├── LICENSE
├── README.md
├── assets
│   ├── CuteCat.jpeg
│   ├── Dog.png
│   └── Lotus.jpeg
├── bash_scripts
│   ├── canny_controlnet_inference.sh
│   ├── controlnet_tile_inference.sh
│   ├── depth_controlnet_inference.sh
│   └── lora_inference.sh
├── inference.py
├── model
│   ├── __pycache__
│   │   ├── adapter.cpython-310.pyc
│   │   └── unet_adapter.cpython-310.pyc
│   ├── adapter.py
│   ├── unet_adapter.py
│   └── utils.py
├── pipeline
│   ├── __pycache__
│   │   ├── pipeline_sd_xl_adapter.cpython-310.pyc
│   │   ├── pipeline_sd_xl_adapter_controlnet.cpython-310.pyc
│   │   └── pipeline_sd_xl_adapter_controlnet_img2img.cpython-310.pyc
│   ├── pipeline_sd_xl_adapter.py
│   ├── pipeline_sd_xl_adapter_controlnet.py
│   └── pipeline_sd_xl_adapter_controlnet_img2img.py
├── requirements.txt
└── scripts
    ├── __pycache__
    │   ├── inference_controlnet.cpython-310.pyc
    │   ├── inference_ctrlnet_tile.cpython-310.pyc
    │   ├── inference_lora.cpython-310.pyc
    │   └── utils.cpython-310.pyc
    ├── inference_controlnet.py
    ├── inference_ctrlnet_tile.py
    ├── inference_lora.py
    └── utils.py
/LICENSE:
--------------------------------------------------------------------------------
1 | Apache License
2 | Version 2.0, January 2004
3 | http://www.apache.org/licenses/
4 |
5 | TERMS AND CONDITIONS FOR USE, REPRODUCTION, AND DISTRIBUTION
6 |
7 | 1. Definitions.
8 |
9 | "License" shall mean the terms and conditions for use, reproduction,
10 | and distribution as defined by Sections 1 through 9 of this document.
11 |
12 | "Licensor" shall mean the copyright owner or entity authorized by
13 | the copyright owner that is granting the License.
14 |
15 | "Legal Entity" shall mean the union of the acting entity and all
16 | other entities that control, are controlled by, or are under common
17 | control with that entity. For the purposes of this definition,
18 | "control" means (i) the power, direct or indirect, to cause the
19 | direction or management of such entity, whether by contract or
20 | otherwise, or (ii) ownership of fifty percent (50%) or more of the
21 | outstanding shares, or (iii) beneficial ownership of such entity.
22 |
23 | "You" (or "Your") shall mean an individual or Legal Entity
24 | exercising permissions granted by this License.
25 |
26 | "Source" form shall mean the preferred form for making modifications,
27 | including but not limited to software source code, documentation
28 | source, and configuration files.
29 |
30 | "Object" form shall mean any form resulting from mechanical
31 | transformation or translation of a Source form, including but
32 | not limited to compiled object code, generated documentation,
33 | and conversions to other media types.
34 |
35 | "Work" shall mean the work of authorship, whether in Source or
36 | Object form, made available under the License, as indicated by a
37 | copyright notice that is included in or attached to the work
38 | (an example is provided in the Appendix below).
39 |
40 | "Derivative Works" shall mean any work, whether in Source or Object
41 | form, that is based on (or derived from) the Work and for which the
42 | editorial revisions, annotations, elaborations, or other modifications
43 | represent, as a whole, an original work of authorship. For the purposes
44 | of this License, Derivative Works shall not include works that remain
45 | separable from, or merely link (or bind by name) to the interfaces of,
46 | the Work and Derivative Works thereof.
47 |
48 | "Contribution" shall mean any work of authorship, including
49 | the original version of the Work and any modifications or additions
50 | to that Work or Derivative Works thereof, that is intentionally
51 | submitted to Licensor for inclusion in the Work by the copyright owner
52 | or by an individual or Legal Entity authorized to submit on behalf of
53 | the copyright owner. For the purposes of this definition, "submitted"
54 | means any form of electronic, verbal, or written communication sent
55 | to the Licensor or its representatives, including but not limited to
56 | communication on electronic mailing lists, source code control systems,
57 | and issue tracking systems that are managed by, or on behalf of, the
58 | Licensor for the purpose of discussing and improving the Work, but
59 | excluding communication that is conspicuously marked or otherwise
60 | designated in writing by the copyright owner as "Not a Contribution."
61 |
62 | "Contributor" shall mean Licensor and any individual or Legal Entity
63 | on behalf of whom a Contribution has been received by Licensor and
64 | subsequently incorporated within the Work.
65 |
66 | 2. Grant of Copyright License. Subject to the terms and conditions of
67 | this License, each Contributor hereby grants to You a perpetual,
68 | worldwide, non-exclusive, no-charge, royalty-free, irrevocable
69 | copyright license to reproduce, prepare Derivative Works of,
70 | publicly display, publicly perform, sublicense, and distribute the
71 | Work and such Derivative Works in Source or Object form.
72 |
73 | 3. Grant of Patent License. Subject to the terms and conditions of
74 | this License, each Contributor hereby grants to You a perpetual,
75 | worldwide, non-exclusive, no-charge, royalty-free, irrevocable
76 | (except as stated in this section) patent license to make, have made,
77 | use, offer to sell, sell, import, and otherwise transfer the Work,
78 | where such license applies only to those patent claims licensable
79 | by such Contributor that are necessarily infringed by their
80 | Contribution(s) alone or by combination of their Contribution(s)
81 | with the Work to which such Contribution(s) was submitted. If You
82 | institute patent litigation against any entity (including a
83 | cross-claim or counterclaim in a lawsuit) alleging that the Work
84 | or a Contribution incorporated within the Work constitutes direct
85 | or contributory patent infringement, then any patent licenses
86 | granted to You under this License for that Work shall terminate
87 | as of the date such litigation is filed.
88 |
89 | 4. Redistribution. You may reproduce and distribute copies of the
90 | Work or Derivative Works thereof in any medium, with or without
91 | modifications, and in Source or Object form, provided that You
92 | meet the following conditions:
93 |
94 | (a) You must give any other recipients of the Work or
95 | Derivative Works a copy of this License; and
96 |
97 | (b) You must cause any modified files to carry prominent notices
98 | stating that You changed the files; and
99 |
100 | (c) You must retain, in the Source form of any Derivative Works
101 | that You distribute, all copyright, patent, trademark, and
102 | attribution notices from the Source form of the Work,
103 | excluding those notices that do not pertain to any part of
104 | the Derivative Works; and
105 |
106 | (d) If the Work includes a "NOTICE" text file as part of its
107 | distribution, then any Derivative Works that You distribute must
108 | include a readable copy of the attribution notices contained
109 | within such NOTICE file, excluding those notices that do not
110 | pertain to any part of the Derivative Works, in at least one
111 | of the following places: within a NOTICE text file distributed
112 | as part of the Derivative Works; within the Source form or
113 | documentation, if provided along with the Derivative Works; or,
114 | within a display generated by the Derivative Works, if and
115 | wherever such third-party notices normally appear. The contents
116 | of the NOTICE file are for informational purposes only and
117 | do not modify the License. You may add Your own attribution
118 | notices within Derivative Works that You distribute, alongside
119 | or as an addendum to the NOTICE text from the Work, provided
120 | that such additional attribution notices cannot be construed
121 | as modifying the License.
122 |
123 | You may add Your own copyright statement to Your modifications and
124 | may provide additional or different license terms and conditions
125 | for use, reproduction, or distribution of Your modifications, or
126 | for any such Derivative Works as a whole, provided Your use,
127 | reproduction, and distribution of the Work otherwise complies with
128 | the conditions stated in this License.
129 |
130 | 5. Submission of Contributions. Unless You explicitly state otherwise,
131 | any Contribution intentionally submitted for inclusion in the Work
132 | by You to the Licensor shall be under the terms and conditions of
133 | this License, without any additional terms or conditions.
134 | Notwithstanding the above, nothing herein shall supersede or modify
135 | the terms of any separate license agreement you may have executed
136 | with Licensor regarding such Contributions.
137 |
138 | 6. Trademarks. This License does not grant permission to use the trade
139 | names, trademarks, service marks, or product names of the Licensor,
140 | except as required for reasonable and customary use in describing the
141 | origin of the Work and reproducing the content of the NOTICE file.
142 |
143 | 7. Disclaimer of Warranty. Unless required by applicable law or
144 | agreed to in writing, Licensor provides the Work (and each
145 | Contributor provides its Contributions) on an "AS IS" BASIS,
146 | WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or
147 | implied, including, without limitation, any warranties or conditions
148 | of TITLE, NON-INFRINGEMENT, MERCHANTABILITY, or FITNESS FOR A
149 | PARTICULAR PURPOSE. You are solely responsible for determining the
150 | appropriateness of using or redistributing the Work and assume any
151 | risks associated with Your exercise of permissions under this License.
152 |
153 | 8. Limitation of Liability. In no event and under no legal theory,
154 | whether in tort (including negligence), contract, or otherwise,
155 | unless required by applicable law (such as deliberate and grossly
156 | negligent acts) or agreed to in writing, shall any Contributor be
157 | liable to You for damages, including any direct, indirect, special,
158 | incidental, or consequential damages of any character arising as a
159 | result of this License or out of the use or inability to use the
160 | Work (including but not limited to damages for loss of goodwill,
161 | work stoppage, computer failure or malfunction, or any and all
162 | other commercial damages or losses), even if such Contributor
163 | has been advised of the possibility of such damages.
164 |
165 | 9. Accepting Warranty or Additional Liability. While redistributing
166 | the Work or Derivative Works thereof, You may choose to offer,
167 | and charge a fee for, acceptance of support, warranty, indemnity,
168 | or other liability obligations and/or rights consistent with this
169 | License. However, in accepting such obligations, You may act only
170 | on Your own behalf and on Your sole responsibility, not on behalf
171 | of any other Contributor, and only if You agree to indemnify,
172 | defend, and hold each Contributor harmless for any liability
173 | incurred by, or claims asserted against, such Contributor by reason
174 | of your accepting any such warranty or additional liability.
175 |
176 | END OF TERMS AND CONDITIONS
177 |
178 | APPENDIX: How to apply the Apache License to your work.
179 |
180 | To apply the Apache License to your work, attach the following
181 | boilerplate notice, with the fields enclosed by brackets "[]"
182 | replaced with your own identifying information. (Don't include
183 | the brackets!) The text should be enclosed in the appropriate
184 | comment syntax for the file format. We also recommend that a
185 | file or class name and description of purpose be included on the
186 | same "printed page" as the copyright notice for easier
187 | identification within third-party archives.
188 |
189 | Copyright [yyyy] [name of copyright owner]
190 |
191 | Licensed under the Apache License, Version 2.0 (the "License");
192 | you may not use this file except in compliance with the License.
193 | You may obtain a copy of the License at
194 |
195 | http://www.apache.org/licenses/LICENSE-2.0
196 |
197 | Unless required by applicable law or agreed to in writing, software
198 | distributed under the License is distributed on an "AS IS" BASIS,
199 | WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
200 | See the License for the specific language governing permissions and
201 | limitations under the License.
202 |
--------------------------------------------------------------------------------
/README.md:
--------------------------------------------------------------------------------
1 | # X-Adapter
2 |
3 | This repository is the official implementation of [X-Adapter](https://arxiv.org/abs/2312.02238).
4 |
5 | **[X-Adapter: Adding Universal Compatibility of Plugins for Upgraded Diffusion Model](https://arxiv.org/abs/2312.02238)**
6 |
7 | [Lingmin Ran](),
8 | [Xiaodong Cun](https://vinthony.github.io/academic/),
9 | [Jia-Wei Liu](https://jia-wei-liu.github.io/),
10 | [Rui Zhao](https://ruizhaocv.github.io/),
11 | [Song Zijie](),
12 | [Xintao Wang](https://xinntao.github.io/),
13 | [Jussi Keppo](https://www.jussikeppo.com/),
14 | [Mike Zheng Shou](https://sites.google.com/view/showlab)
15 |
16 |
17 | [Project Page](https://showlab.github.io/X-Adapter/)
18 | [arXiv](https://arxiv.org/abs/2312.02238)
19 |
20 | 
21 |
22 | X-Adapter enables plugins pretrained on an older base model (e.g., SD1.5) to work directly with an upgraded model (e.g., SDXL) without further retraining.
23 |
34 | ### Thanks to @[kijai](https://github.com/kijai) for the ComfyUI implementation [here](https://github.com/kijai/ComfyUI-Diffusers-X-Adapter)! Please refer to this [tutorial](https://www.reddit.com/r/StableDiffusion/comments/1asuyiw/xadapter/) for hyperparameter settings.
35 |
36 | ## News
37 |
38 | - [17/02/2024] Inference code released
39 |
40 | ## Setup
41 |
42 | ### Requirements
43 |
44 | ```shell
45 | conda create -n xadapter python=3.10
46 | conda activate xadapter
47 |
48 | pip install torch==1.13.1+cu116 torchvision==0.14.1+cu116 --extra-index-url https://download.pytorch.org/whl/cu116
49 | pip install -r requirements.txt
50 | ```
51 |
52 | Installing [xformers](https://github.com/facebookresearch/xformers) is highly recommended for faster inference and lower GPU memory usage.
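
If xformers is installed, memory-efficient attention can be switched on through the standard diffusers call. The snippet below is a minimal sketch against a plain diffusers pipeline; whether the bundled X-Adapter pipelines expose the same toggle is an assumption (they build on diffusers pipelines), so treat it as illustrative rather than part of the official scripts.

```python
# Sketch (not part of the official scripts): enable xformers attention when available.
import torch
from diffusers import StableDiffusionPipeline

pipe = StableDiffusionPipeline.from_pretrained(
    "runwayml/stable-diffusion-v1-5", torch_dtype=torch.float16
).to("cuda")

try:
    import xformers  # noqa: F401
    pipe.enable_xformers_memory_efficient_attention()  # lower VRAM use, usually faster
except ImportError:
    print("xformers not found; using the default attention implementation")
```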
53 |
54 | ### Weights
55 |
56 | **[Stable Diffusion]** [Stable Diffusion](https://arxiv.org/abs/2112.10752) is a latent text-to-image diffusion model capable of generating photo-realistic images given any text input. The pre-trained Stable Diffusion models can be downloaded from Hugging Face (e.g., [Stable Diffusion v1-5](https://huggingface.co/runwayml/stable-diffusion-v1-5)). You can also use fine-tuned Stable Diffusion models trained on different styles (e.g., [Anything V4.0](https://huggingface.co/andite/anything-v4.0), [Redshift](https://huggingface.co/nitrosocke/redshift-diffusion), etc.).
57 |
58 | **[ControlNet]** [ControlNet](https://github.com/lllyasviel/ControlNet) is a method for controlling diffusion models with spatial conditions. You can download the ControlNet family [here](https://huggingface.co/lllyasviel/ControlNet).
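
The inference scripts refer to these by their Hugging Face ids (see the `--controlnet_canny_path` / `--controlnet_depth_path` defaults in `inference.py`), so an explicit download is optional. If you want to pre-fetch the weights into your local cache, a minimal sketch using the standard diffusers loader:

```python
# Sketch: pre-download the SD1.5 ControlNets used by the default arguments.
import torch
from diffusers import ControlNetModel

for repo_id in ("lllyasviel/sd-controlnet-canny", "lllyasviel/sd-controlnet-depth"):
    ControlNetModel.from_pretrained(repo_id, torch_dtype=torch.float16)
```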
59 |
60 | **[LoRA]** [LoRA](https://arxiv.org/abs/2106.09685) is a lightweight adapter for fine-tuning large-scale pretrained models. It is widely used for style or identity customization in diffusion models. You can download LoRAs from the diffusion community (e.g., [civitai](https://civitai.com/)).
61 |
62 | ### Checkpoint
63 |
64 | Models can be downloaded from our [Hugging Face page](https://huggingface.co/Lingmin-Ran/X-Adapter). Put the checkpoint in the folder `./checkpoint/X-Adapter`.
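
For example, the checkpoint can be fetched with `huggingface_hub` (installed alongside diffusers); this sketch assumes the file name `X_Adapter_v1.bin`, which matches the default `--adapter_checkpoint` path in `inference.py`:

```python
# Sketch: download the X-Adapter checkpoint into the folder expected by inference.py.
from huggingface_hub import hf_hub_download

hf_hub_download(
    repo_id="Lingmin-Ran/X-Adapter",
    filename="X_Adapter_v1.bin",
    local_dir="./checkpoint/X-Adapter",
)
```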
65 |
66 | ## Usage
67 |
68 | After preparing all checkpoints, you can run the inference code with different plugins. Refer to this [tutorial](https://www.reddit.com/r/StableDiffusion/comments/1asuyiw/xadapter/) to get started with X-Adapter quickly.
69 |
70 | ### ControlNet Inference
71 |
72 | Set `--controlnet_canny_path` or `--controlnet_depth_path` in the bash script to the path of the corresponding ControlNet checkpoint. The default values point to the Hugging Face model cards.
73 |
74 | sh ./bash_scripts/canny_controlnet_inference.sh
75 | sh ./bash_scripts/depth_controlnet_inference.sh
76 |
77 | ### LoRA Inference
78 |
79 | Set `--lora_model_path` in the bash script to the LoRA checkpoint. In this example we use [MoXin](https://civitai.com/models/12597/moxin) and place it in the folder `./checkpoint/lora`.
80 |
81 | sh ./bash_scripts/lora_inference.sh
82 |
83 | ### ControlNet-Tile Inference
84 |
85 | Set `--controlnet_tile_path` in the bash script to the path of ControlNet-Tile. The default value points to its Hugging Face model card.
86 |
87 | sh ./bash_scripts/controlnet_tile_inference.sh
88 |
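The bash scripts above are thin wrappers around `inference.py`, so the same runs can be driven from Python. A minimal sketch mirroring the canny ControlNet example (argument values are illustrative):

```python
# Sketch: call the inference entry point programmatically instead of via the bash scripts.
from inference import parse_args, run_inference

args = parse_args([
    "--plugin_type", "controlnet",
    "--condition_type", "canny",
    "--prompt", "A cute cat, high quality, extremely detailed",
    "--input_image_path", "./assets/CuteCat.jpeg",
    "--controlnet_condition_scale_list", "1.5",
    "--adapter_guidance_start_list", "1.0",
    "--adapter_condition_scale_list", "1.0",
])
run_inference(args)  # results are written under ./result/<timestamp>_controlnet
```
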
89 | ## Cite
90 | If you find X-Adapter useful for your research and applications, please cite us using this BibTeX:
91 |
92 | ```bibtex
93 | @article{ran2023xadapter,
94 | title={X-Adapter: Adding Universal Compatibility of Plugins for Upgraded Diffusion Model},
95 | author={Lingmin Ran and Xiaodong Cun and Jia-Wei Liu and Rui Zhao and Song Zijie and Xintao Wang and Jussi Keppo and Mike Zheng Shou},
96 | journal={arXiv preprint arXiv:2312.02238},
97 | year={2023}
98 | }
99 | ```
100 |
--------------------------------------------------------------------------------
/assets/CuteCat.jpeg:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/showlab/X-Adapter/e7348cb1b83e039dec8fe6f664266f1594e5103f/assets/CuteCat.jpeg
--------------------------------------------------------------------------------
/assets/Dog.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/showlab/X-Adapter/e7348cb1b83e039dec8fe6f664266f1594e5103f/assets/Dog.png
--------------------------------------------------------------------------------
/assets/Lotus.jpeg:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/showlab/X-Adapter/e7348cb1b83e039dec8fe6f664266f1594e5103f/assets/Lotus.jpeg
--------------------------------------------------------------------------------
/bash_scripts/canny_controlnet_inference.sh:
--------------------------------------------------------------------------------
1 | python inference.py --plugin_type "controlnet" \
2 | --prompt "A cute cat, high quality, extremely detailed" \
3 | --condition_type "canny" \
4 | --input_image_path "./assets/CuteCat.jpeg" \
5 | --controlnet_condition_scale_list 1.5 1.75 2.0 \
6 | --adapter_guidance_start_list 1.00 \
7 | --adapter_condition_scale_list 1.0 1.20 \
8 | --height 1024 \
9 | --width 1024 \
10 | --height_sd1_5 512 \
11 | --width_sd1_5 512 \
12 |
--------------------------------------------------------------------------------
/bash_scripts/controlnet_tile_inference.sh:
--------------------------------------------------------------------------------
1 | python inference.py --plugin_type "controlnet_tile" \
2 |            --prompt "best quality, extremely detailed" \
3 | --controlnet_condition_scale_list 1.0 \
4 | --adapter_guidance_start_list 0.7 \
5 | --adapter_condition_scale_list 1.2 \
6 | --input_image_path "./assets/Dog.png" \
7 | --height 1024 \
8 | --width 768 \
9 | --height_sd1_5 512 \
10 | --width_sd1_5 384 \
11 |
--------------------------------------------------------------------------------
/bash_scripts/depth_controlnet_inference.sh:
--------------------------------------------------------------------------------
1 | python inference.py --plugin_type "controlnet" \
2 | --prompt "A colorful lotus, ink, high quality, extremely detailed" \
3 | --condition_type "depth" \
4 | --input_image_path "./assets/Lotus.jpeg" \
5 | --controlnet_condition_scale_list 1.0 \
6 | --adapter_guidance_start_list 0.80 \
7 | --adapter_condition_scale_list 1.0 \
8 | --height 1024 \
9 | --width 1024 \
10 | --height_sd1_5 512 \
11 | --width_sd1_5 512 \
12 |
--------------------------------------------------------------------------------
/bash_scripts/lora_inference.sh:
--------------------------------------------------------------------------------
1 | python inference.py --plugin_type "lora" \
2 | --prompt "masterpiece, best quality, ultra detailed, 1 girl , solo, smile, looking at viewer, holding flowers" \
3 | --prompt_sd1_5 "masterpiece, best quality, ultra detailed, 1 girl, solo, smile, looking at viewer, holding flowers, shuimobysim, wuchangshuo, bonian, zhenbanqiao, badashanren" \
4 | --adapter_guidance_start_list 0.95 \
5 | --adapter_condition_scale_list 1.50 \
6 | --seed 3943946911 \
7 | --height 1024 \
8 | --width 1024 \
9 | --height_sd1_5 512 \
10 | --width_sd1_5 512 \
11 |
--------------------------------------------------------------------------------
/inference.py:
--------------------------------------------------------------------------------
1 | import os
2 | import datetime
3 | import argparse
4 |
5 | from scripts.inference_controlnet import inference_controlnet
6 | from scripts.inference_lora import inference_lora
7 | from scripts.inference_ctrlnet_tile import inference_ctrlnet_tile
8 |
9 |
10 | def parse_args(input_args=None):
11 | parser = argparse.ArgumentParser(description="Inference setting for X-Adapter.")
12 |
13 | parser.add_argument(
14 | "--plugin_type",
15 |         type=str, help='lora, controlnet or controlnet_tile', default="controlnet"
16 | )
17 | parser.add_argument(
18 | "--controlnet_condition_scale_list",
19 | nargs='+', help='controlnet_scale', default=[1.0, 2.0]
20 | )
21 | parser.add_argument(
22 | "--adapter_guidance_start_list",
23 | nargs='+', help='start of 2nd stage', default=[0.6, 0.65, 0.7, 0.75, 0.8]
24 | )
25 | parser.add_argument(
26 | "--adapter_condition_scale_list",
27 | nargs='+', help='X-Adapter scale', default=[0.8, 1.0, 1.2]
28 | )
29 | parser.add_argument(
30 | "--base_path",
31 | type=str, help='path to base model', default="runwayml/stable-diffusion-v1-5"
32 | )
33 | parser.add_argument(
34 | "--sdxl_path",
35 | type=str, help='path to SDXL', default="stabilityai/stable-diffusion-xl-base-1.0"
36 | )
37 | parser.add_argument(
38 | "--path_vae_sdxl",
39 | type=str, help='path to SDXL vae', default="madebyollin/sdxl-vae-fp16-fix"
40 | )
41 | parser.add_argument(
42 | "--adapter_checkpoint",
43 | type=str, help='path to X-Adapter', default="./checkpoint/X-Adapter/X_Adapter_v1.bin"
44 | )
45 | parser.add_argument(
46 | "--condition_type",
47 | type=str, help='condition type', default="canny"
48 | )
49 | parser.add_argument(
50 | "--controlnet_canny_path",
51 | type=str, help='path to canny controlnet', default="lllyasviel/sd-controlnet-canny"
52 | )
53 | parser.add_argument(
54 | "--controlnet_depth_path",
55 | type=str, help='path to depth controlnet', default="lllyasviel/sd-controlnet-depth"
56 | )
57 | parser.add_argument(
58 | "--controlnet_tile_path",
59 | type=str, help='path to controlnet tile', default="lllyasviel/control_v11f1e_sd15_tile"
60 | )
61 | parser.add_argument(
62 | "--lora_model_path",
63 | type=str, help='path to lora', default="./checkpoint/lora/MoXinV1.safetensors"
64 | )
65 | parser.add_argument(
66 | "--prompt",
67 | type=str, help='SDXL prompt', default=None, required=True
68 | )
69 | parser.add_argument(
70 | "--prompt_sd1_5",
71 | type=str, help='SD1.5 prompt', default=None
72 | )
73 | parser.add_argument(
74 | "--negative_prompt",
75 | type=str, default=None
76 | )
77 | parser.add_argument(
78 | "--iter_num",
79 | type=int, default=1
80 | )
81 | parser.add_argument(
82 | "--input_image_path",
83 | type=str, default="./controlnet_test_image/CuteCat.jpeg"
84 | )
85 | parser.add_argument(
86 | "--num_inference_steps",
87 | type=int, default=50
88 | )
89 | parser.add_argument(
90 | "--guidance_scale",
91 | type=float, default=7.5
92 | )
93 | parser.add_argument(
94 | "--seed",
95 | type=int, default=1674753452
96 | )
97 | parser.add_argument(
98 | "--width",
99 | type=int, default=1024
100 | )
101 | parser.add_argument(
102 | "--height",
103 | type=int, default=1024
104 | )
105 | parser.add_argument(
106 | "--height_sd1_5",
107 | type=int, default=512
108 | )
109 | parser.add_argument(
110 | "--width_sd1_5",
111 | type=int, default=512
112 | )
113 |
114 | if input_args is not None:
115 | args = parser.parse_args(input_args)
116 | else:
117 | args = parser.parse_args()
118 |
119 | return args
120 |
121 |
122 | def run_inference(args):
123 | current_datetime = datetime.datetime.now()
124 | current_datetime = str(current_datetime).replace(":", "_")
125 | save_path = f"./result/{current_datetime}_lora" if args.plugin_type == "lora" else f"./result/{current_datetime}_controlnet"
126 | os.makedirs(save_path)
127 | args.save_path = save_path
128 |
129 | if args.plugin_type == "controlnet":
130 | inference_controlnet(args)
131 | elif args.plugin_type == "controlnet_tile":
132 | inference_ctrlnet_tile(args)
133 | elif args.plugin_type == "lora":
134 | inference_lora(args)
135 | else:
136 |         raise NotImplementedError(f"unsupported plugin_type: {args.plugin_type}")
137 |
138 |
139 | if __name__ == "__main__":
140 | args = parse_args()
141 | run_inference(args)
142 |
--------------------------------------------------------------------------------
/model/__pycache__/adapter.cpython-310.pyc:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/showlab/X-Adapter/e7348cb1b83e039dec8fe6f664266f1594e5103f/model/__pycache__/adapter.cpython-310.pyc
--------------------------------------------------------------------------------
/model/__pycache__/unet_adapter.cpython-310.pyc:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/showlab/X-Adapter/e7348cb1b83e039dec8fe6f664266f1594e5103f/model/__pycache__/unet_adapter.cpython-310.pyc
--------------------------------------------------------------------------------
/model/adapter.py:
--------------------------------------------------------------------------------
1 | import torch
2 | import torch.nn as nn
3 | from collections import OrderedDict
4 | from diffusers.models.embeddings import (
5 | TimestepEmbedding,
6 | Timesteps,
7 | )
8 |
9 |
10 | def conv_nd(dims, *args, **kwargs):
11 | """
12 | Create a 1D, 2D, or 3D convolution module.
13 | """
14 | if dims == 1:
15 | return nn.Conv1d(*args, **kwargs)
16 | elif dims == 2:
17 | return nn.Conv2d(*args, **kwargs)
18 | elif dims == 3:
19 | return nn.Conv3d(*args, **kwargs)
20 | raise ValueError(f"unsupported dimensions: {dims}")
21 |
22 |
23 | def avg_pool_nd(dims, *args, **kwargs):
24 | """
25 | Create a 1D, 2D, or 3D average pooling module.
26 | """
27 | if dims == 1:
28 | return nn.AvgPool1d(*args, **kwargs)
29 | elif dims == 2:
30 | return nn.AvgPool2d(*args, **kwargs)
31 | elif dims == 3:
32 | return nn.AvgPool3d(*args, **kwargs)
33 | raise ValueError(f"unsupported dimensions: {dims}")
34 |
35 |
36 | def get_parameter_dtype(parameter: torch.nn.Module):
37 | try:
38 | params = tuple(parameter.parameters())
39 | if len(params) > 0:
40 | return params[0].dtype
41 |
42 | buffers = tuple(parameter.buffers())
43 | if len(buffers) > 0:
44 | return buffers[0].dtype
45 |
46 | except StopIteration:
47 | # For torch.nn.DataParallel compatibility in PyTorch 1.5
48 |
49 |             def find_tensor_attributes(module: torch.nn.Module) -> list[tuple[str, torch.Tensor]]:
50 | tuples = [(k, v) for k, v in module.__dict__.items() if torch.is_tensor(v)]
51 | return tuples
52 |
53 | gen = parameter._named_members(get_members_fn=find_tensor_attributes)
54 | first_tuple = next(gen)
55 | return first_tuple[1].dtype
56 |
57 |
58 | class Downsample(nn.Module):
59 | """
60 | A downsampling layer with an optional convolution.
61 | :param channels: channels in the inputs and outputs.
62 | :param use_conv: a bool determining if a convolution is applied.
63 | :param dims: determines if the signal is 1D, 2D, or 3D. If 3D, then
64 | downsampling occurs in the inner-two dimensions.
65 | """
66 |
67 | def __init__(self, channels, use_conv, dims=2, out_channels=None, padding=1):
68 | super().__init__()
69 | self.channels = channels
70 | self.out_channels = out_channels or channels
71 | self.use_conv = use_conv
72 | self.dims = dims
73 | stride = 2 if dims != 3 else (1, 2, 2)
74 | if use_conv:
75 | self.op = conv_nd(dims, self.channels, self.out_channels, 3, stride=stride, padding=padding)
76 | else:
77 | assert self.channels == self.out_channels
78 |             # no convolution: downsample with plain average pooling at the same stride
79 |             self.op = avg_pool_nd(dims, kernel_size=stride, stride=stride)
80 |
81 | def forward(self, x):
82 | assert x.shape[1] == self.channels
83 | return self.op(x)
84 |
85 |
86 | class Upsample(nn.Module):
87 | def __init__(self, channels, use_conv, dims=2, out_channels=None, padding=1):
88 | super().__init__()
89 | self.channels = channels
90 | self.out_channels = out_channels or channels
91 | self.use_conv = use_conv
92 | self.dims = dims
93 | stride = 2 if dims != 3 else (1, 2, 2)
94 | if use_conv:
95 | self.op = nn.ConvTranspose2d(self.channels, self.out_channels, 3, stride=stride, padding=1)
96 | else:
97 | assert self.channels == self.out_channels
98 | self.op = avg_pool_nd(dims, kernel_size=stride, stride=stride)
99 |
100 | def forward(self, x, output_size):
101 | assert x.shape[1] == self.channels
102 | return self.op(x, output_size)
103 |
104 |
105 | class Linear(nn.Module):
106 | def __init__(self, temb_channels, out_channels):
107 | super(Linear, self).__init__()
108 | self.linear = nn.Linear(temb_channels, out_channels)
109 |
110 | def forward(self, x):
111 | return self.linear(x)
112 |
113 |
114 |
115 | class ResnetBlock(nn.Module):
116 |
117 | def __init__(self, in_c, out_c, down, up, ksize=3, sk=False, use_conv=True, enable_timestep=False, temb_channels=None, use_norm=False):
118 | super().__init__()
119 | self.use_norm = use_norm
120 | self.enable_timestep = enable_timestep
121 | ps = ksize // 2
122 | if in_c != out_c or sk == False:
123 | self.in_conv = nn.Conv2d(in_c, out_c, ksize, 1, ps)
124 | else:
125 | self.in_conv = None
126 | self.block1 = nn.Conv2d(out_c, out_c, 3, 1, 1)
127 | self.act = nn.ReLU()
128 | if use_norm:
129 | self.norm1 = nn.GroupNorm(num_groups=32, num_channels=out_c, eps=1e-6, affine=True)
130 | self.block2 = nn.Conv2d(out_c, out_c, ksize, 1, ps)
131 | if sk == False:
132 | self.skep = nn.Conv2d(in_c, out_c, ksize, 1, ps)
133 | else:
134 | self.skep = None
135 |
136 | self.down = down
137 | self.up = up
138 | if self.down:
139 | self.down_opt = Downsample(in_c, use_conv=use_conv)
140 | if self.up:
141 | self.up_opt = Upsample(in_c, use_conv=use_conv)
142 | if enable_timestep:
143 | self.timestep_proj = Linear(temb_channels, out_c)
144 |
145 |
146 | def forward(self, x, output_size=None, temb=None):
147 | if self.down == True:
148 | x = self.down_opt(x)
149 | if self.up == True:
150 | x = self.up_opt(x, output_size)
151 | if self.in_conv is not None: # edit
152 | x = self.in_conv(x)
153 |
154 | h = self.block1(x)
155 | if temb is not None:
156 | temb = self.timestep_proj(temb)[:, :, None, None]
157 | h = h + temb
158 | if self.use_norm:
159 | h = self.norm1(h)
160 | h = self.act(h)
161 | h = self.block2(h)
162 | if self.skep is not None:
163 | return h + self.skep(x)
164 | else:
165 | return h + x
166 |
167 |
168 | class Adapter_XL(nn.Module):
169 |
170 | def __init__(self, in_channels=[1280, 640, 320], out_channels=[1280, 1280, 640], nums_rb=3, ksize=3, sk=True, use_conv=False, use_zero_conv=True,
171 | enable_timestep=False, use_norm=False, temb_channels=None, fusion_type='ADD'):
172 | super(Adapter_XL, self).__init__()
173 | self.channels = in_channels
174 | self.nums_rb = nums_rb
175 | self.body = []
176 | self.out = []
177 | self.use_zero_conv = use_zero_conv
178 | self.fusion_type = fusion_type
179 | self.gamma = []
180 | self.beta = []
181 | self.norm = []
182 | if fusion_type == "SPADE":
183 | self.use_zero_conv = False
184 | for i in range(len(self.channels)):
185 | if self.fusion_type == 'SPADE':
186 | # Corresponding to SPADE
187 | self.gamma.append(nn.Conv2d(out_channels[i], out_channels[i], 1, padding=0))
188 | self.beta.append(nn.Conv2d(out_channels[i], out_channels[i], 1, padding=0))
189 | self.norm.append(nn.BatchNorm2d(out_channels[i]))
190 | elif use_zero_conv:
191 | self.out.append(self.make_zero_conv(out_channels[i]))
192 | else:
193 | self.out.append(nn.Conv2d(out_channels[i], out_channels[i], 1, padding=0))
194 | for j in range(nums_rb):
195 | if i==0:
196 | # 1280, 32, 32 -> 1280, 32, 32
197 | self.body.append(
198 | ResnetBlock(in_channels[i], out_channels[i], down=False, up=False, ksize=ksize, sk=sk, use_conv=use_conv,
199 | enable_timestep=enable_timestep, temb_channels=temb_channels, use_norm=use_norm))
200 | # 1280, 32, 32 -> 1280, 32, 32
201 | elif i==1:
202 | # 640, 64, 64 -> 1280, 64, 64
203 | if j==0:
204 | self.body.append(
205 | ResnetBlock(in_channels[i], out_channels[i], down=False, up=False, ksize=ksize, sk=sk,
206 | use_conv=use_conv, enable_timestep=enable_timestep, temb_channels=temb_channels, use_norm=use_norm))
207 | else:
208 | self.body.append(
209 | ResnetBlock(out_channels[i], out_channels[i], down=False, up=False, ksize=ksize,sk=sk,
210 | use_conv=use_conv, enable_timestep=enable_timestep, temb_channels=temb_channels, use_norm=use_norm))
211 | else:
212 | # 320, 64, 64 -> 640, 128, 128
213 | if j==0:
214 | self.body.append(
215 | ResnetBlock(in_channels[i], out_channels[i], down=False, up=True, ksize=ksize, sk=sk,
216 | use_conv=True, enable_timestep=enable_timestep, temb_channels=temb_channels, use_norm=use_norm))
217 | # use convtranspose2d
218 | else:
219 | self.body.append(
220 | ResnetBlock(out_channels[i], out_channels[i], down=False, up=False, ksize=ksize, sk=sk,
221 | use_conv=use_conv, enable_timestep=enable_timestep, temb_channels=temb_channels, use_norm=use_norm))
222 |
223 |
224 | self.body = nn.ModuleList(self.body)
225 |         if self.fusion_type == 'SPADE':
226 |             # register the SPADE heads so they are moved/saved with the module
227 |             self.norm = nn.ModuleList(self.norm)
228 |             self.gamma = nn.ModuleList(self.gamma)
229 |             self.beta = nn.ModuleList(self.beta)
230 |         else:
231 |             # forward() indexes self.zero_out for every non-SPADE fusion type
232 |             self.zero_out = nn.ModuleList(self.out)
233 |
234 |
235 |
236 |         if enable_timestep:
237 |             a = 320  # matches SD1.5's first block channel width
238 |
239 |             time_embed_dim = a * 4
240 |             self.time_proj = Timesteps(a, True, 0)
241 |             timestep_input_dim = a
242 |
243 |             self.time_embedding = TimestepEmbedding(
244 |                 timestep_input_dim,
245 |                 time_embed_dim,
246 |                 act_fn='silu',
247 |                 post_act_fn=None,
248 |                 cond_proj_dim=None,
249 |             )
250 |
251 |
252 | def make_zero_conv(self, channels):
253 |
254 | return zero_module(nn.Conv2d(channels, channels, 1, padding=0))
255 |
256 | @property
257 | def dtype(self) -> torch.dtype:
258 | """
259 | `torch.dtype`: The dtype of the module (assuming that all the module parameters have the same dtype).
260 | """
261 | return get_parameter_dtype(self)
262 |
263 | def forward(self, x, t=None):
264 | # extract features
265 | features = []
266 | b, c, _, _ = x[-1].shape
267 | if t is not None:
268 | if not torch.is_tensor(t):
269 | # TODO: this requires sync between CPU and GPU. So try to pass timesteps as tensors if you can
270 | # This would be a good case for the `match` statement (Python 3.10+)
271 | is_mps = x[0].device.type == "mps"
272 |                 if isinstance(t, float):
273 | dtype = torch.float32 if is_mps else torch.float64
274 | else:
275 | dtype = torch.int32 if is_mps else torch.int64
276 | t = torch.tensor([t], dtype=dtype, device=x[0].device)
277 | elif len(t.shape) == 0:
278 | t = t[None].to(x[0].device)
279 |
280 | t = t.expand(b)
281 | t = self.time_proj(t) # b, 320
282 | t = t.to(dtype=x[0].dtype)
283 | t = self.time_embedding(t) # b, 1280
284 | # output_size = (b, 640, 128, 128) # last CA layer output
285 |         output_size = (b, 640, x[0].shape[2] * 4, x[0].shape[3] * 4)  # the upsampled output must track the input resolution (4x the deepest feature map)
286 |
287 | for i in range(len(self.channels)):
288 | for j in range(self.nums_rb):
289 | idx = i * self.nums_rb + j
290 | if j == 0:
291 | if i < 2:
292 | out = self.body[idx](x[i], temb=t)
293 | else:
294 | out = self.body[idx](x[i], output_size=output_size, temb=t)
295 | else:
296 | out = self.body[idx](out, temb=t)
297 | if self.fusion_type == 'SPADE':
298 | out_gamma = self.gamma[i](out)
299 | out_beta = self.beta[i](out)
300 | out = [out_gamma, out_beta]
301 | else:
302 | out = self.zero_out[i](out)
303 | features.append(out)
304 |
305 | return features
306 |
307 |
308 | def zero_module(module):
309 | """
310 | Zero out the parameters of a module and return it.
311 | """
312 | for p in module.parameters():
313 | p.detach().zero_()
314 | return module
315 |
316 |
317 | if __name__=='__main__':
318 | adapter = Adapter_XL(use_zero_conv=True,
319 | enable_timestep=True, use_norm=True, temb_channels=1280, fusion_type='SPADE').cuda()
320 | x = [torch.randn(4, 1280, 32, 32).cuda(), torch.randn(4, 640, 64, 64).cuda(), torch.randn(4, 320, 64, 64).cuda()]
321 | t = torch.tensor([1,2,3,4]).cuda()
322 | result = adapter(x, t=t)
323 | for xx in result:
324 | print(xx[0].shape)
325 | print(xx[1].shape)
326 |
327 |
328 |
329 |
330 |
--------------------------------------------------------------------------------
/model/unet_adapter.py:
--------------------------------------------------------------------------------
1 | # Copyright 2023 The HuggingFace Team. All rights reserved.
2 | #
3 | # Licensed under the Apache License, Version 2.0 (the "License");
4 | # you may not use this file except in compliance with the License.
5 | # You may obtain a copy of the License at
6 | #
7 | # http://www.apache.org/licenses/LICENSE-2.0
8 | #
9 | # Unless required by applicable law or agreed to in writing, software
10 | # distributed under the License is distributed on an "AS IS" BASIS,
11 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12 | # See the License for the specific language governing permissions and
13 | # limitations under the License.
14 | from dataclasses import dataclass
15 | from typing import Any, Dict, List, Optional, Tuple, Union
16 |
17 | import torch
18 | import torch.nn as nn
19 | import torch.utils.checkpoint
20 |
21 | from diffusers.configuration_utils import ConfigMixin, register_to_config
22 | from diffusers.loaders import UNet2DConditionLoadersMixin
23 | from diffusers.utils import BaseOutput, logging
24 | from diffusers.models.activations import get_activation
25 | from diffusers.models.attention_processor import AttentionProcessor, AttnProcessor
26 | from diffusers.models.embeddings import (
27 | GaussianFourierProjection,
28 | ImageHintTimeEmbedding,
29 | ImageProjection,
30 | ImageTimeEmbedding,
31 | PositionNet,
32 | TextImageProjection,
33 | TextImageTimeEmbedding,
34 | TextTimeEmbedding,
35 | TimestepEmbedding,
36 | Timesteps,
37 | )
38 | from diffusers.models.modeling_utils import ModelMixin
39 | from diffusers.models.unet_2d_blocks import (
40 | UNetMidBlock2DCrossAttn,
41 | UNetMidBlock2DSimpleCrossAttn,
42 | get_down_block,
43 | get_up_block,
44 | )
45 |
46 |
47 | logger = logging.get_logger(__name__) # pylint: disable=invalid-name
48 |
49 |
50 | @dataclass
51 | class UNet2DConditionOutput(BaseOutput):
52 | """
53 | The output of [`UNet2DConditionModel`].
54 |
55 | Args:
56 | sample (`torch.FloatTensor` of shape `(batch_size, num_channels, height, width)`):
57 | The hidden states output conditioned on `encoder_hidden_states` input. Output of last layer of model.
58 | """
59 |
60 | sample: torch.FloatTensor = None
61 | hidden_states: Optional[list] = None
62 | encoder_feature: Optional[list] = None
63 |
64 |
65 | class UNet2DConditionModel(ModelMixin, ConfigMixin, UNet2DConditionLoadersMixin):
66 | r"""
67 | A conditional 2D UNet model that takes a noisy sample, conditional state, and a timestep and returns a sample
68 | shaped output.
69 |
70 |     This model inherits from [`ModelMixin`]. Check the superclass documentation for its generic methods implemented
71 | for all models (such as downloading or saving).
72 |
73 | Parameters:
74 | sample_size (`int` or `Tuple[int, int]`, *optional*, defaults to `None`):
75 | Height and width of input/output sample.
76 | in_channels (`int`, *optional*, defaults to 4): Number of channels in the input sample.
77 | out_channels (`int`, *optional*, defaults to 4): Number of channels in the output.
78 | center_input_sample (`bool`, *optional*, defaults to `False`): Whether to center the input sample.
79 | flip_sin_to_cos (`bool`, *optional*, defaults to `False`):
80 | Whether to flip the sin to cos in the time embedding.
81 | freq_shift (`int`, *optional*, defaults to 0): The frequency shift to apply to the time embedding.
82 | down_block_types (`Tuple[str]`, *optional*, defaults to `("CrossAttnDownBlock2D", "CrossAttnDownBlock2D", "CrossAttnDownBlock2D", "DownBlock2D")`):
83 | The tuple of downsample blocks to use.
84 | mid_block_type (`str`, *optional*, defaults to `"UNetMidBlock2DCrossAttn"`):
85 | Block type for middle of UNet, it can be either `UNetMidBlock2DCrossAttn` or
86 | `UNetMidBlock2DSimpleCrossAttn`. If `None`, the mid block layer is skipped.
87 | up_block_types (`Tuple[str]`, *optional*, defaults to `("UpBlock2D", "CrossAttnUpBlock2D", "CrossAttnUpBlock2D", "CrossAttnUpBlock2D")`):
88 | The tuple of upsample blocks to use.
89 | only_cross_attention(`bool` or `Tuple[bool]`, *optional*, default to `False`):
90 | Whether to include self-attention in the basic transformer blocks, see
91 | [`~models.attention.BasicTransformerBlock`].
92 | block_out_channels (`Tuple[int]`, *optional*, defaults to `(320, 640, 1280, 1280)`):
93 | The tuple of output channels for each block.
94 | layers_per_block (`int`, *optional*, defaults to 2): The number of layers per block.
95 | downsample_padding (`int`, *optional*, defaults to 1): The padding to use for the downsampling convolution.
96 | mid_block_scale_factor (`float`, *optional*, defaults to 1.0): The scale factor to use for the mid block.
97 | act_fn (`str`, *optional*, defaults to `"silu"`): The activation function to use.
98 | norm_num_groups (`int`, *optional*, defaults to 32): The number of groups to use for the normalization.
99 |             If `None`, normalization and activation layers are skipped in post-processing.
100 | norm_eps (`float`, *optional*, defaults to 1e-5): The epsilon to use for the normalization.
101 | cross_attention_dim (`int` or `Tuple[int]`, *optional*, defaults to 1280):
102 | The dimension of the cross attention features.
103 | transformer_layers_per_block (`int` or `Tuple[int]`, *optional*, defaults to 1):
104 | The number of transformer blocks of type [`~models.attention.BasicTransformerBlock`]. Only relevant for
105 | [`~models.unet_2d_blocks.CrossAttnDownBlock2D`], [`~models.unet_2d_blocks.CrossAttnUpBlock2D`],
106 | [`~models.unet_2d_blocks.UNetMidBlock2DCrossAttn`].
107 | encoder_hid_dim (`int`, *optional*, defaults to None):
108 | If `encoder_hid_dim_type` is defined, `encoder_hidden_states` will be projected from `encoder_hid_dim`
109 | dimension to `cross_attention_dim`.
110 | encoder_hid_dim_type (`str`, *optional*, defaults to `None`):
111 | If given, the `encoder_hidden_states` and potentially other embeddings are down-projected to text
112 | embeddings of dimension `cross_attention` according to `encoder_hid_dim_type`.
113 | attention_head_dim (`int`, *optional*, defaults to 8): The dimension of the attention heads.
114 | num_attention_heads (`int`, *optional*):
115 | The number of attention heads. If not defined, defaults to `attention_head_dim`
116 | resnet_time_scale_shift (`str`, *optional*, defaults to `"default"`): Time scale shift config
117 | for ResNet blocks (see [`~models.resnet.ResnetBlock2D`]). Choose from `default` or `scale_shift`.
118 | class_embed_type (`str`, *optional*, defaults to `None`):
119 | The type of class embedding to use which is ultimately summed with the time embeddings. Choose from `None`,
120 | `"timestep"`, `"identity"`, `"projection"`, or `"simple_projection"`.
121 | addition_embed_type (`str`, *optional*, defaults to `None`):
122 | Configures an optional embedding which will be summed with the time embeddings. Choose from `None` or
123 | "text". "text" will use the `TextTimeEmbedding` layer.
124 | addition_time_embed_dim: (`int`, *optional*, defaults to `None`):
125 | Dimension for the timestep embeddings.
126 | num_class_embeds (`int`, *optional*, defaults to `None`):
127 | Input dimension of the learnable embedding matrix to be projected to `time_embed_dim`, when performing
128 | class conditioning with `class_embed_type` equal to `None`.
129 | time_embedding_type (`str`, *optional*, defaults to `positional`):
130 | The type of position embedding to use for timesteps. Choose from `positional` or `fourier`.
131 | time_embedding_dim (`int`, *optional*, defaults to `None`):
132 | An optional override for the dimension of the projected time embedding.
133 | time_embedding_act_fn (`str`, *optional*, defaults to `None`):
134 | Optional activation function to use only once on the time embeddings before they are passed to the rest of
135 | the UNet. Choose from `silu`, `mish`, `gelu`, and `swish`.
136 | timestep_post_act (`str`, *optional*, defaults to `None`):
137 | The second activation function to use in timestep embedding. Choose from `silu`, `mish` and `gelu`.
138 | time_cond_proj_dim (`int`, *optional*, defaults to `None`):
139 | The dimension of `cond_proj` layer in the timestep embedding.
140 | conv_in_kernel (`int`, *optional*, default to `3`): The kernel size of `conv_in` layer.
141 | conv_out_kernel (`int`, *optional*, default to `3`): The kernel size of `conv_out` layer.
142 | projection_class_embeddings_input_dim (`int`, *optional*): The dimension of the `class_labels` input when
143 | `class_embed_type="projection"`. Required when `class_embed_type="projection"`.
144 | class_embeddings_concat (`bool`, *optional*, defaults to `False`): Whether to concatenate the time
145 | embeddings with the class embeddings.
146 | mid_block_only_cross_attention (`bool`, *optional*, defaults to `None`):
147 | Whether to use cross attention with the mid block when using the `UNetMidBlock2DSimpleCrossAttn`. If
148 | `only_cross_attention` is given as a single boolean and `mid_block_only_cross_attention` is `None`, the
149 | `only_cross_attention` value is used as the value for `mid_block_only_cross_attention`. Default to `False`
150 | otherwise.
151 | """
152 |
153 | _supports_gradient_checkpointing = True
154 |
155 | @register_to_config
156 | def __init__(
157 | self,
158 | sample_size: Optional[int] = None,
159 | in_channels: int = 4,
160 | out_channels: int = 4,
161 | center_input_sample: bool = False,
162 | flip_sin_to_cos: bool = True,
163 | freq_shift: int = 0,
164 | down_block_types: Tuple[str] = (
165 | "CrossAttnDownBlock2D",
166 | "CrossAttnDownBlock2D",
167 | "CrossAttnDownBlock2D",
168 | "DownBlock2D",
169 | ),
170 | mid_block_type: Optional[str] = "UNetMidBlock2DCrossAttn",
171 | up_block_types: Tuple[str] = ("UpBlock2D", "CrossAttnUpBlock2D", "CrossAttnUpBlock2D", "CrossAttnUpBlock2D"),
172 | only_cross_attention: Union[bool, Tuple[bool]] = False,
173 | block_out_channels: Tuple[int] = (320, 640, 1280, 1280),
174 | layers_per_block: Union[int, Tuple[int]] = 2,
175 | downsample_padding: int = 1,
176 | mid_block_scale_factor: float = 1,
177 | act_fn: str = "silu",
178 | norm_num_groups: Optional[int] = 32,
179 | norm_eps: float = 1e-5,
180 | cross_attention_dim: Union[int, Tuple[int]] = 1280,
181 | transformer_layers_per_block: Union[int, Tuple[int]] = 1,
182 | encoder_hid_dim: Optional[int] = None,
183 | encoder_hid_dim_type: Optional[str] = None,
184 | attention_head_dim: Union[int, Tuple[int]] = 8,
185 | num_attention_heads: Optional[Union[int, Tuple[int]]] = None,
186 | dual_cross_attention: bool = False,
187 | use_linear_projection: bool = False,
188 | class_embed_type: Optional[str] = None,
189 | addition_embed_type: Optional[str] = None,
190 | addition_time_embed_dim: Optional[int] = None,
191 | num_class_embeds: Optional[int] = None,
192 | upcast_attention: bool = False,
193 | resnet_time_scale_shift: str = "default",
194 | resnet_skip_time_act: bool = False,
195 | resnet_out_scale_factor: int = 1.0,
196 | time_embedding_type: str = "positional",
197 | time_embedding_dim: Optional[int] = None,
198 | time_embedding_act_fn: Optional[str] = None,
199 | timestep_post_act: Optional[str] = None,
200 | time_cond_proj_dim: Optional[int] = None,
201 | conv_in_kernel: int = 3,
202 | conv_out_kernel: int = 3,
203 | projection_class_embeddings_input_dim: Optional[int] = None,
204 | attention_type: str = "default",
205 | class_embeddings_concat: bool = False,
206 | mid_block_only_cross_attention: Optional[bool] = None,
207 | cross_attention_norm: Optional[str] = None,
208 | addition_embed_type_num_heads=64,
209 | ):
210 | super().__init__()
211 |
212 | self.sample_size = sample_size
213 |
214 | if num_attention_heads is not None:
215 | raise ValueError(
216 | "At the moment it is not possible to define the number of attention heads via `num_attention_heads` because of a naming issue as described in https://github.com/huggingface/diffusers/issues/2011#issuecomment-1547958131. Passing `num_attention_heads` will only be supported in diffusers v0.19."
217 | )
218 |
219 | # If `num_attention_heads` is not defined (which is the case for most models)
220 | # it will default to `attention_head_dim`. This looks weird upon first reading it and it is.
221 | # The reason for this behavior is to correct for incorrectly named variables that were introduced
222 | # when this library was created. The incorrect naming was only discovered much later in https://github.com/huggingface/diffusers/issues/2011#issuecomment-1547958131
223 | # Changing `attention_head_dim` to `num_attention_heads` for 40,000+ configurations is too backwards breaking
224 | # which is why we correct for the naming here.
225 | num_attention_heads = num_attention_heads or attention_head_dim
226 |
227 | # Check inputs
228 | if len(down_block_types) != len(up_block_types):
229 | raise ValueError(
230 | f"Must provide the same number of `down_block_types` as `up_block_types`. `down_block_types`: {down_block_types}. `up_block_types`: {up_block_types}."
231 | )
232 |
233 | if len(block_out_channels) != len(down_block_types):
234 | raise ValueError(
235 | f"Must provide the same number of `block_out_channels` as `down_block_types`. `block_out_channels`: {block_out_channels}. `down_block_types`: {down_block_types}."
236 | )
237 |
238 | if not isinstance(only_cross_attention, bool) and len(only_cross_attention) != len(down_block_types):
239 | raise ValueError(
240 | f"Must provide the same number of `only_cross_attention` as `down_block_types`. `only_cross_attention`: {only_cross_attention}. `down_block_types`: {down_block_types}."
241 | )
242 |
243 | if not isinstance(num_attention_heads, int) and len(num_attention_heads) != len(down_block_types):
244 | raise ValueError(
245 | f"Must provide the same number of `num_attention_heads` as `down_block_types`. `num_attention_heads`: {num_attention_heads}. `down_block_types`: {down_block_types}."
246 | )
247 |
248 | if not isinstance(attention_head_dim, int) and len(attention_head_dim) != len(down_block_types):
249 | raise ValueError(
250 | f"Must provide the same number of `attention_head_dim` as `down_block_types`. `attention_head_dim`: {attention_head_dim}. `down_block_types`: {down_block_types}."
251 | )
252 |
253 | if isinstance(cross_attention_dim, list) and len(cross_attention_dim) != len(down_block_types):
254 | raise ValueError(
255 | f"Must provide the same number of `cross_attention_dim` as `down_block_types`. `cross_attention_dim`: {cross_attention_dim}. `down_block_types`: {down_block_types}."
256 | )
257 |
258 | if not isinstance(layers_per_block, int) and len(layers_per_block) != len(down_block_types):
259 | raise ValueError(
260 | f"Must provide the same number of `layers_per_block` as `down_block_types`. `layers_per_block`: {layers_per_block}. `down_block_types`: {down_block_types}."
261 | )
262 |
263 | # input
264 | conv_in_padding = (conv_in_kernel - 1) // 2
265 | self.conv_in = nn.Conv2d(
266 | in_channels, block_out_channels[0], kernel_size=conv_in_kernel, padding=conv_in_padding
267 | )
268 |
269 | # time
270 | if time_embedding_type == "fourier":
271 | time_embed_dim = time_embedding_dim or block_out_channels[0] * 2
272 | if time_embed_dim % 2 != 0:
273 | raise ValueError(f"`time_embed_dim` should be divisible by 2, but is {time_embed_dim}.")
274 | self.time_proj = GaussianFourierProjection(
275 | time_embed_dim // 2, set_W_to_weight=False, log=False, flip_sin_to_cos=flip_sin_to_cos
276 | )
277 | timestep_input_dim = time_embed_dim
278 | elif time_embedding_type == "positional":
279 | time_embed_dim = time_embedding_dim or block_out_channels[0] * 4
280 |
281 | self.time_proj = Timesteps(block_out_channels[0], flip_sin_to_cos, freq_shift)
282 | timestep_input_dim = block_out_channels[0]
283 | else:
284 | raise ValueError(
285 | f"{time_embedding_type} does not exist. Please make sure to use one of `fourier` or `positional`."
286 | )
287 |
288 | self.time_embedding = TimestepEmbedding(
289 | timestep_input_dim,
290 | time_embed_dim,
291 | act_fn=act_fn,
292 | post_act_fn=timestep_post_act,
293 | cond_proj_dim=time_cond_proj_dim,
294 | )
295 |
296 | if encoder_hid_dim_type is None and encoder_hid_dim is not None:
297 | encoder_hid_dim_type = "text_proj"
298 | self.register_to_config(encoder_hid_dim_type=encoder_hid_dim_type)
299 | logger.info("encoder_hid_dim_type defaults to 'text_proj' as `encoder_hid_dim` is defined.")
300 |
301 | if encoder_hid_dim is None and encoder_hid_dim_type is not None:
302 | raise ValueError(
303 | f"`encoder_hid_dim` has to be defined when `encoder_hid_dim_type` is set to {encoder_hid_dim_type}."
304 | )
305 |
306 | if encoder_hid_dim_type == "text_proj":
307 | self.encoder_hid_proj = nn.Linear(encoder_hid_dim, cross_attention_dim)
308 | elif encoder_hid_dim_type == "text_image_proj":
309 | # image_embed_dim DOESN'T have to be `cross_attention_dim`. To not clutter the __init__ too much
310 | # they are set to `cross_attention_dim` here as this is exactly the required dimension for the currently only use
311 |             # case when `addition_embed_type == "text_image_proj"` (Kandinsky 2.1)`
312 | self.encoder_hid_proj = TextImageProjection(
313 | text_embed_dim=encoder_hid_dim,
314 | image_embed_dim=cross_attention_dim,
315 | cross_attention_dim=cross_attention_dim,
316 | )
317 | elif encoder_hid_dim_type == "image_proj":
318 | # Kandinsky 2.2
319 | self.encoder_hid_proj = ImageProjection(
320 | image_embed_dim=encoder_hid_dim,
321 | cross_attention_dim=cross_attention_dim,
322 | )
323 | elif encoder_hid_dim_type is not None:
324 | raise ValueError(
325 | f"encoder_hid_dim_type: {encoder_hid_dim_type} must be None, 'text_proj' or 'text_image_proj'."
326 | )
327 | else:
328 | self.encoder_hid_proj = None
329 |
330 | # class embedding
331 | if class_embed_type is None and num_class_embeds is not None:
332 | self.class_embedding = nn.Embedding(num_class_embeds, time_embed_dim)
333 | elif class_embed_type == "timestep":
334 | self.class_embedding = TimestepEmbedding(timestep_input_dim, time_embed_dim, act_fn=act_fn)
335 | elif class_embed_type == "identity":
336 | self.class_embedding = nn.Identity(time_embed_dim, time_embed_dim)
337 | elif class_embed_type == "projection":
338 | if projection_class_embeddings_input_dim is None:
339 | raise ValueError(
340 | "`class_embed_type`: 'projection' requires `projection_class_embeddings_input_dim` be set"
341 | )
342 | # The projection `class_embed_type` is the same as the timestep `class_embed_type` except
343 | # 1. the `class_labels` inputs are not first converted to sinusoidal embeddings
344 | # 2. it projects from an arbitrary input dimension.
345 | #
346 | # Note that `TimestepEmbedding` is quite general, being mainly linear layers and activations.
347 | # When used for embedding actual timesteps, the timesteps are first converted to sinusoidal embeddings.
348 | # As a result, `TimestepEmbedding` can be passed arbitrary vectors.
349 | self.class_embedding = TimestepEmbedding(projection_class_embeddings_input_dim, time_embed_dim)
350 | elif class_embed_type == "simple_projection":
351 | if projection_class_embeddings_input_dim is None:
352 | raise ValueError(
353 | "`class_embed_type`: 'simple_projection' requires `projection_class_embeddings_input_dim` be set"
354 | )
355 | self.class_embedding = nn.Linear(projection_class_embeddings_input_dim, time_embed_dim)
356 | else:
357 | self.class_embedding = None
358 |
359 | if addition_embed_type == "text":
360 | if encoder_hid_dim is not None:
361 | text_time_embedding_from_dim = encoder_hid_dim
362 | else:
363 | text_time_embedding_from_dim = cross_attention_dim
364 |
365 | self.add_embedding = TextTimeEmbedding(
366 | text_time_embedding_from_dim, time_embed_dim, num_heads=addition_embed_type_num_heads
367 | )
368 | elif addition_embed_type == "text_image":
369 | # text_embed_dim and image_embed_dim DON'T have to be `cross_attention_dim`. To not clutter the __init__ too much
370 | # they are set to `cross_attention_dim` here as this is exactly the required dimension for the currently only use
371 | # case when `addition_embed_type == "text_image"` (Kandinsky 2.1)
372 | self.add_embedding = TextImageTimeEmbedding(
373 | text_embed_dim=cross_attention_dim, image_embed_dim=cross_attention_dim, time_embed_dim=time_embed_dim
374 | )
375 | elif addition_embed_type == "text_time":
376 | self.add_time_proj = Timesteps(addition_time_embed_dim, flip_sin_to_cos, freq_shift)
377 | self.add_embedding = TimestepEmbedding(projection_class_embeddings_input_dim, time_embed_dim)
378 | elif addition_embed_type == "image":
379 | # Kandinsky 2.2
380 | self.add_embedding = ImageTimeEmbedding(image_embed_dim=encoder_hid_dim, time_embed_dim=time_embed_dim)
381 | elif addition_embed_type == "image_hint":
382 | # Kandinsky 2.2 ControlNet
383 | self.add_embedding = ImageHintTimeEmbedding(image_embed_dim=encoder_hid_dim, time_embed_dim=time_embed_dim)
384 | elif addition_embed_type is not None:
385 | raise ValueError(f"addition_embed_type: {addition_embed_type} must be None, 'text', 'text_image', 'text_time', 'image' or 'image_hint'.")
386 |
387 | if time_embedding_act_fn is None:
388 | self.time_embed_act = None
389 | else:
390 | self.time_embed_act = get_activation(time_embedding_act_fn)
391 |
392 | self.down_blocks = nn.ModuleList([])
393 | self.up_blocks = nn.ModuleList([])
394 |
395 | if isinstance(only_cross_attention, bool):
396 | if mid_block_only_cross_attention is None:
397 | mid_block_only_cross_attention = only_cross_attention
398 |
399 | only_cross_attention = [only_cross_attention] * len(down_block_types)
400 |
401 | if mid_block_only_cross_attention is None:
402 | mid_block_only_cross_attention = False
403 |
404 | if isinstance(num_attention_heads, int):
405 | num_attention_heads = (num_attention_heads,) * len(down_block_types)
406 |
407 | if isinstance(attention_head_dim, int):
408 | attention_head_dim = (attention_head_dim,) * len(down_block_types)
409 |
410 | if isinstance(cross_attention_dim, int):
411 | cross_attention_dim = (cross_attention_dim,) * len(down_block_types)
412 |
413 | if isinstance(layers_per_block, int):
414 | layers_per_block = [layers_per_block] * len(down_block_types)
415 |
416 | if isinstance(transformer_layers_per_block, int):
417 | transformer_layers_per_block = [transformer_layers_per_block] * len(down_block_types)
418 |
419 | if class_embeddings_concat:
420 | # The time embeddings are concatenated with the class embeddings. The dimension of the
421 | # time embeddings passed to the down, middle, and up blocks is twice the dimension of the
422 | # regular time embeddings
423 | blocks_time_embed_dim = time_embed_dim * 2
424 | else:
425 | blocks_time_embed_dim = time_embed_dim
426 |
427 | # down
428 | output_channel = block_out_channels[0]
429 | for i, down_block_type in enumerate(down_block_types):
430 | input_channel = output_channel
431 | output_channel = block_out_channels[i]
432 | is_final_block = i == len(block_out_channels) - 1
433 |
434 | down_block = get_down_block(
435 | down_block_type,
436 | num_layers=layers_per_block[i],
437 | transformer_layers_per_block=transformer_layers_per_block[i],
438 | in_channels=input_channel,
439 | out_channels=output_channel,
440 | temb_channels=blocks_time_embed_dim,
441 | add_downsample=not is_final_block,
442 | resnet_eps=norm_eps,
443 | resnet_act_fn=act_fn,
444 | resnet_groups=norm_num_groups,
445 | cross_attention_dim=cross_attention_dim[i],
446 | num_attention_heads=num_attention_heads[i],
447 | downsample_padding=downsample_padding,
448 | dual_cross_attention=dual_cross_attention,
449 | use_linear_projection=use_linear_projection,
450 | only_cross_attention=only_cross_attention[i],
451 | upcast_attention=upcast_attention,
452 | resnet_time_scale_shift=resnet_time_scale_shift,
453 | attention_type=attention_type,
454 | resnet_skip_time_act=resnet_skip_time_act,
455 | resnet_out_scale_factor=resnet_out_scale_factor,
456 | cross_attention_norm=cross_attention_norm,
457 | attention_head_dim=attention_head_dim[i] if attention_head_dim[i] is not None else output_channel,
458 | )
459 | self.down_blocks.append(down_block)
460 |
461 | # mid
462 | if mid_block_type == "UNetMidBlock2DCrossAttn":
463 | self.mid_block = UNetMidBlock2DCrossAttn(
464 | transformer_layers_per_block=transformer_layers_per_block[-1],
465 | in_channels=block_out_channels[-1],
466 | temb_channels=blocks_time_embed_dim,
467 | resnet_eps=norm_eps,
468 | resnet_act_fn=act_fn,
469 | output_scale_factor=mid_block_scale_factor,
470 | resnet_time_scale_shift=resnet_time_scale_shift,
471 | cross_attention_dim=cross_attention_dim[-1],
472 | num_attention_heads=num_attention_heads[-1],
473 | resnet_groups=norm_num_groups,
474 | dual_cross_attention=dual_cross_attention,
475 | use_linear_projection=use_linear_projection,
476 | upcast_attention=upcast_attention,
477 | attention_type=attention_type,
478 | )
479 | elif mid_block_type == "UNetMidBlock2DSimpleCrossAttn":
480 | self.mid_block = UNetMidBlock2DSimpleCrossAttn(
481 | in_channels=block_out_channels[-1],
482 | temb_channels=blocks_time_embed_dim,
483 | resnet_eps=norm_eps,
484 | resnet_act_fn=act_fn,
485 | output_scale_factor=mid_block_scale_factor,
486 | cross_attention_dim=cross_attention_dim[-1],
487 | attention_head_dim=attention_head_dim[-1],
488 | resnet_groups=norm_num_groups,
489 | resnet_time_scale_shift=resnet_time_scale_shift,
490 | skip_time_act=resnet_skip_time_act,
491 | only_cross_attention=mid_block_only_cross_attention,
492 | cross_attention_norm=cross_attention_norm,
493 | )
494 | elif mid_block_type is None:
495 | self.mid_block = None
496 | else:
497 | raise ValueError(f"unknown mid_block_type : {mid_block_type}")
498 |
499 | # count how many layers upsample the images
500 | self.num_upsamplers = 0
501 |
502 | # up
503 | reversed_block_out_channels = list(reversed(block_out_channels))
504 | reversed_num_attention_heads = list(reversed(num_attention_heads))
505 | reversed_layers_per_block = list(reversed(layers_per_block))
506 | reversed_cross_attention_dim = list(reversed(cross_attention_dim))
507 | reversed_transformer_layers_per_block = list(reversed(transformer_layers_per_block))
508 | only_cross_attention = list(reversed(only_cross_attention))
509 |
510 | output_channel = reversed_block_out_channels[0]
511 | for i, up_block_type in enumerate(up_block_types):
512 | is_final_block = i == len(block_out_channels) - 1
513 |
514 | prev_output_channel = output_channel
515 | output_channel = reversed_block_out_channels[i]
516 | input_channel = reversed_block_out_channels[min(i + 1, len(block_out_channels) - 1)]
517 |
518 | # add upsample block for all BUT final layer
519 | if not is_final_block:
520 | add_upsample = True
521 | self.num_upsamplers += 1
522 | else:
523 | add_upsample = False
524 |
525 | up_block = get_up_block(
526 | up_block_type,
527 | num_layers=reversed_layers_per_block[i] + 1,
528 | transformer_layers_per_block=reversed_transformer_layers_per_block[i],
529 | in_channels=input_channel,
530 | out_channels=output_channel,
531 | prev_output_channel=prev_output_channel,
532 | temb_channels=blocks_time_embed_dim,
533 | add_upsample=add_upsample,
534 | resnet_eps=norm_eps,
535 | resnet_act_fn=act_fn,
536 | resnet_groups=norm_num_groups,
537 | cross_attention_dim=reversed_cross_attention_dim[i],
538 | num_attention_heads=reversed_num_attention_heads[i],
539 | dual_cross_attention=dual_cross_attention,
540 | use_linear_projection=use_linear_projection,
541 | only_cross_attention=only_cross_attention[i],
542 | upcast_attention=upcast_attention,
543 | resnet_time_scale_shift=resnet_time_scale_shift,
544 | attention_type=attention_type,
545 | resnet_skip_time_act=resnet_skip_time_act,
546 | resnet_out_scale_factor=resnet_out_scale_factor,
547 | cross_attention_norm=cross_attention_norm,
548 | attention_head_dim=attention_head_dim[i] if attention_head_dim[i] is not None else output_channel,
549 | )
550 | self.up_blocks.append(up_block)
551 | prev_output_channel = output_channel
552 |
553 | # out
554 | if norm_num_groups is not None:
555 | self.conv_norm_out = nn.GroupNorm(
556 | num_channels=block_out_channels[0], num_groups=norm_num_groups, eps=norm_eps
557 | )
558 |
559 | self.conv_act = get_activation(act_fn)
560 |
561 | else:
562 | self.conv_norm_out = None
563 | self.conv_act = None
564 |
565 | conv_out_padding = (conv_out_kernel - 1) // 2
566 | self.conv_out = nn.Conv2d(
567 | block_out_channels[0], out_channels, kernel_size=conv_out_kernel, padding=conv_out_padding
568 | )
569 |
570 | if attention_type == "gated":
571 | positive_len = 768
572 | if isinstance(cross_attention_dim, int):
573 | positive_len = cross_attention_dim
574 | elif isinstance(cross_attention_dim, tuple) or isinstance(cross_attention_dim, list):
575 | positive_len = cross_attention_dim[0]
576 | self.position_net = PositionNet(positive_len=positive_len, out_dim=cross_attention_dim)
577 |
578 |
579 | @property
580 | def attn_processors(self) -> Dict[str, AttentionProcessor]:
581 | r"""
582 | Returns:
583 | `dict` of attention processors: A dictionary containing all attention processors used in the model,
584 | indexed by their weight names.
585 | """
586 | # set recursively
587 | processors = {}
588 |
589 | def fn_recursive_add_processors(name: str, module: torch.nn.Module, processors: Dict[str, AttentionProcessor]):
590 | if hasattr(module, "set_processor"):
591 | processors[f"{name}.processor"] = module.processor
592 |
593 | for sub_name, child in module.named_children():
594 | fn_recursive_add_processors(f"{name}.{sub_name}", child, processors)
595 |
596 | return processors
597 |
598 | for name, module in self.named_children():
599 | fn_recursive_add_processors(name, module, processors)
600 |
601 | return processors
602 |
603 | def set_attn_processor(self, processor: Union[AttentionProcessor, Dict[str, AttentionProcessor]]):
604 | r"""
605 | Sets the attention processor to use to compute attention.
606 |
607 | Parameters:
608 | processor (`dict` of `AttentionProcessor` or only `AttentionProcessor`):
609 | The instantiated processor class or a dictionary of processor classes that will be set as the processor
610 | for **all** `Attention` layers.
611 |
612 | If `processor` is a dict, the key needs to define the path to the corresponding cross attention
613 | processor. This is strongly recommended when setting trainable attention processors.
614 |
615 | """
616 | count = len(self.attn_processors.keys())
617 |
618 | if isinstance(processor, dict) and len(processor) != count:
619 | raise ValueError(
620 | f"A dict of processors was passed, but the number of processors {len(processor)} does not match the"
621 | f" number of attention layers: {count}. Please make sure to pass {count} processor classes."
622 | )
623 |
624 | def fn_recursive_attn_processor(name: str, module: torch.nn.Module, processor):
625 | if hasattr(module, "set_processor"):
626 | if not isinstance(processor, dict):
627 | module.set_processor(processor)
628 | else:
629 | module.set_processor(processor.pop(f"{name}.processor"))
630 |
631 | for sub_name, child in module.named_children():
632 | fn_recursive_attn_processor(f"{name}.{sub_name}", child, processor)
633 |
634 | for name, module in self.named_children():
635 | fn_recursive_attn_processor(name, module, processor)
636 |
637 | def set_default_attn_processor(self):
638 | """
639 | Disables custom attention processors and sets the default attention implementation.
640 | """
641 | self.set_attn_processor(AttnProcessor())
642 |
643 | def set_attention_slice(self, slice_size):
644 | r"""
645 | Enable sliced attention computation.
646 |
647 | When this option is enabled, the attention module splits the input tensor in slices to compute attention in
648 | several steps. This is useful for saving some memory in exchange for a small decrease in speed.
649 |
650 | Args:
651 | slice_size (`str` or `int` or `list(int)`, *optional*, defaults to `"auto"`):
652 | When `"auto"`, input to the attention heads is halved, so attention is computed in two steps. If
653 | `"max"`, maximum amount of memory is saved by running only one slice at a time. If a number is
654 | provided, uses as many slices as `attention_head_dim // slice_size`. In this case, `attention_head_dim`
655 | must be a multiple of `slice_size`.
656 | """
657 | sliceable_head_dims = []
658 |
659 | def fn_recursive_retrieve_sliceable_dims(module: torch.nn.Module):
660 | if hasattr(module, "set_attention_slice"):
661 | sliceable_head_dims.append(module.sliceable_head_dim)
662 |
663 | for child in module.children():
664 | fn_recursive_retrieve_sliceable_dims(child)
665 |
666 | # retrieve number of attention layers
667 | for module in self.children():
668 | fn_recursive_retrieve_sliceable_dims(module)
669 |
670 | num_sliceable_layers = len(sliceable_head_dims)
671 |
672 | if slice_size == "auto":
673 | # half the attention head size is usually a good trade-off between
674 | # speed and memory
675 | slice_size = [dim // 2 for dim in sliceable_head_dims]
676 | elif slice_size == "max":
677 | # make smallest slice possible
678 | slice_size = num_sliceable_layers * [1]
679 |
680 | slice_size = num_sliceable_layers * [slice_size] if not isinstance(slice_size, list) else slice_size
681 |
682 | if len(slice_size) != len(sliceable_head_dims):
683 | raise ValueError(
684 | f"You have provided {len(slice_size)}, but {self.config} has {len(sliceable_head_dims)} different"
685 | f" attention layers. Make sure to match `len(slice_size)` to be {len(sliceable_head_dims)}."
686 | )
687 |
688 | for i in range(len(slice_size)):
689 | size = slice_size[i]
690 | dim = sliceable_head_dims[i]
691 | if size is not None and size > dim:
692 | raise ValueError(f"size {size} has to be smaller or equal to {dim}.")
693 |
694 | # Recursively walk through all the children.
695 | # Any child that exposes the set_attention_slice method
696 | # gets the message
697 | def fn_recursive_set_attention_slice(module: torch.nn.Module, slice_size: List[int]):
698 | if hasattr(module, "set_attention_slice"):
699 | module.set_attention_slice(slice_size.pop())
700 |
701 | for child in module.children():
702 | fn_recursive_set_attention_slice(child, slice_size)
703 |
704 | reversed_slice_size = list(reversed(slice_size))
705 | for module in self.children():
706 | fn_recursive_set_attention_slice(module, reversed_slice_size)
707 |
708 | def _set_gradient_checkpointing(self, module, value=False):
709 | if hasattr(module, "gradient_checkpointing"):
710 | module.gradient_checkpointing = value
711 |
712 | def forward(
713 | self,
714 | sample: torch.FloatTensor,
715 | timestep: Union[torch.Tensor, float, int],
716 | encoder_hidden_states: torch.Tensor,
717 | class_labels: Optional[torch.Tensor] = None,
718 | timestep_cond: Optional[torch.Tensor] = None,
719 | attention_mask: Optional[torch.Tensor] = None,
720 | cross_attention_kwargs: Optional[Dict[str, Any]] = None,
721 | added_cond_kwargs: Optional[Dict[str, torch.Tensor]] = None,
722 | down_block_additional_residuals: Optional[Tuple[torch.Tensor]] = None,
723 | mid_block_additional_residual: Optional[torch.Tensor] = None,
724 | up_block_additional_residual: Optional[torch.Tensor] = None,
725 | encoder_attention_mask: Optional[torch.Tensor] = None,
726 | return_dict: bool = True,
727 | return_hidden_states: bool = False,
728 | return_encoder_feature: bool = False,
729 | return_early: bool = False,
730 | down_bridge_residuals: Optional[Tuple[torch.Tensor]] = None,
731 | fusion_guidance_scale: Optional[torch.FloatTensor] = None,
732 | fusion_type: Optional[str] = 'ADD',
733 | adapter: Optional[Any] = None
734 | ) -> Union[UNet2DConditionOutput, Tuple]:
735 | r"""
736 | The [`UNet2DConditionModel`] forward method.
737 |
738 | Args:
739 | sample (`torch.FloatTensor`):
740 | The noisy input tensor with the following shape `(batch, channel, height, width)`.
741 | timestep (`torch.FloatTensor` or `float` or `int`): The number of timesteps to denoise an input.
742 | encoder_hidden_states (`torch.FloatTensor`):
743 | The encoder hidden states with shape `(batch, sequence_length, feature_dim)`.
744 | encoder_attention_mask (`torch.Tensor`):
745 | A cross-attention mask of shape `(batch, sequence_length)` is applied to `encoder_hidden_states`. If
746 | `True` the mask is kept, otherwise if `False` it is discarded. Mask will be converted into a bias,
747 | which adds large negative values to the attention scores corresponding to "discard" tokens.
748 | return_dict (`bool`, *optional*, defaults to `True`):
749 | Whether or not to return a [`~models.unet_2d_condition.UNet2DConditionOutput`] instead of a plain
750 | tuple.
751 | cross_attention_kwargs (`dict`, *optional*):
752 | A kwargs dictionary that if specified is passed along to the [`AttnProcessor`].
753 | added_cond_kwargs: (`dict`, *optional*):
754 | A kwargs dictionary containing additional embeddings that, if specified, are added to the embeddings that
755 | are passed along to the UNet blocks.
756 |
757 | Returns:
758 | [`~models.unet_2d_condition.UNet2DConditionOutput`] or `tuple`:
759 | If `return_dict` is True, an [`~models.unet_2d_condition.UNet2DConditionOutput`] is returned, otherwise
760 | a `tuple` is returned where the first element is the sample tensor.
761 | """
762 | # By default, samples have to be at least a multiple of the overall upsampling factor.
763 | # The overall upsampling factor is equal to 2 ** (# num of upsampling layers).
764 | # However, the upsampling interpolation output size can be forced to fit any upsampling size
765 | # on the fly if necessary.
766 | ############## bridge usage ##################
767 | if return_hidden_states:
768 | hidden_states = []
769 | return_dict = True
770 | ############## end of bridge usage ##################
771 |
772 |
773 |
774 | default_overall_up_factor = 2**self.num_upsamplers
775 |
776 | # upsample size should be forwarded when sample is not a multiple of `default_overall_up_factor`
777 | forward_upsample_size = False
778 | upsample_size = None
779 |
780 | if any(s % default_overall_up_factor != 0 for s in sample.shape[-2:]):
781 | logger.info("Forward upsample size to force interpolation output size.")
782 | forward_upsample_size = True
783 |
784 | # ensure attention_mask is a bias, and give it a singleton query_tokens dimension
785 | # expects mask of shape:
786 | # [batch, key_tokens]
787 | # adds singleton query_tokens dimension:
788 | # [batch, 1, key_tokens]
789 | # this helps to broadcast it as a bias over attention scores, which will be in one of the following shapes:
790 | # [batch, heads, query_tokens, key_tokens] (e.g. torch sdp attn)
791 | # [batch * heads, query_tokens, key_tokens] (e.g. xformers or classic attn)
792 | if attention_mask is not None:
793 | # assume that mask is expressed as:
794 | # (1 = keep, 0 = discard)
795 | # convert mask into a bias that can be added to attention scores:
796 | # (keep = +0, discard = -10000.0)
797 | attention_mask = (1 - attention_mask.to(sample.dtype)) * -10000.0
798 | attention_mask = attention_mask.unsqueeze(1)
799 |
800 | # convert encoder_attention_mask to a bias the same way we do for attention_mask
801 | if encoder_attention_mask is not None:
802 | encoder_attention_mask = (1 - encoder_attention_mask.to(sample.dtype)) * -10000.0
803 | encoder_attention_mask = encoder_attention_mask.unsqueeze(1)
804 |
805 | # 0. center input if necessary
806 | if self.config.center_input_sample:
807 | sample = 2 * sample - 1.0
808 |
809 | # 1. time
810 | timesteps = timestep
811 | if not torch.is_tensor(timesteps):
812 | # TODO: this requires sync between CPU and GPU. So try to pass timesteps as tensors if you can
813 | # This would be a good case for the `match` statement (Python 3.10+)
814 | is_mps = sample.device.type == "mps"
815 | if isinstance(timestep, float):
816 | dtype = torch.float32 if is_mps else torch.float64
817 | else:
818 | dtype = torch.int32 if is_mps else torch.int64
819 | timesteps = torch.tensor([timesteps], dtype=dtype, device=sample.device)
820 | elif len(timesteps.shape) == 0:
821 | timesteps = timesteps[None].to(sample.device)
822 |
823 | # broadcast to batch dimension in a way that's compatible with ONNX/Core ML
824 | timesteps = timesteps.expand(sample.shape[0])
825 |
826 | t_emb = self.time_proj(timesteps)  # shape: (batch, block_out_channels[0]), e.g. (2, 320)
827 |
828 | # `Timesteps` does not contain any weights and will always return f32 tensors
829 | # but time_embedding might actually be running in fp16. so we need to cast here.
830 | # there might be better ways to encapsulate this.
831 | t_emb = t_emb.to(dtype=sample.dtype)
832 |
833 | emb = self.time_embedding(t_emb, timestep_cond)
834 |
835 | aug_emb = None
836 |
837 | if self.class_embedding is not None:
838 | if class_labels is None:
839 | raise ValueError("class_labels should be provided when num_class_embeds > 0")
840 |
841 | if self.config.class_embed_type == "timestep":
842 | class_labels = self.time_proj(class_labels)
843 |
844 | # `Timesteps` does not contain any weights and will always return f32 tensors
845 | # there might be better ways to encapsulate this.
846 | class_labels = class_labels.to(dtype=sample.dtype)
847 |
848 | class_emb = self.class_embedding(class_labels).to(dtype=sample.dtype)
849 |
850 | if self.config.class_embeddings_concat:
851 | emb = torch.cat([emb, class_emb], dim=-1)
852 | else:
853 | emb = emb + class_emb
854 |
855 | if self.config.addition_embed_type == "text":
856 | aug_emb = self.add_embedding(encoder_hidden_states)
857 | elif self.config.addition_embed_type == "text_image":
858 | # Kandinsky 2.1 - style
859 | if "image_embeds" not in added_cond_kwargs:
860 | raise ValueError(
861 | f"{self.__class__} has the config param `addition_embed_type` set to 'text_image' which requires the keyword argument `image_embeds` to be passed in `added_cond_kwargs`"
862 | )
863 |
864 | image_embs = added_cond_kwargs.get("image_embeds")
865 | text_embs = added_cond_kwargs.get("text_embeds", encoder_hidden_states)
866 | aug_emb = self.add_embedding(text_embs, image_embs)
867 | elif self.config.addition_embed_type == "text_time":
868 | # SDXL - style
869 | if "text_embeds" not in added_cond_kwargs:
870 | raise ValueError(
871 | f"{self.__class__} has the config param `addition_embed_type` set to 'text_time' which requires the keyword argument `text_embeds` to be passed in `added_cond_kwargs`"
872 | )
873 | text_embeds = added_cond_kwargs.get("text_embeds")
874 | if "time_ids" not in added_cond_kwargs:
875 | raise ValueError(
876 | f"{self.__class__} has the config param `addition_embed_type` set to 'text_time' which requires the keyword argument `time_ids` to be passed in `added_cond_kwargs`"
877 | )
878 | time_ids = added_cond_kwargs.get("time_ids")
879 | time_embeds = self.add_time_proj(time_ids.flatten())
880 | time_embeds = time_embeds.reshape((text_embeds.shape[0], -1))
881 |
882 | add_embeds = torch.concat([text_embeds, time_embeds], dim=-1)
883 | add_embeds = add_embeds.to(emb.dtype)
884 | aug_emb = self.add_embedding(add_embeds)
885 | elif self.config.addition_embed_type == "image":
886 | # Kandinsky 2.2 - style
887 | if "image_embeds" not in added_cond_kwargs:
888 | raise ValueError(
889 | f"{self.__class__} has the config param `addition_embed_type` set to 'image' which requires the keyword argument `image_embeds` to be passed in `added_cond_kwargs`"
890 | )
891 | image_embs = added_cond_kwargs.get("image_embeds")
892 | aug_emb = self.add_embedding(image_embs)
893 | elif self.config.addition_embed_type == "image_hint":
894 | # Kandinsky 2.2 - style
895 | if "image_embeds" not in added_cond_kwargs or "hint" not in added_cond_kwargs:
896 | raise ValueError(
897 | f"{self.__class__} has the config param `addition_embed_type` set to 'image_hint' which requires the keyword arguments `image_embeds` and `hint` to be passed in `added_cond_kwargs`"
898 | )
899 | image_embs = added_cond_kwargs.get("image_embeds")
900 | hint = added_cond_kwargs.get("hint")
901 | aug_emb, hint = self.add_embedding(image_embs, hint)
902 | sample = torch.cat([sample, hint], dim=1)
903 |
904 | emb = emb + aug_emb if aug_emb is not None else emb
905 |
906 | if self.time_embed_act is not None:
907 | emb = self.time_embed_act(emb)
908 |
909 | if self.encoder_hid_proj is not None and self.config.encoder_hid_dim_type == "text_proj":
910 | encoder_hidden_states = self.encoder_hid_proj(encoder_hidden_states)
911 | elif self.encoder_hid_proj is not None and self.config.encoder_hid_dim_type == "text_image_proj":
912 | # Kandinsky 2.1 - style
913 | if "image_embeds" not in added_cond_kwargs:
914 | raise ValueError(
915 | f"{self.__class__} has the config param `encoder_hid_dim_type` set to 'text_image_proj' which requires the keyword argument `image_embeds` to be passed in `added_conditions`"
916 | )
917 |
918 | image_embeds = added_cond_kwargs.get("image_embeds")
919 | encoder_hidden_states = self.encoder_hid_proj(encoder_hidden_states, image_embeds)
920 | elif self.encoder_hid_proj is not None and self.config.encoder_hid_dim_type == "image_proj":
921 | # Kandinsky 2.2 - style
922 | if "image_embeds" not in added_cond_kwargs:
923 | raise ValueError(
924 | f"{self.__class__} has the config param `encoder_hid_dim_type` set to 'image_proj' which requires the keyword argument `image_embeds` to be passed in `added_conditions`"
925 | )
926 | image_embeds = added_cond_kwargs.get("image_embeds")
927 | encoder_hidden_states = self.encoder_hid_proj(image_embeds)
928 | # 2. pre-process
929 | sample = self.conv_in(sample)
930 |
931 | # 2.5 GLIGEN position net
932 | if cross_attention_kwargs is not None and cross_attention_kwargs.get("gligen", None) is not None:
933 | cross_attention_kwargs = cross_attention_kwargs.copy()
934 | gligen_args = cross_attention_kwargs.pop("gligen")
935 | cross_attention_kwargs["gligen"] = {"objs": self.position_net(**gligen_args)}
936 |
937 | # 3. down
938 |
939 | if return_encoder_feature:
940 | encoder_feature = []
941 |
942 | is_controlnet = mid_block_additional_residual is not None and down_block_additional_residuals is not None
943 | is_adapter = mid_block_additional_residual is None and down_block_additional_residuals is not None
944 | is_bridge_encoder = down_bridge_residuals is not None
945 | is_bridge = up_block_additional_residual is not None
946 |
947 | down_block_res_samples = (sample,)
948 |
949 |
950 |
951 | for downsample_block in self.down_blocks:
952 | if hasattr(downsample_block, "has_cross_attention") and downsample_block.has_cross_attention:
953 | # For t2i-adapter CrossAttnDownBlock2D
954 | additional_residuals = {}
955 | if is_adapter and len(down_block_additional_residuals) > 0:
956 | additional_residuals["additional_residuals"] = down_block_additional_residuals.pop(0)
957 |
958 | sample, res_samples = downsample_block(
959 | hidden_states=sample,
960 | temb=emb,
961 | encoder_hidden_states=encoder_hidden_states,
962 | attention_mask=attention_mask,
963 | cross_attention_kwargs=cross_attention_kwargs,
964 | encoder_attention_mask=encoder_attention_mask,
965 | **additional_residuals,
966 | )
967 |
968 | if is_bridge_encoder and len(down_bridge_residuals) > 0:
969 | sample += down_bridge_residuals.pop(0)
970 |
971 | if return_encoder_feature:
972 | encoder_feature.append(sample)
973 | else:
974 | sample, res_samples = downsample_block(hidden_states=sample, temb=emb)
975 |
976 | if is_adapter and len(down_block_additional_residuals) > 0:
977 | sample += down_block_additional_residuals.pop(0)
978 |
979 | if is_bridge_encoder and len(down_bridge_residuals) > 0:
980 | sample += down_bridge_residuals.pop(0)
981 |
982 | down_block_res_samples += res_samples
983 |
984 |
985 | if is_controlnet:
986 | new_down_block_res_samples = ()
987 |
988 | for down_block_res_sample, down_block_additional_residual in zip(
989 | down_block_res_samples, down_block_additional_residuals
990 | ):
991 | down_block_res_sample = down_block_res_sample + down_block_additional_residual
992 | new_down_block_res_samples = new_down_block_res_samples + (down_block_res_sample,)
993 |
994 | down_block_res_samples = new_down_block_res_samples
995 |
996 | if return_encoder_feature and return_early:
997 | return encoder_feature
998 |
999 | # 4. mid
1000 | if self.mid_block is not None:
1001 | sample = self.mid_block(
1002 | sample,
1003 | emb,
1004 | encoder_hidden_states=encoder_hidden_states,
1005 | attention_mask=attention_mask,
1006 | cross_attention_kwargs=cross_attention_kwargs,
1007 | encoder_attention_mask=encoder_attention_mask,
1008 | )
1009 |
1010 | if is_controlnet:
1011 | sample = sample + mid_block_additional_residual
1012 |
1013 | ################# bridge usage #################
1014 |
1015 | if is_bridge:
1016 | if fusion_guidance_scale is not None:
1017 | sample = sample + fusion_guidance_scale * (up_block_additional_residual.pop(0) - sample)
1018 | else:
1019 | sample += up_block_additional_residual.pop(0)
1020 | ################# end of bridge usage #################
1021 | # 5. up
1022 |
1023 | for i, upsample_block in enumerate(self.up_blocks):
1024 | is_final_block = i == len(self.up_blocks) - 1
1025 |
1026 | res_samples = down_block_res_samples[-len(upsample_block.resnets) :]
1027 | down_block_res_samples = down_block_res_samples[: -len(upsample_block.resnets)]
1028 |
1029 | # if we have not reached the final block and need to forward the
1030 | # upsample size, we do it here
1031 | if not is_final_block and forward_upsample_size:
1032 | upsample_size = down_block_res_samples[-1].shape[2:]
1033 |
1034 | if hasattr(upsample_block, "has_cross_attention") and upsample_block.has_cross_attention:
1035 | sample = upsample_block(
1036 | hidden_states=sample,
1037 | temb=emb,
1038 | res_hidden_states_tuple=res_samples,
1039 | encoder_hidden_states=encoder_hidden_states,
1040 | cross_attention_kwargs=cross_attention_kwargs,
1041 | upsample_size=upsample_size,
1042 | attention_mask=attention_mask,
1043 | encoder_attention_mask=encoder_attention_mask,
1044 | )
1045 | else:
1046 | sample = upsample_block(
1047 | hidden_states=sample, temb=emb, res_hidden_states_tuple=res_samples, upsample_size=upsample_size
1048 | )
1049 |
1050 |
1051 | ################# bridge usage #################
1052 | if is_bridge and len(up_block_additional_residual) > 0:
1053 | if fusion_guidance_scale is not None:
1054 | sample = sample + fusion_guidance_scale * (up_block_additional_residual.pop(0) - sample)
1055 | else:
1056 | sample += up_block_additional_residual.pop(0)
1057 |
1058 | if return_hidden_states and i > 0:
1059 | # Collect the outputs of the last three up blocks (SD1.5)
1060 | hidden_states.append(sample)
1061 | ################# end of bridge usage #################
1062 |
1063 | # 6. post-process
1064 | if self.conv_norm_out:
1065 | sample = self.conv_norm_out(sample)
1066 | sample = self.conv_act(sample)
1067 | sample = self.conv_out(sample)
1068 |
1069 | if not return_dict:
1070 | return (sample,)
1071 |
1072 | return UNet2DConditionOutput(sample=sample, hidden_states=hidden_states if return_hidden_states else None,
1073 | encoder_feature=encoder_feature if return_encoder_feature else None)
1074 |
--------------------------------------------------------------------------------
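The forward() above extends the stock diffusers UNet with bridge arguments (`return_hidden_states`, `up_block_additional_residual`, `down_bridge_residuals`, `fusion_guidance_scale`). Below is a minimal, hypothetical sketch of how they fit together, not code from this repository: the adapter call, tensor arguments, and all names are assumptions, and the real wiring lives in the pipelines under `pipeline/`.

```python
from typing import Callable, Sequence

import torch


def bridged_denoise_step(
    unet_sd1_5,                    # UNet2DConditionModel (SD1.5 copy) as defined above
    unet_sdxl,                     # UNet2DConditionModel (SDXL copy) as defined above
    adapter: Callable[[Sequence[torch.Tensor]], Sequence[torch.Tensor]],  # assumed mapper
    latents_sd1_5: torch.Tensor,
    latents_sdxl: torch.Tensor,
    t: torch.Tensor,
    text_emb_sd1_5: torch.Tensor,
    text_emb_sdxl: torch.Tensor,
    added_cond_kwargs: dict,       # SDXL "text_time" conditioning (text_embeds, time_ids)
    fusion_guidance_scale: float = 1.0,
) -> torch.Tensor:
    # 1. Run the SD1.5 UNet and collect its up-block features; the modified output
    #    dataclass exposes them as `hidden_states` when `return_hidden_states=True`.
    out_sd1_5 = unet_sd1_5(latents_sd1_5, t, text_emb_sd1_5, return_hidden_states=True)

    # 2. Map the collected features to bridge residuals (the adapter API is an assumption
    #    here; the residual shapes must match what the SDXL decoder expects).
    residuals = list(adapter(out_sd1_5.hidden_states))

    # 3. Run the SDXL UNet; forward() pops the residuals front to back, first after the
    #    mid block and then after each up block, blending each one in with
    #    `fusion_guidance_scale` rather than plain addition.
    out_sdxl = unet_sdxl(
        latents_sdxl,
        t,
        text_emb_sdxl,
        added_cond_kwargs=added_cond_kwargs,
        up_block_additional_residual=residuals,
        fusion_guidance_scale=fusion_guidance_scale,
    )
    return out_sdxl.sample
```

In the actual pipelines a step like this sits inside the denoising loop, with separate schedulers driving `latents_sd1_5` and `latents_sdxl`.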
/model/utils.py:
--------------------------------------------------------------------------------
1 | import os
2 | import imageio
3 | import numpy as np
4 | from typing import Union
5 |
6 | import torch
7 | import torchvision
8 | import torch.distributed as dist
9 |
10 | from safetensors import safe_open
11 | from tqdm import tqdm
12 | from einops import rearrange
13 | from model.convert_from_ckpt import convert_ldm_unet_checkpoint, convert_ldm_clip_checkpoint, convert_ldm_vae_checkpoint
14 | # from animatediff.utils.convert_lora_safetensor_to_diffusers import convert_lora, convert_motion_lora_ckpt_to_diffusers
--------------------------------------------------------------------------------
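The import block above (cut off at line 14 in this excerpt) suggests `model/utils.py` handles loading LDM-format checkpoints and `.safetensors` weight files. As a small illustrative sketch (not code from this repository; the file path is hypothetical), a `.safetensors` checkpoint can be read into a plain state dict with `safe_open` before being handed to the `convert_ldm_*` helpers imported above:

```python
from safetensors import safe_open

# Hypothetical path; reads every tensor of a .safetensors checkpoint into a dict.
state_dict = {}
with safe_open("checkpoints/sd15_finetune.safetensors", framework="pt", device="cpu") as f:
    for key in f.keys():
        state_dict[key] = f.get_tensor(key)

# The resulting dict would then typically be passed to helpers such as
# convert_ldm_unet_checkpoint / convert_ldm_vae_checkpoint, whose exact signatures
# are not shown in this excerpt.
```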
/pipeline/__pycache__/pipeline_sd_xl_adapter.cpython-310.pyc:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/showlab/X-Adapter/e7348cb1b83e039dec8fe6f664266f1594e5103f/pipeline/__pycache__/pipeline_sd_xl_adapter.cpython-310.pyc
--------------------------------------------------------------------------------
/pipeline/__pycache__/pipeline_sd_xl_adapter_controlnet.cpython-310.pyc:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/showlab/X-Adapter/e7348cb1b83e039dec8fe6f664266f1594e5103f/pipeline/__pycache__/pipeline_sd_xl_adapter_controlnet.cpython-310.pyc
--------------------------------------------------------------------------------
/pipeline/__pycache__/pipeline_sd_xl_adapter_controlnet_img2img.cpython-310.pyc:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/showlab/X-Adapter/e7348cb1b83e039dec8fe6f664266f1594e5103f/pipeline/__pycache__/pipeline_sd_xl_adapter_controlnet_img2img.cpython-310.pyc
--------------------------------------------------------------------------------
/pipeline/pipeline_sd_xl_adapter.py:
--------------------------------------------------------------------------------
1 | # Copyright 2023 The HuggingFace Team. All rights reserved.
2 | #
3 | # Licensed under the Apache License, Version 2.0 (the "License");
4 | # you may not use this file except in compliance with the License.
5 | # You may obtain a copy of the License at
6 | #
7 | # http://www.apache.org/licenses/LICENSE-2.0
8 | #
9 | # Unless required by applicable law or agreed to in writing, software
10 | # distributed under the License is distributed on an "AS IS" BASIS,
11 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12 | # See the License for the specific language governing permissions and
13 | # limitations under the License.
14 |
15 | import inspect
16 | import os
17 | import PIL
18 | import numpy as np
19 | import torch.nn.functional as F
20 | from PIL import Image
21 | from typing import Any, Callable, Dict, List, Optional, Tuple, Union
22 |
23 | import torch
24 | from transformers import CLIPTextModel, CLIPTextModelWithProjection, CLIPTokenizer
25 |
26 | from diffusers.image_processor import VaeImageProcessor
27 | from diffusers.loaders import FromSingleFileMixin, LoraLoaderMixin, TextualInversionLoaderMixin
28 | # from diffusers.models import AutoencoderKL, UNet2DConditionModel
29 | from diffusers.models import AutoencoderKL
30 | from model.unet_adapter import UNet2DConditionModel
31 |
32 | from diffusers.models.attention_processor import (
33 | AttnProcessor2_0,
34 | LoRAAttnProcessor2_0,
35 | LoRAXFormersAttnProcessor,
36 | XFormersAttnProcessor,
37 | )
38 | from diffusers.schedulers import KarrasDiffusionSchedulers
39 | from diffusers.utils import (
40 | is_accelerate_available,
41 | is_accelerate_version,
42 | is_invisible_watermark_available,
43 | logging,
44 | randn_tensor,
45 | replace_example_docstring,
46 | )
47 | from diffusers.pipelines.pipeline_utils import DiffusionPipeline
48 | from diffusers.pipelines.stable_diffusion_xl import StableDiffusionXLPipelineOutput
49 | from model.adapter import Adapter_XL
50 |
51 |
52 | if is_invisible_watermark_available():
53 | from diffusers.pipelines.stable_diffusion_xl.watermark import StableDiffusionXLWatermarker
54 |
55 | logger = logging.get_logger(__name__) # pylint: disable=invalid-name
56 |
57 | EXAMPLE_DOC_STRING = """
58 | Examples:
59 | ```py
60 | >>> import torch
61 | >>> from diffusers import StableDiffusionXLPipeline
62 |
63 | >>> pipe = StableDiffusionXLPipeline.from_pretrained(
64 | ... "stabilityai/stable-diffusion-xl-base-1.0", torch_dtype=torch.float16
65 | ... )
66 | >>> pipe = pipe.to("cuda")
67 |
68 | >>> prompt = "a photo of an astronaut riding a horse on mars"
69 | >>> image = pipe(prompt).images[0]
70 | ```
71 | """
72 |
73 |
74 | # Copied from diffusers.pipelines.stable_diffusion.pipeline_stable_diffusion.rescale_noise_cfg
75 | def rescale_noise_cfg(noise_cfg, noise_pred_text, guidance_rescale=0.0):
76 | """
77 | Rescale `noise_cfg` according to `guidance_rescale`. Based on findings of [Common Diffusion Noise Schedules and
78 | Sample Steps are Flawed](https://arxiv.org/pdf/2305.08891.pdf). See Section 3.4
79 | """
80 | std_text = noise_pred_text.std(dim=list(range(1, noise_pred_text.ndim)), keepdim=True)
81 | std_cfg = noise_cfg.std(dim=list(range(1, noise_cfg.ndim)), keepdim=True)
82 | # rescale the results from guidance (fixes overexposure)
83 | noise_pred_rescaled = noise_cfg * (std_text / std_cfg)
84 | # mix with the original results from guidance by factor guidance_rescale to avoid "plain looking" images
85 | noise_cfg = guidance_rescale * noise_pred_rescaled + (1 - guidance_rescale) * noise_cfg
86 | return noise_cfg
87 |
88 |
89 | class StableDiffusionXLAdapterPipeline(DiffusionPipeline, FromSingleFileMixin, LoraLoaderMixin):
90 | r"""
91 | Pipeline for text-to-image generation using Stable Diffusion XL.
92 |
93 | This model inherits from [`DiffusionPipeline`]. Check the superclass documentation for the generic methods the
94 | library implements for all the pipelines (such as downloading or saving, running on a particular device, etc.)
95 |
96 | In addition the pipeline inherits the following loading methods:
97 | - *LoRA*: [`StableDiffusionXLPipeline.load_lora_weights`]
98 | - *Ckpt*: [`loaders.FromSingleFileMixin.from_single_file`]
99 |
100 | as well as the following saving methods:
101 | - *LoRA*: [`loaders.StableDiffusionXLPipeline.save_lora_weights`]
102 |
103 | Args:
104 | vae ([`AutoencoderKL`]):
105 | Variational Auto-Encoder (VAE) Model to encode and decode images to and from latent representations.
106 | text_encoder ([`CLIPTextModel`]):
107 | Frozen text-encoder. Stable Diffusion XL uses the text portion of
108 | [CLIP](https://huggingface.co/docs/transformers/model_doc/clip#transformers.CLIPTextModel), specifically
109 | the [clip-vit-large-patch14](https://huggingface.co/openai/clip-vit-large-patch14) variant.
110 | text_encoder_2 ([`CLIPTextModelWithProjection`]):
111 | Second frozen text-encoder. Stable Diffusion XL uses the text and pool portion of
112 | [CLIP](https://huggingface.co/docs/transformers/model_doc/clip#transformers.CLIPTextModelWithProjection),
113 | specifically the
114 | [laion/CLIP-ViT-bigG-14-laion2B-39B-b160k](https://huggingface.co/laion/CLIP-ViT-bigG-14-laion2B-39B-b160k)
115 | variant.
116 | tokenizer (`CLIPTokenizer`):
117 | Tokenizer of class
118 | [CLIPTokenizer](https://huggingface.co/docs/transformers/v4.21.0/en/model_doc/clip#transformers.CLIPTokenizer).
119 | tokenizer_2 (`CLIPTokenizer`):
120 | Second Tokenizer of class
121 | [CLIPTokenizer](https://huggingface.co/docs/transformers/v4.21.0/en/model_doc/clip#transformers.CLIPTokenizer).
122 | unet ([`UNet2DConditionModel`]): Conditional U-Net architecture to denoise the encoded image latents.
123 | scheduler ([`SchedulerMixin`]):
124 | A scheduler to be used in combination with `unet` to denoise the encoded image latents. Can be one of
125 | [`DDIMScheduler`], [`LMSDiscreteScheduler`], or [`PNDMScheduler`].
126 | """
127 |
128 | def __init__(
129 | self,
130 | vae: AutoencoderKL,
131 | text_encoder: CLIPTextModel,
132 | text_encoder_2: CLIPTextModelWithProjection,
133 | tokenizer: CLIPTokenizer,
134 | tokenizer_2: CLIPTokenizer,
135 | unet: UNet2DConditionModel,
136 | scheduler: KarrasDiffusionSchedulers,
137 | vae_sd1_5: AutoencoderKL,
138 | text_encoder_sd1_5: CLIPTextModel,
139 | tokenizer_sd1_5: CLIPTokenizer,
140 | unet_sd1_5: UNet2DConditionModel,
141 | scheduler_sd1_5: KarrasDiffusionSchedulers,
142 | adapter: Adapter_XL,
143 | force_zeros_for_empty_prompt: bool = True,
144 | add_watermarker: Optional[bool] = None,
145 | ):
146 | super().__init__()
147 |
148 | self.register_modules(
149 | vae=vae,
150 | text_encoder=text_encoder,
151 | text_encoder_2=text_encoder_2,
152 | tokenizer=tokenizer,
153 | tokenizer_2=tokenizer_2,
154 | unet=unet,
155 | scheduler=scheduler,
156 | vae_sd1_5=vae_sd1_5,
157 | text_encoder_sd1_5=text_encoder_sd1_5,
158 | tokenizer_sd1_5=tokenizer_sd1_5,
159 | unet_sd1_5=unet_sd1_5,
160 | scheduler_sd1_5=scheduler_sd1_5,
161 | adapter=adapter,
162 | )
163 | self.register_to_config(force_zeros_for_empty_prompt=force_zeros_for_empty_prompt)
164 | self.vae_scale_factor = 2 ** (len(self.vae.config.block_out_channels) - 1)
165 | self.vae_scale_factor_sd1_5 = 2 ** (len(self.vae_sd1_5.config.block_out_channels) - 1)
166 | self.image_processor = VaeImageProcessor(vae_scale_factor=self.vae_scale_factor)
167 | self.default_sample_size = self.unet.config.sample_size
168 | self.image_processor_sd1_5 = VaeImageProcessor(vae_scale_factor=self.vae_scale_factor_sd1_5)
169 |
170 | add_watermarker = add_watermarker if add_watermarker is not None else is_invisible_watermark_available()
171 |
172 | if add_watermarker:
173 | self.watermark = StableDiffusionXLWatermarker()
174 | else:
175 | self.watermark = None
176 |
177 | # Copied from diffusers.pipelines.stable_diffusion.pipeline_stable_diffusion.StableDiffusionPipeline.enable_vae_slicing
178 | def enable_vae_slicing(self):
179 | r"""
180 | Enable sliced VAE decoding. When this option is enabled, the VAE will split the input tensor in slices to
181 | compute decoding in several steps. This is useful to save some memory and allow larger batch sizes.
182 | """
183 | self.vae.enable_slicing()
184 |
185 | # Copied from diffusers.pipelines.stable_diffusion.pipeline_stable_diffusion.StableDiffusionPipeline.disable_vae_slicing
186 | def disable_vae_slicing(self):
187 | r"""
188 | Disable sliced VAE decoding. If `enable_vae_slicing` was previously enabled, this method will go back to
189 | computing decoding in one step.
190 | """
191 | self.vae.disable_slicing()
192 |
193 | # Copied from diffusers.pipelines.stable_diffusion.pipeline_stable_diffusion.StableDiffusionPipeline.enable_vae_tiling
194 | def enable_vae_tiling(self):
195 | r"""
196 | Enable tiled VAE decoding. When this option is enabled, the VAE will split the input tensor into tiles to
197 | compute decoding and encoding in several steps. This is useful for saving a large amount of memory and to allow
198 | processing larger images.
199 | """
200 | self.vae.enable_tiling()
201 |
202 | # Copied from diffusers.pipelines.stable_diffusion.pipeline_stable_diffusion.StableDiffusionPipeline.disable_vae_tiling
203 | def disable_vae_tiling(self):
204 | r"""
205 | Disable tiled VAE decoding. If `enable_vae_tiling` was previously enabled, this method will go back to
206 | computing decoding in one step.
207 | """
208 | self.vae.disable_tiling()
209 |
210 | def enable_model_cpu_offload(self, gpu_id=0):
211 | r"""
212 | Offloads all models to CPU using accelerate, reducing memory usage with a low impact on performance. Compared
213 | to `enable_sequential_cpu_offload`, this method moves one whole model at a time to the GPU when its `forward`
214 | method is called, and the model remains on the GPU until the next model runs. Memory savings are lower than with
215 | `enable_sequential_cpu_offload`, but performance is much better due to the iterative execution of the `unet`.
216 | """
217 | if is_accelerate_available() and is_accelerate_version(">=", "0.17.0.dev0"):
218 | from accelerate import cpu_offload_with_hook
219 | else:
220 | raise ImportError("`enable_model_cpu_offload` requires `accelerate v0.17.0` or higher.")
221 |
222 | device = torch.device(f"cuda:{gpu_id}")
223 |
224 |
225 | self.to("cpu", silence_dtype_warnings=True)
226 | torch.cuda.empty_cache() # otherwise we don't see the memory savings (but they probably exist)
227 |
228 | model_sequence = (
229 | [self.text_encoder, self.text_encoder_2] if self.text_encoder is not None else [self.text_encoder_2]
230 | )
231 | model_sequence.extend([self.unet, self.vae])
232 |
233 | model_sequence.extend([self.unet_sd1_5, self.vae_sd1_5, self.text_encoder_sd1_5, self.adapter])
234 |
235 | hook = None
236 | for cpu_offloaded_model in model_sequence:
237 | _, hook = cpu_offload_with_hook(cpu_offloaded_model, device, prev_module_hook=hook)
238 |
239 | # We'll offload the last model manually.
240 | self.final_offload_hook = hook
241 |
242 | def encode_prompt(
243 | self,
244 | prompt: str,
245 | prompt_2: Optional[str] = None,
246 | device: Optional[torch.device] = None,
247 | num_images_per_prompt: int = 1,
248 | do_classifier_free_guidance: bool = True,
249 | negative_prompt: Optional[str] = None,
250 | negative_prompt_2: Optional[str] = None,
251 | prompt_embeds: Optional[torch.FloatTensor] = None,
252 | negative_prompt_embeds: Optional[torch.FloatTensor] = None,
253 | pooled_prompt_embeds: Optional[torch.FloatTensor] = None,
254 | negative_pooled_prompt_embeds: Optional[torch.FloatTensor] = None,
255 | lora_scale: Optional[float] = None,
256 | ):
257 | r"""
258 | Encodes the prompt into text encoder hidden states.
259 |
260 | Args:
261 | prompt (`str` or `List[str]`, *optional*):
262 | prompt to be encoded
263 | prompt_2 (`str` or `List[str]`, *optional*):
264 | The prompt or prompts to be sent to the `tokenizer_2` and `text_encoder_2`. If not defined, `prompt` is
265 | used in both text-encoders
266 | device: (`torch.device`):
267 | torch device
268 | num_images_per_prompt (`int`):
269 | number of images that should be generated per prompt
270 | do_classifier_free_guidance (`bool`):
271 | whether to use classifier free guidance or not
272 | negative_prompt (`str` or `List[str]`, *optional*):
273 | The prompt or prompts not to guide the image generation. If not defined, one has to pass
274 | `negative_prompt_embeds` instead. Ignored when not using guidance (i.e., ignored if `guidance_scale` is
275 | less than `1`).
276 | negative_prompt_2 (`str` or `List[str]`, *optional*):
277 | The prompt or prompts not to guide the image generation to be sent to `tokenizer_2` and
278 | `text_encoder_2`. If not defined, `negative_prompt` is used in both text-encoders
279 | prompt_embeds (`torch.FloatTensor`, *optional*):
280 | Pre-generated text embeddings. Can be used to easily tweak text inputs, *e.g.* prompt weighting. If not
281 | provided, text embeddings will be generated from `prompt` input argument.
282 | negative_prompt_embeds (`torch.FloatTensor`, *optional*):
283 | Pre-generated negative text embeddings. Can be used to easily tweak text inputs, *e.g.* prompt
284 | weighting. If not provided, negative_prompt_embeds will be generated from `negative_prompt` input
285 | argument.
286 | pooled_prompt_embeds (`torch.FloatTensor`, *optional*):
287 | Pre-generated pooled text embeddings. Can be used to easily tweak text inputs, *e.g.* prompt weighting.
288 | If not provided, pooled text embeddings will be generated from `prompt` input argument.
289 | negative_pooled_prompt_embeds (`torch.FloatTensor`, *optional*):
290 | Pre-generated negative pooled text embeddings. Can be used to easily tweak text inputs, *e.g.* prompt
291 | weighting. If not provided, pooled negative_prompt_embeds will be generated from `negative_prompt`
292 | input argument.
293 | lora_scale (`float`, *optional*):
294 | A lora scale that will be applied to all LoRA layers of the text encoder if LoRA layers are loaded.
295 | """
296 | device = device or self._execution_device
297 |
298 | # set lora scale so that monkey patched LoRA
299 | # function of text encoder can correctly access it
300 | if lora_scale is not None and isinstance(self, LoraLoaderMixin):
301 | self._lora_scale = lora_scale
302 |
303 | if prompt is not None and isinstance(prompt, str):
304 | batch_size = 1
305 | elif prompt is not None and isinstance(prompt, list):
306 | batch_size = len(prompt)
307 | else:
308 | batch_size = prompt_embeds.shape[0]
309 |
310 | # Define tokenizers and text encoders
311 | tokenizers = [self.tokenizer, self.tokenizer_2] if self.tokenizer is not None else [self.tokenizer_2]
312 | text_encoders = (
313 | [self.text_encoder, self.text_encoder_2] if self.text_encoder is not None else [self.text_encoder_2]
314 | )
315 |
316 | if prompt_embeds is None:
317 | prompt_2 = prompt_2 or prompt
318 | # textual inversion: process multi-vector tokens if necessary
319 | prompt_embeds_list = []
320 | prompts = [prompt, prompt_2]
321 | for prompt, tokenizer, text_encoder in zip(prompts, tokenizers, text_encoders):
322 | if isinstance(self, TextualInversionLoaderMixin):
323 | prompt = self.maybe_convert_prompt(prompt, tokenizer)
324 |
325 | text_inputs = tokenizer(
326 | prompt,
327 | padding="max_length",
328 | max_length=tokenizer.model_max_length,
329 | truncation=True,
330 | return_tensors="pt",
331 | )
332 |
333 | text_input_ids = text_inputs.input_ids
334 | untruncated_ids = tokenizer(prompt, padding="longest", return_tensors="pt").input_ids
335 |
336 | if untruncated_ids.shape[-1] >= text_input_ids.shape[-1] and not torch.equal(
337 | text_input_ids, untruncated_ids
338 | ):
339 | removed_text = tokenizer.batch_decode(untruncated_ids[:, tokenizer.model_max_length - 1: -1])
340 | logger.warning(
341 | "The following part of your input was truncated because CLIP can only handle sequences up to"
342 | f" {tokenizer.model_max_length} tokens: {removed_text}"
343 | )
344 |
345 | prompt_embeds = text_encoder(
346 | text_input_ids.to(device),
347 | output_hidden_states=True,
348 | )
349 |
350 | # We are always interested only in the pooled output of the final text encoder
351 | pooled_prompt_embeds = prompt_embeds[0]
352 | prompt_embeds = prompt_embeds.hidden_states[-2]
353 |
354 | prompt_embeds_list.append(prompt_embeds)
355 |
356 | prompt_embeds = torch.concat(prompt_embeds_list, dim=-1)
357 |
358 | # get unconditional embeddings for classifier free guidance
359 | zero_out_negative_prompt = negative_prompt is None and self.config.force_zeros_for_empty_prompt
360 | if do_classifier_free_guidance and negative_prompt_embeds is None and zero_out_negative_prompt:
361 | negative_prompt_embeds = torch.zeros_like(prompt_embeds)
362 | negative_pooled_prompt_embeds = torch.zeros_like(pooled_prompt_embeds)
363 | elif do_classifier_free_guidance and negative_prompt_embeds is None:
364 | negative_prompt = negative_prompt or ""
365 | negative_prompt_2 = negative_prompt_2 or negative_prompt
366 |
367 | uncond_tokens: List[str]
368 | if prompt is not None and type(prompt) is not type(negative_prompt):
369 | raise TypeError(
370 | f"`negative_prompt` should be the same type to `prompt`, but got {type(negative_prompt)} !="
371 | f" {type(prompt)}."
372 | )
373 | elif isinstance(negative_prompt, str):
374 | uncond_tokens = [negative_prompt, negative_prompt_2]
375 | elif batch_size != len(negative_prompt):
376 | raise ValueError(
377 | f"`negative_prompt`: {negative_prompt} has batch size {len(negative_prompt)}, but `prompt`:"
378 | f" {prompt} has batch size {batch_size}. Please make sure that passed `negative_prompt` matches"
379 | " the batch size of `prompt`."
380 | )
381 | else:
382 | uncond_tokens = [negative_prompt, negative_prompt_2]
383 |
384 | negative_prompt_embeds_list = []
385 | for negative_prompt, tokenizer, text_encoder in zip(uncond_tokens, tokenizers, text_encoders):
386 | if isinstance(self, TextualInversionLoaderMixin):
387 | negative_prompt = self.maybe_convert_prompt(negative_prompt, tokenizer)
388 |
389 | max_length = prompt_embeds.shape[1]
390 | uncond_input = tokenizer(
391 | negative_prompt,
392 | padding="max_length",
393 | max_length=max_length,
394 | truncation=True,
395 | return_tensors="pt",
396 | )
397 |
398 | negative_prompt_embeds = text_encoder(
399 | uncond_input.input_ids.to(device),
400 | output_hidden_states=True,
401 | )
402 | # We are always interested only in the pooled output of the final text encoder
403 | negative_pooled_prompt_embeds = negative_prompt_embeds[0]
404 | negative_prompt_embeds = negative_prompt_embeds.hidden_states[-2]
405 |
406 | negative_prompt_embeds_list.append(negative_prompt_embeds)
407 |
408 | negative_prompt_embeds = torch.concat(negative_prompt_embeds_list, dim=-1)
409 |
410 | prompt_embeds = prompt_embeds.to(dtype=self.text_encoder_2.dtype, device=device)
411 | bs_embed, seq_len, _ = prompt_embeds.shape
412 | # duplicate text embeddings for each generation per prompt, using mps friendly method
413 | prompt_embeds = prompt_embeds.repeat(1, num_images_per_prompt, 1)
414 | prompt_embeds = prompt_embeds.view(bs_embed * num_images_per_prompt, seq_len, -1)
415 |
416 | if do_classifier_free_guidance:
417 | # duplicate unconditional embeddings for each generation per prompt, using mps friendly method
418 | seq_len = negative_prompt_embeds.shape[1]
419 | negative_prompt_embeds = negative_prompt_embeds.to(dtype=self.text_encoder_2.dtype, device=device)
420 | negative_prompt_embeds = negative_prompt_embeds.repeat(1, num_images_per_prompt, 1)
421 | negative_prompt_embeds = negative_prompt_embeds.view(batch_size * num_images_per_prompt, seq_len, -1)
422 |
423 | pooled_prompt_embeds = pooled_prompt_embeds.repeat(1, num_images_per_prompt).view(
424 | bs_embed * num_images_per_prompt, -1
425 | )
426 | if do_classifier_free_guidance:
427 | negative_pooled_prompt_embeds = negative_pooled_prompt_embeds.repeat(1, num_images_per_prompt).view(
428 | bs_embed * num_images_per_prompt, -1
429 | )
430 |
431 | return prompt_embeds, negative_prompt_embeds, pooled_prompt_embeds, negative_pooled_prompt_embeds
432 |
433 | # Copied from diffusers.pipelines.stable_diffusion.pipeline_stable_diffusion.StableDiffusionPipeline.prepare_extra_step_kwargs
434 | def prepare_extra_step_kwargs(self, generator, eta):
435 | # prepare extra kwargs for the scheduler step, since not all schedulers have the same signature
436 | # eta (η) is only used with the DDIMScheduler; it will be ignored for other schedulers.
437 | # eta corresponds to η in DDIM paper: https://arxiv.org/abs/2010.02502
438 | # and should be between [0, 1]
439 |
440 | accepts_eta = "eta" in set(inspect.signature(self.scheduler.step).parameters.keys())
441 | extra_step_kwargs = {}
442 | if accepts_eta:
443 | extra_step_kwargs["eta"] = eta
444 |
445 | # check if the scheduler accepts generator
446 | accepts_generator = "generator" in set(inspect.signature(self.scheduler.step).parameters.keys())
447 | if accepts_generator:
448 | extra_step_kwargs["generator"] = generator
449 | return extra_step_kwargs
450 |
451 | def check_inputs(
452 | self,
453 | prompt,
454 | prompt_2,
455 | height,
456 | width,
457 | callback_steps,
458 | negative_prompt=None,
459 | negative_prompt_2=None,
460 | prompt_embeds=None,
461 | negative_prompt_embeds=None,
462 | pooled_prompt_embeds=None,
463 | negative_pooled_prompt_embeds=None,
464 | ):
465 | if height % 8 != 0 or width % 8 != 0:
466 | raise ValueError(f"`height` and `width` have to be divisible by 8 but are {height} and {width}.")
467 |
468 | if (callback_steps is None) or (
469 | callback_steps is not None and (not isinstance(callback_steps, int) or callback_steps <= 0)
470 | ):
471 | raise ValueError(
472 | f"`callback_steps` has to be a positive integer but is {callback_steps} of type"
473 | f" {type(callback_steps)}."
474 | )
475 |
476 | if prompt is not None and prompt_embeds is not None:
477 | raise ValueError(
478 | f"Cannot forward both `prompt`: {prompt} and `prompt_embeds`: {prompt_embeds}. Please make sure to"
479 | " only forward one of the two."
480 | )
481 | elif prompt_2 is not None and prompt_embeds is not None:
482 | raise ValueError(
483 | f"Cannot forward both `prompt_2`: {prompt_2} and `prompt_embeds`: {prompt_embeds}. Please make sure to"
484 | " only forward one of the two."
485 | )
486 | elif prompt is None and prompt_embeds is None:
487 | raise ValueError(
488 | "Provide either `prompt` or `prompt_embeds`. Cannot leave both `prompt` and `prompt_embeds` undefined."
489 | )
490 | elif prompt is not None and (not isinstance(prompt, str) and not isinstance(prompt, list)):
491 | raise ValueError(f"`prompt` has to be of type `str` or `list` but is {type(prompt)}")
492 | elif prompt_2 is not None and (not isinstance(prompt_2, str) and not isinstance(prompt_2, list)):
493 | raise ValueError(f"`prompt_2` has to be of type `str` or `list` but is {type(prompt_2)}")
494 |
495 | if negative_prompt is not None and negative_prompt_embeds is not None:
496 | raise ValueError(
497 | f"Cannot forward both `negative_prompt`: {negative_prompt} and `negative_prompt_embeds`:"
498 | f" {negative_prompt_embeds}. Please make sure to only forward one of the two."
499 | )
500 | elif negative_prompt_2 is not None and negative_prompt_embeds is not None:
501 | raise ValueError(
502 | f"Cannot forward both `negative_prompt_2`: {negative_prompt_2} and `negative_prompt_embeds`:"
503 | f" {negative_prompt_embeds}. Please make sure to only forward one of the two."
504 | )
505 |
506 | if prompt_embeds is not None and negative_prompt_embeds is not None:
507 | if prompt_embeds.shape != negative_prompt_embeds.shape:
508 | raise ValueError(
509 | "`prompt_embeds` and `negative_prompt_embeds` must have the same shape when passed directly, but"
510 | f" got: `prompt_embeds` {prompt_embeds.shape} != `negative_prompt_embeds`"
511 | f" {negative_prompt_embeds.shape}."
512 | )
513 |
514 | if prompt_embeds is not None and pooled_prompt_embeds is None:
515 | raise ValueError(
516 | "If `prompt_embeds` are provided, `pooled_prompt_embeds` also have to be passed. Make sure to generate `pooled_prompt_embeds` from the same text encoder that was used to generate `prompt_embeds`."
517 | )
518 |
519 | if negative_prompt_embeds is not None and negative_pooled_prompt_embeds is None:
520 | raise ValueError(
521 | "If `negative_prompt_embeds` are provided, `negative_pooled_prompt_embeds` also have to be passed. Make sure to generate `negative_pooled_prompt_embeds` from the same text encoder that was used to generate `negative_prompt_embeds`."
522 | )
523 |
524 | # Copied from diffusers.pipelines.stable_diffusion.pipeline_stable_diffusion.StableDiffusionPipeline.prepare_latents
525 | def prepare_latents(self, batch_size, num_channels_latents, height, width, dtype, device, generator, latents=None):
526 | shape = (batch_size, num_channels_latents, height // self.vae_scale_factor, width // self.vae_scale_factor)
527 | if isinstance(generator, list) and len(generator) != batch_size:
528 | raise ValueError(
529 | f"You have passed a list of generators of length {len(generator)}, but requested an effective batch"
530 | f" size of {batch_size}. Make sure the batch size matches the length of the generators."
531 | )
532 |
533 | if latents is None:
534 | latents = randn_tensor(shape, generator=generator, device=device, dtype=dtype)
535 | else:
536 | latents = latents.to(device)
537 |
538 | # scale the initial noise by the standard deviation required by the scheduler
539 | latents = latents * self.scheduler.init_noise_sigma
540 | return latents
541 |
542 | def _get_add_time_ids(self, original_size, crops_coords_top_left, target_size, dtype):
543 | add_time_ids = list(original_size + crops_coords_top_left + target_size)
544 |
545 | passed_add_embed_dim = (
546 | self.unet.config.addition_time_embed_dim * len(add_time_ids) + self.text_encoder_2.config.projection_dim
547 | )
548 | expected_add_embed_dim = self.unet.add_embedding.linear_1.in_features
549 |
550 | if expected_add_embed_dim != passed_add_embed_dim:
551 | raise ValueError(
552 | f"Model expects an added time embedding vector of length {expected_add_embed_dim}, but a vector of {passed_add_embed_dim} was created. The model has an incorrect config. Please check `unet.config.time_embedding_type` and `text_encoder_2.config.projection_dim`."
553 | )
554 |
555 | add_time_ids = torch.tensor([add_time_ids], dtype=dtype)
556 | return add_time_ids
557 |
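# Note on the check above (editorial comment): for the standard SDXL base 1.0 configuration
# (an assumption about the released config, not something this file enforces),
# `add_time_ids` holds 6 integers (original_size + crops_coords_top_left + target_size),
# so passed_add_embed_dim = 6 * addition_time_embed_dim (256) + projection_dim (1280) = 2816,
# which matches unet.add_embedding.linear_1.in_features (2816). For a 1024x1024 image with
# no cropping the ids are simply [1024, 1024, 0, 0, 1024, 1024].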
558 | # Copied from diffusers.pipelines.stable_diffusion.pipeline_stable_diffusion_upscale.StableDiffusionUpscalePipeline.upcast_vae
559 | def upcast_vae(self):
560 | dtype = self.vae.dtype
561 | self.vae.to(dtype=torch.float32)
562 | use_torch_2_0_or_xformers = isinstance(
563 | self.vae.decoder.mid_block.attentions[0].processor,
564 | (
565 | AttnProcessor2_0,
566 | XFormersAttnProcessor,
567 | LoRAXFormersAttnProcessor,
568 | LoRAAttnProcessor2_0,
569 | ),
570 | )
571 | # if xformers or torch_2_0 is used attention block does not need
572 | # to be in float32 which can save lots of memory
573 | if use_torch_2_0_or_xformers:
574 | self.vae.post_quant_conv.to(dtype)
575 | self.vae.decoder.conv_in.to(dtype)
576 | self.vae.decoder.mid_block.to(dtype)
577 |
578 | @torch.no_grad()
579 | @replace_example_docstring(EXAMPLE_DOC_STRING)
580 | def __call__(
581 | self,
582 | prompt: Union[str, List[str]] = None,
583 | prompt_2: Optional[Union[str, List[str]]] = None,
584 | prompt_sd1_5: Optional[Union[str, List[str]]] = None,
585 | height: Optional[int] = None,
586 | width: Optional[int] = None,
587 | height_sd1_5: Optional[int] = None,
588 | width_sd1_5: Optional[int] = None,
589 | num_inference_steps: int = 50,
590 | denoising_end: Optional[float] = None,
591 | guidance_scale: float = 5.0,
592 | negative_prompt: Optional[Union[str, List[str]]] = None,
593 | negative_prompt_2: Optional[Union[str, List[str]]] = None,
594 | num_images_per_prompt: Optional[int] = 1,
595 | eta: float = 0.0,
596 | generator: Optional[Union[torch.Generator, List[torch.Generator]]] = None,
597 | latents: Optional[torch.FloatTensor] = None,
598 | latents_sd1_5: Optional[torch.FloatTensor] = None,
599 | prompt_embeds: Optional[torch.FloatTensor] = None,
600 | negative_prompt_embeds: Optional[torch.FloatTensor] = None,
601 | prompt_embeds_sd1_5: Optional[torch.FloatTensor] = None,
602 | negative_prompt_embeds_sd1_5: Optional[torch.FloatTensor] = None,
603 | pooled_prompt_embeds: Optional[torch.FloatTensor] = None,
604 | negative_pooled_prompt_embeds: Optional[torch.FloatTensor] = None,
605 | output_type: Optional[str] = "pil",
606 | return_dict: bool = True,
607 | callback: Optional[Callable[[int, int, torch.FloatTensor], None]] = None,
608 | callback_steps: int = 1,
609 | cross_attention_kwargs: Optional[Dict[str, Any]] = None,
610 | guidance_rescale: float = 0.0,
611 | original_size: Optional[Tuple[int, int]] = None,
612 | crops_coords_top_left: Tuple[int, int] = (0, 0),
613 | target_size: Optional[Tuple[int, int]] = None,
614 | adapter_condition_scale: Optional[float] = 1.0,
615 | adapter_guidance_start: Union[float, List[float]] = 0.5,
616 | denoising_start: Optional[float] = None,
617 | adapter_type: str = "de", # "de", "en", "en_de"
618 | fusion_guidance_scale: Optional[float] = None,
619 | enable_time_step: bool = False
620 | ):
621 | r"""
622 | Function invoked when calling the pipeline for generation.
623 |
624 | Args:
625 | prompt (`str` or `List[str]`, *optional*):
626 | The prompt or prompts to guide the image generation. If not defined, one has to pass `prompt_embeds`
627 | instead.
628 | prompt_2 (`str` or `List[str]`, *optional*):
629 | The prompt or prompts to be sent to the `tokenizer_2` and `text_encoder_2`. If not defined, `prompt` is
630 | used in both text-encoders
631 | height (`int`, *optional*, defaults to self.unet.config.sample_size * self.vae_scale_factor):
632 | The height in pixels of the generated image.
633 | width (`int`, *optional*, defaults to self.unet.config.sample_size * self.vae_scale_factor):
634 | The width in pixels of the generated image.
635 | num_inference_steps (`int`, *optional*, defaults to 50):
636 | The number of denoising steps. More denoising steps usually lead to a higher quality image at the
637 | expense of slower inference.
638 | denoising_end (`float`, *optional*):
639 | When specified, determines the fraction (between 0.0 and 1.0) of the total denoising process to be
640 | completed before it is intentionally prematurely terminated. As a result, the returned sample will
641 | still retain a substantial amount of noise as determined by the discrete timesteps selected by the
642 | scheduler. The denoising_end parameter should ideally be utilized when this pipeline forms a part of a
643 | "Mixture of Denoisers" multi-pipeline setup, as elaborated in [**Refining the Image
644 | Output**](https://huggingface.co/docs/diffusers/api/pipelines/stable_diffusion/stable_diffusion_xl#refining-the-image-output)
645 | guidance_scale (`float`, *optional*, defaults to 5.0):
646 | Guidance scale as defined in [Classifier-Free Diffusion Guidance](https://arxiv.org/abs/2207.12598).
647 | `guidance_scale` is defined as `w` of equation 2. of [Imagen
648 | Paper](https://arxiv.org/pdf/2205.11487.pdf). Guidance scale is enabled by setting `guidance_scale >
649 | 1`. Higher guidance scale encourages the model to generate images that are closely linked to the text `prompt`,
650 | usually at the expense of lower image quality.
651 | negative_prompt (`str` or `List[str]`, *optional*):
652 | The prompt or prompts not to guide the image generation. If not defined, one has to pass
653 | `negative_prompt_embeds` instead. Ignored when not using guidance (i.e., ignored if `guidance_scale` is
654 | less than `1`).
655 | negative_prompt_2 (`str` or `List[str]`, *optional*):
656 | The prompt or prompts not to guide the image generation to be sent to `tokenizer_2` and
657 | `text_encoder_2`. If not defined, `negative_prompt` is used in both text-encoders
658 | num_images_per_prompt (`int`, *optional*, defaults to 1):
659 | The number of images to generate per prompt.
660 | eta (`float`, *optional*, defaults to 0.0):
661 | Corresponds to parameter eta (η) in the DDIM paper: https://arxiv.org/abs/2010.02502. Only applies to
662 | [`schedulers.DDIMScheduler`], will be ignored for others.
663 | generator (`torch.Generator` or `List[torch.Generator]`, *optional*):
664 | One or a list of [torch generator(s)](https://pytorch.org/docs/stable/generated/torch.Generator.html)
665 | to make generation deterministic.
666 | latents (`torch.FloatTensor`, *optional*):
667 | Pre-generated noisy latents, sampled from a Gaussian distribution, to be used as inputs for image
668 | generation. Can be used to tweak the same generation with different prompts. If not provided, a latents
669 | tensor will be generated by sampling using the supplied random `generator`.
670 | prompt_embeds (`torch.FloatTensor`, *optional*):
671 | Pre-generated text embeddings. Can be used to easily tweak text inputs, *e.g.* prompt weighting. If not
672 | provided, text embeddings will be generated from `prompt` input argument.
673 | negative_prompt_embeds (`torch.FloatTensor`, *optional*):
674 | Pre-generated negative text embeddings. Can be used to easily tweak text inputs, *e.g.* prompt
675 | weighting. If not provided, negative_prompt_embeds will be generated from `negative_prompt` input
676 | argument.
677 | pooled_prompt_embeds (`torch.FloatTensor`, *optional*):
678 | Pre-generated pooled text embeddings. Can be used to easily tweak text inputs, *e.g.* prompt weighting.
679 | If not provided, pooled text embeddings will be generated from `prompt` input argument.
680 | negative_pooled_prompt_embeds (`torch.FloatTensor`, *optional*):
681 | Pre-generated negative pooled text embeddings. Can be used to easily tweak text inputs, *e.g.* prompt
682 | weighting. If not provided, pooled negative_prompt_embeds will be generated from `negative_prompt`
683 | input argument.
684 | output_type (`str`, *optional*, defaults to `"pil"`):
685 | The output format of the generated image. Choose between
686 | [PIL](https://pillow.readthedocs.io/en/stable/): `PIL.Image.Image` or `np.array`.
687 | return_dict (`bool`, *optional*, defaults to `True`):
688 | Whether or not to return a [`~pipelines.stable_diffusion_xl.StableDiffusionXLPipelineOutput`] instead
689 | of a plain tuple.
690 | callback (`Callable`, *optional*):
691 | A function that will be called every `callback_steps` steps during inference. The function will be
692 | called with the following arguments: `callback(step: int, timestep: int, latents: torch.FloatTensor)`.
693 | callback_steps (`int`, *optional*, defaults to 1):
694 | The frequency at which the `callback` function will be called. If not specified, the callback will be
695 | called at every step.
696 | cross_attention_kwargs (`dict`, *optional*):
697 | A kwargs dictionary that if specified is passed along to the `AttentionProcessor` as defined under
698 | `self.processor` in
699 | [diffusers.models.attention_processor](https://github.com/huggingface/diffusers/blob/main/src/diffusers/models/attention_processor.py).
700 | guidance_rescale (`float`, *optional*, defaults to 0.0):
701 | Guidance rescale factor proposed by [Common Diffusion Noise Schedules and Sample Steps are
702 | Flawed](https://arxiv.org/pdf/2305.08891.pdf). `guidance_rescale` is defined as `φ` in equation 16. of
703 | [Common Diffusion Noise Schedules and Sample Steps are Flawed](https://arxiv.org/pdf/2305.08891.pdf).
704 | Guidance rescale factor should fix overexposure when using zero terminal SNR.
705 | original_size (`Tuple[int]`, *optional*, defaults to (1024, 1024)):
706 | If `original_size` is not the same as `target_size` the image will appear to be down- or upsampled.
707 | `original_size` defaults to `(width, height)` if not specified. Part of SDXL's micro-conditioning as
708 | explained in section 2.2 of
709 | [https://huggingface.co/papers/2307.01952](https://huggingface.co/papers/2307.01952).
710 | crops_coords_top_left (`Tuple[int]`, *optional*, defaults to (0, 0)):
711 | `crops_coords_top_left` can be used to generate an image that appears to be "cropped" from the position
712 | `crops_coords_top_left` downwards. Favorable, well-centered images are usually achieved by setting
713 | `crops_coords_top_left` to (0, 0). Part of SDXL's micro-conditioning as explained in section 2.2 of
714 | [https://huggingface.co/papers/2307.01952](https://huggingface.co/papers/2307.01952).
715 | target_size (`Tuple[int]`, *optional*, defaults to (1024, 1024)):
716 | For most cases, `target_size` should be set to the desired height and width of the generated image. If
717 | not specified it will default to `(width, height)`. Part of SDXL's micro-conditioning as explained in
718 | section 2.2 of [https://huggingface.co/papers/2307.01952](https://huggingface.co/papers/2307.01952).
719 |
720 | Examples:
721 |
722 | Returns:
723 | [`~pipelines.stable_diffusion_xl.StableDiffusionXLPipelineOutput`] or `tuple`:
724 | [`~pipelines.stable_diffusion_xl.StableDiffusionXLPipelineOutput`] if `return_dict` is True, otherwise a
725 | `tuple`. When returning a tuple, the first element is a list with the generated images.
726 | """
727 | # 0. Default height and width to unet
728 | height = height or self.default_sample_size * self.vae_scale_factor
729 | width = width or self.default_sample_size * self.vae_scale_factor
730 |
731 | height_sd1_5 = height_sd1_5 or self.default_sample_size_sd1_5 * self.vae_scale_factor_sd1_5
732 | width_sd1_5 = width_sd1_5 or self.default_sample_size_sd1_5 * self.vae_scale_factor_sd1_5
733 |
734 | original_size = original_size or (height, width)
735 | target_size = target_size or (height, width)
736 |
737 | # 1. Check inputs. Raise error if not correct
738 | self.check_inputs(
739 | prompt,
740 | prompt_2,
741 | height,
742 | width,
743 | callback_steps,
744 | negative_prompt,
745 | negative_prompt_2,
746 | prompt_embeds,
747 | negative_prompt_embeds,
748 | pooled_prompt_embeds,
749 | negative_pooled_prompt_embeds,
750 | )
751 |
752 | self.check_inputs_sd1_5(
753 | prompt if prompt_sd1_5 is None else prompt_sd1_5, height_sd1_5, width_sd1_5, callback_steps, negative_prompt, prompt_embeds, negative_prompt_embeds
754 | )
755 |
756 | # 2. Define call parameters
757 | if prompt is not None and isinstance(prompt, str):
758 | batch_size = 1
759 | elif prompt is not None and isinstance(prompt, list):
760 | batch_size = len(prompt)
761 | else:
762 | batch_size = prompt_embeds.shape[0]
763 |
764 | device = torch.device('cuda')
765 |
766 | # here `guidance_scale` is defined analog to the guidance weight `w` of equation (2)
767 | # of the Imagen paper: https://arxiv.org/pdf/2205.11487.pdf . `guidance_scale = 1`
768 | # corresponds to doing no classifier free guidance.
769 | do_classifier_free_guidance = guidance_scale > 1.0
770 |
771 | # 3. Encode input prompt
772 | text_encoder_lora_scale = (
773 | cross_attention_kwargs.get("scale", None) if cross_attention_kwargs is not None else None
774 | )
775 | (
776 | prompt_embeds,
777 | negative_prompt_embeds,
778 | pooled_prompt_embeds,
779 | negative_pooled_prompt_embeds,
780 | ) = self.encode_prompt(
781 | prompt=prompt,
782 | prompt_2=prompt_2,
783 | device=device,
784 | num_images_per_prompt=num_images_per_prompt,
785 | do_classifier_free_guidance=do_classifier_free_guidance,
786 | negative_prompt=negative_prompt,
787 | negative_prompt_2=negative_prompt_2,
788 | prompt_embeds=prompt_embeds,
789 | negative_prompt_embeds=negative_prompt_embeds,
790 | pooled_prompt_embeds=pooled_prompt_embeds,
791 | negative_pooled_prompt_embeds=negative_pooled_prompt_embeds,
792 | lora_scale=text_encoder_lora_scale,
793 | )
794 |
795 | prompt_embeds_sd1_5 = self._encode_prompt_sd1_5(
796 | prompt if prompt_sd1_5 is None else prompt_sd1_5,
797 | device,
798 | num_images_per_prompt,
799 | do_classifier_free_guidance,
800 | negative_prompt,
801 | prompt_embeds=prompt_embeds_sd1_5,
802 | negative_prompt_embeds=negative_prompt_embeds_sd1_5,
803 | lora_scale=text_encoder_lora_scale,
804 | )
805 | # todo: implement prompt_embeds for SD1.5
806 |
807 | # 4. Prepare timesteps
808 | self.scheduler_sd1_5.set_timesteps(num_inference_steps, device=device)
809 | timesteps_sd1_5 = self.scheduler_sd1_5.timesteps
810 | num_inference_steps_sd1_5 = num_inference_steps
811 |
812 | self.scheduler.set_timesteps(num_inference_steps, device=device)
813 |
814 | timesteps, num_inference_steps = self.get_timesteps(
815 | num_inference_steps, adapter_guidance_start, device, denoising_start=denoising_start
816 | )
817 | latent_timestep = timesteps[:1].repeat(batch_size * num_images_per_prompt)
818 |
819 | # 5. Prepare latent variables
820 | num_channels_latents = self.unet.config.in_channels
821 | latents = self.prepare_latents(
822 | batch_size * num_images_per_prompt,
823 | num_channels_latents,
824 | height,
825 | width,
826 | prompt_embeds.dtype,
827 | device,
828 | generator,
829 | latents,
830 | )
831 |
832 | num_channels_latents_sd1_5 = self.unet_sd1_5.config.in_channels
833 | latents_sd1_5 = self.prepare_latents_sd1_5(
834 | batch_size * num_images_per_prompt,
835 | num_channels_latents_sd1_5,
836 | height_sd1_5,
837 | width_sd1_5,
838 | prompt_embeds_sd1_5.dtype,
839 | device,
840 | generator,
841 | latents_sd1_5,
842 | )
843 |
844 | # 6. Prepare extra step kwargs. TODO: Logic should ideally just be moved out of the pipeline
845 | extra_step_kwargs = self.prepare_extra_step_kwargs(generator, eta)
846 |
847 | # 7. Prepare added time ids & embeddings
848 | add_text_embeds = pooled_prompt_embeds
849 | add_time_ids = self._get_add_time_ids(
850 | original_size, crops_coords_top_left, target_size, dtype=prompt_embeds.dtype
851 | )
852 |
853 | if do_classifier_free_guidance:
854 | prompt_embeds = torch.cat([negative_prompt_embeds, prompt_embeds], dim=0)
855 | add_text_embeds = torch.cat([negative_pooled_prompt_embeds, add_text_embeds], dim=0)
856 | add_time_ids = torch.cat([add_time_ids, add_time_ids], dim=0)
857 |
858 | prompt_embeds = prompt_embeds.to(device)
859 | add_text_embeds = add_text_embeds.to(device)
860 | add_time_ids = add_time_ids.to(device).repeat(batch_size * num_images_per_prompt, 1)
861 |
862 | # 8. Denoising loop
863 | num_warmup_steps = max(len(timesteps) - num_inference_steps * self.scheduler.order, 0)
864 |
865 | # 7.1 Apply denoising_end
866 | if denoising_end is not None and isinstance(denoising_end, float) and 0 < denoising_end < 1:
867 | discrete_timestep_cutoff = int(
868 | round(
869 | self.scheduler.config.num_train_timesteps
870 | - (denoising_end * self.scheduler.config.num_train_timesteps)
871 | )
872 | )
873 | num_inference_steps = len(list(filter(lambda ts: ts >= discrete_timestep_cutoff, timesteps)))
874 | timesteps = timesteps[:num_inference_steps]
875 |
876 | latents_sd1_5_prior = latents_sd1_5.clone()
877 |
878 | with self.progress_bar(total=num_inference_steps_sd1_5) as progress_bar:
879 | for i, t in enumerate(timesteps_sd1_5):
880 |
881 | #################### SD1.5 forward ####################
882 | t_sd1_5 = timesteps_sd1_5[i]
883 |
884 | latent_model_input = torch.cat([latents_sd1_5_prior] * 2) if do_classifier_free_guidance else latents_sd1_5_prior
885 | latent_model_input = self.scheduler_sd1_5.scale_model_input(latent_model_input, t_sd1_5)
886 |
887 | # predict the noise residual
888 | unet_output = self.unet_sd1_5(
889 | latent_model_input,
890 | t_sd1_5,
891 | encoder_hidden_states=prompt_embeds_sd1_5,
892 | cross_attention_kwargs=cross_attention_kwargs,
893 | return_hidden_states=False
894 | )
895 | noise_pred = unet_output.sample
896 |
897 | # perform guidance
898 | if do_classifier_free_guidance:
899 | noise_pred_uncond, noise_pred_text = noise_pred.chunk(2)
900 | noise_pred = noise_pred_uncond + guidance_scale * (noise_pred_text - noise_pred_uncond)
901 |
902 | if do_classifier_free_guidance and guidance_rescale > 0.0:
903 | # Based on 3.4. in https://arxiv.org/pdf/2305.08891.pdf
904 | noise_pred = rescale_noise_cfg(noise_pred, noise_pred_text, guidance_rescale=guidance_rescale)
905 |
906 | # compute the previous noisy sample x_t -> x_t-1
907 | latents_sd1_5_prior = self.scheduler_sd1_5.step(noise_pred, t_sd1_5, latents_sd1_5_prior, **extra_step_kwargs, return_dict=False)[0]
908 |
909 | #################### End of SD1.5 forward ####################
910 |
911 | # call the callback, if provided
912 | if i == len(timesteps_sd1_5) - 1 or ((i + 1) > num_warmup_steps and (i + 1) % self.scheduler_sd1_5.order == 0):
913 | progress_bar.update()
914 |
915 | add_noise = denoising_start is None
916 | latents = self.prepare_xl_latents_from_sd_1_5(latents_sd1_5_prior, latent_timestep, batch_size,
917 | num_images_per_prompt, height, width, prompt_embeds.dtype, device,
918 | generator=generator, add_noise=add_noise)
919 | latents_sd1_5 = self.sd1_5_add_noise(latents_sd1_5_prior, latent_timestep, generator, device,
920 | prompt_embeds.dtype)
921 |
922 | with self.progress_bar(total=num_inference_steps) as progress_bar:
923 | for i, t in enumerate(timesteps):
924 | # expand the latents if we are doing classifier free guidance
925 |
926 | #################### SD1.5 forward ####################
927 | t_sd1_5 = timesteps_sd1_5[i]
928 | latent_model_input = torch.cat([latents_sd1_5] * 2) if do_classifier_free_guidance else latents_sd1_5
929 | latent_model_input = self.scheduler_sd1_5.scale_model_input(latent_model_input, t_sd1_5)
930 |
931 | unet_output = self.unet_sd1_5(
932 | latent_model_input,
933 | t_sd1_5,
934 | encoder_hidden_states=prompt_embeds_sd1_5,
935 | cross_attention_kwargs=cross_attention_kwargs,
936 | return_hidden_states=True,
937 | return_encoder_feature=True
938 | )
939 | noise_pred = unet_output.sample
940 | hidden_states = unet_output.hidden_states
941 | encoder_feature = unet_output.encoder_feature
942 |
943 |
944 | # adapter forward
945 | if adapter_type == "de":
946 | down_bridge_residuals = None
947 | up_block_additional_residual = self.adapter(hidden_states, t=t_sd1_5 if enable_time_step else None)
948 | for xx in range(len(up_block_additional_residual)):
949 | up_block_additional_residual[xx] = up_block_additional_residual[xx] * adapter_condition_scale
950 | elif adapter_type == "en":
951 | up_block_additional_residual = None
952 | down_bridge_residuals = self.adapter(encoder_feature)
953 | for xx in range(len(down_bridge_residuals)):
954 | down_bridge_residuals[xx] = down_bridge_residuals[xx] * adapter_condition_scale
955 | else:
956 | adapter_out = self.adapter(x=hidden_states, enc_x=encoder_feature)
957 | down_bridge_residuals = adapter_out['encoder_features']
958 | up_block_additional_residual = adapter_out['decoder_features']
959 | for xx in range(len(up_block_additional_residual)):
960 | up_block_additional_residual[xx] = up_block_additional_residual[xx] * adapter_condition_scale
961 | for xx in range(len(down_bridge_residuals)):
962 | down_bridge_residuals[xx] = down_bridge_residuals[xx] * adapter_condition_scale
963 |
964 | # perform guidance
965 | if do_classifier_free_guidance:
966 | noise_pred_uncond, noise_pred_text = noise_pred.chunk(2)
967 | noise_pred = noise_pred_uncond + guidance_scale * (noise_pred_text - noise_pred_uncond)
968 |
969 | if do_classifier_free_guidance and guidance_rescale > 0.0:
970 | # Based on 3.4. in https://arxiv.org/pdf/2305.08891.pdf
971 | noise_pred = rescale_noise_cfg(noise_pred, noise_pred_text, guidance_rescale=guidance_rescale)
972 |
973 | # compute the previous noisy sample x_t -> x_t-1
974 |
975 | latents_sd1_5 = self.scheduler_sd1_5.step(noise_pred, t_sd1_5, latents_sd1_5, **extra_step_kwargs,
976 | return_dict=False)[0]
977 |
978 | #################### End of SD1.5 forward ####################
979 |
980 | latent_model_input = torch.cat([latents] * 2) if do_classifier_free_guidance else latents
981 |
982 | latent_model_input = self.scheduler.scale_model_input(latent_model_input, t)
983 |
984 | # predict the noise residual
985 | added_cond_kwargs = {"text_embeds": add_text_embeds, "time_ids": add_time_ids}
986 |
987 | noise_pred = self.unet(
988 | latent_model_input,
989 | t,
990 | encoder_hidden_states=prompt_embeds,
991 | cross_attention_kwargs=cross_attention_kwargs,
992 | added_cond_kwargs=added_cond_kwargs,
993 | up_block_additional_residual=up_block_additional_residual,
994 | down_bridge_residuals=down_bridge_residuals,
995 | return_dict=False,
996 | fusion_guidance_scale=fusion_guidance_scale
997 | )[0]
998 |
999 | # perform guidance
1000 | if do_classifier_free_guidance:
1001 | noise_pred_uncond, noise_pred_text = noise_pred.chunk(2)
1002 | noise_pred = noise_pred_uncond + guidance_scale * (noise_pred_text - noise_pred_uncond)
1003 |
1004 | if do_classifier_free_guidance and guidance_rescale > 0.0:
1005 | # Based on 3.4. in https://arxiv.org/pdf/2305.08891.pdf
1006 | noise_pred = rescale_noise_cfg(noise_pred, noise_pred_text, guidance_rescale=guidance_rescale)
1007 |
1008 | # compute the previous noisy sample x_t -> x_t-1
1009 | latents = self.scheduler.step(noise_pred, t, latents, **extra_step_kwargs, return_dict=False)[0]
1010 |
1011 | # call the callback, if provided
1012 | if i == len(timesteps) - 1 or ((i + 1) > num_warmup_steps and (i + 1) % self.scheduler.order == 0):
1013 | progress_bar.update()
1014 | if callback is not None and i % callback_steps == 0:
1015 | callback(i, t, latents)
1016 |
1017 | # make sure the VAE is in float32 mode, as it overflows in float16
1018 | if self.vae.dtype == torch.float16 and self.vae.config.force_upcast:
1019 | self.upcast_vae()
1020 | latents = latents.to(next(iter(self.vae.post_quant_conv.parameters())).dtype)
1021 |
1022 | if not output_type == "latent":
1023 | image = self.vae.decode(latents / self.vae.config.scaling_factor, return_dict=False)[0]
1024 | else:
1025 | image = latents
1026 | return StableDiffusionXLPipelineOutput(images=image)
1027 |
1028 | # apply watermark if available
1029 | if self.watermark is not None:
1030 | image = self.watermark.apply_watermark(image)
1031 |
1032 | image = self.image_processor.postprocess(image, output_type=output_type)
1033 |
1034 | # Offload last model to CPU
1035 | if hasattr(self, "final_offload_hook") and self.final_offload_hook is not None:
1036 | self.final_offload_hook.offload()
1037 |
1038 | if not return_dict:
1039 | return (image,)
1040 |
1041 | return StableDiffusionXLPipelineOutput(images=image)
1042 |
1043 | # Override to properly handle the loading and unloading of the additional text encoder.
1044 | def load_lora_weights(self, pretrained_model_name_or_path_or_dict: Union[str, Dict[str, torch.Tensor]], **kwargs):
1045 | # We could have accessed the unet config from `lora_state_dict()` too. We pass
1046 | # it here explicitly to be able to tell that it's coming from an SDXL
1047 | # pipeline.
1048 | state_dict, network_alphas = self.lora_state_dict(
1049 | pretrained_model_name_or_path_or_dict,
1050 | unet_config=self.unet.config,
1051 | **kwargs,
1052 | )
1053 | self.load_lora_into_unet(state_dict, network_alphas=network_alphas, unet=self.unet)
1054 |
1055 | text_encoder_state_dict = {k: v for k, v in state_dict.items() if "text_encoder." in k}
1056 | if len(text_encoder_state_dict) > 0:
1057 | self.load_lora_into_text_encoder(
1058 | text_encoder_state_dict,
1059 | network_alphas=network_alphas,
1060 | text_encoder=self.text_encoder,
1061 | prefix="text_encoder",
1062 | lora_scale=self.lora_scale,
1063 | )
1064 |
1065 | text_encoder_2_state_dict = {k: v for k, v in state_dict.items() if "text_encoder_2." in k}
1066 | if len(text_encoder_2_state_dict) > 0:
1067 | self.load_lora_into_text_encoder(
1068 | text_encoder_2_state_dict,
1069 | network_alphas=network_alphas,
1070 | text_encoder=self.text_encoder_2,
1071 | prefix="text_encoder_2",
1072 | lora_scale=self.lora_scale,
1073 | )
1074 |
1075 | @classmethod
1076 | def save_lora_weights(
1077 | self,
1078 | save_directory: Union[str, os.PathLike],
1079 | unet_lora_layers: Dict[str, Union[torch.nn.Module, torch.Tensor]] = None,
1080 | text_encoder_lora_layers: Dict[str, Union[torch.nn.Module, torch.Tensor]] = None,
1081 | text_encoder_2_lora_layers: Dict[str, Union[torch.nn.Module, torch.Tensor]] = None,
1082 | is_main_process: bool = True,
1083 | weight_name: str = None,
1084 | save_function: Callable = None,
1085 | safe_serialization: bool = True,
1086 | ):
1087 | state_dict = {}
1088 |
1089 | def pack_weights(layers, prefix):
1090 | layers_weights = layers.state_dict() if isinstance(layers, torch.nn.Module) else layers
1091 | layers_state_dict = {f"{prefix}.{module_name}": param for module_name, param in layers_weights.items()}
1092 | return layers_state_dict
1093 |
1094 | state_dict.update(pack_weights(unet_lora_layers, "unet"))
1095 |
1096 | if text_encoder_lora_layers and text_encoder_2_lora_layers:
1097 | state_dict.update(pack_weights(text_encoder_lora_layers, "text_encoder"))
1098 | state_dict.update(pack_weights(text_encoder_2_lora_layers, "text_encoder_2"))
1099 |
1100 | self.write_lora_layers(
1101 | state_dict=state_dict,
1102 | save_directory=save_directory,
1103 | is_main_process=is_main_process,
1104 | weight_name=weight_name,
1105 | save_function=save_function,
1106 | safe_serialization=safe_serialization,
1107 | )
1108 |
1109 | def _remove_text_encoder_monkey_patch(self):
1110 | self._remove_text_encoder_monkey_patch_classmethod(self.text_encoder)
1111 | self._remove_text_encoder_monkey_patch_classmethod(self.text_encoder_2)
1112 |
1113 | def _encode_prompt_sd1_5(
1114 | self,
1115 | prompt,
1116 | device,
1117 | num_images_per_prompt,
1118 | do_classifier_free_guidance,
1119 | negative_prompt=None,
1120 | prompt_embeds: Optional[torch.FloatTensor] = None,
1121 | negative_prompt_embeds: Optional[torch.FloatTensor] = None,
1122 | lora_scale: Optional[float] = None,
1123 | ):
1124 | r"""
1125 | Encodes the prompt into text encoder hidden states.
1126 |
1127 | Args:
1128 | prompt (`str` or `List[str]`, *optional*):
1129 | prompt to be encoded
1130 | device: (`torch.device`):
1131 | torch device
1132 | num_images_per_prompt (`int`):
1133 | number of images that should be generated per prompt
1134 | do_classifier_free_guidance (`bool`):
1135 | whether to use classifier free guidance or not
1136 | negative_prompt (`str` or `List[str]`, *optional*):
1137 | The prompt or prompts not to guide the image generation. If not defined, one has to pass
1138 | `negative_prompt_embeds` instead. Ignored when not using guidance (i.e., ignored if `guidance_scale` is
1139 | less than `1`).
1140 | prompt_embeds (`torch.FloatTensor`, *optional*):
1141 | Pre-generated text embeddings. Can be used to easily tweak text inputs, *e.g.* prompt weighting. If not
1142 | provided, text embeddings will be generated from `prompt` input argument.
1143 | negative_prompt_embeds (`torch.FloatTensor`, *optional*):
1144 | Pre-generated negative text embeddings. Can be used to easily tweak text inputs, *e.g.* prompt
1145 | weighting. If not provided, negative_prompt_embeds will be generated from `negative_prompt` input
1146 | argument.
1147 | lora_scale (`float`, *optional*):
1148 | A lora scale that will be applied to all LoRA layers of the text encoder if LoRA layers are loaded.
1149 | """
1150 | # set lora scale so that monkey patched LoRA
1151 | # function of text encoder can correctly access it
1152 | if lora_scale is not None and isinstance(self, LoraLoaderMixin):
1153 | self._lora_scale = lora_scale
1154 |
1155 | if prompt is not None and isinstance(prompt, str):
1156 | batch_size = 1
1157 | elif prompt is not None and isinstance(prompt, list):
1158 | batch_size = len(prompt)
1159 | else:
1160 | batch_size = prompt_embeds.shape[0]
1161 |
1162 | if prompt_embeds is None:
1163 | # textual inversion: process multi-vector tokens if necessary
1164 | if isinstance(self, TextualInversionLoaderMixin):
1165 | prompt = self.maybe_convert_prompt(prompt, self.tokenizer_sd1_5)
1166 |
1167 | text_inputs = self.tokenizer_sd1_5(
1168 | prompt,
1169 | padding="max_length",
1170 | max_length=self.tokenizer_sd1_5.model_max_length,
1171 | truncation=True,
1172 | return_tensors="pt",
1173 | )
1174 | text_input_ids = text_inputs.input_ids
1175 | untruncated_ids = self.tokenizer_sd1_5(prompt, padding="longest", return_tensors="pt").input_ids
1176 |
1177 | if untruncated_ids.shape[-1] >= text_input_ids.shape[-1] and not torch.equal(
1178 | text_input_ids, untruncated_ids
1179 | ):
1180 | removed_text = self.tokenizer_sd1_5.batch_decode(
1181 | untruncated_ids[:, self.tokenizer_sd1_5.model_max_length - 1: -1]
1182 | )
1183 | logger.warning(
1184 | "The following part of your input was truncated because CLIP can only handle sequences up to"
1185 | f" {self.tokenizer_sd1_5.model_max_length} tokens: {removed_text}"
1186 | )
1187 |
1188 | if hasattr(self.text_encoder_sd1_5.config,
1189 | "use_attention_mask") and self.text_encoder_sd1_5.config.use_attention_mask:
1190 | attention_mask = text_inputs.attention_mask.to(device)
1191 | else:
1192 | attention_mask = None
1193 |
1194 | prompt_embeds = self.text_encoder_sd1_5(
1195 | text_input_ids.to(device),
1196 | attention_mask=attention_mask,
1197 | )
1198 | prompt_embeds = prompt_embeds[0]
1199 |
1200 | if self.text_encoder_sd1_5 is not None:
1201 | prompt_embeds_dtype = self.text_encoder_sd1_5.dtype
1202 | elif self.unet_sd1_5 is not None:
1203 | prompt_embeds_dtype = self.unet_sd1_5.dtype
1204 | else:
1205 | prompt_embeds_dtype = prompt_embeds.dtype
1206 |
1207 | prompt_embeds = prompt_embeds.to(dtype=prompt_embeds_dtype, device=device)
1208 |
1209 | bs_embed, seq_len, _ = prompt_embeds.shape
1210 | # duplicate text embeddings for each generation per prompt, using mps friendly method
1211 | prompt_embeds = prompt_embeds.repeat(1, num_images_per_prompt, 1)
1212 | prompt_embeds = prompt_embeds.view(bs_embed * num_images_per_prompt, seq_len, -1)
1213 |
1214 | # get unconditional embeddings for classifier free guidance
1215 | if do_classifier_free_guidance and negative_prompt_embeds is None:
1216 | uncond_tokens: List[str]
1217 | if negative_prompt is None:
1218 | uncond_tokens = [""] * batch_size
1219 | elif prompt is not None and type(prompt) is not type(negative_prompt):
1220 | raise TypeError(
1221 | f"`negative_prompt` should be the same type to `prompt`, but got {type(negative_prompt)} !="
1222 | f" {type(prompt)}."
1223 | )
1224 | elif isinstance(negative_prompt, str):
1225 | uncond_tokens = [negative_prompt]
1226 | elif batch_size != len(negative_prompt):
1227 | raise ValueError(
1228 | f"`negative_prompt`: {negative_prompt} has batch size {len(negative_prompt)}, but `prompt`:"
1229 | f" {prompt} has batch size {batch_size}. Please make sure that passed `negative_prompt` matches"
1230 | " the batch size of `prompt`."
1231 | )
1232 | else:
1233 | uncond_tokens = negative_prompt
1234 |
1235 | # textual inversion: process multi-vector tokens if necessary
1236 | if isinstance(self, TextualInversionLoaderMixin):
1237 | uncond_tokens = self.maybe_convert_prompt(uncond_tokens, self.tokenizer_sd1_5)
1238 |
1239 | max_length = prompt_embeds.shape[1]
1240 | uncond_input = self.tokenizer_sd1_5(
1241 | uncond_tokens,
1242 | padding="max_length",
1243 | max_length=max_length,
1244 | truncation=True,
1245 | return_tensors="pt",
1246 | )
1247 |
1248 | if hasattr(self.text_encoder_sd1_5.config,
1249 | "use_attention_mask") and self.text_encoder_sd1_5.config.use_attention_mask:
1250 | attention_mask = uncond_input.attention_mask.to(device)
1251 | else:
1252 | attention_mask = None
1253 |
1254 | negative_prompt_embeds = self.text_encoder_sd1_5(
1255 | uncond_input.input_ids.to(device),
1256 | attention_mask=attention_mask,
1257 | )
1258 | negative_prompt_embeds = negative_prompt_embeds[0]
1259 |
1260 | if do_classifier_free_guidance:
1261 | # duplicate unconditional embeddings for each generation per prompt, using mps friendly method
1262 | seq_len = negative_prompt_embeds.shape[1]
1263 |
1264 | negative_prompt_embeds = negative_prompt_embeds.to(dtype=prompt_embeds_dtype, device=device)
1265 |
1266 | negative_prompt_embeds = negative_prompt_embeds.repeat(1, num_images_per_prompt, 1)
1267 | negative_prompt_embeds = negative_prompt_embeds.view(batch_size * num_images_per_prompt, seq_len, -1)
1268 |
1269 | # For classifier free guidance, we need to do two forward passes.
1270 | # Here we concatenate the unconditional and text embeddings into a single batch
1271 | # to avoid doing two forward passes
1272 | prompt_embeds = torch.cat([negative_prompt_embeds, prompt_embeds])
1273 |
1274 | return prompt_embeds
1275 |
1276 | def decode_latents_sd1_5(self, latents):
1277 | warnings.warn(
1278 | "The decode_latents method is deprecated and will be removed in a future version. Please"
1279 | " use VaeImageProcessor instead",
1280 | FutureWarning,
1281 | )
1282 | latents = 1 / self.vae_sd1_5.config.scaling_factor * latents
1283 | image = self.vae_sd1_5.decode(latents, return_dict=False)[0]
1284 | image = (image / 2 + 0.5).clamp(0, 1)
1285 | # we always cast to float32 as this does not cause significant overhead and is compatible with bfloat16
1286 | image = image.cpu().permute(0, 2, 3, 1).float().numpy()
1287 | return image
1288 |
1289 | def check_inputs_sd1_5(
1290 | self,
1291 | prompt,
1292 | height,
1293 | width,
1294 | callback_steps,
1295 | negative_prompt=None,
1296 | prompt_embeds=None,
1297 | negative_prompt_embeds=None,
1298 | ):
1299 | if height % 8 != 0 or width % 8 != 0:
1300 | raise ValueError(f"`height` and `width` have to be divisible by 8 but are {height} and {width}.")
1301 |
1302 | if (callback_steps is None) or (
1303 | callback_steps is not None and (not isinstance(callback_steps, int) or callback_steps <= 0)
1304 | ):
1305 | raise ValueError(
1306 | f"`callback_steps` has to be a positive integer but is {callback_steps} of type"
1307 | f" {type(callback_steps)}."
1308 | )
1309 |
1310 | if prompt is not None and prompt_embeds is not None:
1311 | raise ValueError(
1312 | f"Cannot forward both `prompt`: {prompt} and `prompt_embeds`: {prompt_embeds}. Please make sure to"
1313 | " only forward one of the two."
1314 | )
1315 | elif prompt is None and prompt_embeds is None:
1316 | raise ValueError(
1317 | "Provide either `prompt` or `prompt_embeds`. Cannot leave both `prompt` and `prompt_embeds` undefined."
1318 | )
1319 | elif prompt is not None and (not isinstance(prompt, str) and not isinstance(prompt, list)):
1320 | raise ValueError(f"`prompt` has to be of type `str` or `list` but is {type(prompt)}")
1321 |
1322 | if negative_prompt is not None and negative_prompt_embeds is not None:
1323 | raise ValueError(
1324 | f"Cannot forward both `negative_prompt`: {negative_prompt} and `negative_prompt_embeds`:"
1325 | f" {negative_prompt_embeds}. Please make sure to only forward one of the two."
1326 | )
1327 |
1328 | if prompt_embeds is not None and negative_prompt_embeds is not None:
1329 | if prompt_embeds.shape != negative_prompt_embeds.shape:
1330 | raise ValueError(
1331 | "`prompt_embeds` and `negative_prompt_embeds` must have the same shape when passed directly, but"
1332 | f" got: `prompt_embeds` {prompt_embeds.shape} != `negative_prompt_embeds`"
1333 | f" {negative_prompt_embeds.shape}."
1334 | )
1335 |
1336 | def prepare_xl_latents_from_sd_1_5(
1337 | self, latent, timestep, batch_size, num_images_per_prompt, height, width, dtype, device, generator=None,
1338 | add_noise=True
1339 | ):
1340 | # sd1.5 latent -> img
1341 | image = self.vae_sd1_5.decode(latent / self.vae_sd1_5.config.scaling_factor, return_dict=False)[0]
1342 | do_denormalize = [True] * image.shape[0]
1343 | image = self.image_processor_sd1_5.postprocess(image, output_type='pil', do_denormalize=do_denormalize)[0]
1344 | image = image.resize((width, height))
1345 | # image.save('./test_img/image_sd1_5.jpg')
1346 | # input()
1347 |
1348 | if not isinstance(image, (torch.Tensor, PIL.Image.Image, list)):
1349 | raise ValueError(
1350 | f"`image` has to be of type `torch.Tensor`, `PIL.Image.Image` or list but is {type(image)}"
1351 | )
1352 |
1353 | # Offload text encoder if `enable_model_cpu_offload` was enabled
1354 | if hasattr(self, "final_offload_hook") and self.final_offload_hook is not None:
1355 | self.text_encoder_2.to("cpu")
1356 | torch.cuda.empty_cache()
1357 |
1358 | image = self.image_processor.preprocess(image)
1359 |
1360 | image = image.to(device=device, dtype=dtype)
1361 |
1362 | batch_size = batch_size * num_images_per_prompt
1363 |
1364 | if image.shape[1] == 4:
1365 | init_latents = image
1366 |
1367 | else:
1368 | # make sure the VAE is in float32 mode, as it overflows in float16
1369 | if self.vae.config.force_upcast:
1370 | image = image.float()
1371 | self.vae.to(dtype=torch.float32)
1372 |
1373 | if isinstance(generator, list) and len(generator) != batch_size:
1374 | raise ValueError(
1375 | f"You have passed a list of generators of length {len(generator)}, but requested an effective batch"
1376 | f" size of {batch_size}. Make sure the batch size matches the length of the generators."
1377 | )
1378 |
1379 | elif isinstance(generator, list):
1380 | init_latents = [
1381 | self.vae.encode(image[i: i + 1]).latent_dist.sample(generator[i]) for i in range(batch_size)
1382 | ]
1383 | init_latents = torch.cat(init_latents, dim=0)
1384 | else:
1385 | init_latents = self.vae.encode(image).latent_dist.sample(generator)
1386 |
1387 | if self.vae.config.force_upcast:
1388 | self.vae.to(dtype)
1389 |
1390 | init_latents = init_latents.to(dtype)
1391 | init_latents = self.vae.config.scaling_factor * init_latents
1392 |
1393 | if batch_size > init_latents.shape[0] and batch_size % init_latents.shape[0] == 0:
1394 | # expand init_latents for batch_size
1395 | additional_image_per_prompt = batch_size // init_latents.shape[0]
1396 | init_latents = torch.cat([init_latents] * additional_image_per_prompt, dim=0)
1397 | elif batch_size > init_latents.shape[0] and batch_size % init_latents.shape[0] != 0:
1398 | raise ValueError(
1399 | f"Cannot duplicate `image` of batch size {init_latents.shape[0]} to {batch_size} text prompts."
1400 | )
1401 | else:
1402 | init_latents = torch.cat([init_latents], dim=0)
1403 |
1404 | if add_noise:
1405 | shape = init_latents.shape
1406 | noise = randn_tensor(shape, generator=generator, device=device, dtype=dtype)
1407 | # get latents
1408 | init_latents = self.scheduler.add_noise(init_latents, noise, timestep)
1409 |
1410 | latents = init_latents
1411 |
1412 | return latents
1413 |
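# Note on the handoff implemented above: `prepare_xl_latents_from_sd_1_5` bridges the two
# latent spaces by decoding the SD1.5 latent with `vae_sd1_5`, resizing the resulting PIL
# image to the SDXL resolution, re-encoding it with the SDXL VAE, and (unless
# `denoising_start` is set, in which case `add_noise` is False) re-noising the result at
# the handoff timestep so the SDXL scheduler can resume denoising from
# `adapter_guidance_start` onwards.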
1414 | def sd1_5_add_noise(self, init_latents, timestep, generator, device, dtype):
1415 | shape = init_latents.shape
1416 | noise = randn_tensor(shape, generator=generator, device=device, dtype=dtype)
1417 | # get latents
1418 | init_latents = self.scheduler.add_noise(init_latents, noise, timestep)
1419 |
1420 | image = self.vae_sd1_5.decode(init_latents / self.vae_sd1_5.config.scaling_factor, return_dict=False)[0]
1421 | do_denormalize = [True] * image.shape[0]
1422 | image = self.image_processor_sd1_5.postprocess(image, output_type='pil', do_denormalize=do_denormalize)[0]
1423 | # image.save(f'./test_img/noisy_image_sd1_5_{int(timestep)}.jpg')
1424 |
1425 | return init_latents
1426 |
1427 | def get_timesteps(self, num_inference_steps, strength, device, denoising_start=None):
1428 | # get the original timestep using init_timestep
1429 | if denoising_start is None:
1430 | init_timestep = min(int(num_inference_steps * strength), num_inference_steps)
1431 | t_start = max(num_inference_steps - init_timestep, 0)
1432 | else:
1433 | t_start = 0
1434 |
1435 | timesteps = self.scheduler.timesteps[t_start * self.scheduler.order :]
1436 |
1437 | # Strength is irrelevant if we directly request a timestep to start at;
1438 | # that is, strength is determined by the denoising_start instead.
1439 | if denoising_start is not None:
1440 | discrete_timestep_cutoff = int(
1441 | round(
1442 | self.scheduler.config.num_train_timesteps
1443 | - (denoising_start * self.scheduler.config.num_train_timesteps)
1444 | )
1445 | )
1446 | timesteps = list(filter(lambda ts: ts < discrete_timestep_cutoff, timesteps))
1447 | return torch.tensor(timesteps), len(timesteps)
1448 |
1449 | return timesteps, num_inference_steps - t_start
1450 |
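# Worked example for get_timesteps (assuming self.scheduler.order == 1 and no
# denoising_start): with num_inference_steps=50 and strength=0.5 (the
# `adapter_guidance_start` value passed from __call__), init_timestep = min(int(50 * 0.5), 50) = 25
# and t_start = 50 - 25 = 25, so only the last 25 scheduler timesteps are returned and the
# joint SD1.5 + SDXL loop resumes at the halfway point of the schedule, after the full
# SD1.5 prior pass has finished.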
1451 | def prepare_latents_sd1_5(self, batch_size, num_channels_latents, height, width, dtype, device, generator, latents=None):
1452 | shape = (batch_size, num_channels_latents, height // self.vae_scale_factor_sd1_5, width // self.vae_scale_factor_sd1_5)
1453 | if isinstance(generator, list) and len(generator) != batch_size:
1454 | raise ValueError(
1455 | f"You have passed a list of generators of length {len(generator)}, but requested an effective batch"
1456 | f" size of {batch_size}. Make sure the batch size matches the length of the generators."
1457 | )
1458 |
1459 | if latents is None:
1460 | latents = randn_tensor(shape, generator=generator, device=device, dtype=dtype)
1461 | else:
1462 | latents = latents.to(device)
1463 |
1464 | # scale the initial noise by the standard deviation required by the scheduler
1465 | latents = latents * self.scheduler_sd1_5.init_noise_sigma
1466 | return latents
1467 |
--------------------------------------------------------------------------------
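Both denoising loops in the pipeline above combine classifier-free guidance with the optional guidance rescale from "Common Diffusion Noise Schedules and Sample Steps are Flawed" (https://arxiv.org/pdf/2305.08891.pdf) via `rescale_noise_cfg`. The snippet below is a minimal, self-contained sketch of that combination for readers who want to try it outside the pipeline; it mirrors the formula referenced in the comments but is not the repository's own helper.

    import torch

    def cfg_with_rescale(noise_pred_uncond, noise_pred_text, guidance_scale=5.0, guidance_rescale=0.0):
        # standard classifier-free guidance: move the unconditional prediction towards the text one
        noise_cfg = noise_pred_uncond + guidance_scale * (noise_pred_text - noise_pred_uncond)
        if guidance_rescale > 0.0:
            # Sec. 3.4 of arXiv:2305.08891: rescale the guided prediction so its per-sample std
            # matches the text-conditioned prediction, then blend to avoid overly plain images.
            std_text = noise_pred_text.std(dim=tuple(range(1, noise_pred_text.ndim)), keepdim=True)
            std_cfg = noise_cfg.std(dim=tuple(range(1, noise_cfg.ndim)), keepdim=True)
            rescaled = noise_cfg * (std_text / std_cfg)
            noise_cfg = guidance_rescale * rescaled + (1.0 - guidance_rescale) * noise_cfg
        return noise_cfg

    # usage on dummy latent-shaped tensors (batch, channels, height, width)
    uncond, text = torch.randn(1, 4, 128, 128), torch.randn(1, 4, 128, 128)
    out = cfg_with_rescale(uncond, text, guidance_scale=5.0, guidance_rescale=0.7)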
/requirements.txt:
--------------------------------------------------------------------------------
1 | accelerate == 0.18.0
2 | controlnet_aux == 0.0.7
3 | opencv_python_headless == 4.8.0.76
4 | dataclasses == 0.6
5 | diffusers == 0.20.0
6 | einops == 0.4.1
7 | huggingface_hub == 0.17.2
8 | imageio == 2.26.0
9 | matplotlib == 3.7.1
10 | numpy == 1.23.3
11 | safetensors == 0.3.3
12 | tqdm == 4.64.1
13 | transformers == 4.25.1
14 | Pillow == 10.2.0
--------------------------------------------------------------------------------
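Since the pins above are exact (`==`), version drift is a common source of import errors, especially around `diffusers == 0.20.0` and `transformers == 4.25.1`. The sketch below is a small, hypothetical helper (not part of the repository) for checking an environment against these pins; it assumes your Python's `importlib.metadata` resolves the distribution names as written in requirements.txt.

    from importlib.metadata import version, PackageNotFoundError

    PINS = {
        "accelerate": "0.18.0", "controlnet_aux": "0.0.7", "opencv_python_headless": "4.8.0.76",
        "diffusers": "0.20.0", "einops": "0.4.1", "huggingface_hub": "0.17.2",
        "imageio": "2.26.0", "matplotlib": "3.7.1", "numpy": "1.23.3",
        "safetensors": "0.3.3", "tqdm": "4.64.1", "transformers": "4.25.1", "Pillow": "10.2.0",
    }

    for name, expected in PINS.items():
        try:
            installed = version(name)
        except PackageNotFoundError:
            print(f"{name}: not installed (expected {expected})")
            continue
        print(f"{name}: {installed} " + ("ok" if installed == expected else f"mismatch (expected {expected})"))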
/scripts/__pycache__/inference_controlnet.cpython-310.pyc:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/showlab/X-Adapter/e7348cb1b83e039dec8fe6f664266f1594e5103f/scripts/__pycache__/inference_controlnet.cpython-310.pyc
--------------------------------------------------------------------------------
/scripts/__pycache__/inference_ctrlnet_tile.cpython-310.pyc:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/showlab/X-Adapter/e7348cb1b83e039dec8fe6f664266f1594e5103f/scripts/__pycache__/inference_ctrlnet_tile.cpython-310.pyc
--------------------------------------------------------------------------------
/scripts/__pycache__/inference_lora.cpython-310.pyc:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/showlab/X-Adapter/e7348cb1b83e039dec8fe6f664266f1594e5103f/scripts/__pycache__/inference_lora.cpython-310.pyc
--------------------------------------------------------------------------------
/scripts/__pycache__/utils.cpython-310.pyc:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/showlab/X-Adapter/e7348cb1b83e039dec8fe6f664266f1594e5103f/scripts/__pycache__/utils.cpython-310.pyc
--------------------------------------------------------------------------------
/scripts/inference_controlnet.py:
--------------------------------------------------------------------------------
1 | import torch
2 | import json
3 | import os
4 | import numpy as np
5 | import cv2
6 | import matplotlib
7 | from tqdm import tqdm
8 | from diffusers import DiffusionPipeline
9 | from diffusers import DPMSolverMultistepScheduler
10 | from diffusers.utils import load_image
11 | from torch import Generator
12 | from PIL import Image
13 | from packaging import version
14 |
15 | from transformers import CLIPTextModel, CLIPTokenizer, AutoTokenizer, PretrainedConfig
16 |
17 | import diffusers
18 | from diffusers import AutoencoderKL, DDPMScheduler, StableDiffusionPipeline, UNet2DConditionModel, ControlNetModel, T2IAdapter
19 | from diffusers.optimization import get_scheduler
20 | from diffusers.training_utils import EMAModel
21 | from diffusers.utils import check_min_version, deprecate, is_wandb_available
22 | from diffusers.utils.import_utils import is_xformers_available
23 |
24 | from model.unet_adapter import UNet2DConditionModel
25 | from model.adapter import Adapter_XL
26 | from pipeline.pipeline_sd_xl_adapter_controlnet import StableDiffusionXLAdapterControlnetPipeline
27 | from controlnet_aux import MidasDetector, CannyDetector
28 |
29 | from scripts.utils import str2float
30 |
31 |
32 | def import_model_class_from_model_name_or_path(
33 | pretrained_model_name_or_path: str, revision: str, subfolder: str = "text_encoder"
34 | ):
35 | text_encoder_config = PretrainedConfig.from_pretrained(
36 | pretrained_model_name_or_path, subfolder=subfolder, revision=revision
37 | )
38 | model_class = text_encoder_config.architectures[0]
39 |
40 | if model_class == "CLIPTextModel":
41 | from transformers import CLIPTextModel
42 |
43 | return CLIPTextModel
44 | elif model_class == "CLIPTextModelWithProjection":
45 | from transformers import CLIPTextModelWithProjection
46 |
47 | return CLIPTextModelWithProjection
48 | else:
49 | raise ValueError(f"{model_class} is not supported.")
50 |
51 |
52 | def inference_controlnet(args):
53 | device = 'cuda'
54 | weight_dtype = torch.float16
55 |
56 | controlnet_condition_scale_list = str2float(args.controlnet_condition_scale_list)
57 | adapter_guidance_start_list = str2float(args.adapter_guidance_start_list)
58 | adapter_condition_scale_list = str2float(args.adapter_condition_scale_list)
59 |
60 | path = args.base_path
61 | path_sdxl = args.sdxl_path
62 | path_vae_sdxl = args.path_vae_sdxl
63 | adapter_path = args.adapter_checkpoint
64 |
65 | if args.condition_type == "canny":
66 | controlnet_path = args.controlnet_canny_path
67 | canny = CannyDetector()
68 | elif args.condition_type == "depth":
69 | controlnet_path = args.controlnet_depth_path # todo: not yet defined in args
70 | depth = MidasDetector.from_pretrained("lllyasviel/Annotators")
71 | else:
72 | raise NotImplementedError("not implemented yet")
73 |
74 | prompt = args.prompt
75 | if args.prompt_sd1_5 is None:
76 | prompt_sd1_5 = prompt
77 | else:
78 | prompt_sd1_5 = args.prompt_sd1_5
79 |
80 | if args.negative_prompt is None:
81 | negative_prompt = "(deformed iris, deformed pupils, semi-realistic, cgi, 3d, render, sketch, cartoon, drawing, anime:1.4), text, close up, cropped, out of frame, worst quality, low quality, jpeg artifacts, ugly, duplicate, morbid, mutilated, extra fingers, mutated hands, poorly drawn hands, poorly drawn face, mutation, deformed, blurry, dehydrated, bad anatomy, bad proportions, extra limbs, cloned face, disfigured, gross proportions, malformed limbs, missing arms, missing legs, extra arms, extra legs, fused fingers, too many fingers, long neck"
82 | else:
83 | negative_prompt = args.negative_prompt
84 |
85 | torch.set_grad_enabled(False)
86 | torch.backends.cudnn.benchmark = True
87 |
88 | # load controlnet
89 | controlnet = ControlNetModel.from_pretrained(
90 | controlnet_path, torch_dtype=weight_dtype
91 | )
92 | print('successfully load controlnet')
93 |
94 | input_image = Image.open(args.input_image_path)
95 | # input_image = input_image.resize((512, 512), Image.LANCZOS)
96 | input_image = input_image.resize((args.width_sd1_5, args.height_sd1_5), Image.LANCZOS)
97 | if args.condition_type == "canny":
98 | control_image = canny(input_image)
99 | control_image.save(f'{args.save_path}/{prompt[:10]}_canny_condition.png')
100 | elif args.condition_type == "depth":
101 | control_image = depth(input_image)
102 | control_image.save(f'{args.save_path}/{prompt[:10]}_depth_condition.png')
103 |
104 | # load adapter
105 | adapter = Adapter_XL()
106 | ckpt = torch.load(adapter_path)
107 | adapter.load_state_dict(ckpt)
108 | adapter.to(weight_dtype)
109 | print('successfully load adapter')
110 | # load SD1.5
111 | noise_scheduler_sd1_5 = DDPMScheduler.from_pretrained(
112 | path, subfolder="scheduler"
113 | )
114 | tokenizer_sd1_5 = CLIPTokenizer.from_pretrained(
115 | path, subfolder="tokenizer", revision=None, torch_dtype=weight_dtype
116 | )
117 | text_encoder_sd1_5 = CLIPTextModel.from_pretrained(
118 | path, subfolder="text_encoder", revision=None, torch_dtype=weight_dtype
119 | )
120 | vae_sd1_5 = AutoencoderKL.from_pretrained(
121 | path, subfolder="vae", revision=None, torch_dtype=weight_dtype
122 | )
123 | unet_sd1_5 = UNet2DConditionModel.from_pretrained(
124 | path, subfolder="unet", revision=None, torch_dtype=weight_dtype
125 | )
126 | print('successfully load SD1.5')
127 | # load SDXL
128 | tokenizer_one = AutoTokenizer.from_pretrained(
129 | path_sdxl, subfolder="tokenizer", revision=None, use_fast=False, torch_dtype=weight_dtype
130 | )
131 | tokenizer_two = AutoTokenizer.from_pretrained(
132 | path_sdxl, subfolder="tokenizer_2", revision=None, use_fast=False, torch_dtype=weight_dtype
133 | )
134 | # import correct text encoder classes
135 | text_encoder_cls_one = import_model_class_from_model_name_or_path(
136 | path_sdxl, None
137 | )
138 | text_encoder_cls_two = import_model_class_from_model_name_or_path(
139 | path_sdxl, None, subfolder="text_encoder_2"
140 | )
141 | # Load scheduler and models
142 | noise_scheduler = DDPMScheduler.from_pretrained(path_sdxl, subfolder="scheduler")
143 | text_encoder_one = text_encoder_cls_one.from_pretrained(
144 | path_sdxl, subfolder="text_encoder", revision=None, torch_dtype=weight_dtype
145 | )
146 | text_encoder_two = text_encoder_cls_two.from_pretrained(
147 | path_sdxl, subfolder="text_encoder_2", revision=None, torch_dtype=weight_dtype
148 | )
149 | vae = AutoencoderKL.from_pretrained(
150 | path_vae_sdxl, revision=None, torch_dtype=weight_dtype
151 | )
152 | unet = UNet2DConditionModel.from_pretrained(
153 | path_sdxl, subfolder="unet", revision=None, torch_dtype=weight_dtype
154 | )
155 | print('successfully load SDXL')
156 |
157 |
158 | if is_xformers_available():
159 | import xformers
160 |
161 | xformers_version = version.parse(xformers.__version__)
162 | if xformers_version == version.parse("0.0.16"):
163 | print(  # no logger is configured in this script, so fall back to a plain warning
164 | "xFormers 0.0.16 cannot be used for training in some GPUs. If you observe problems during training, please update xFormers to at least 0.0.17. See https://huggingface.co/docs/diffusers/main/en/optimization/xformers for more details."
165 | )
166 | unet.enable_xformers_memory_efficient_attention()
167 | unet_sd1_5.enable_xformers_memory_efficient_attention()
168 | controlnet.enable_xformers_memory_efficient_attention()
169 |
170 |
171 | with torch.inference_mode():
172 | gen = Generator("cuda")
173 | gen.manual_seed(args.seed)
174 | pipe = StableDiffusionXLAdapterControlnetPipeline(
175 | vae=vae,
176 | text_encoder=text_encoder_one,
177 | text_encoder_2=text_encoder_two,
178 | tokenizer=tokenizer_one,
179 | tokenizer_2=tokenizer_two,
180 | unet=unet,
181 | scheduler=noise_scheduler,
182 | vae_sd1_5=vae_sd1_5,
183 | text_encoder_sd1_5=text_encoder_sd1_5,
184 | tokenizer_sd1_5=tokenizer_sd1_5,
185 | unet_sd1_5=unet_sd1_5,
186 | scheduler_sd1_5=noise_scheduler_sd1_5,
187 | adapter=adapter,
188 | controlnet=controlnet
189 | )
190 |
191 | pipe.enable_model_cpu_offload()
192 |
193 | pipe.scheduler = DPMSolverMultistepScheduler.from_config(pipe.scheduler.config)
194 | pipe.scheduler_sd1_5 = DPMSolverMultistepScheduler.from_config(pipe.scheduler_sd1_5.config)
195 | pipe.scheduler_sd1_5.config.timestep_spacing = "leading"
196 | pipe.unet.to(device=device, dtype=torch.float16, memory_format=torch.channels_last)
197 |
198 | for i in range(args.iter_num):
199 | for controlnet_condition_scale in controlnet_condition_scale_list:
200 | for adapter_guidance_start in adapter_guidance_start_list:
201 | for adapter_condition_scale in adapter_condition_scale_list:
202 | img = \
203 | pipe(prompt=prompt, negative_prompt=negative_prompt, prompt_sd1_5=prompt_sd1_5,
204 | width=args.width, height=args.height, height_sd1_5=args.height_sd1_5,
205 | width_sd1_5=args.width_sd1_5, image=control_image,
206 | num_inference_steps=args.num_inference_steps, guidance_scale=args.guidance_scale,
207 | num_images_per_prompt=1, generator=gen,
208 | controlnet_conditioning_scale=controlnet_condition_scale,
209 | adapter_condition_scale=adapter_condition_scale,
210 | adapter_guidance_start=adapter_guidance_start).images[0]
211 | img.save(
212 | f"{args.save_path}/{prompt[:10]}_{i}_ccs_{controlnet_condition_scale:.2f}_ags_{adapter_guidance_start:.2f}_acs_{adapter_condition_scale:.2f}.png")
213 |
214 | print(f"results saved in {args.save_path}")
215 |
--------------------------------------------------------------------------------
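The three nested loops at the end of /scripts/inference_controlnet.py sweep every combination of controlnet_condition_scale, adapter_guidance_start and adapter_condition_scale and embed the values in each output filename, so one run writes iter_num x len(ccs list) x len(ags list) x len(acs list) images. A minimal sketch of how the sweep maps to filenames, using placeholder values for the prompt, lists and save path (not repository defaults):

prompt = "a cute cat"          # placeholder
save_path = "./results"        # placeholder
iter_num = 1
controlnet_condition_scale_list = [1.0]
adapter_guidance_start_list = [0.7, 0.8]
adapter_condition_scale_list = [1.0, 1.2]

for i in range(iter_num):
    for ccs in controlnet_condition_scale_list:
        for ags in adapter_guidance_start_list:
            for acs in adapter_condition_scale_list:
                # same pattern as the img.save(...) call in the script
                print(f"{save_path}/{prompt[:10]}_{i}_ccs_{ccs:.2f}_ags_{ags:.2f}_acs_{acs:.2f}.png")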
/scripts/inference_ctrlnet_tile.py:
--------------------------------------------------------------------------------
1 | import torch
2 | import json
3 | import os
4 | import numpy as np
5 | import cv2
6 | from tqdm import tqdm
7 | from diffusers import DiffusionPipeline
8 | from diffusers import DPMSolverMultistepScheduler
9 | from diffusers.utils import load_image
10 | from torch import Generator
11 | from safetensors.torch import load_file
12 | from PIL import Image
13 | from packaging import version
14 | from huggingface_hub import HfApi
15 | from pathlib import Path
16 |
17 | from transformers import CLIPTextModel, CLIPTokenizer, AutoTokenizer, PretrainedConfig
18 |
19 | import diffusers
20 | from diffusers import AutoencoderKL, DDPMScheduler, StableDiffusionPipeline, UNet2DConditionModel, ControlNetModel, T2IAdapter, StableDiffusionControlNetPipeline
21 | from diffusers.optimization import get_scheduler
22 | from diffusers.training_utils import EMAModel
23 | from diffusers.utils import check_min_version, deprecate, is_wandb_available
24 | from diffusers.utils.import_utils import is_xformers_available
25 |
26 | from model.unet_adapter import UNet2DConditionModel as UNet2DConditionModel_v2
27 | from model.adapter import Adapter_XL
28 | from pipeline.pipeline_sd_xl_adapter_controlnet_img2img import StableDiffusionXLAdapterControlnetI2IPipeline
29 | from scripts.utils import str2float
30 |
31 | def import_model_class_from_model_name_or_path(
32 | pretrained_model_name_or_path: str, revision: str, subfolder: str = "text_encoder"
33 | ):
34 | text_encoder_config = PretrainedConfig.from_pretrained(
35 | pretrained_model_name_or_path, subfolder=subfolder, revision=revision
36 | )
37 | model_class = text_encoder_config.architectures[0]
38 |
39 | if model_class == "CLIPTextModel":
40 | from transformers import CLIPTextModel
41 |
42 | return CLIPTextModel
43 | elif model_class == "CLIPTextModelWithProjection":
44 | from transformers import CLIPTextModelWithProjection
45 |
46 | return CLIPTextModelWithProjection
47 | else:
48 | raise ValueError(f"{model_class} is not supported.")
49 |
50 |
51 | def resize_for_condition_image(input_image: Image.Image, resolution: int):
52 | input_image = input_image.convert("RGB")
53 | W, H = input_image.size
54 | k = float(resolution) / min(H, W)
55 | H *= k
56 | W *= k
57 | H = int(round(H / 64.0)) * 64
58 | W = int(round(W / 64.0)) * 64
59 | img = input_image.resize((W, H), resample=Image.LANCZOS)
60 | return img
61 |
62 |
63 | def inference_ctrlnet_tile(args):
64 | device = 'cuda'
65 | weight_dtype = torch.float16
66 |
67 | controlnet_condition_scale_list = str2float(args.controlnet_condition_scale_list)
68 | adapter_guidance_start_list = str2float(args.adapter_guidance_start_list)
69 | adapter_condition_scale_list = str2float(args.adapter_condition_scale_list)
70 |
71 | path = args.base_path
72 | path_sdxl = args.sdxl_path
73 | path_vae_sdxl = args.path_vae_sdxl
74 | adapter_path = args.adapter_checkpoint
75 | controlnet_path = args.controlnet_tile_path
76 |
77 | prompt = args.prompt
78 | if args.prompt_sd1_5 is None:
79 | prompt_sd1_5 = prompt
80 | else:
81 | prompt_sd1_5 = args.prompt_sd1_5
82 |
83 | if args.negative_prompt is None:
84 | negative_prompt = "(deformed iris, deformed pupils, semi-realistic, cgi, 3d, render, sketch, cartoon, drawing, anime:1.4), text, close up, cropped, out of frame, worst quality, low quality, jpeg artifacts, ugly, duplicate, morbid, mutilated, extra fingers, mutated hands, poorly drawn hands, poorly drawn face, mutation, deformed, blurry, dehydrated, bad anatomy, bad proportions, extra limbs, cloned face, disfigured, gross proportions, malformed limbs, missing arms, missing legs, extra arms, extra legs, fused fingers, too many fingers, long neck"
85 | else:
86 | negative_prompt = args.negative_prompt
87 |
88 | torch.set_grad_enabled(False)
89 | torch.backends.cudnn.benchmark = True
90 |
91 | # load controlnet
92 | controlnet = ControlNetModel.from_pretrained(
93 | controlnet_path, torch_dtype=weight_dtype
94 | )
95 |
96 | source_image = Image.open(args.input_image_path)
97 | # control_image = resize_for_condition_image(source_image, 512)
98 | input_image = source_image.convert("RGB")
99 | control_image = input_image.resize((args.width_sd1_5, args.height_sd1_5), resample=Image.LANCZOS)
100 |
101 |     print('successfully loaded controlnet')
102 | # load adapter
103 | adapter = Adapter_XL()
104 | ckpt = torch.load(adapter_path)
105 | adapter.load_state_dict(ckpt)
106 | adapter.to(weight_dtype)
107 |     print('successfully loaded adapter')
108 | # load SD1.5
109 | noise_scheduler_sd1_5 = DDPMScheduler.from_pretrained(
110 | path, subfolder="scheduler"
111 | )
112 | tokenizer_sd1_5 = CLIPTokenizer.from_pretrained(
113 | path, subfolder="tokenizer", revision=None, torch_dtype=weight_dtype
114 | )
115 | text_encoder_sd1_5 = CLIPTextModel.from_pretrained(
116 | path, subfolder="text_encoder", revision=None, torch_dtype=weight_dtype
117 | )
118 | vae_sd1_5 = AutoencoderKL.from_pretrained(
119 | path, subfolder="vae", revision=None, torch_dtype=weight_dtype
120 | )
121 | unet_sd1_5 = UNet2DConditionModel_v2.from_pretrained(
122 | path, subfolder="unet", revision=None, torch_dtype=weight_dtype
123 | )
124 |     print('successfully loaded SD1.5')
125 | # load SDXL
126 | tokenizer_one = AutoTokenizer.from_pretrained(
127 | path_sdxl, subfolder="tokenizer", revision=None, use_fast=False, torch_dtype=weight_dtype
128 | )
129 | tokenizer_two = AutoTokenizer.from_pretrained(
130 | path_sdxl, subfolder="tokenizer_2", revision=None, use_fast=False, torch_dtype=weight_dtype
131 | )
132 | # import correct text encoder classes
133 | text_encoder_cls_one = import_model_class_from_model_name_or_path(
134 | path_sdxl, None
135 | )
136 | text_encoder_cls_two = import_model_class_from_model_name_or_path(
137 | path_sdxl, None, subfolder="text_encoder_2"
138 | )
139 | # Load scheduler and models
140 | noise_scheduler = DDPMScheduler.from_pretrained(path_sdxl, subfolder="scheduler")
141 | text_encoder_one = text_encoder_cls_one.from_pretrained(
142 | path_sdxl, subfolder="text_encoder", revision=None, torch_dtype=weight_dtype
143 | )
144 | text_encoder_two = text_encoder_cls_two.from_pretrained(
145 | path_sdxl, subfolder="text_encoder_2", revision=None, torch_dtype=weight_dtype
146 | )
147 | vae = AutoencoderKL.from_pretrained(
148 | path_vae_sdxl, revision=None, torch_dtype=weight_dtype
149 | )
150 | unet = UNet2DConditionModel_v2.from_pretrained(
151 | path_sdxl, subfolder="unet", revision=None, torch_dtype=weight_dtype
152 | )
153 |     print('successfully loaded SDXL')
154 |
155 | if is_xformers_available():
156 | import xformers
157 |
158 | xformers_version = version.parse(xformers.__version__)
159 | if xformers_version == version.parse("0.0.16"):
160 |             print(  # no logger is configured in this script, so warn via print
161 |                 "WARNING: xFormers 0.0.16 cannot be used for training in some GPUs. If you observe problems during training, please update xFormers to at least 0.0.17. See https://huggingface.co/docs/diffusers/main/en/optimization/xformers for more details."
162 |             )
163 | unet.enable_xformers_memory_efficient_attention()
164 | unet_sd1_5.enable_xformers_memory_efficient_attention()
165 | controlnet.enable_xformers_memory_efficient_attention()
166 |
167 | with torch.inference_mode():
168 | gen = Generator(device)
169 | gen.manual_seed(args.seed)
170 | pipe = StableDiffusionXLAdapterControlnetI2IPipeline(
171 | vae=vae,
172 | text_encoder=text_encoder_one,
173 | text_encoder_2=text_encoder_two,
174 | tokenizer=tokenizer_one,
175 | tokenizer_2=tokenizer_two,
176 | unet=unet,
177 | scheduler=noise_scheduler,
178 | vae_sd1_5=vae_sd1_5,
179 | text_encoder_sd1_5=text_encoder_sd1_5,
180 | tokenizer_sd1_5=tokenizer_sd1_5,
181 | unet_sd1_5=unet_sd1_5,
182 | scheduler_sd1_5=noise_scheduler_sd1_5,
183 | adapter=adapter,
184 | controlnet=controlnet
185 | )
186 | pipe.enable_model_cpu_offload()
187 |
188 | pipe.scheduler = DPMSolverMultistepScheduler.from_config(pipe.scheduler.config)
189 | pipe.scheduler_sd1_5 = DPMSolverMultistepScheduler.from_config(pipe.scheduler_sd1_5.config)
190 | pipe.scheduler_sd1_5.config.timestep_spacing = "leading"
191 | pipe.unet.to(device=device, dtype=weight_dtype, memory_format=torch.channels_last)
192 |
193 |
194 | for i in range(args.iter_num):
195 | for controlnet_condition_scale in controlnet_condition_scale_list:
196 | for adapter_guidance_start in adapter_guidance_start_list:
197 | for adapter_condition_scale in adapter_condition_scale_list:
198 | img = \
199 | pipe(prompt=prompt, negative_prompt=negative_prompt, prompt_sd1_5=prompt_sd1_5,
200 | width=args.width, height=args.height, height_sd1_5=args.height_sd1_5,
201 | width_sd1_5=args.width_sd1_5, source_img=control_image, image=control_image,
202 | num_inference_steps=args.num_inference_steps, guidance_scale=args.guidance_scale,
203 | num_images_per_prompt=1, generator=gen,
204 | controlnet_conditioning_scale=controlnet_condition_scale,
205 | adapter_condition_scale=adapter_condition_scale,
206 | adapter_guidance_start=adapter_guidance_start).images[0]
207 | img.save(
208 | f"{args.save_path}/{prompt[:10]}_{i}_ccs_{controlnet_condition_scale:.2f}_ags_{adapter_guidance_start:.2f}_acs_{adapter_condition_scale:.2f}.png")
209 |
210 | print(f"results saved in {args.save_path}")
211 |
212 |
--------------------------------------------------------------------------------
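Because inference_ctrlnet_tile() reads everything it needs from the args object, it can also be driven directly from Python (run from the repository root). The sketch below is a hypothetical driver, not repository code: the repo normally invokes it through inference.py / bash_scripts/controlnet_tile_inference.sh, and every model id, checkpoint path and resolution here is a placeholder to replace with your own.

import os
from types import SimpleNamespace

from scripts.inference_ctrlnet_tile import inference_ctrlnet_tile

# All values below are illustrative placeholders, not repository defaults.
args = SimpleNamespace(
    base_path="runwayml/stable-diffusion-v1-5",             # SD1.5 weights (assumed id)
    sdxl_path="stabilityai/stable-diffusion-xl-base-1.0",   # SDXL weights (assumed id)
    path_vae_sdxl="madebyollin/sdxl-vae-fp16-fix",          # fp16-friendly SDXL VAE (assumed id)
    adapter_checkpoint="./checkpoints/adapter.ckpt",        # hypothetical local path
    controlnet_tile_path="lllyasviel/control_v11f1e_sd15_tile",  # SD1.5 tile ControlNet (assumed id)
    prompt="a photo of a cute cat, best quality, extremely detailed",
    prompt_sd1_5=None,            # None falls back to `prompt`
    negative_prompt=None,         # None falls back to the built-in negative prompt
    input_image_path="./assets/CuteCat.jpeg",
    width=1024, height=1024,
    width_sd1_5=512, height_sd1_5=512,
    num_inference_steps=30,
    guidance_scale=7.5,
    seed=0,
    iter_num=1,
    save_path="./results",
    controlnet_condition_scale_list=["1.0"],
    adapter_guidance_start_list=["0.7", "0.8"],
    adapter_condition_scale_list=["1.0", "1.2"],
)
os.makedirs(args.save_path, exist_ok=True)
inference_ctrlnet_tile(args)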
/scripts/inference_lora.py:
--------------------------------------------------------------------------------
1 | import torch
2 | import json
3 | import os
4 | import numpy as np
5 | import cv2
6 | from tqdm import tqdm
7 | from diffusers import DiffusionPipeline
8 | from diffusers import DPMSolverMultistepScheduler
9 | from diffusers.utils import load_image
10 | from torch import Generator
11 | from safetensors.torch import load_file
12 | from PIL import Image
13 | from packaging import version
14 |
15 | from transformers import CLIPTextModel, CLIPTokenizer, AutoTokenizer, PretrainedConfig
16 |
17 | import diffusers
18 | from diffusers import AutoencoderKL, DDPMScheduler, StableDiffusionPipeline, UNet2DConditionModel, ControlNetModel, \
19 | T2IAdapter
20 | from diffusers.optimization import get_scheduler
21 | from diffusers.training_utils import EMAModel
22 | from diffusers.utils import check_min_version, deprecate, is_wandb_available
23 | from diffusers.utils.import_utils import is_xformers_available
24 |
25 | from model.unet_adapter import UNet2DConditionModel  # intentionally shadows the diffusers UNet2DConditionModel imported above
26 | from pipeline.pipeline_sd_xl_adapter import StableDiffusionXLAdapterPipeline
27 | from model.adapter import Adapter_XL
28 | from scripts.utils import str2float
29 |
30 |
31 | def import_model_class_from_model_name_or_path(
32 | pretrained_model_name_or_path: str, revision: str, subfolder: str = "text_encoder"
33 | ):
34 | text_encoder_config = PretrainedConfig.from_pretrained(
35 | pretrained_model_name_or_path, subfolder=subfolder, revision=revision
36 | )
37 | model_class = text_encoder_config.architectures[0]
38 |
39 | if model_class == "CLIPTextModel":
40 | from transformers import CLIPTextModel
41 |
42 | return CLIPTextModel
43 | elif model_class == "CLIPTextModelWithProjection":
44 | from transformers import CLIPTextModelWithProjection
45 |
46 | return CLIPTextModelWithProjection
47 | else:
48 | raise ValueError(f"{model_class} is not supported.")
49 |
50 |
51 | def load_lora(pipeline, lora_model_path, alpha):
52 | state_dict = load_file(lora_model_path)
53 |
54 | LORA_PREFIX_UNET = 'lora_unet'
55 | LORA_PREFIX_TEXT_ENCODER = 'lora_te'
56 |
57 | visited = []
58 |
59 | # directly update weight in diffusers model
60 | for key in state_dict:
61 |
62 |         # keys typically look like the following (print them out to inspect a checkpoint):
63 | # "lora_te_text_model_encoder_layers_0_self_attn_k_proj.lora_down.weight"
64 |
65 |         # alpha is passed to this function as an argument, so skip per-key alpha entries and keys already processed
66 | if '.alpha' in key or key in visited:
67 | continue
68 |
69 | if 'text' in key:
70 | layer_infos = key.split('.')[0].split(LORA_PREFIX_TEXT_ENCODER + '_')[-1].split('_')
71 | curr_layer = pipeline.text_encoder_sd1_5
72 | else:
73 | layer_infos = key.split('.')[0].split(LORA_PREFIX_UNET + '_')[-1].split('_')
74 | curr_layer = pipeline.unet_sd1_5
75 |
76 | # find the target layer
77 | temp_name = layer_infos.pop(0)
78 |         while len(layer_infos) > -1:  # condition is always true; the loop exits via the break below
79 | try:
80 | curr_layer = curr_layer.__getattr__(temp_name)
81 | if len(layer_infos) > 0:
82 | temp_name = layer_infos.pop(0)
83 | elif len(layer_infos) == 0:
84 | break
85 | except Exception:
86 | if len(temp_name) > 0:
87 | temp_name += '_' + layer_infos.pop(0)
88 | else:
89 | temp_name = layer_infos.pop(0)
90 |
91 | # org_forward(x) + lora_up(lora_down(x)) * multiplier
92 | pair_keys = []
93 | if 'lora_down' in key:
94 | pair_keys.append(key.replace('lora_down', 'lora_up'))
95 | pair_keys.append(key)
96 | else:
97 | pair_keys.append(key)
98 | pair_keys.append(key.replace('lora_up', 'lora_down'))
99 |
100 | # update weight
101 | if len(state_dict[pair_keys[0]].shape) == 4:
102 | weight_up = state_dict[pair_keys[0]].squeeze(3).squeeze(2).to(torch.float32)
103 | weight_down = state_dict[pair_keys[1]].squeeze(3).squeeze(2).to(torch.float32)
104 | curr_layer.weight.data += alpha * torch.mm(weight_up, weight_down).unsqueeze(2).unsqueeze(3)
105 | else:
106 | weight_up = state_dict[pair_keys[0]].to(torch.float32)
107 | weight_down = state_dict[pair_keys[1]].to(torch.float32)
108 | curr_layer.weight.data += alpha * torch.mm(weight_up, weight_down)
109 |
110 | # update visited list
111 | for item in pair_keys:
112 | visited.append(item)
113 |
114 |
115 | def inference_lora(args):
116 | device = 'cuda'
117 | weight_dtype = torch.float16
118 |
119 | adapter_guidance_start_list = str2float(args.adapter_guidance_start_list)
120 | adapter_condition_scale_list = str2float(args.adapter_condition_scale_list)
121 |
122 | path = args.base_path
123 | path_sdxl = args.sdxl_path
124 | path_vae_sdxl = args.path_vae_sdxl
125 | adapter_path = args.adapter_checkpoint
126 | lora_model_path = args.lora_model_path
127 |
128 | prompt = args.prompt
129 | if args.prompt_sd1_5 is None:
130 | prompt_sd1_5 = prompt
131 | else:
132 | prompt_sd1_5 = args.prompt_sd1_5
133 |
134 | if args.negative_prompt is None:
135 | negative_prompt = "(deformed iris, deformed pupils, semi-realistic, cgi, 3d, render, sketch, cartoon, drawing, anime:1.4), text, close up, cropped, out of frame, worst quality, low quality, jpeg artifacts, ugly, duplicate, morbid, mutilated, extra fingers, mutated hands, poorly drawn hands, poorly drawn face, mutation, deformed, blurry, dehydrated, bad anatomy, bad proportions, extra limbs, cloned face, disfigured, gross proportions, malformed limbs, missing arms, missing legs, extra arms, extra legs, fused fingers, too many fingers, long neck"
136 | else:
137 | negative_prompt = args.negative_prompt
138 |
139 | torch.set_grad_enabled(False)
140 | torch.backends.cudnn.benchmark = True
141 |
142 | # load adapter
143 | adapter = Adapter_XL()
144 | ckpt = torch.load(adapter_path)
145 | adapter.load_state_dict(ckpt)
146 |     print('successfully loaded adapter')
147 | # load SD1.5
148 | noise_scheduler_sd1_5 = DDPMScheduler.from_pretrained(
149 | path, subfolder="scheduler"
150 | )
151 | tokenizer_sd1_5 = CLIPTokenizer.from_pretrained(
152 | path, subfolder="tokenizer", revision=None
153 | )
154 | text_encoder_sd1_5 = CLIPTextModel.from_pretrained(
155 | path, subfolder="text_encoder", revision=None
156 | )
157 | vae_sd1_5 = AutoencoderKL.from_pretrained(
158 | path, subfolder="vae", revision=None
159 | )
160 | unet_sd1_5 = UNet2DConditionModel.from_pretrained(
161 | path, subfolder="unet", revision=None
162 | )
163 |     print('successfully loaded SD1.5')
164 | # load SDXL
165 | tokenizer_one = AutoTokenizer.from_pretrained(
166 | path_sdxl, subfolder="tokenizer", revision=None, use_fast=False
167 | )
168 | tokenizer_two = AutoTokenizer.from_pretrained(
169 | path_sdxl, subfolder="tokenizer_2", revision=None, use_fast=False
170 | )
171 | # import correct text encoder classes
172 | text_encoder_cls_one = import_model_class_from_model_name_or_path(
173 | path_sdxl, None
174 | )
175 | text_encoder_cls_two = import_model_class_from_model_name_or_path(
176 | path_sdxl, None, subfolder="text_encoder_2"
177 | )
178 | # Load scheduler and models
179 | noise_scheduler = DDPMScheduler.from_pretrained(path_sdxl, subfolder="scheduler")
180 | text_encoder_one = text_encoder_cls_one.from_pretrained(
181 | path_sdxl, subfolder="text_encoder", revision=None
182 | )
183 | text_encoder_two = text_encoder_cls_two.from_pretrained(
184 | path_sdxl, subfolder="text_encoder_2", revision=None
185 | )
186 | vae = AutoencoderKL.from_pretrained(
187 | path_vae_sdxl, revision=None
188 | )
189 | unet = UNet2DConditionModel.from_pretrained(
190 | path_sdxl, subfolder="unet", revision=None
191 | )
192 |     print('successfully loaded SDXL')
193 |
194 | if is_xformers_available():
195 | import xformers
196 |
197 | xformers_version = version.parse(xformers.__version__)
198 | if xformers_version == version.parse("0.0.16"):
199 |             print(  # no logger is configured in this script, so warn via print
200 |                 "WARNING: xFormers 0.0.16 cannot be used for training in some GPUs. If you observe problems during training, please update xFormers to at least 0.0.17. See https://huggingface.co/docs/diffusers/main/en/optimization/xformers for more details."
201 |             )
202 | unet.enable_xformers_memory_efficient_attention()
203 | unet_sd1_5.enable_xformers_memory_efficient_attention()
204 |
205 | with torch.inference_mode():
206 | gen = Generator("cuda")
207 | gen.manual_seed(args.seed)
208 |
209 | pipe = StableDiffusionXLAdapterPipeline(
210 | vae=vae,
211 | text_encoder=text_encoder_one,
212 | text_encoder_2=text_encoder_two,
213 | tokenizer=tokenizer_one,
214 | tokenizer_2=tokenizer_two,
215 | unet=unet,
216 | scheduler=noise_scheduler,
217 | vae_sd1_5=vae_sd1_5,
218 | text_encoder_sd1_5=text_encoder_sd1_5,
219 | tokenizer_sd1_5=tokenizer_sd1_5,
220 | unet_sd1_5=unet_sd1_5,
221 | scheduler_sd1_5=noise_scheduler_sd1_5,
222 | adapter=adapter,
223 | )
224 | # load lora
225 | load_lora(pipe, lora_model_path, 1)
226 |         print('successfully loaded LoRA')
227 |
228 | pipe.to('cuda', weight_dtype)
229 | pipe.enable_model_cpu_offload()
230 | pipe.scheduler = DPMSolverMultistepScheduler.from_config(pipe.scheduler.config)
231 | pipe.scheduler_sd1_5 = DPMSolverMultistepScheduler.from_config(pipe.scheduler_sd1_5.config)
232 | pipe.scheduler_sd1_5.config.timestep_spacing = "leading"
233 |
234 | for i in range(args.iter_num):
235 | for adapter_guidance_start in adapter_guidance_start_list:
236 | for adapter_condition_scale in adapter_condition_scale_list:
237 | img = \
238 | pipe(prompt=prompt, prompt_sd1_5=prompt_sd1_5, negative_prompt=negative_prompt, width=args.width,
239 | height=args.height, height_sd1_5=args.height_sd1_5, width_sd1_5=args.width_sd1_5,
240 | num_inference_steps=args.num_inference_steps, guidance_scale=args.guidance_scale,
241 | num_images_per_prompt=1, generator=gen,
242 | adapter_guidance_start=adapter_guidance_start,
243 | adapter_condition_scale=adapter_condition_scale).images[0]
244 | img.save(
245 | f"{args.save_path}/{prompt[:10]}_{i}_ags_{adapter_guidance_start:.2f}_acs_{adapter_condition_scale:.2f}.png")
246 | print(f"results saved in {args.save_path}")
247 |
248 |
249 |
250 |
251 |
252 |
253 |
--------------------------------------------------------------------------------
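load_lora() above folds every LoRA pair directly into the SD1.5 weights (W += alpha * up @ down) instead of patching the forward pass. The snippet below is an illustrative, self-contained check with dummy tensors (not repository code) that this merge matches the runtime form noted in the comment inside load_lora, org_forward(x) + lora_up(lora_down(x)) * alpha:

import torch

torch.manual_seed(0)
out_dim, in_dim, rank = 8, 8, 2
W = torch.randn(out_dim, in_dim)          # base linear weight, shape (out, in)
lora_up = torch.randn(out_dim, rank)      # LoRA up projection
lora_down = torch.randn(rank, in_dim)     # LoRA down projection
x = torch.randn(4, in_dim)
alpha = 1.0

# runtime form: base output plus the low-rank branch
runtime = x @ W.T + alpha * (x @ lora_down.T) @ lora_up.T
# merged form, as applied by load_lora: fold alpha * (up @ down) into W once
merged = x @ (W + alpha * lora_up @ lora_down).T

print(torch.allclose(runtime, merged, atol=1e-5))  # True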
/scripts/utils.py:
--------------------------------------------------------------------------------
1 | def str2float(x):
2 |     # convert a list of string values (e.g. list-type CLI arguments) to a list of floats
3 |     return [float(v) for v in x]
4 | 
5 |
--------------------------------------------------------------------------------
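str2float() expects a list of string values (for example, a list-valued command-line argument), not a single comma-separated string; typical usage:

from scripts.utils import str2float

# hypothetical values, e.g. collected by argparse as a list of strings
adapter_condition_scale_list = str2float(["0.8", "1.0", "1.2"])
print(adapter_condition_scale_list)  # [0.8, 1.0, 1.2]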