├── LICENSE
├── README.md
├── assets
├── demo
│ ├── crop_and_resize_cmp.jpg
│ ├── image-to-image.jpg
│ ├── image_variations.jpg
│ ├── inpainting.jpg
│ ├── ip_adpter_plus_image_variations.jpg
│ ├── ip_adpter_plus_multi.jpg
│ ├── multi_prompts.jpg
│ ├── sd15_face.jpg
│ ├── sdxl_cmp.jpg
│ ├── structural_cond.jpg
│ └── t2i-adapter_demo.jpg
├── figs
│ └── fig1.png
├── images
│ ├── ai_face.png
│ ├── ai_face2.png
│ ├── girl.png
│ ├── river.png
│ ├── statue.png
│ ├── stone.png
│ ├── vermeer.jpg
│ └── woman.png
├── inpainting
│ ├── image.png
│ └── mask.png
└── structure_controls
│ ├── depth.png
│ ├── depth2.png
│ └── openpose.png
├── demo.ipynb
├── ip_adapter-full-face_demo.ipynb
├── ip_adapter-plus-face_demo.ipynb
├── ip_adapter-plus_demo.ipynb
├── ip_adapter-plus_sdxl_demo.ipynb
├── ip_adapter
├── __init__.py
├── attention_processor.py
├── attention_processor_faceid.py
├── custom_pipelines.py
├── ip_adapter.py
├── ip_adapter_faceid.py
├── ip_adapter_faceid_separate.py
├── resampler.py
├── test_resampler.py
└── utils.py
├── ip_adapter_controlnet_demo_new.ipynb
├── ip_adapter_demo.ipynb
├── ip_adapter_multimodal_prompts_demo.ipynb
├── ip_adapter_sdxl_controlnet_demo.ipynb
├── ip_adapter_sdxl_demo.ipynb
├── ip_adapter_sdxl_plus-face_demo.ipynb
├── ip_adapter_t2i-adapter_demo.ipynb
├── ip_adapter_t2i_demo.ipynb
├── pyproject.toml
├── tutorial_train.py
├── tutorial_train_faceid.py
├── tutorial_train_plus.py
├── tutorial_train_sdxl.py
├── visualization_attnmap_faceid.ipynb
└── visualization_attnmap_sdxl_plus-face.ipynb
/LICENSE:
--------------------------------------------------------------------------------
1 | Apache License
2 | Version 2.0, January 2004
3 | http://www.apache.org/licenses/
4 |
5 | TERMS AND CONDITIONS FOR USE, REPRODUCTION, AND DISTRIBUTION
6 |
7 | 1. Definitions.
8 |
9 | "License" shall mean the terms and conditions for use, reproduction,
10 | and distribution as defined by Sections 1 through 9 of this document.
11 |
12 | "Licensor" shall mean the copyright owner or entity authorized by
13 | the copyright owner that is granting the License.
14 |
15 | "Legal Entity" shall mean the union of the acting entity and all
16 | other entities that control, are controlled by, or are under common
17 | control with that entity. For the purposes of this definition,
18 | "control" means (i) the power, direct or indirect, to cause the
19 | direction or management of such entity, whether by contract or
20 | otherwise, or (ii) ownership of fifty percent (50%) or more of the
21 | outstanding shares, or (iii) beneficial ownership of such entity.
22 |
23 | "You" (or "Your") shall mean an individual or Legal Entity
24 | exercising permissions granted by this License.
25 |
26 | "Source" form shall mean the preferred form for making modifications,
27 | including but not limited to software source code, documentation
28 | source, and configuration files.
29 |
30 | "Object" form shall mean any form resulting from mechanical
31 | transformation or translation of a Source form, including but
32 | not limited to compiled object code, generated documentation,
33 | and conversions to other media types.
34 |
35 | "Work" shall mean the work of authorship, whether in Source or
36 | Object form, made available under the License, as indicated by a
37 | copyright notice that is included in or attached to the work
38 | (an example is provided in the Appendix below).
39 |
40 | "Derivative Works" shall mean any work, whether in Source or Object
41 | form, that is based on (or derived from) the Work and for which the
42 | editorial revisions, annotations, elaborations, or other modifications
43 | represent, as a whole, an original work of authorship. For the purposes
44 | of this License, Derivative Works shall not include works that remain
45 | separable from, or merely link (or bind by name) to the interfaces of,
46 | the Work and Derivative Works thereof.
47 |
48 | "Contribution" shall mean any work of authorship, including
49 | the original version of the Work and any modifications or additions
50 | to that Work or Derivative Works thereof, that is intentionally
51 | submitted to Licensor for inclusion in the Work by the copyright owner
52 | or by an individual or Legal Entity authorized to submit on behalf of
53 | the copyright owner. For the purposes of this definition, "submitted"
54 | means any form of electronic, verbal, or written communication sent
55 | to the Licensor or its representatives, including but not limited to
56 | communication on electronic mailing lists, source code control systems,
57 | and issue tracking systems that are managed by, or on behalf of, the
58 | Licensor for the purpose of discussing and improving the Work, but
59 | excluding communication that is conspicuously marked or otherwise
60 | designated in writing by the copyright owner as "Not a Contribution."
61 |
62 | "Contributor" shall mean Licensor and any individual or Legal Entity
63 | on behalf of whom a Contribution has been received by Licensor and
64 | subsequently incorporated within the Work.
65 |
66 | 2. Grant of Copyright License. Subject to the terms and conditions of
67 | this License, each Contributor hereby grants to You a perpetual,
68 | worldwide, non-exclusive, no-charge, royalty-free, irrevocable
69 | copyright license to reproduce, prepare Derivative Works of,
70 | publicly display, publicly perform, sublicense, and distribute the
71 | Work and such Derivative Works in Source or Object form.
72 |
73 | 3. Grant of Patent License. Subject to the terms and conditions of
74 | this License, each Contributor hereby grants to You a perpetual,
75 | worldwide, non-exclusive, no-charge, royalty-free, irrevocable
76 | (except as stated in this section) patent license to make, have made,
77 | use, offer to sell, sell, import, and otherwise transfer the Work,
78 | where such license applies only to those patent claims licensable
79 | by such Contributor that are necessarily infringed by their
80 | Contribution(s) alone or by combination of their Contribution(s)
81 | with the Work to which such Contribution(s) was submitted. If You
82 | institute patent litigation against any entity (including a
83 | cross-claim or counterclaim in a lawsuit) alleging that the Work
84 | or a Contribution incorporated within the Work constitutes direct
85 | or contributory patent infringement, then any patent licenses
86 | granted to You under this License for that Work shall terminate
87 | as of the date such litigation is filed.
88 |
89 | 4. Redistribution. You may reproduce and distribute copies of the
90 | Work or Derivative Works thereof in any medium, with or without
91 | modifications, and in Source or Object form, provided that You
92 | meet the following conditions:
93 |
94 | (a) You must give any other recipients of the Work or
95 | Derivative Works a copy of this License; and
96 |
97 | (b) You must cause any modified files to carry prominent notices
98 | stating that You changed the files; and
99 |
100 | (c) You must retain, in the Source form of any Derivative Works
101 | that You distribute, all copyright, patent, trademark, and
102 | attribution notices from the Source form of the Work,
103 | excluding those notices that do not pertain to any part of
104 | the Derivative Works; and
105 |
106 | (d) If the Work includes a "NOTICE" text file as part of its
107 | distribution, then any Derivative Works that You distribute must
108 | include a readable copy of the attribution notices contained
109 | within such NOTICE file, excluding those notices that do not
110 | pertain to any part of the Derivative Works, in at least one
111 | of the following places: within a NOTICE text file distributed
112 | as part of the Derivative Works; within the Source form or
113 | documentation, if provided along with the Derivative Works; or,
114 | within a display generated by the Derivative Works, if and
115 | wherever such third-party notices normally appear. The contents
116 | of the NOTICE file are for informational purposes only and
117 | do not modify the License. You may add Your own attribution
118 | notices within Derivative Works that You distribute, alongside
119 | or as an addendum to the NOTICE text from the Work, provided
120 | that such additional attribution notices cannot be construed
121 | as modifying the License.
122 |
123 | You may add Your own copyright statement to Your modifications and
124 | may provide additional or different license terms and conditions
125 | for use, reproduction, or distribution of Your modifications, or
126 | for any such Derivative Works as a whole, provided Your use,
127 | reproduction, and distribution of the Work otherwise complies with
128 | the conditions stated in this License.
129 |
130 | 5. Submission of Contributions. Unless You explicitly state otherwise,
131 | any Contribution intentionally submitted for inclusion in the Work
132 | by You to the Licensor shall be under the terms and conditions of
133 | this License, without any additional terms or conditions.
134 | Notwithstanding the above, nothing herein shall supersede or modify
135 | the terms of any separate license agreement you may have executed
136 | with Licensor regarding such Contributions.
137 |
138 | 6. Trademarks. This License does not grant permission to use the trade
139 | names, trademarks, service marks, or product names of the Licensor,
140 | except as required for reasonable and customary use in describing the
141 | origin of the Work and reproducing the content of the NOTICE file.
142 |
143 | 7. Disclaimer of Warranty. Unless required by applicable law or
144 | agreed to in writing, Licensor provides the Work (and each
145 | Contributor provides its Contributions) on an "AS IS" BASIS,
146 | WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or
147 | implied, including, without limitation, any warranties or conditions
148 | of TITLE, NON-INFRINGEMENT, MERCHANTABILITY, or FITNESS FOR A
149 | PARTICULAR PURPOSE. You are solely responsible for determining the
150 | appropriateness of using or redistributing the Work and assume any
151 | risks associated with Your exercise of permissions under this License.
152 |
153 | 8. Limitation of Liability. In no event and under no legal theory,
154 | whether in tort (including negligence), contract, or otherwise,
155 | unless required by applicable law (such as deliberate and grossly
156 | negligent acts) or agreed to in writing, shall any Contributor be
157 | liable to You for damages, including any direct, indirect, special,
158 | incidental, or consequential damages of any character arising as a
159 | result of this License or out of the use or inability to use the
160 | Work (including but not limited to damages for loss of goodwill,
161 | work stoppage, computer failure or malfunction, or any and all
162 | other commercial damages or losses), even if such Contributor
163 | has been advised of the possibility of such damages.
164 |
165 | 9. Accepting Warranty or Additional Liability. While redistributing
166 | the Work or Derivative Works thereof, You may choose to offer,
167 | and charge a fee for, acceptance of support, warranty, indemnity,
168 | or other liability obligations and/or rights consistent with this
169 | License. However, in accepting such obligations, You may act only
170 | on Your own behalf and on Your sole responsibility, not on behalf
171 | of any other Contributor, and only if You agree to indemnify,
172 | defend, and hold each Contributor harmless for any liability
173 | incurred by, or claims asserted against, such Contributor by reason
174 | of your accepting any such warranty or additional liability.
175 |
176 | END OF TERMS AND CONDITIONS
177 |
178 | APPENDIX: How to apply the Apache License to your work.
179 |
180 | To apply the Apache License to your work, attach the following
181 | boilerplate notice, with the fields enclosed by brackets "[]"
182 | replaced with your own identifying information. (Don't include
183 | the brackets!) The text should be enclosed in the appropriate
184 | comment syntax for the file format. We also recommend that a
185 | file or class name and description of purpose be included on the
186 | same "printed page" as the copyright notice for easier
187 | identification within third-party archives.
188 |
189 | Copyright [yyyy] [name of copyright owner]
190 |
191 | Licensed under the Apache License, Version 2.0 (the "License");
192 | you may not use this file except in compliance with the License.
193 | You may obtain a copy of the License at
194 |
195 | http://www.apache.org/licenses/LICENSE-2.0
196 |
197 | Unless required by applicable law or agreed to in writing, software
198 | distributed under the License is distributed on an "AS IS" BASIS,
199 | WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
200 | See the License for the specific language governing permissions and
201 | limitations under the License.
202 |
--------------------------------------------------------------------------------
/README.md:
--------------------------------------------------------------------------------
1 | # Stable Diffusion IP-Adapter with a negative image prompt
2 | Experimental implementation of negative image prompts. An image embedding is used for the unconditioned embedding in the same way as the negative prompt. Read the write-up [here](https://stable-diffusion-art.com/negative-image-prompt).
3 |
4 | Update: [ComfyUI version](https://github.com/sagiodev/ComfyUI_IPAdapter_plus)
5 |
6 | ## Usage
7 | Run `demo.ipynb` for a GUI. [](https://colab.research.google.com/github/sagiodev/IP-Adapter-Negative/blob/main/demo.ipynb)
8 |
9 | You can specify any or all of the following
10 | - Text prompt
11 | - Image prompt
12 | - Negative text prompt
13 | - Negative image prompt
14 |
15 | It seems to behave the best when:
16 | 1. Supply a positive and a negative image prompt. Leave the text prompts empty.
17 | 2. Supply a text prompt and a negative image prompt. Leave the image prompt empty.
18 |
19 | You will need to adjust the positive and negative prompt weight to get the desired effect.
20 |
21 | ## Samples
22 | Cherry-picked examples 🤣 (Positive and negative image prompts only. Text prompts not used.)
23 |
24 | 
25 |
26 | 
27 |
28 | 
29 |
30 |
31 | ## IP-adapters
32 | Supports IP-adapter and IP-adapter Plus for SD 1.5
33 |
34 |
35 |
36 |
37 | # ___***IP-Adapter: Text Compatible Image Prompt Adapter for Text-to-Image Diffusion Models***___
38 |
39 |
40 |
41 |
42 | [](https://github.com/tencent-ailab/IP-Adapter/)
43 |
44 |
45 | ---
46 |
47 |
48 | ## Introduction
49 |
50 | we present IP-Adapter, an effective and lightweight adapter to achieve image prompt capability for the pre-trained text-to-image diffusion models. An IP-Adapter with only 22M parameters can achieve comparable or even better performance to a fine-tuned image prompt model. IP-Adapter can be generalized not only to other custom models fine-tuned from the same base model, but also to controllable generation using existing controllable tools. Moreover, the image prompt can also work well with the text prompt to accomplish multimodal image generation.
51 |
52 | 
53 |
54 | ## Release
55 | - [2024/01/04] 🔥 Add an experimental version of IP-Adapter-FaceID for SDXL, more information can be found [here](https://huggingface.co/h94/IP-Adapter-FaceID).
56 | - [2023/12/29] 🔥 Add an experimental version of IP-Adapter-FaceID-PlusV2, more information can be found [here](https://huggingface.co/h94/IP-Adapter-FaceID).
57 | - [2023/12/27] 🔥 Add an experimental version of IP-Adapter-FaceID-Plus, more information can be found [here](https://huggingface.co/h94/IP-Adapter-FaceID).
58 | - [2023/12/20] 🔥 Add an experimental version of IP-Adapter-FaceID, more information can be found [here](https://huggingface.co/h94/IP-Adapter-FaceID).
59 | - [2023/11/22] IP-Adapter is available in [Diffusers](https://github.com/huggingface/diffusers/pull/5713) thanks to Diffusers Team.
60 | - [2023/11/10] 🔥 Add an updated version of IP-Adapter-Face. The demo is [here](ip_adapter-full-face_demo.ipynb).
61 | - [2023/11/05] 🔥 Add text-to-image [demo](ip_adapter_t2i_demo.ipynb) with IP-Adapter and [Kandinsky 2.2 Prior](https://huggingface.co/kandinsky-community/kandinsky-2-2-prior)
62 | - [2023/11/02] Support [safetensors](https://github.com/huggingface/safetensors)
63 | - [2023/9/08] 🔥 Update a new version of IP-Adapter with SDXL_1.0. More information can be found [here](#sdxl_10).
64 | - [2023/9/05] 🔥🔥🔥 IP-Adapter is supported in [WebUI](https://github.com/Mikubill/sd-webui-controlnet/discussions/2039) and [ComfyUI](https://github.com/laksjdjf/IPAdapter-ComfyUI) (or [ComfyUI_IPAdapter_plus](https://github.com/cubiq/ComfyUI_IPAdapter_plus)).
65 | - [2023/8/30] 🔥 Add an IP-Adapter with face image as prompt. The demo is [here](ip_adapter-plus-face_demo.ipynb).
66 | - [2023/8/29] 🔥 Release the training code.
67 | - [2023/8/23] 🔥 Add code and models of IP-Adapter with fine-grained features. The demo is [here](ip_adapter-plus_demo.ipynb).
68 | - [2023/8/18] 🔥 Add code and models for [SDXL 1.0](https://huggingface.co/stabilityai/stable-diffusion-xl-base-1.0). The demo is [here](ip_adapter_sdxl_demo.ipynb).
69 | - [2023/8/16] 🔥 We release the code and models.
70 |
71 |
72 | ## Installation
73 |
74 | ```
75 | # install latest diffusers
76 | pip install diffusers==0.22.1
77 |
78 | # install ip-adapter
79 | pip install git+https://github.com/tencent-ailab/IP-Adapter.git
80 |
81 | # download the models
82 | cd IP-Adapter
83 | git lfs install
84 | git clone https://huggingface.co/h94/IP-Adapter
85 | mv IP-Adapter/models models
86 | mv IP-Adapter/sdxl_models sdxl_models
87 |
88 | # then you can use the notebook
89 | ```
90 |
91 | ## Download Models
92 |
93 | you can download models from [here](https://huggingface.co/h94/IP-Adapter). To run the demo, you should also download the following models:
94 | - [runwayml/stable-diffusion-v1-5](https://huggingface.co/runwayml/stable-diffusion-v1-5)
95 | - [stabilityai/sd-vae-ft-mse](https://huggingface.co/stabilityai/sd-vae-ft-mse)
96 | - [SG161222/Realistic_Vision_V4.0_noVAE](https://huggingface.co/SG161222/Realistic_Vision_V4.0_noVAE)
97 | - [ControlNet models](https://huggingface.co/lllyasviel)
98 |
99 | ## How to Use
100 |
101 | ### SD_1.5
102 |
103 | - [**ip_adapter_demo**](ip_adapter_demo.ipynb): image variations, image-to-image, and inpainting with image prompt.
104 | - [](https://colab.research.google.com/github/tencent-ailab/IP-Adapter/blob/main/ip_adapter_demo.ipynb)
105 |
106 | 
107 |
108 | 
109 |
110 | 
111 |
112 | - [**ip_adapter_controlnet_demo**](ip_adapter_controlnet_demo_new.ipynb), [**ip_adapter_t2i-adapter**](ip_adapter_t2i-adapter_demo.ipynb): structural generation with image prompt.
113 | - [](https://colab.research.google.com/github/tencent-ailab/IP-Adapter/blob/main/ip_adapter_controlnet_demo.ipynb)
114 |
115 | 
116 | 
117 |
118 | - [**ip_adapter_multimodal_prompts_demo**](ip_adapter_multimodal_prompts_demo.ipynb): generation with multimodal prompts.
119 | - [](https://colab.research.google.com/github/tencent-ailab/IP-Adapter/blob/main/ip_adapter_multimodal_prompts_demo.ipynb)
120 |
121 | 
122 |
123 | - [**ip_adapter-plus_demo**](ip_adapter-plus_demo.ipynb): the demo of IP-Adapter with fine-grained features.
124 |
125 | 
126 | 
127 |
128 | - [**ip_adapter-plus-face_demo**](ip_adapter-plus-face_demo.ipynb): generation with face image as prompt.
129 |
130 | 
131 |
132 | **Best Practice**
133 | - If you only use the image prompt, you can set the `scale=1.0` and `text_prompt=""`(or some generic text prompts, e.g. "best quality", you can also use any negative text prompt). If you lower the `scale`, more diverse images can be generated, but they may not be as consistent with the image prompt.
134 | - For multimodal prompts, you can adjust the `scale` to get the best results. In most cases, setting `scale=0.5` can get good results. For the version of SD 1.5, we recommend using community models to generate good images.
135 |
136 | **IP-Adapter for non-square images**
137 |
138 | As the image is center cropped in the default image processor of CLIP, IP-Adapter works best for square images. For the non square images, it will miss the information outside the center. But you can just resize to 224x224 for non-square images, the comparison is as follows:
139 |
140 | 
141 |
142 | ### SDXL_1.0
143 |
144 | - [**ip_adapter_sdxl_demo**](ip_adapter_sdxl_demo.ipynb): image variations with image prompt.
145 | - [**ip_adapter_sdxl_controlnet_demo**](ip_adapter_sdxl_controlnet_demo.ipynb): structural generation with image prompt.
146 |
147 | The comparison of **IP-Adapter_XL** with [Reimagine XL](https://clipdrop.co/stable-diffusion-reimagine) is shown as follows:
148 |
149 | 
150 |
151 | **Improvements in new version (2023.9.8)**:
152 | - **Switch to CLIP-ViT-H**: we trained the new IP-Adapter with [OpenCLIP-ViT-H-14](https://huggingface.co/laion/CLIP-ViT-H-14-laion2B-s32B-b79K) instead of [OpenCLIP-ViT-bigG-14](https://huggingface.co/laion/CLIP-ViT-bigG-14-laion2B-39B-b160k). Although ViT-bigG is much larger than ViT-H, our experimental results did not find a significant difference, and the smaller model can reduce the memory usage in the inference phase.
153 | - **A Faster and better training recipe**: In our previous version, training directly at a resolution of 1024x1024 proved to be highly inefficient. However, in the new version, we have implemented a more effective two-stage training strategy. Firstly, we perform pre-training at a resolution of 512x512. Then, we employ a multi-scale strategy for fine-tuning. (Maybe this training strategy can also be used to speed up the training of controlnet).
154 |
155 | ## How to Train
156 | For training, you should install [accelerate](https://github.com/huggingface/accelerate) and make your own dataset into a json file.
157 |
158 | ```
159 | accelerate launch --num_processes 8 --multi_gpu --mixed_precision "fp16" \
160 | tutorial_train.py \
161 | --pretrained_model_name_or_path="runwayml/stable-diffusion-v1-5/" \
162 | --image_encoder_path="{image_encoder_path}" \
163 | --data_json_file="{data.json}" \
164 | --data_root_path="{image_path}" \
165 | --mixed_precision="fp16" \
166 | --resolution=512 \
167 | --train_batch_size=8 \
168 | --dataloader_num_workers=4 \
169 | --learning_rate=1e-04 \
170 | --weight_decay=0.01 \
171 | --output_dir="{output_dir}" \
172 | --save_steps=10000
173 | ```
174 |
175 | Once training is complete, you can convert the weights with the following code:
176 |
177 | ```python
178 | import torch
179 | ckpt = "checkpoint-50000/pytorch_model.bin"
180 | sd = torch.load(ckpt, map_location="cpu")
181 | image_proj_sd = {}
182 | ip_sd = {}
183 | for k in sd:
184 | if k.startswith("unet"):
185 | pass
186 | elif k.startswith("image_proj_model"):
187 | image_proj_sd[k.replace("image_proj_model.", "")] = sd[k]
188 | elif k.startswith("adapter_modules"):
189 | ip_sd[k.replace("adapter_modules.", "")] = sd[k]
190 |
191 | torch.save({"image_proj": image_proj_sd, "ip_adapter": ip_sd}, "ip_adapter.bin")
192 | ```
193 |
194 | ## Third-party Usage
195 | - [IP-Adapter for WebUI](https://github.com/Mikubill/sd-webui-controlnet) [[release notes](https://github.com/Mikubill/sd-webui-controlnet/discussions/2039)]
196 | - IP-Adapter for ComfyUI [[IPAdapter-ComfyUI](https://github.com/laksjdjf/IPAdapter-ComfyUI) or [ComfyUI_IPAdapter_plus](https://github.com/cubiq/ComfyUI_IPAdapter_plus)]
197 | - [IP-Adapter for InvokeAI](https://github.com/invoke-ai/InvokeAI) [[release notes](https://github.com/invoke-ai/InvokeAI/releases/tag/v3.2.0)]
198 | - [IP-Adapter for AnimateDiff prompt travel](https://github.com/s9roll7/animatediff-cli-prompt-travel)
199 | - [Diffusers_IPAdapter](https://github.com/cubiq/Diffusers_IPAdapter): more features such as supporting multiple input images
200 | - [Official Diffusers ](https://github.com/huggingface/diffusers/pull/5713)
201 |
202 | ## Disclaimer
203 |
204 | This project strives to positively impact the domain of AI-driven image generation. Users are granted the freedom to create images using this tool, but they are expected to comply with local laws and utilize it in a responsible manner. **The developers do not assume any responsibility for potential misuse by users.**
205 |
206 | ## Citation
207 | If you find IP-Adapter useful for your research and applications, please cite using this BibTeX:
208 | ```bibtex
209 | @article{ye2023ip-adapter,
210 | title={IP-Adapter: Text Compatible Image Prompt Adapter for Text-to-Image Diffusion Models},
211 | author={Ye, Hu and Zhang, Jun and Liu, Sibo and Han, Xiao and Yang, Wei},
212 | booktitle={arXiv preprint arxiv:2308.06721},
213 | year={2023}
214 | }
215 | ```
216 |
--------------------------------------------------------------------------------
/assets/demo/crop_and_resize_cmp.jpg:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/sagiodev/IP-Adapter-Negative/912a655ef473d887847c8d61bb66d106eb39a26c/assets/demo/crop_and_resize_cmp.jpg
--------------------------------------------------------------------------------
/assets/demo/image-to-image.jpg:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/sagiodev/IP-Adapter-Negative/912a655ef473d887847c8d61bb66d106eb39a26c/assets/demo/image-to-image.jpg
--------------------------------------------------------------------------------
/assets/demo/image_variations.jpg:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/sagiodev/IP-Adapter-Negative/912a655ef473d887847c8d61bb66d106eb39a26c/assets/demo/image_variations.jpg
--------------------------------------------------------------------------------
/assets/demo/inpainting.jpg:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/sagiodev/IP-Adapter-Negative/912a655ef473d887847c8d61bb66d106eb39a26c/assets/demo/inpainting.jpg
--------------------------------------------------------------------------------
/assets/demo/ip_adpter_plus_image_variations.jpg:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/sagiodev/IP-Adapter-Negative/912a655ef473d887847c8d61bb66d106eb39a26c/assets/demo/ip_adpter_plus_image_variations.jpg
--------------------------------------------------------------------------------
/assets/demo/ip_adpter_plus_multi.jpg:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/sagiodev/IP-Adapter-Negative/912a655ef473d887847c8d61bb66d106eb39a26c/assets/demo/ip_adpter_plus_multi.jpg
--------------------------------------------------------------------------------
/assets/demo/multi_prompts.jpg:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/sagiodev/IP-Adapter-Negative/912a655ef473d887847c8d61bb66d106eb39a26c/assets/demo/multi_prompts.jpg
--------------------------------------------------------------------------------
/assets/demo/sd15_face.jpg:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/sagiodev/IP-Adapter-Negative/912a655ef473d887847c8d61bb66d106eb39a26c/assets/demo/sd15_face.jpg
--------------------------------------------------------------------------------
/assets/demo/sdxl_cmp.jpg:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/sagiodev/IP-Adapter-Negative/912a655ef473d887847c8d61bb66d106eb39a26c/assets/demo/sdxl_cmp.jpg
--------------------------------------------------------------------------------
/assets/demo/structural_cond.jpg:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/sagiodev/IP-Adapter-Negative/912a655ef473d887847c8d61bb66d106eb39a26c/assets/demo/structural_cond.jpg
--------------------------------------------------------------------------------
/assets/demo/t2i-adapter_demo.jpg:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/sagiodev/IP-Adapter-Negative/912a655ef473d887847c8d61bb66d106eb39a26c/assets/demo/t2i-adapter_demo.jpg
--------------------------------------------------------------------------------
/assets/figs/fig1.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/sagiodev/IP-Adapter-Negative/912a655ef473d887847c8d61bb66d106eb39a26c/assets/figs/fig1.png
--------------------------------------------------------------------------------
/assets/images/ai_face.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/sagiodev/IP-Adapter-Negative/912a655ef473d887847c8d61bb66d106eb39a26c/assets/images/ai_face.png
--------------------------------------------------------------------------------
/assets/images/ai_face2.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/sagiodev/IP-Adapter-Negative/912a655ef473d887847c8d61bb66d106eb39a26c/assets/images/ai_face2.png
--------------------------------------------------------------------------------
/assets/images/girl.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/sagiodev/IP-Adapter-Negative/912a655ef473d887847c8d61bb66d106eb39a26c/assets/images/girl.png
--------------------------------------------------------------------------------
/assets/images/river.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/sagiodev/IP-Adapter-Negative/912a655ef473d887847c8d61bb66d106eb39a26c/assets/images/river.png
--------------------------------------------------------------------------------
/assets/images/statue.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/sagiodev/IP-Adapter-Negative/912a655ef473d887847c8d61bb66d106eb39a26c/assets/images/statue.png
--------------------------------------------------------------------------------
/assets/images/stone.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/sagiodev/IP-Adapter-Negative/912a655ef473d887847c8d61bb66d106eb39a26c/assets/images/stone.png
--------------------------------------------------------------------------------
/assets/images/vermeer.jpg:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/sagiodev/IP-Adapter-Negative/912a655ef473d887847c8d61bb66d106eb39a26c/assets/images/vermeer.jpg
--------------------------------------------------------------------------------
/assets/images/woman.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/sagiodev/IP-Adapter-Negative/912a655ef473d887847c8d61bb66d106eb39a26c/assets/images/woman.png
--------------------------------------------------------------------------------
/assets/inpainting/image.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/sagiodev/IP-Adapter-Negative/912a655ef473d887847c8d61bb66d106eb39a26c/assets/inpainting/image.png
--------------------------------------------------------------------------------
/assets/inpainting/mask.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/sagiodev/IP-Adapter-Negative/912a655ef473d887847c8d61bb66d106eb39a26c/assets/inpainting/mask.png
--------------------------------------------------------------------------------
/assets/structure_controls/depth.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/sagiodev/IP-Adapter-Negative/912a655ef473d887847c8d61bb66d106eb39a26c/assets/structure_controls/depth.png
--------------------------------------------------------------------------------
/assets/structure_controls/depth2.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/sagiodev/IP-Adapter-Negative/912a655ef473d887847c8d61bb66d106eb39a26c/assets/structure_controls/depth2.png
--------------------------------------------------------------------------------
/assets/structure_controls/openpose.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/sagiodev/IP-Adapter-Negative/912a655ef473d887847c8d61bb66d106eb39a26c/assets/structure_controls/openpose.png
--------------------------------------------------------------------------------
/ip_adapter/__init__.py:
--------------------------------------------------------------------------------
1 | from .ip_adapter import IPAdapter, IPAdapterPlus, IPAdapterPlusXL, IPAdapterXL, IPAdapterFull
2 |
3 | __all__ = [
4 | "IPAdapter",
5 | "IPAdapterPlus",
6 | "IPAdapterPlusXL",
7 | "IPAdapterXL",
8 | "IPAdapterFull",
9 | ]
10 |
--------------------------------------------------------------------------------
/ip_adapter/attention_processor.py:
--------------------------------------------------------------------------------
1 | # modified from https://github.com/huggingface/diffusers/blob/main/src/diffusers/models/attention_processor.py
2 | import torch
3 | import torch.nn as nn
4 | import torch.nn.functional as F
5 |
6 |
7 | class AttnProcessor(nn.Module):
8 | r"""
9 | Default processor for performing attention-related computations.
10 | """
11 |
12 | def __init__(
13 | self,
14 | hidden_size=None,
15 | cross_attention_dim=None,
16 | ):
17 | super().__init__()
18 |
19 | def __call__(
20 | self,
21 | attn,
22 | hidden_states,
23 | encoder_hidden_states=None,
24 | attention_mask=None,
25 | temb=None,
26 | ):
27 | residual = hidden_states
28 |
29 | if attn.spatial_norm is not None:
30 | hidden_states = attn.spatial_norm(hidden_states, temb)
31 |
32 | input_ndim = hidden_states.ndim
33 |
34 | if input_ndim == 4:
35 | batch_size, channel, height, width = hidden_states.shape
36 | hidden_states = hidden_states.view(batch_size, channel, height * width).transpose(1, 2)
37 |
38 | batch_size, sequence_length, _ = (
39 | hidden_states.shape if encoder_hidden_states is None else encoder_hidden_states.shape
40 | )
41 | attention_mask = attn.prepare_attention_mask(attention_mask, sequence_length, batch_size)
42 |
43 | if attn.group_norm is not None:
44 | hidden_states = attn.group_norm(hidden_states.transpose(1, 2)).transpose(1, 2)
45 |
46 | query = attn.to_q(hidden_states)
47 |
48 | if encoder_hidden_states is None:
49 | encoder_hidden_states = hidden_states
50 | elif attn.norm_cross:
51 | encoder_hidden_states = attn.norm_encoder_hidden_states(encoder_hidden_states)
52 |
53 | key = attn.to_k(encoder_hidden_states)
54 | value = attn.to_v(encoder_hidden_states)
55 |
56 | query = attn.head_to_batch_dim(query)
57 | key = attn.head_to_batch_dim(key)
58 | value = attn.head_to_batch_dim(value)
59 |
60 | attention_probs = attn.get_attention_scores(query, key, attention_mask)
61 | hidden_states = torch.bmm(attention_probs, value)
62 | hidden_states = attn.batch_to_head_dim(hidden_states)
63 |
64 | # linear proj
65 | hidden_states = attn.to_out[0](hidden_states)
66 | # dropout
67 | hidden_states = attn.to_out[1](hidden_states)
68 |
69 | if input_ndim == 4:
70 | hidden_states = hidden_states.transpose(-1, -2).reshape(batch_size, channel, height, width)
71 |
72 | if attn.residual_connection:
73 | hidden_states = hidden_states + residual
74 |
75 | hidden_states = hidden_states / attn.rescale_output_factor
76 |
77 | return hidden_states
78 |
79 |
80 | class IPAttnProcessor(nn.Module):
81 | r"""
82 | Attention processor for IP-Adapater.
83 | Args:
84 | hidden_size (`int`):
85 | The hidden size of the attention layer.
86 | cross_attention_dim (`int`):
87 | The number of channels in the `encoder_hidden_states`.
88 | scale (`float`, defaults to 1.0):
89 | the weight scale of image prompt.
90 | num_tokens (`int`, defaults to 4 when do ip_adapter_plus it should be 16):
91 | The context length of the image features.
92 | """
93 |
94 | def __init__(self, hidden_size, cross_attention_dim=None, scale=1.0, num_tokens=4):
95 | super().__init__()
96 |
97 | self.hidden_size = hidden_size
98 | self.cross_attention_dim = cross_attention_dim
99 | self.scale = scale
100 | self.num_tokens = num_tokens
101 |
102 | self.to_k_ip = nn.Linear(cross_attention_dim or hidden_size, hidden_size, bias=False)
103 | self.to_v_ip = nn.Linear(cross_attention_dim or hidden_size, hidden_size, bias=False)
104 |
105 | def __call__(
106 | self,
107 | attn,
108 | hidden_states,
109 | encoder_hidden_states=None,
110 | attention_mask=None,
111 | temb=None,
112 | ):
113 | residual = hidden_states
114 |
115 | if attn.spatial_norm is not None:
116 | hidden_states = attn.spatial_norm(hidden_states, temb)
117 |
118 | input_ndim = hidden_states.ndim
119 |
120 | if input_ndim == 4:
121 | batch_size, channel, height, width = hidden_states.shape
122 | hidden_states = hidden_states.view(batch_size, channel, height * width).transpose(1, 2)
123 |
124 | batch_size, sequence_length, _ = (
125 | hidden_states.shape if encoder_hidden_states is None else encoder_hidden_states.shape
126 | )
127 | attention_mask = attn.prepare_attention_mask(attention_mask, sequence_length, batch_size)
128 |
129 | if attn.group_norm is not None:
130 | hidden_states = attn.group_norm(hidden_states.transpose(1, 2)).transpose(1, 2)
131 |
132 | query = attn.to_q(hidden_states)
133 |
134 | if encoder_hidden_states is None:
135 | encoder_hidden_states = hidden_states
136 | else:
137 | # get encoder_hidden_states, ip_hidden_states
138 | end_pos = encoder_hidden_states.shape[1] - self.num_tokens
139 | encoder_hidden_states, ip_hidden_states = (
140 | encoder_hidden_states[:, :end_pos, :],
141 | encoder_hidden_states[:, end_pos:, :],
142 | )
143 | if attn.norm_cross:
144 | encoder_hidden_states = attn.norm_encoder_hidden_states(encoder_hidden_states)
145 |
146 | key = attn.to_k(encoder_hidden_states)
147 | value = attn.to_v(encoder_hidden_states)
148 |
149 | query = attn.head_to_batch_dim(query)
150 | key = attn.head_to_batch_dim(key)
151 | value = attn.head_to_batch_dim(value)
152 |
153 | attention_probs = attn.get_attention_scores(query, key, attention_mask)
154 | hidden_states = torch.bmm(attention_probs, value)
155 | hidden_states = attn.batch_to_head_dim(hidden_states)
156 |
157 | # for ip-adapter
158 | ip_key = self.to_k_ip(ip_hidden_states)
159 | ip_value = self.to_v_ip(ip_hidden_states)
160 |
161 | ip_key = attn.head_to_batch_dim(ip_key)
162 | ip_value = attn.head_to_batch_dim(ip_value)
163 |
164 | ip_attention_probs = attn.get_attention_scores(query, ip_key, None)
165 | self.attn_map = ip_attention_probs
166 | ip_hidden_states = torch.bmm(ip_attention_probs, ip_value)
167 | ip_hidden_states = attn.batch_to_head_dim(ip_hidden_states)
168 |
169 | hidden_states = hidden_states + self.scale * ip_hidden_states
170 |
171 | # linear proj
172 | hidden_states = attn.to_out[0](hidden_states)
173 | # dropout
174 | hidden_states = attn.to_out[1](hidden_states)
175 |
176 | if input_ndim == 4:
177 | hidden_states = hidden_states.transpose(-1, -2).reshape(batch_size, channel, height, width)
178 |
179 | if attn.residual_connection:
180 | hidden_states = hidden_states + residual
181 |
182 | hidden_states = hidden_states / attn.rescale_output_factor
183 |
184 | return hidden_states
185 |
186 |
187 | class AttnProcessor2_0(torch.nn.Module):
188 | r"""
189 | Processor for implementing scaled dot-product attention (enabled by default if you're using PyTorch 2.0).
190 | """
191 |
192 | def __init__(
193 | self,
194 | hidden_size=None,
195 | cross_attention_dim=None,
196 | ):
197 | super().__init__()
198 | if not hasattr(F, "scaled_dot_product_attention"):
199 | raise ImportError("AttnProcessor2_0 requires PyTorch 2.0, to use it, please upgrade PyTorch to 2.0.")
200 |
201 | def __call__(
202 | self,
203 | attn,
204 | hidden_states,
205 | encoder_hidden_states=None,
206 | attention_mask=None,
207 | temb=None,
208 | ):
209 | residual = hidden_states
210 |
211 | if attn.spatial_norm is not None:
212 | hidden_states = attn.spatial_norm(hidden_states, temb)
213 |
214 | input_ndim = hidden_states.ndim
215 |
216 | if input_ndim == 4:
217 | batch_size, channel, height, width = hidden_states.shape
218 | hidden_states = hidden_states.view(batch_size, channel, height * width).transpose(1, 2)
219 |
220 | batch_size, sequence_length, _ = (
221 | hidden_states.shape if encoder_hidden_states is None else encoder_hidden_states.shape
222 | )
223 |
224 | if attention_mask is not None:
225 | attention_mask = attn.prepare_attention_mask(attention_mask, sequence_length, batch_size)
226 | # scaled_dot_product_attention expects attention_mask shape to be
227 | # (batch, heads, source_length, target_length)
228 | attention_mask = attention_mask.view(batch_size, attn.heads, -1, attention_mask.shape[-1])
229 |
230 | if attn.group_norm is not None:
231 | hidden_states = attn.group_norm(hidden_states.transpose(1, 2)).transpose(1, 2)
232 |
233 | query = attn.to_q(hidden_states)
234 |
235 | if encoder_hidden_states is None:
236 | encoder_hidden_states = hidden_states
237 | elif attn.norm_cross:
238 | encoder_hidden_states = attn.norm_encoder_hidden_states(encoder_hidden_states)
239 |
240 | key = attn.to_k(encoder_hidden_states)
241 | value = attn.to_v(encoder_hidden_states)
242 |
243 | inner_dim = key.shape[-1]
244 | head_dim = inner_dim // attn.heads
245 |
246 | query = query.view(batch_size, -1, attn.heads, head_dim).transpose(1, 2)
247 |
248 | key = key.view(batch_size, -1, attn.heads, head_dim).transpose(1, 2)
249 | value = value.view(batch_size, -1, attn.heads, head_dim).transpose(1, 2)
250 |
251 | # the output of sdp = (batch, num_heads, seq_len, head_dim)
252 | # TODO: add support for attn.scale when we move to Torch 2.1
253 | hidden_states = F.scaled_dot_product_attention(
254 | query, key, value, attn_mask=attention_mask, dropout_p=0.0, is_causal=False
255 | )
256 |
257 | hidden_states = hidden_states.transpose(1, 2).reshape(batch_size, -1, attn.heads * head_dim)
258 | hidden_states = hidden_states.to(query.dtype)
259 |
260 | # linear proj
261 | hidden_states = attn.to_out[0](hidden_states)
262 | # dropout
263 | hidden_states = attn.to_out[1](hidden_states)
264 |
265 | if input_ndim == 4:
266 | hidden_states = hidden_states.transpose(-1, -2).reshape(batch_size, channel, height, width)
267 |
268 | if attn.residual_connection:
269 | hidden_states = hidden_states + residual
270 |
271 | hidden_states = hidden_states / attn.rescale_output_factor
272 |
273 | return hidden_states
274 |
275 |
276 | class IPAttnProcessor2_0(torch.nn.Module):
277 | r"""
278 | Attention processor for IP-Adapater for PyTorch 2.0.
279 | Args:
280 | hidden_size (`int`):
281 | The hidden size of the attention layer.
282 | cross_attention_dim (`int`):
283 | The number of channels in the `encoder_hidden_states`.
284 | scale (`float`, defaults to 1.0):
285 | the weight scale of image prompt.
286 | num_tokens (`int`, defaults to 4 when do ip_adapter_plus it should be 16):
287 | The context length of the image features.
288 | """
289 |
290 | def __init__(self, hidden_size, cross_attention_dim=None, scale=1.0, num_tokens=4):
291 | super().__init__()
292 |
293 | if not hasattr(F, "scaled_dot_product_attention"):
294 | raise ImportError("AttnProcessor2_0 requires PyTorch 2.0, to use it, please upgrade PyTorch to 2.0.")
295 |
296 | self.hidden_size = hidden_size
297 | self.cross_attention_dim = cross_attention_dim
298 | self.scale = scale
299 | self.num_tokens = num_tokens
300 |
301 | self.to_k_ip = nn.Linear(cross_attention_dim or hidden_size, hidden_size, bias=False)
302 | self.to_v_ip = nn.Linear(cross_attention_dim or hidden_size, hidden_size, bias=False)
303 |
304 | def __call__(
305 | self,
306 | attn,
307 | hidden_states,
308 | encoder_hidden_states=None,
309 | attention_mask=None,
310 | temb=None,
311 | ):
312 | residual = hidden_states
313 |
314 | if attn.spatial_norm is not None:
315 | hidden_states = attn.spatial_norm(hidden_states, temb)
316 |
317 | input_ndim = hidden_states.ndim
318 |
319 | if input_ndim == 4:
320 | batch_size, channel, height, width = hidden_states.shape
321 | hidden_states = hidden_states.view(batch_size, channel, height * width).transpose(1, 2)
322 |
323 | batch_size, sequence_length, _ = (
324 | hidden_states.shape if encoder_hidden_states is None else encoder_hidden_states.shape
325 | )
326 |
327 | if attention_mask is not None:
328 | attention_mask = attn.prepare_attention_mask(attention_mask, sequence_length, batch_size)
329 | # scaled_dot_product_attention expects attention_mask shape to be
330 | # (batch, heads, source_length, target_length)
331 | attention_mask = attention_mask.view(batch_size, attn.heads, -1, attention_mask.shape[-1])
332 |
333 | if attn.group_norm is not None:
334 | hidden_states = attn.group_norm(hidden_states.transpose(1, 2)).transpose(1, 2)
335 |
336 | query = attn.to_q(hidden_states)
337 |
338 | if encoder_hidden_states is None:
339 | encoder_hidden_states = hidden_states
340 | else:
341 | # get encoder_hidden_states, ip_hidden_states
342 | end_pos = encoder_hidden_states.shape[1] - self.num_tokens
343 | encoder_hidden_states, ip_hidden_states = (
344 | encoder_hidden_states[:, :end_pos, :],
345 | encoder_hidden_states[:, end_pos:, :],
346 | )
347 | if attn.norm_cross:
348 | encoder_hidden_states = attn.norm_encoder_hidden_states(encoder_hidden_states)
349 |
350 | key = attn.to_k(encoder_hidden_states)
351 | value = attn.to_v(encoder_hidden_states)
352 |
353 | inner_dim = key.shape[-1]
354 | head_dim = inner_dim // attn.heads
355 |
356 | query = query.view(batch_size, -1, attn.heads, head_dim).transpose(1, 2)
357 |
358 | key = key.view(batch_size, -1, attn.heads, head_dim).transpose(1, 2)
359 | value = value.view(batch_size, -1, attn.heads, head_dim).transpose(1, 2)
360 |
361 | # the output of sdp = (batch, num_heads, seq_len, head_dim)
362 | # TODO: add support for attn.scale when we move to Torch 2.1
363 | hidden_states = F.scaled_dot_product_attention(
364 | query, key, value, attn_mask=attention_mask, dropout_p=0.0, is_causal=False
365 | )
366 |
367 | hidden_states = hidden_states.transpose(1, 2).reshape(batch_size, -1, attn.heads * head_dim)
368 | hidden_states = hidden_states.to(query.dtype)
369 |
370 | # for ip-adapter
371 | ip_key = self.to_k_ip(ip_hidden_states)
372 | ip_value = self.to_v_ip(ip_hidden_states)
373 |
374 | ip_key = ip_key.view(batch_size, -1, attn.heads, head_dim).transpose(1, 2)
375 | ip_value = ip_value.view(batch_size, -1, attn.heads, head_dim).transpose(1, 2)
376 |
377 | # the output of sdp = (batch, num_heads, seq_len, head_dim)
378 | # TODO: add support for attn.scale when we move to Torch 2.1
379 | ip_hidden_states = F.scaled_dot_product_attention(
380 | query, ip_key, ip_value, attn_mask=None, dropout_p=0.0, is_causal=False
381 | )
382 | with torch.no_grad():
383 | self.attn_map = query @ ip_key.transpose(-2, -1).softmax(dim=-1)
384 | #print(self.attn_map.shape)
385 |
386 | ip_hidden_states = ip_hidden_states.transpose(1, 2).reshape(batch_size, -1, attn.heads * head_dim)
387 | ip_hidden_states = ip_hidden_states.to(query.dtype)
388 |
389 | scales = torch.ones(ip_hidden_states.shape, dtype= ip_hidden_states.dtype)
390 | scales[0:int(scales.shape[0]/2)] = self.scale[0] # uncond
391 | scales[int(scales.shape[0]/2):] = self.scale[1] # cond
392 | scales = scales.to(hidden_states.device)
393 |
394 | hidden_states = hidden_states + scales * ip_hidden_states
395 |
396 | # linear proj
397 | hidden_states = attn.to_out[0](hidden_states)
398 | # dropout
399 | hidden_states = attn.to_out[1](hidden_states)
400 |
401 | if input_ndim == 4:
402 | hidden_states = hidden_states.transpose(-1, -2).reshape(batch_size, channel, height, width)
403 |
404 | if attn.residual_connection:
405 | hidden_states = hidden_states + residual
406 |
407 | hidden_states = hidden_states / attn.rescale_output_factor
408 |
409 | return hidden_states
410 |
411 |
412 | ## for controlnet
413 | class CNAttnProcessor:
414 | r"""
415 | Default processor for performing attention-related computations.
416 | """
417 |
418 | def __init__(self, num_tokens=4):
419 | self.num_tokens = num_tokens
420 |
421 | def __call__(self, attn, hidden_states, encoder_hidden_states=None, attention_mask=None, temb=None):
422 | residual = hidden_states
423 |
424 | if attn.spatial_norm is not None:
425 | hidden_states = attn.spatial_norm(hidden_states, temb)
426 |
427 | input_ndim = hidden_states.ndim
428 |
429 | if input_ndim == 4:
430 | batch_size, channel, height, width = hidden_states.shape
431 | hidden_states = hidden_states.view(batch_size, channel, height * width).transpose(1, 2)
432 |
433 | batch_size, sequence_length, _ = (
434 | hidden_states.shape if encoder_hidden_states is None else encoder_hidden_states.shape
435 | )
436 | attention_mask = attn.prepare_attention_mask(attention_mask, sequence_length, batch_size)
437 |
438 | if attn.group_norm is not None:
439 | hidden_states = attn.group_norm(hidden_states.transpose(1, 2)).transpose(1, 2)
440 |
441 | query = attn.to_q(hidden_states)
442 |
443 | if encoder_hidden_states is None:
444 | encoder_hidden_states = hidden_states
445 | else:
446 | end_pos = encoder_hidden_states.shape[1] - self.num_tokens
447 | encoder_hidden_states = encoder_hidden_states[:, :end_pos] # only use text
448 | if attn.norm_cross:
449 | encoder_hidden_states = attn.norm_encoder_hidden_states(encoder_hidden_states)
450 |
451 | key = attn.to_k(encoder_hidden_states)
452 | value = attn.to_v(encoder_hidden_states)
453 |
454 | query = attn.head_to_batch_dim(query)
455 | key = attn.head_to_batch_dim(key)
456 | value = attn.head_to_batch_dim(value)
457 |
458 | attention_probs = attn.get_attention_scores(query, key, attention_mask)
459 | hidden_states = torch.bmm(attention_probs, value)
460 | hidden_states = attn.batch_to_head_dim(hidden_states)
461 |
462 | # linear proj
463 | hidden_states = attn.to_out[0](hidden_states)
464 | # dropout
465 | hidden_states = attn.to_out[1](hidden_states)
466 |
467 | if input_ndim == 4:
468 | hidden_states = hidden_states.transpose(-1, -2).reshape(batch_size, channel, height, width)
469 |
470 | if attn.residual_connection:
471 | hidden_states = hidden_states + residual
472 |
473 | hidden_states = hidden_states / attn.rescale_output_factor
474 |
475 | return hidden_states
476 |
477 |
478 | class CNAttnProcessor2_0:
479 | r"""
480 | Processor for implementing scaled dot-product attention (enabled by default if you're using PyTorch 2.0).
481 | """
482 |
483 | def __init__(self, num_tokens=4):
484 | if not hasattr(F, "scaled_dot_product_attention"):
485 | raise ImportError("AttnProcessor2_0 requires PyTorch 2.0, to use it, please upgrade PyTorch to 2.0.")
486 | self.num_tokens = num_tokens
487 |
488 | def __call__(
489 | self,
490 | attn,
491 | hidden_states,
492 | encoder_hidden_states=None,
493 | attention_mask=None,
494 | temb=None,
495 | ):
496 | residual = hidden_states
497 |
498 | if attn.spatial_norm is not None:
499 | hidden_states = attn.spatial_norm(hidden_states, temb)
500 |
501 | input_ndim = hidden_states.ndim
502 |
503 | if input_ndim == 4:
504 | batch_size, channel, height, width = hidden_states.shape
505 | hidden_states = hidden_states.view(batch_size, channel, height * width).transpose(1, 2)
506 |
507 | batch_size, sequence_length, _ = (
508 | hidden_states.shape if encoder_hidden_states is None else encoder_hidden_states.shape
509 | )
510 |
511 | if attention_mask is not None:
512 | attention_mask = attn.prepare_attention_mask(attention_mask, sequence_length, batch_size)
513 | # scaled_dot_product_attention expects attention_mask shape to be
514 | # (batch, heads, source_length, target_length)
515 | attention_mask = attention_mask.view(batch_size, attn.heads, -1, attention_mask.shape[-1])
516 |
517 | if attn.group_norm is not None:
518 | hidden_states = attn.group_norm(hidden_states.transpose(1, 2)).transpose(1, 2)
519 |
520 | query = attn.to_q(hidden_states)
521 |
522 | if encoder_hidden_states is None:
523 | encoder_hidden_states = hidden_states
524 | else:
525 | end_pos = encoder_hidden_states.shape[1] - self.num_tokens
526 | encoder_hidden_states = encoder_hidden_states[:, :end_pos] # only use text
527 | if attn.norm_cross:
528 | encoder_hidden_states = attn.norm_encoder_hidden_states(encoder_hidden_states)
529 |
530 | key = attn.to_k(encoder_hidden_states)
531 | value = attn.to_v(encoder_hidden_states)
532 |
533 | inner_dim = key.shape[-1]
534 | head_dim = inner_dim // attn.heads
535 |
536 | query = query.view(batch_size, -1, attn.heads, head_dim).transpose(1, 2)
537 |
538 | key = key.view(batch_size, -1, attn.heads, head_dim).transpose(1, 2)
539 | value = value.view(batch_size, -1, attn.heads, head_dim).transpose(1, 2)
540 |
541 | # the output of sdp = (batch, num_heads, seq_len, head_dim)
542 | # TODO: add support for attn.scale when we move to Torch 2.1
543 | hidden_states = F.scaled_dot_product_attention(
544 | query, key, value, attn_mask=attention_mask, dropout_p=0.0, is_causal=False
545 | )
546 |
547 | hidden_states = hidden_states.transpose(1, 2).reshape(batch_size, -1, attn.heads * head_dim)
548 | hidden_states = hidden_states.to(query.dtype)
549 |
550 | # linear proj
551 | hidden_states = attn.to_out[0](hidden_states)
552 | # dropout
553 | hidden_states = attn.to_out[1](hidden_states)
554 |
555 | if input_ndim == 4:
556 | hidden_states = hidden_states.transpose(-1, -2).reshape(batch_size, channel, height, width)
557 |
558 | if attn.residual_connection:
559 | hidden_states = hidden_states + residual
560 |
561 | hidden_states = hidden_states / attn.rescale_output_factor
562 |
563 | return hidden_states
564 |
--------------------------------------------------------------------------------
/ip_adapter/attention_processor_faceid.py:
--------------------------------------------------------------------------------
1 | # modified from https://github.com/huggingface/diffusers/blob/main/src/diffusers/models/attention_processor.py
2 | import torch
3 | import torch.nn as nn
4 | import torch.nn.functional as F
5 |
6 | from diffusers.models.lora import LoRALinearLayer
7 |
8 |
9 | class LoRAAttnProcessor(nn.Module):
10 | r"""
11 | Default processor for performing attention-related computations.
12 | """
13 |
14 | def __init__(
15 | self,
16 | hidden_size=None,
17 | cross_attention_dim=None,
18 | rank=4,
19 | network_alpha=None,
20 | lora_scale=1.0,
21 | ):
22 | super().__init__()
23 |
24 | self.rank = rank
25 | self.lora_scale = lora_scale
26 |
27 | self.to_q_lora = LoRALinearLayer(hidden_size, hidden_size, rank, network_alpha)
28 | self.to_k_lora = LoRALinearLayer(cross_attention_dim or hidden_size, hidden_size, rank, network_alpha)
29 | self.to_v_lora = LoRALinearLayer(cross_attention_dim or hidden_size, hidden_size, rank, network_alpha)
30 | self.to_out_lora = LoRALinearLayer(hidden_size, hidden_size, rank, network_alpha)
31 |
32 | def __call__(
33 | self,
34 | attn,
35 | hidden_states,
36 | encoder_hidden_states=None,
37 | attention_mask=None,
38 | temb=None,
39 | ):
40 | residual = hidden_states
41 |
42 | if attn.spatial_norm is not None:
43 | hidden_states = attn.spatial_norm(hidden_states, temb)
44 |
45 | input_ndim = hidden_states.ndim
46 |
47 | if input_ndim == 4:
48 | batch_size, channel, height, width = hidden_states.shape
49 | hidden_states = hidden_states.view(batch_size, channel, height * width).transpose(1, 2)
50 |
51 | batch_size, sequence_length, _ = (
52 | hidden_states.shape if encoder_hidden_states is None else encoder_hidden_states.shape
53 | )
54 | attention_mask = attn.prepare_attention_mask(attention_mask, sequence_length, batch_size)
55 |
56 | if attn.group_norm is not None:
57 | hidden_states = attn.group_norm(hidden_states.transpose(1, 2)).transpose(1, 2)
58 |
59 | query = attn.to_q(hidden_states) + self.lora_scale * self.to_q_lora(hidden_states)
60 |
61 | if encoder_hidden_states is None:
62 | encoder_hidden_states = hidden_states
63 | elif attn.norm_cross:
64 | encoder_hidden_states = attn.norm_encoder_hidden_states(encoder_hidden_states)
65 |
66 | key = attn.to_k(encoder_hidden_states) + self.lora_scale * self.to_k_lora(encoder_hidden_states)
67 | value = attn.to_v(encoder_hidden_states) + self.lora_scale * self.to_v_lora(encoder_hidden_states)
68 |
69 | query = attn.head_to_batch_dim(query)
70 | key = attn.head_to_batch_dim(key)
71 | value = attn.head_to_batch_dim(value)
72 |
73 | attention_probs = attn.get_attention_scores(query, key, attention_mask)
74 | hidden_states = torch.bmm(attention_probs, value)
75 | hidden_states = attn.batch_to_head_dim(hidden_states)
76 |
77 | # linear proj
78 | hidden_states = attn.to_out[0](hidden_states) + self.lora_scale * self.to_out_lora(hidden_states)
79 | # dropout
80 | hidden_states = attn.to_out[1](hidden_states)
81 |
82 | if input_ndim == 4:
83 | hidden_states = hidden_states.transpose(-1, -2).reshape(batch_size, channel, height, width)
84 |
85 | if attn.residual_connection:
86 | hidden_states = hidden_states + residual
87 |
88 | hidden_states = hidden_states / attn.rescale_output_factor
89 |
90 | return hidden_states
91 |
92 |
93 | class LoRAIPAttnProcessor(nn.Module):
94 | r"""
95 | Attention processor for IP-Adapater.
96 | Args:
97 | hidden_size (`int`):
98 | The hidden size of the attention layer.
99 | cross_attention_dim (`int`):
100 | The number of channels in the `encoder_hidden_states`.
101 | scale (`float`, defaults to 1.0):
102 | the weight scale of image prompt.
103 | num_tokens (`int`, defaults to 4 when do ip_adapter_plus it should be 16):
104 | The context length of the image features.
105 | """
106 |
107 | def __init__(self, hidden_size, cross_attention_dim=None, rank=4, network_alpha=None, lora_scale=1.0, scale=1.0, num_tokens=4):
108 | super().__init__()
109 |
110 | self.rank = rank
111 | self.lora_scale = lora_scale
112 |
113 | self.to_q_lora = LoRALinearLayer(hidden_size, hidden_size, rank, network_alpha)
114 | self.to_k_lora = LoRALinearLayer(cross_attention_dim or hidden_size, hidden_size, rank, network_alpha)
115 | self.to_v_lora = LoRALinearLayer(cross_attention_dim or hidden_size, hidden_size, rank, network_alpha)
116 | self.to_out_lora = LoRALinearLayer(hidden_size, hidden_size, rank, network_alpha)
117 |
118 | self.hidden_size = hidden_size
119 | self.cross_attention_dim = cross_attention_dim
120 | self.scale = scale
121 | self.num_tokens = num_tokens
122 |
123 | self.to_k_ip = nn.Linear(cross_attention_dim or hidden_size, hidden_size, bias=False)
124 | self.to_v_ip = nn.Linear(cross_attention_dim or hidden_size, hidden_size, bias=False)
125 |
126 | def __call__(
127 | self,
128 | attn,
129 | hidden_states,
130 | encoder_hidden_states=None,
131 | attention_mask=None,
132 | temb=None,
133 | ):
134 | residual = hidden_states
135 |
136 | if attn.spatial_norm is not None:
137 | hidden_states = attn.spatial_norm(hidden_states, temb)
138 |
139 | input_ndim = hidden_states.ndim
140 |
141 | if input_ndim == 4:
142 | batch_size, channel, height, width = hidden_states.shape
143 | hidden_states = hidden_states.view(batch_size, channel, height * width).transpose(1, 2)
144 |
145 | batch_size, sequence_length, _ = (
146 | hidden_states.shape if encoder_hidden_states is None else encoder_hidden_states.shape
147 | )
148 | attention_mask = attn.prepare_attention_mask(attention_mask, sequence_length, batch_size)
149 |
150 | if attn.group_norm is not None:
151 | hidden_states = attn.group_norm(hidden_states.transpose(1, 2)).transpose(1, 2)
152 |
153 | query = attn.to_q(hidden_states) + self.lora_scale * self.to_q_lora(hidden_states)
154 |
155 | if encoder_hidden_states is None:
156 | encoder_hidden_states = hidden_states
157 | else:
158 | # get encoder_hidden_states, ip_hidden_states
159 | end_pos = encoder_hidden_states.shape[1] - self.num_tokens
160 | encoder_hidden_states, ip_hidden_states = (
161 | encoder_hidden_states[:, :end_pos, :],
162 | encoder_hidden_states[:, end_pos:, :],
163 | )
164 | if attn.norm_cross:
165 | encoder_hidden_states = attn.norm_encoder_hidden_states(encoder_hidden_states)
166 |
167 | key = attn.to_k(encoder_hidden_states) + self.lora_scale * self.to_k_lora(encoder_hidden_states)
168 | value = attn.to_v(encoder_hidden_states) + self.lora_scale * self.to_v_lora(encoder_hidden_states)
169 |
170 | query = attn.head_to_batch_dim(query)
171 | key = attn.head_to_batch_dim(key)
172 | value = attn.head_to_batch_dim(value)
173 |
174 | attention_probs = attn.get_attention_scores(query, key, attention_mask)
175 | hidden_states = torch.bmm(attention_probs, value)
176 | hidden_states = attn.batch_to_head_dim(hidden_states)
177 |
178 | # for ip-adapter
179 | ip_key = self.to_k_ip(ip_hidden_states)
180 | ip_value = self.to_v_ip(ip_hidden_states)
181 |
182 | ip_key = attn.head_to_batch_dim(ip_key)
183 | ip_value = attn.head_to_batch_dim(ip_value)
184 |
185 | ip_attention_probs = attn.get_attention_scores(query, ip_key, None)
186 | self.attn_map = ip_attention_probs
187 | ip_hidden_states = torch.bmm(ip_attention_probs, ip_value)
188 | ip_hidden_states = attn.batch_to_head_dim(ip_hidden_states)
189 |
190 | hidden_states = hidden_states + self.scale * ip_hidden_states
191 |
192 | # linear proj
193 | hidden_states = attn.to_out[0](hidden_states) + self.lora_scale * self.to_out_lora(hidden_states)
194 | # dropout
195 | hidden_states = attn.to_out[1](hidden_states)
196 |
197 | if input_ndim == 4:
198 | hidden_states = hidden_states.transpose(-1, -2).reshape(batch_size, channel, height, width)
199 |
200 | if attn.residual_connection:
201 | hidden_states = hidden_states + residual
202 |
203 | hidden_states = hidden_states / attn.rescale_output_factor
204 |
205 | return hidden_states
206 |
207 |
208 | class LoRAAttnProcessor2_0(nn.Module):
209 |
210 | r"""
211 | Default processor for performing attention-related computations.
212 | """
213 |
214 | def __init__(
215 | self,
216 | hidden_size=None,
217 | cross_attention_dim=None,
218 | rank=4,
219 | network_alpha=None,
220 | lora_scale=1.0,
221 | ):
222 | super().__init__()
223 |
224 | self.rank = rank
225 | self.lora_scale = lora_scale
226 |
227 | self.to_q_lora = LoRALinearLayer(hidden_size, hidden_size, rank, network_alpha)
228 | self.to_k_lora = LoRALinearLayer(cross_attention_dim or hidden_size, hidden_size, rank, network_alpha)
229 | self.to_v_lora = LoRALinearLayer(cross_attention_dim or hidden_size, hidden_size, rank, network_alpha)
230 | self.to_out_lora = LoRALinearLayer(hidden_size, hidden_size, rank, network_alpha)
231 |
232 | def __call__(
233 | self,
234 | attn,
235 | hidden_states,
236 | encoder_hidden_states=None,
237 | attention_mask=None,
238 | temb=None,
239 | ):
240 | residual = hidden_states
241 |
242 | if attn.spatial_norm is not None:
243 | hidden_states = attn.spatial_norm(hidden_states, temb)
244 |
245 | input_ndim = hidden_states.ndim
246 |
247 | if input_ndim == 4:
248 | batch_size, channel, height, width = hidden_states.shape
249 | hidden_states = hidden_states.view(batch_size, channel, height * width).transpose(1, 2)
250 |
251 | batch_size, sequence_length, _ = (
252 | hidden_states.shape if encoder_hidden_states is None else encoder_hidden_states.shape
253 | )
254 | attention_mask = attn.prepare_attention_mask(attention_mask, sequence_length, batch_size)
255 |
256 | if attn.group_norm is not None:
257 | hidden_states = attn.group_norm(hidden_states.transpose(1, 2)).transpose(1, 2)
258 |
259 | query = attn.to_q(hidden_states) + self.lora_scale * self.to_q_lora(hidden_states)
260 |
261 | if encoder_hidden_states is None:
262 | encoder_hidden_states = hidden_states
263 | elif attn.norm_cross:
264 | encoder_hidden_states = attn.norm_encoder_hidden_states(encoder_hidden_states)
265 |
266 | key = attn.to_k(encoder_hidden_states) + self.lora_scale * self.to_k_lora(encoder_hidden_states)
267 | value = attn.to_v(encoder_hidden_states) + self.lora_scale * self.to_v_lora(encoder_hidden_states)
268 |
269 | inner_dim = key.shape[-1]
270 | head_dim = inner_dim // attn.heads
271 |
272 | query = query.view(batch_size, -1, attn.heads, head_dim).transpose(1, 2)
273 |
274 | key = key.view(batch_size, -1, attn.heads, head_dim).transpose(1, 2)
275 | value = value.view(batch_size, -1, attn.heads, head_dim).transpose(1, 2)
276 |
277 | # the output of sdp = (batch, num_heads, seq_len, head_dim)
278 | # TODO: add support for attn.scale when we move to Torch 2.1
279 | hidden_states = F.scaled_dot_product_attention(
280 | query, key, value, attn_mask=attention_mask, dropout_p=0.0, is_causal=False
281 | )
282 |
283 | hidden_states = hidden_states.transpose(1, 2).reshape(batch_size, -1, attn.heads * head_dim)
284 | hidden_states = hidden_states.to(query.dtype)
285 |
286 | # linear proj
287 | hidden_states = attn.to_out[0](hidden_states) + self.lora_scale * self.to_out_lora(hidden_states)
288 | # dropout
289 | hidden_states = attn.to_out[1](hidden_states)
290 |
291 | if input_ndim == 4:
292 | hidden_states = hidden_states.transpose(-1, -2).reshape(batch_size, channel, height, width)
293 |
294 | if attn.residual_connection:
295 | hidden_states = hidden_states + residual
296 |
297 | hidden_states = hidden_states / attn.rescale_output_factor
298 |
299 | return hidden_states
300 |
301 |
302 | class LoRAIPAttnProcessor2_0(nn.Module):
303 | r"""
304 | Processor for implementing the LoRA attention mechanism.
305 |
306 | Args:
307 | hidden_size (`int`, *optional*):
308 | The hidden size of the attention layer.
309 | cross_attention_dim (`int`, *optional*):
310 | The number of channels in the `encoder_hidden_states`.
311 | rank (`int`, defaults to 4):
312 | The dimension of the LoRA update matrices.
313 | network_alpha (`int`, *optional*):
314 | Equivalent to `alpha` but it's usage is specific to Kohya (A1111) style LoRAs.
315 | """
316 |
317 | def __init__(self, hidden_size, cross_attention_dim=None, rank=4, network_alpha=None, lora_scale=1.0, scale=1.0, num_tokens=4):
318 | super().__init__()
319 |
320 | self.rank = rank
321 | self.lora_scale = lora_scale
322 | self.num_tokens = num_tokens
323 |
324 | self.to_q_lora = LoRALinearLayer(hidden_size, hidden_size, rank, network_alpha)
325 | self.to_k_lora = LoRALinearLayer(cross_attention_dim or hidden_size, hidden_size, rank, network_alpha)
326 | self.to_v_lora = LoRALinearLayer(cross_attention_dim or hidden_size, hidden_size, rank, network_alpha)
327 | self.to_out_lora = LoRALinearLayer(hidden_size, hidden_size, rank, network_alpha)
328 |
329 |
330 | self.hidden_size = hidden_size
331 | self.cross_attention_dim = cross_attention_dim
332 | self.scale = scale
333 |
334 | self.to_k_ip = nn.Linear(cross_attention_dim or hidden_size, hidden_size, bias=False)
335 | self.to_v_ip = nn.Linear(cross_attention_dim or hidden_size, hidden_size, bias=False)
336 |
337 | def __call__(
338 | self, attn, hidden_states, encoder_hidden_states=None, attention_mask=None, scale=1.0, temb=None
339 | ):
340 | residual = hidden_states
341 |
342 | if attn.spatial_norm is not None:
343 | hidden_states = attn.spatial_norm(hidden_states, temb)
344 |
345 | input_ndim = hidden_states.ndim
346 |
347 | if input_ndim == 4:
348 | batch_size, channel, height, width = hidden_states.shape
349 | hidden_states = hidden_states.view(batch_size, channel, height * width).transpose(1, 2)
350 |
351 | batch_size, sequence_length, _ = (
352 | hidden_states.shape if encoder_hidden_states is None else encoder_hidden_states.shape
353 | )
354 | attention_mask = attn.prepare_attention_mask(attention_mask, sequence_length, batch_size)
355 |
356 | if attn.group_norm is not None:
357 | hidden_states = attn.group_norm(hidden_states.transpose(1, 2)).transpose(1, 2)
358 |
359 | query = attn.to_q(hidden_states) + self.lora_scale * self.to_q_lora(hidden_states)
360 | #query = attn.head_to_batch_dim(query)
361 |
362 | if encoder_hidden_states is None:
363 | encoder_hidden_states = hidden_states
364 | else:
365 | # get encoder_hidden_states, ip_hidden_states
366 | end_pos = encoder_hidden_states.shape[1] - self.num_tokens
367 | encoder_hidden_states, ip_hidden_states = (
368 | encoder_hidden_states[:, :end_pos, :],
369 | encoder_hidden_states[:, end_pos:, :],
370 | )
371 | if attn.norm_cross:
372 | encoder_hidden_states = attn.norm_encoder_hidden_states(encoder_hidden_states)
373 |
374 | # for text
375 | key = attn.to_k(encoder_hidden_states) + self.lora_scale * self.to_k_lora(encoder_hidden_states)
376 | value = attn.to_v(encoder_hidden_states) + self.lora_scale * self.to_v_lora(encoder_hidden_states)
377 |
378 | inner_dim = key.shape[-1]
379 | head_dim = inner_dim // attn.heads
380 |
381 | query = query.view(batch_size, -1, attn.heads, head_dim).transpose(1, 2)
382 |
383 | key = key.view(batch_size, -1, attn.heads, head_dim).transpose(1, 2)
384 | value = value.view(batch_size, -1, attn.heads, head_dim).transpose(1, 2)
385 |
386 | # the output of sdp = (batch, num_heads, seq_len, head_dim)
387 | # TODO: add support for attn.scale when we move to Torch 2.1
388 | hidden_states = F.scaled_dot_product_attention(
389 | query, key, value, attn_mask=attention_mask, dropout_p=0.0, is_causal=False
390 | )
391 |
392 | hidden_states = hidden_states.transpose(1, 2).reshape(batch_size, -1, attn.heads * head_dim)
393 | hidden_states = hidden_states.to(query.dtype)
394 |
395 | # for ip
396 | ip_key = self.to_k_ip(ip_hidden_states)
397 | ip_value = self.to_v_ip(ip_hidden_states)
398 |
399 | ip_key = ip_key.view(batch_size, -1, attn.heads, head_dim).transpose(1, 2)
400 | ip_value = ip_value.view(batch_size, -1, attn.heads, head_dim).transpose(1, 2)
401 |
402 | # the output of sdp = (batch, num_heads, seq_len, head_dim)
403 | # TODO: add support for attn.scale when we move to Torch 2.1
404 | ip_hidden_states = F.scaled_dot_product_attention(
405 | query, ip_key, ip_value, attn_mask=None, dropout_p=0.0, is_causal=False
406 | )
407 |
408 |
409 | ip_hidden_states = ip_hidden_states.transpose(1, 2).reshape(batch_size, -1, attn.heads * head_dim)
410 | ip_hidden_states = ip_hidden_states.to(query.dtype)
411 |
412 | hidden_states = hidden_states + self.scale * ip_hidden_states
413 |
414 | # linear proj
415 | hidden_states = attn.to_out[0](hidden_states) + self.lora_scale * self.to_out_lora(hidden_states)
416 | # dropout
417 | hidden_states = attn.to_out[1](hidden_states)
418 |
419 | if input_ndim == 4:
420 | hidden_states = hidden_states.transpose(-1, -2).reshape(batch_size, channel, height, width)
421 |
422 | if attn.residual_connection:
423 | hidden_states = hidden_states + residual
424 |
425 | hidden_states = hidden_states / attn.rescale_output_factor
426 |
427 | return hidden_states
428 |
--------------------------------------------------------------------------------
/ip_adapter/ip_adapter.py:
--------------------------------------------------------------------------------
1 | import os
2 | from typing import List
3 |
4 | import torch
5 | from diffusers import StableDiffusionPipeline
6 | from diffusers.pipelines.controlnet import MultiControlNetModel
7 | from PIL import Image
8 | from safetensors import safe_open
9 | from transformers import CLIPImageProcessor, CLIPVisionModelWithProjection
10 |
11 | from .utils import is_torch2_available
12 |
13 | if is_torch2_available():
14 | from .attention_processor import (
15 | AttnProcessor2_0 as AttnProcessor,
16 | )
17 | from .attention_processor import (
18 | CNAttnProcessor2_0 as CNAttnProcessor,
19 | )
20 | from .attention_processor import (
21 | IPAttnProcessor2_0 as IPAttnProcessor,
22 | )
23 | else:
24 | from .attention_processor import AttnProcessor, CNAttnProcessor, IPAttnProcessor
25 | from .resampler import Resampler
26 |
27 |
28 | class ImageProjModel(torch.nn.Module):
29 | """Projection Model"""
30 |
31 | def __init__(self, cross_attention_dim=1024, clip_embeddings_dim=1024, clip_extra_context_tokens=4):
32 | super().__init__()
33 |
34 | self.cross_attention_dim = cross_attention_dim
35 | self.clip_extra_context_tokens = clip_extra_context_tokens
36 | self.proj = torch.nn.Linear(clip_embeddings_dim, self.clip_extra_context_tokens * cross_attention_dim)
37 | self.norm = torch.nn.LayerNorm(cross_attention_dim)
38 |
39 | def forward(self, image_embeds):
40 | embeds = image_embeds
41 | clip_extra_context_tokens = self.proj(embeds).reshape(
42 | -1, self.clip_extra_context_tokens, self.cross_attention_dim
43 | )
44 | clip_extra_context_tokens = self.norm(clip_extra_context_tokens)
45 | return clip_extra_context_tokens
46 |
47 |
48 | class MLPProjModel(torch.nn.Module):
49 | """SD model with image prompt"""
50 | def __init__(self, cross_attention_dim=1024, clip_embeddings_dim=1024):
51 | super().__init__()
52 |
53 | self.proj = torch.nn.Sequential(
54 | torch.nn.Linear(clip_embeddings_dim, clip_embeddings_dim),
55 | torch.nn.GELU(),
56 | torch.nn.Linear(clip_embeddings_dim, cross_attention_dim),
57 | torch.nn.LayerNorm(cross_attention_dim)
58 | )
59 |
60 | def forward(self, image_embeds):
61 | clip_extra_context_tokens = self.proj(image_embeds)
62 | return clip_extra_context_tokens
63 |
64 |
65 | class IPAdapter:
66 | def __init__(self, sd_pipe, image_encoder_path, ip_ckpt, device, num_tokens=4):
67 | self.device = device
68 | self.image_encoder_path = image_encoder_path
69 | self.ip_ckpt = ip_ckpt
70 | self.num_tokens = num_tokens
71 |
72 | self.pipe = sd_pipe.to(self.device)
73 | self.set_ip_adapter()
74 |
75 | # load image encoder
76 | self.image_encoder = CLIPVisionModelWithProjection.from_pretrained(self.image_encoder_path).to(
77 | self.device, dtype=torch.float16
78 | )
79 | self.clip_image_processor = CLIPImageProcessor()
80 | # image proj model
81 | self.image_proj_model = self.init_proj()
82 |
83 | self.load_ip_adapter()
84 |
85 | def init_proj(self):
86 | image_proj_model = ImageProjModel(
87 | cross_attention_dim=self.pipe.unet.config.cross_attention_dim,
88 | clip_embeddings_dim=self.image_encoder.config.projection_dim,
89 | clip_extra_context_tokens=self.num_tokens,
90 | ).to(self.device, dtype=torch.float16)
91 | return image_proj_model
92 |
93 | def set_ip_adapter(self):
94 | unet = self.pipe.unet
95 | attn_procs = {}
96 | for name in unet.attn_processors.keys():
97 | cross_attention_dim = None if name.endswith("attn1.processor") else unet.config.cross_attention_dim
98 | if name.startswith("mid_block"):
99 | hidden_size = unet.config.block_out_channels[-1]
100 | elif name.startswith("up_blocks"):
101 | block_id = int(name[len("up_blocks.")])
102 | hidden_size = list(reversed(unet.config.block_out_channels))[block_id]
103 | elif name.startswith("down_blocks"):
104 | block_id = int(name[len("down_blocks.")])
105 | hidden_size = unet.config.block_out_channels[block_id]
106 | if cross_attention_dim is None:
107 | attn_procs[name] = AttnProcessor()
108 | else:
109 | attn_procs[name] = IPAttnProcessor(
110 | hidden_size=hidden_size,
111 | cross_attention_dim=cross_attention_dim,
112 | scale=1.0,
113 | num_tokens=self.num_tokens,
114 | ).to(self.device, dtype=torch.float16)
115 | unet.set_attn_processor(attn_procs)
116 | if hasattr(self.pipe, "controlnet"):
117 | if isinstance(self.pipe.controlnet, MultiControlNetModel):
118 | for controlnet in self.pipe.controlnet.nets:
119 | controlnet.set_attn_processor(CNAttnProcessor(num_tokens=self.num_tokens))
120 | else:
121 | self.pipe.controlnet.set_attn_processor(CNAttnProcessor(num_tokens=self.num_tokens))
122 |
123 | def load_ip_adapter(self):
124 | if os.path.splitext(self.ip_ckpt)[-1] == ".safetensors":
125 | state_dict = {"image_proj": {}, "ip_adapter": {}}
126 | with safe_open(self.ip_ckpt, framework="pt", device="cpu") as f:
127 | for key in f.keys():
128 | if key.startswith("image_proj."):
129 | state_dict["image_proj"][key.replace("image_proj.", "")] = f.get_tensor(key)
130 | elif key.startswith("ip_adapter."):
131 | state_dict["ip_adapter"][key.replace("ip_adapter.", "")] = f.get_tensor(key)
132 | else:
133 | state_dict = torch.load(self.ip_ckpt, map_location="cpu")
134 | self.image_proj_model.load_state_dict(state_dict["image_proj"])
135 | ip_layers = torch.nn.ModuleList(self.pipe.unet.attn_processors.values())
136 | ip_layers.load_state_dict(state_dict["ip_adapter"])
137 |
138 | @torch.inference_mode()
139 | def get_image_embeds(self, pil_image=None, clip_image_embeds=None):
140 | if pil_image is not None:
141 | if isinstance(pil_image, Image.Image):
142 | pil_image = [pil_image]
143 | clip_image = self.clip_image_processor(images=pil_image, return_tensors="pt").pixel_values
144 | clip_image_embeds = self.image_encoder(clip_image.to(self.device, dtype=torch.float16)).image_embeds
145 | else:
146 | clip_image_embeds = clip_image_embeds.to(self.device, dtype=torch.float16)
147 | image_prompt_embeds = self.image_proj_model(clip_image_embeds)
148 | uncond_image_prompt_embeds = self.image_proj_model(torch.zeros_like(clip_image_embeds))
149 | return image_prompt_embeds, uncond_image_prompt_embeds
150 |
151 | def set_scale(self, scale):
152 | for attn_processor in self.pipe.unet.attn_processors.values():
153 | if isinstance(attn_processor, IPAttnProcessor):
154 | attn_processor.scale = scale
155 |
156 | def generate(
157 | self,
158 | pil_image=None,
159 | clip_image_embeds=None,
160 | negative_pil_image=None,
161 | negative_clip_image_embeds=None,
162 | prompt=None,
163 | negative_prompt=None,
164 | scale=1.0, # weight for image prompt
165 | scale_start = 0.0,
166 | scale_stop = 1.0,
167 | scale_neg = 0.0, # weight for negative image prompt
168 | scale_neg_start = 0.0,
169 | scale_neg_stop = 1.0,
170 | num_samples=4,
171 | seed=None,
172 | guidance_scale=7.5,
173 | num_inference_steps=30,
174 | **kwargs,
175 | ):
176 |
177 | # set scales for negative and positive image prompt
178 | self.set_scale( (scale_neg if scale_neg_start == 0. else 0., scale if scale_start == 0. else 0.) )
179 | self.pipe.ip_scale = scale
180 | self.pipe.ip_scale_start = scale_start
181 | self.pipe.ip_scale_stop = scale_stop
182 | self.pipe.ip_scale_neg = scale
183 | self.pipe.ip_scale_neg_start = scale_neg_start
184 | self.pipe.ip_scale_neg_stop = scale_neg_stop
185 |
186 |
187 | if pil_image is not None:
188 | num_prompts = 1 if isinstance(pil_image, Image.Image) else len(pil_image)
189 | elif negative_pil_image:
190 | num_prompts = 1 if isinstance(negative_pil_image, Image.Image) else len(negative_pil_image)
191 | else:
192 | num_prompts = clip_image_embeds.size(0)
193 |
194 | if prompt is None:
195 | prompt = "best quality, high quality"
196 | if negative_prompt is None:
197 | negative_prompt = "monochrome, lowres, bad anatomy, worst quality, low quality"
198 |
199 | if not isinstance(prompt, List):
200 | prompt = [prompt] * num_prompts
201 | if not isinstance(negative_prompt, List):
202 | negative_prompt = [negative_prompt] * num_prompts
203 |
204 | if (pil_image or clip_image_embeds) and not (negative_pil_image or negative_clip_image_embeds):
205 | # positive image only
206 | print('positive image')
207 | image_prompt_embeds, uncond_image_prompt_embeds = self.get_image_embeds(
208 | pil_image=pil_image, clip_image_embeds=clip_image_embeds
209 | )
210 | elif not (pil_image or clip_image_embeds) and (negative_pil_image or negative_clip_image_embeds):
211 | # negative prompt only
212 | print('negative image')
213 | uncond_image_prompt_embeds, image_prompt_embeds = self.get_image_embeds(
214 | pil_image=negative_pil_image, clip_image_embeds=negative_clip_image_embeds
215 | )
216 | elif (pil_image or clip_image_embeds) and (negative_pil_image or negative_clip_image_embeds):
217 | # positive and negative images
218 | print('positive and negative image')
219 | image_prompt_embeds, uncond_image_prompt_embeds = self.get_image_embeds(
220 | pil_image=pil_image, clip_image_embeds=clip_image_embeds
221 | )
222 | uncond_image_prompt_embeds, _ = self.get_image_embeds(
223 | pil_image=negative_pil_image, clip_image_embeds=negative_clip_image_embeds
224 | )
225 | else:
226 | # No postive or negative images
227 | NotImplementedError("")
228 |
229 | bs_embed, seq_len, _ = image_prompt_embeds.shape
230 | image_prompt_embeds = image_prompt_embeds.repeat(1, num_samples, 1)
231 | image_prompt_embeds = image_prompt_embeds.view(bs_embed * num_samples, seq_len, -1)
232 | uncond_image_prompt_embeds = uncond_image_prompt_embeds.repeat(1, num_samples, 1)
233 | uncond_image_prompt_embeds = uncond_image_prompt_embeds.view(bs_embed * num_samples, seq_len, -1)
234 |
235 | with torch.inference_mode():
236 | prompt_embeds_, negative_prompt_embeds_ = self.pipe.encode_prompt(
237 | prompt,
238 | device=self.device,
239 | num_images_per_prompt=num_samples,
240 | do_classifier_free_guidance=True,
241 | negative_prompt=negative_prompt,
242 | )
243 | prompt_embeds = torch.cat([prompt_embeds_, image_prompt_embeds], dim=1)
244 | negative_prompt_embeds = torch.cat([negative_prompt_embeds_, uncond_image_prompt_embeds], dim=1)
245 |
246 | def callback_weight_schedule(pipe, step_index, timestep, callback_kwargs):
247 |
248 | # turn on ip scale when it is between start and stop
249 | if step_index >= int(pipe.num_timesteps * pipe.ip_scale_start) and step_index <= int(pipe.num_timesteps * pipe.ip_scale_stop):
250 | scale = pipe.ip_scale
251 | else:
252 | scale = 0
253 |
254 | # turn on ip negative scale when it is between start and stop
255 | if step_index >= int(pipe.num_timesteps * pipe.ip_scale_neg_start) and step_index <= int(pipe.num_timesteps * pipe.ip_scale_neg_stop):
256 | scale_neg = pipe.ip_scale_neg
257 | else:
258 | scale_neg = 0
259 |
260 | # update guidance_scale and prompt_embeds
261 | self.set_scale((scale_neg, scale))
262 | return callback_kwargs
263 |
264 | generator = torch.Generator(self.device).manual_seed(seed) if seed is not None else None
265 | images = self.pipe(
266 | prompt_embeds=prompt_embeds,
267 | negative_prompt_embeds=negative_prompt_embeds,
268 | guidance_scale=guidance_scale,
269 | num_inference_steps=num_inference_steps,
270 | generator=generator,
271 | callback_on_step_end= callback_weight_schedule,
272 | **kwargs,
273 | ).images
274 |
275 | return images
276 |
277 |
278 | class IPAdapterXL(IPAdapter):
279 | """SDXL"""
280 |
281 | def generate(
282 | self,
283 | pil_image,
284 | prompt=None,
285 | negative_prompt=None,
286 | scale=1.0,
287 | num_samples=4,
288 | seed=None,
289 | num_inference_steps=30,
290 | **kwargs,
291 | ):
292 | self.set_scale(scale)
293 |
294 | num_prompts = 1 if isinstance(pil_image, Image.Image) else len(pil_image)
295 |
296 | if prompt is None:
297 | prompt = "best quality, high quality"
298 | if negative_prompt is None:
299 | negative_prompt = "monochrome, lowres, bad anatomy, worst quality, low quality"
300 |
301 | if not isinstance(prompt, List):
302 | prompt = [prompt] * num_prompts
303 | if not isinstance(negative_prompt, List):
304 | negative_prompt = [negative_prompt] * num_prompts
305 |
306 | image_prompt_embeds, uncond_image_prompt_embeds = self.get_image_embeds(pil_image)
307 | bs_embed, seq_len, _ = image_prompt_embeds.shape
308 | image_prompt_embeds = image_prompt_embeds.repeat(1, num_samples, 1)
309 | image_prompt_embeds = image_prompt_embeds.view(bs_embed * num_samples, seq_len, -1)
310 | uncond_image_prompt_embeds = uncond_image_prompt_embeds.repeat(1, num_samples, 1)
311 | uncond_image_prompt_embeds = uncond_image_prompt_embeds.view(bs_embed * num_samples, seq_len, -1)
312 |
313 | with torch.inference_mode():
314 | (
315 | prompt_embeds,
316 | negative_prompt_embeds,
317 | pooled_prompt_embeds,
318 | negative_pooled_prompt_embeds,
319 | ) = self.pipe.encode_prompt(
320 | prompt,
321 | num_images_per_prompt=num_samples,
322 | do_classifier_free_guidance=True,
323 | negative_prompt=negative_prompt,
324 | )
325 | prompt_embeds = torch.cat([prompt_embeds, image_prompt_embeds], dim=1)
326 | negative_prompt_embeds = torch.cat([negative_prompt_embeds, uncond_image_prompt_embeds], dim=1)
327 |
328 | generator = torch.Generator(self.device).manual_seed(seed) if seed is not None else None
329 | images = self.pipe(
330 | prompt_embeds=prompt_embeds,
331 | negative_prompt_embeds=negative_prompt_embeds,
332 | pooled_prompt_embeds=pooled_prompt_embeds,
333 | negative_pooled_prompt_embeds=negative_pooled_prompt_embeds,
334 | num_inference_steps=num_inference_steps,
335 |
336 | generator=generator,
337 | **kwargs,
338 | ).images
339 |
340 | return images
341 |
342 |
343 | class IPAdapterPlus(IPAdapter):
344 | """IP-Adapter with fine-grained features"""
345 |
346 | def init_proj(self):
347 | image_proj_model = Resampler(
348 | dim=self.pipe.unet.config.cross_attention_dim,
349 | depth=4,
350 | dim_head=64,
351 | heads=12,
352 | num_queries=self.num_tokens,
353 | embedding_dim=self.image_encoder.config.hidden_size,
354 | output_dim=self.pipe.unet.config.cross_attention_dim,
355 | ff_mult=4,
356 | ).to(self.device, dtype=torch.float16)
357 | return image_proj_model
358 |
359 | @torch.inference_mode()
360 | def get_image_embeds(self, pil_image=None, clip_image_embeds=None):
361 | if isinstance(pil_image, Image.Image):
362 | pil_image = [pil_image]
363 | clip_image = self.clip_image_processor(images=pil_image, return_tensors="pt").pixel_values
364 | clip_image = clip_image.to(self.device, dtype=torch.float16)
365 | clip_image_embeds = self.image_encoder(clip_image, output_hidden_states=True).hidden_states[-2]
366 | image_prompt_embeds = self.image_proj_model(clip_image_embeds)
367 | uncond_clip_image_embeds = self.image_encoder(
368 | torch.zeros_like(clip_image), output_hidden_states=True
369 | ).hidden_states[-2]
370 | uncond_image_prompt_embeds = self.image_proj_model(uncond_clip_image_embeds)
371 | return image_prompt_embeds, uncond_image_prompt_embeds
372 |
373 |
374 | class IPAdapterFull(IPAdapterPlus):
375 | """IP-Adapter with full features"""
376 |
377 | def init_proj(self):
378 | image_proj_model = MLPProjModel(
379 | cross_attention_dim=self.pipe.unet.config.cross_attention_dim,
380 | clip_embeddings_dim=self.image_encoder.config.hidden_size,
381 | ).to(self.device, dtype=torch.float16)
382 | return image_proj_model
383 |
384 |
385 | class IPAdapterPlusXL(IPAdapter):
386 | """SDXL"""
387 |
388 | def init_proj(self):
389 | image_proj_model = Resampler(
390 | dim=1280,
391 | depth=4,
392 | dim_head=64,
393 | heads=20,
394 | num_queries=self.num_tokens,
395 | embedding_dim=self.image_encoder.config.hidden_size,
396 | output_dim=self.pipe.unet.config.cross_attention_dim,
397 | ff_mult=4,
398 | ).to(self.device, dtype=torch.float16)
399 | return image_proj_model
400 |
401 | @torch.inference_mode()
402 | def get_image_embeds(self, pil_image):
403 | if isinstance(pil_image, Image.Image):
404 | pil_image = [pil_image]
405 | clip_image = self.clip_image_processor(images=pil_image, return_tensors="pt").pixel_values
406 | clip_image = clip_image.to(self.device, dtype=torch.float16)
407 | clip_image_embeds = self.image_encoder(clip_image, output_hidden_states=True).hidden_states[-2]
408 | image_prompt_embeds = self.image_proj_model(clip_image_embeds)
409 | uncond_clip_image_embeds = self.image_encoder(
410 | torch.zeros_like(clip_image), output_hidden_states=True
411 | ).hidden_states[-2]
412 | uncond_image_prompt_embeds = self.image_proj_model(uncond_clip_image_embeds)
413 | return image_prompt_embeds, uncond_image_prompt_embeds
414 |
415 | def generate(
416 | self,
417 | pil_image,
418 | prompt=None,
419 | negative_prompt=None,
420 | scale=1.0,
421 | num_samples=4,
422 | seed=None,
423 | num_inference_steps=30,
424 | **kwargs,
425 | ):
426 | self.set_scale(scale)
427 |
428 | num_prompts = 1 if isinstance(pil_image, Image.Image) else len(pil_image)
429 |
430 | if prompt is None:
431 | prompt = "best quality, high quality"
432 | if negative_prompt is None:
433 | negative_prompt = "monochrome, lowres, bad anatomy, worst quality, low quality"
434 |
435 | if not isinstance(prompt, List):
436 | prompt = [prompt] * num_prompts
437 | if not isinstance(negative_prompt, List):
438 | negative_prompt = [negative_prompt] * num_prompts
439 |
440 | image_prompt_embeds, uncond_image_prompt_embeds = self.get_image_embeds(pil_image)
441 | bs_embed, seq_len, _ = image_prompt_embeds.shape
442 | image_prompt_embeds = image_prompt_embeds.repeat(1, num_samples, 1)
443 | image_prompt_embeds = image_prompt_embeds.view(bs_embed * num_samples, seq_len, -1)
444 | uncond_image_prompt_embeds = uncond_image_prompt_embeds.repeat(1, num_samples, 1)
445 | uncond_image_prompt_embeds = uncond_image_prompt_embeds.view(bs_embed * num_samples, seq_len, -1)
446 |
447 | with torch.inference_mode():
448 | (
449 | prompt_embeds,
450 | negative_prompt_embeds,
451 | pooled_prompt_embeds,
452 | negative_pooled_prompt_embeds,
453 | ) = self.pipe.encode_prompt(
454 | prompt,
455 | num_images_per_prompt=num_samples,
456 | do_classifier_free_guidance=True,
457 | negative_prompt=negative_prompt,
458 | )
459 | prompt_embeds = torch.cat([prompt_embeds, image_prompt_embeds], dim=1)
460 | negative_prompt_embeds = torch.cat([negative_prompt_embeds, uncond_image_prompt_embeds], dim=1)
461 |
462 | generator = torch.Generator(self.device).manual_seed(seed) if seed is not None else None
463 | images = self.pipe(
464 | prompt_embeds=prompt_embeds,
465 | negative_prompt_embeds=negative_prompt_embeds,
466 | pooled_prompt_embeds=pooled_prompt_embeds,
467 | negative_pooled_prompt_embeds=negative_pooled_prompt_embeds,
468 | num_inference_steps=num_inference_steps,
469 | generator=generator,
470 | **kwargs,
471 | ).images
472 |
473 | return images
474 |
--------------------------------------------------------------------------------
/ip_adapter/ip_adapter_faceid.py:
--------------------------------------------------------------------------------
1 | import os
2 | from typing import List
3 |
4 | import torch
5 | from diffusers import StableDiffusionPipeline
6 | from diffusers.pipelines.controlnet import MultiControlNetModel
7 | from PIL import Image
8 | from safetensors import safe_open
9 | from transformers import CLIPImageProcessor, CLIPVisionModelWithProjection
10 |
11 | from .attention_processor_faceid import LoRAAttnProcessor, LoRAIPAttnProcessor
12 | from .utils import is_torch2_available
13 |
14 | USE_DAFAULT_ATTN = False # should be True for visualization_attnmap
15 | if is_torch2_available() and (not USE_DAFAULT_ATTN):
16 | from .attention_processor_faceid import (
17 | LoRAAttnProcessor2_0 as LoRAAttnProcessor,
18 | )
19 | from .attention_processor_faceid import (
20 | LoRAIPAttnProcessor2_0 as LoRAIPAttnProcessor,
21 | )
22 | else:
23 | from .attention_processor_faceid import LoRAAttnProcessor, LoRAIPAttnProcessor
24 | from .resampler import PerceiverAttention, FeedForward
25 |
26 |
27 | class FacePerceiverResampler(torch.nn.Module):
28 | def __init__(
29 | self,
30 | *,
31 | dim=768,
32 | depth=4,
33 | dim_head=64,
34 | heads=16,
35 | embedding_dim=1280,
36 | output_dim=768,
37 | ff_mult=4,
38 | ):
39 | super().__init__()
40 |
41 | self.proj_in = torch.nn.Linear(embedding_dim, dim)
42 | self.proj_out = torch.nn.Linear(dim, output_dim)
43 | self.norm_out = torch.nn.LayerNorm(output_dim)
44 | self.layers = torch.nn.ModuleList([])
45 | for _ in range(depth):
46 | self.layers.append(
47 | torch.nn.ModuleList(
48 | [
49 | PerceiverAttention(dim=dim, dim_head=dim_head, heads=heads),
50 | FeedForward(dim=dim, mult=ff_mult),
51 | ]
52 | )
53 | )
54 |
55 | def forward(self, latents, x):
56 | x = self.proj_in(x)
57 | for attn, ff in self.layers:
58 | latents = attn(x, latents) + latents
59 | latents = ff(latents) + latents
60 | latents = self.proj_out(latents)
61 | return self.norm_out(latents)
62 |
63 |
64 | class MLPProjModel(torch.nn.Module):
65 | def __init__(self, cross_attention_dim=768, id_embeddings_dim=512, num_tokens=4):
66 | super().__init__()
67 |
68 | self.cross_attention_dim = cross_attention_dim
69 | self.num_tokens = num_tokens
70 |
71 | self.proj = torch.nn.Sequential(
72 | torch.nn.Linear(id_embeddings_dim, id_embeddings_dim*2),
73 | torch.nn.GELU(),
74 | torch.nn.Linear(id_embeddings_dim*2, cross_attention_dim*num_tokens),
75 | )
76 | self.norm = torch.nn.LayerNorm(cross_attention_dim)
77 |
78 | def forward(self, id_embeds):
79 | x = self.proj(id_embeds)
80 | x = x.reshape(-1, self.num_tokens, self.cross_attention_dim)
81 | x = self.norm(x)
82 | return x
83 |
84 |
85 | class ProjPlusModel(torch.nn.Module):
86 | def __init__(self, cross_attention_dim=768, id_embeddings_dim=512, clip_embeddings_dim=1280, num_tokens=4):
87 | super().__init__()
88 |
89 | self.cross_attention_dim = cross_attention_dim
90 | self.num_tokens = num_tokens
91 |
92 | self.proj = torch.nn.Sequential(
93 | torch.nn.Linear(id_embeddings_dim, id_embeddings_dim*2),
94 | torch.nn.GELU(),
95 | torch.nn.Linear(id_embeddings_dim*2, cross_attention_dim*num_tokens),
96 | )
97 | self.norm = torch.nn.LayerNorm(cross_attention_dim)
98 |
99 | self.perceiver_resampler = FacePerceiverResampler(
100 | dim=cross_attention_dim,
101 | depth=4,
102 | dim_head=64,
103 | heads=cross_attention_dim // 64,
104 | embedding_dim=clip_embeddings_dim,
105 | output_dim=cross_attention_dim,
106 | ff_mult=4,
107 | )
108 |
109 | def forward(self, id_embeds, clip_embeds, shortcut=False, scale=1.0):
110 |
111 | x = self.proj(id_embeds)
112 | x = x.reshape(-1, self.num_tokens, self.cross_attention_dim)
113 | x = self.norm(x)
114 | out = self.perceiver_resampler(x, clip_embeds)
115 | if shortcut:
116 | out = x + scale * out
117 | return out
118 |
119 |
120 | class IPAdapterFaceID:
121 | def __init__(self, sd_pipe, ip_ckpt, device, lora_rank=128, num_tokens=4, torch_dtype=torch.float16):
122 | self.device = device
123 | self.ip_ckpt = ip_ckpt
124 | self.lora_rank = lora_rank
125 | self.num_tokens = num_tokens
126 | self.torch_dtype = torch_dtype
127 |
128 | self.pipe = sd_pipe.to(self.device)
129 | self.set_ip_adapter()
130 |
131 | # image proj model
132 | self.image_proj_model = self.init_proj()
133 |
134 | self.load_ip_adapter()
135 |
136 | def init_proj(self):
137 | image_proj_model = MLPProjModel(
138 | cross_attention_dim=self.pipe.unet.config.cross_attention_dim,
139 | id_embeddings_dim=512,
140 | num_tokens=self.num_tokens,
141 | ).to(self.device, dtype=self.torch_dtype)
142 | return image_proj_model
143 |
144 | def set_ip_adapter(self):
145 | unet = self.pipe.unet
146 | attn_procs = {}
147 | for name in unet.attn_processors.keys():
148 | cross_attention_dim = None if name.endswith("attn1.processor") else unet.config.cross_attention_dim
149 | if name.startswith("mid_block"):
150 | hidden_size = unet.config.block_out_channels[-1]
151 | elif name.startswith("up_blocks"):
152 | block_id = int(name[len("up_blocks.")])
153 | hidden_size = list(reversed(unet.config.block_out_channels))[block_id]
154 | elif name.startswith("down_blocks"):
155 | block_id = int(name[len("down_blocks.")])
156 | hidden_size = unet.config.block_out_channels[block_id]
157 | if cross_attention_dim is None:
158 | attn_procs[name] = LoRAAttnProcessor(
159 | hidden_size=hidden_size, cross_attention_dim=cross_attention_dim, rank=self.lora_rank,
160 | ).to(self.device, dtype=self.torch_dtype)
161 | else:
162 | attn_procs[name] = LoRAIPAttnProcessor(
163 | hidden_size=hidden_size, cross_attention_dim=cross_attention_dim, scale=1.0, rank=self.lora_rank, num_tokens=self.num_tokens,
164 | ).to(self.device, dtype=self.torch_dtype)
165 | unet.set_attn_processor(attn_procs)
166 |
167 | def load_ip_adapter(self):
168 | if os.path.splitext(self.ip_ckpt)[-1] == ".safetensors":
169 | state_dict = {"image_proj": {}, "ip_adapter": {}}
170 | with safe_open(self.ip_ckpt, framework="pt", device="cpu") as f:
171 | for key in f.keys():
172 | if key.startswith("image_proj."):
173 | state_dict["image_proj"][key.replace("image_proj.", "")] = f.get_tensor(key)
174 | elif key.startswith("ip_adapter."):
175 | state_dict["ip_adapter"][key.replace("ip_adapter.", "")] = f.get_tensor(key)
176 | else:
177 | state_dict = torch.load(self.ip_ckpt, map_location="cpu")
178 | self.image_proj_model.load_state_dict(state_dict["image_proj"])
179 | ip_layers = torch.nn.ModuleList(self.pipe.unet.attn_processors.values())
180 | ip_layers.load_state_dict(state_dict["ip_adapter"])
181 |
182 | @torch.inference_mode()
183 | def get_image_embeds(self, faceid_embeds):
184 |
185 | faceid_embeds = faceid_embeds.to(self.device, dtype=self.torch_dtype)
186 | image_prompt_embeds = self.image_proj_model(faceid_embeds)
187 | uncond_image_prompt_embeds = self.image_proj_model(torch.zeros_like(faceid_embeds))
188 | return image_prompt_embeds, uncond_image_prompt_embeds
189 |
190 | def set_scale(self, scale):
191 | for attn_processor in self.pipe.unet.attn_processors.values():
192 | if isinstance(attn_processor, LoRAIPAttnProcessor):
193 | attn_processor.scale = scale
194 |
195 | def generate(
196 | self,
197 | faceid_embeds=None,
198 | prompt=None,
199 | negative_prompt=None,
200 | scale=1.0,
201 | num_samples=4,
202 | seed=None,
203 | guidance_scale=7.5,
204 | num_inference_steps=30,
205 | **kwargs,
206 | ):
207 | self.set_scale(scale)
208 |
209 |
210 | num_prompts = faceid_embeds.size(0)
211 |
212 | if prompt is None:
213 | prompt = "best quality, high quality"
214 | if negative_prompt is None:
215 | negative_prompt = "monochrome, lowres, bad anatomy, worst quality, low quality"
216 |
217 | if not isinstance(prompt, List):
218 | prompt = [prompt] * num_prompts
219 | if not isinstance(negative_prompt, List):
220 | negative_prompt = [negative_prompt] * num_prompts
221 |
222 | image_prompt_embeds, uncond_image_prompt_embeds = self.get_image_embeds(faceid_embeds)
223 |
224 | bs_embed, seq_len, _ = image_prompt_embeds.shape
225 | image_prompt_embeds = image_prompt_embeds.repeat(1, num_samples, 1)
226 | image_prompt_embeds = image_prompt_embeds.view(bs_embed * num_samples, seq_len, -1)
227 | uncond_image_prompt_embeds = uncond_image_prompt_embeds.repeat(1, num_samples, 1)
228 | uncond_image_prompt_embeds = uncond_image_prompt_embeds.view(bs_embed * num_samples, seq_len, -1)
229 |
230 | with torch.inference_mode():
231 | prompt_embeds_, negative_prompt_embeds_ = self.pipe.encode_prompt(
232 | prompt,
233 | device=self.device,
234 | num_images_per_prompt=num_samples,
235 | do_classifier_free_guidance=True,
236 | negative_prompt=negative_prompt,
237 | )
238 | prompt_embeds = torch.cat([prompt_embeds_, image_prompt_embeds], dim=1)
239 | negative_prompt_embeds = torch.cat([negative_prompt_embeds_, uncond_image_prompt_embeds], dim=1)
240 |
241 | generator = torch.Generator(self.device).manual_seed(seed) if seed is not None else None
242 | images = self.pipe(
243 | prompt_embeds=prompt_embeds,
244 | negative_prompt_embeds=negative_prompt_embeds,
245 | guidance_scale=guidance_scale,
246 | num_inference_steps=num_inference_steps,
247 | generator=generator,
248 | **kwargs,
249 | ).images
250 |
251 | return images
252 |
253 |
254 | class IPAdapterFaceIDPlus:
255 | def __init__(self, sd_pipe, image_encoder_path, ip_ckpt, device, lora_rank=128, num_tokens=4, torch_dtype=torch.float16):
256 | self.device = device
257 | self.image_encoder_path = image_encoder_path
258 | self.ip_ckpt = ip_ckpt
259 | self.lora_rank = lora_rank
260 | self.num_tokens = num_tokens
261 | self.torch_dtype = torch_dtype
262 |
263 | self.pipe = sd_pipe.to(self.device)
264 | self.set_ip_adapter()
265 |
266 | # load image encoder
267 | self.image_encoder = CLIPVisionModelWithProjection.from_pretrained(self.image_encoder_path).to(
268 | self.device, dtype=self.torch_dtype
269 | )
270 | self.clip_image_processor = CLIPImageProcessor()
271 | # image proj model
272 | self.image_proj_model = self.init_proj()
273 |
274 | self.load_ip_adapter()
275 |
276 | def init_proj(self):
277 | image_proj_model = ProjPlusModel(
278 | cross_attention_dim=self.pipe.unet.config.cross_attention_dim,
279 | id_embeddings_dim=512,
280 | clip_embeddings_dim=self.image_encoder.config.hidden_size,
281 | num_tokens=self.num_tokens,
282 | ).to(self.device, dtype=self.torch_dtype)
283 | return image_proj_model
284 |
285 | def set_ip_adapter(self):
286 | unet = self.pipe.unet
287 | attn_procs = {}
288 | for name in unet.attn_processors.keys():
289 | cross_attention_dim = None if name.endswith("attn1.processor") else unet.config.cross_attention_dim
290 | if name.startswith("mid_block"):
291 | hidden_size = unet.config.block_out_channels[-1]
292 | elif name.startswith("up_blocks"):
293 | block_id = int(name[len("up_blocks.")])
294 | hidden_size = list(reversed(unet.config.block_out_channels))[block_id]
295 | elif name.startswith("down_blocks"):
296 | block_id = int(name[len("down_blocks.")])
297 | hidden_size = unet.config.block_out_channels[block_id]
298 | if cross_attention_dim is None:
299 | attn_procs[name] = LoRAAttnProcessor(
300 | hidden_size=hidden_size, cross_attention_dim=cross_attention_dim, rank=self.lora_rank,
301 | ).to(self.device, dtype=self.torch_dtype)
302 | else:
303 | attn_procs[name] = LoRAIPAttnProcessor(
304 | hidden_size=hidden_size, cross_attention_dim=cross_attention_dim, scale=1.0, rank=self.lora_rank, num_tokens=self.num_tokens,
305 | ).to(self.device, dtype=self.torch_dtype)
306 | unet.set_attn_processor(attn_procs)
307 |
308 | def load_ip_adapter(self):
309 | if os.path.splitext(self.ip_ckpt)[-1] == ".safetensors":
310 | state_dict = {"image_proj": {}, "ip_adapter": {}}
311 | with safe_open(self.ip_ckpt, framework="pt", device="cpu") as f:
312 | for key in f.keys():
313 | if key.startswith("image_proj."):
314 | state_dict["image_proj"][key.replace("image_proj.", "")] = f.get_tensor(key)
315 | elif key.startswith("ip_adapter."):
316 | state_dict["ip_adapter"][key.replace("ip_adapter.", "")] = f.get_tensor(key)
317 | else:
318 | state_dict = torch.load(self.ip_ckpt, map_location="cpu")
319 | self.image_proj_model.load_state_dict(state_dict["image_proj"])
320 | ip_layers = torch.nn.ModuleList(self.pipe.unet.attn_processors.values())
321 | ip_layers.load_state_dict(state_dict["ip_adapter"])
322 |
323 | @torch.inference_mode()
324 | def get_image_embeds(self, faceid_embeds, face_image, s_scale, shortcut):
325 | if isinstance(face_image, Image.Image):
326 | pil_image = [face_image]
327 | clip_image = self.clip_image_processor(images=face_image, return_tensors="pt").pixel_values
328 | clip_image = clip_image.to(self.device, dtype=self.torch_dtype)
329 | clip_image_embeds = self.image_encoder(clip_image, output_hidden_states=True).hidden_states[-2]
330 | uncond_clip_image_embeds = self.image_encoder(
331 | torch.zeros_like(clip_image), output_hidden_states=True
332 | ).hidden_states[-2]
333 |
334 | faceid_embeds = faceid_embeds.to(self.device, dtype=self.torch_dtype)
335 | image_prompt_embeds = self.image_proj_model(faceid_embeds, clip_image_embeds, shortcut=shortcut, scale=s_scale)
336 | uncond_image_prompt_embeds = self.image_proj_model(torch.zeros_like(faceid_embeds), uncond_clip_image_embeds, shortcut=shortcut, scale=s_scale)
337 | return image_prompt_embeds, uncond_image_prompt_embeds
338 |
339 | def set_scale(self, scale):
340 | for attn_processor in self.pipe.unet.attn_processors.values():
341 | if isinstance(attn_processor, LoRAIPAttnProcessor):
342 | attn_processor.scale = scale
343 |
344 | def generate(
345 | self,
346 | face_image=None,
347 | faceid_embeds=None,
348 | prompt=None,
349 | negative_prompt=None,
350 | scale=1.0,
351 | num_samples=4,
352 | seed=None,
353 | guidance_scale=7.5,
354 | num_inference_steps=30,
355 | s_scale=1.0,
356 | shortcut=False,
357 | **kwargs,
358 | ):
359 | self.set_scale(scale)
360 |
361 |
362 | num_prompts = faceid_embeds.size(0)
363 |
364 | if prompt is None:
365 | prompt = "best quality, high quality"
366 | if negative_prompt is None:
367 | negative_prompt = "monochrome, lowres, bad anatomy, worst quality, low quality"
368 |
369 | if not isinstance(prompt, List):
370 | prompt = [prompt] * num_prompts
371 | if not isinstance(negative_prompt, List):
372 | negative_prompt = [negative_prompt] * num_prompts
373 |
374 | image_prompt_embeds, uncond_image_prompt_embeds = self.get_image_embeds(faceid_embeds, face_image, s_scale, shortcut)
375 |
376 | bs_embed, seq_len, _ = image_prompt_embeds.shape
377 | image_prompt_embeds = image_prompt_embeds.repeat(1, num_samples, 1)
378 | image_prompt_embeds = image_prompt_embeds.view(bs_embed * num_samples, seq_len, -1)
379 | uncond_image_prompt_embeds = uncond_image_prompt_embeds.repeat(1, num_samples, 1)
380 | uncond_image_prompt_embeds = uncond_image_prompt_embeds.view(bs_embed * num_samples, seq_len, -1)
381 |
382 | with torch.inference_mode():
383 | prompt_embeds_, negative_prompt_embeds_ = self.pipe.encode_prompt(
384 | prompt,
385 | device=self.device,
386 | num_images_per_prompt=num_samples,
387 | do_classifier_free_guidance=True,
388 | negative_prompt=negative_prompt,
389 | )
390 | prompt_embeds = torch.cat([prompt_embeds_, image_prompt_embeds], dim=1)
391 | negative_prompt_embeds = torch.cat([negative_prompt_embeds_, uncond_image_prompt_embeds], dim=1)
392 |
393 | generator = torch.Generator(self.device).manual_seed(seed) if seed is not None else None
394 | images = self.pipe(
395 | prompt_embeds=prompt_embeds,
396 | negative_prompt_embeds=negative_prompt_embeds,
397 | guidance_scale=guidance_scale,
398 | num_inference_steps=num_inference_steps,
399 | generator=generator,
400 | **kwargs,
401 | ).images
402 |
403 | return images
404 |
405 |
406 | class IPAdapterFaceIDXL(IPAdapterFaceID):
407 | """SDXL"""
408 |
409 | def generate(
410 | self,
411 | faceid_embeds=None,
412 | prompt=None,
413 | negative_prompt=None,
414 | scale=1.0,
415 | num_samples=4,
416 | seed=None,
417 | num_inference_steps=30,
418 | **kwargs,
419 | ):
420 | self.set_scale(scale)
421 |
422 | num_prompts = faceid_embeds.size(0)
423 |
424 | if prompt is None:
425 | prompt = "best quality, high quality"
426 | if negative_prompt is None:
427 | negative_prompt = "monochrome, lowres, bad anatomy, worst quality, low quality"
428 |
429 | if not isinstance(prompt, List):
430 | prompt = [prompt] * num_prompts
431 | if not isinstance(negative_prompt, List):
432 | negative_prompt = [negative_prompt] * num_prompts
433 |
434 | image_prompt_embeds, uncond_image_prompt_embeds = self.get_image_embeds(faceid_embeds)
435 |
436 | bs_embed, seq_len, _ = image_prompt_embeds.shape
437 | image_prompt_embeds = image_prompt_embeds.repeat(1, num_samples, 1)
438 | image_prompt_embeds = image_prompt_embeds.view(bs_embed * num_samples, seq_len, -1)
439 | uncond_image_prompt_embeds = uncond_image_prompt_embeds.repeat(1, num_samples, 1)
440 | uncond_image_prompt_embeds = uncond_image_prompt_embeds.view(bs_embed * num_samples, seq_len, -1)
441 |
442 | with torch.inference_mode():
443 | (
444 | prompt_embeds,
445 | negative_prompt_embeds,
446 | pooled_prompt_embeds,
447 | negative_pooled_prompt_embeds,
448 | ) = self.pipe.encode_prompt(
449 | prompt,
450 | num_images_per_prompt=num_samples,
451 | do_classifier_free_guidance=True,
452 | negative_prompt=negative_prompt,
453 | )
454 | prompt_embeds = torch.cat([prompt_embeds, image_prompt_embeds], dim=1)
455 | negative_prompt_embeds = torch.cat([negative_prompt_embeds, uncond_image_prompt_embeds], dim=1)
456 |
457 | generator = torch.Generator(self.device).manual_seed(seed) if seed is not None else None
458 | images = self.pipe(
459 | prompt_embeds=prompt_embeds,
460 | negative_prompt_embeds=negative_prompt_embeds,
461 | pooled_prompt_embeds=pooled_prompt_embeds,
462 | negative_pooled_prompt_embeds=negative_pooled_prompt_embeds,
463 | num_inference_steps=num_inference_steps,
464 | generator=generator,
465 | **kwargs,
466 | ).images
467 |
468 | return images
469 |
--------------------------------------------------------------------------------
/ip_adapter/ip_adapter_faceid_separate.py:
--------------------------------------------------------------------------------
1 | import os
2 | from typing import List
3 |
4 | import torch
5 | from diffusers import StableDiffusionPipeline
6 | from diffusers.pipelines.controlnet import MultiControlNetModel
7 | from PIL import Image
8 | from safetensors import safe_open
9 | from transformers import CLIPImageProcessor, CLIPVisionModelWithProjection
10 |
11 | from .attention_processor_faceid import LoRAAttnProcessor, LoRAIPAttnProcessor
12 | from .utils import is_torch2_available
13 |
14 | USE_DAFAULT_ATTN = False # should be True for visualization_attnmap
15 | if is_torch2_available() and (not USE_DAFAULT_ATTN):
16 | from .attention_processor import (
17 | AttnProcessor2_0 as AttnProcessor,
18 | )
19 | from .attention_processor import (
20 | IPAttnProcessor2_0 as IPAttnProcessor,
21 | )
22 | else:
23 | from .attention_processor import AttnProcessor, IPAttnProcessor
24 | from .resampler import PerceiverAttention, FeedForward
25 |
26 |
27 | class FacePerceiverResampler(torch.nn.Module):
28 | def __init__(
29 | self,
30 | *,
31 | dim=768,
32 | depth=4,
33 | dim_head=64,
34 | heads=16,
35 | embedding_dim=1280,
36 | output_dim=768,
37 | ff_mult=4,
38 | ):
39 | super().__init__()
40 |
41 | self.proj_in = torch.nn.Linear(embedding_dim, dim)
42 | self.proj_out = torch.nn.Linear(dim, output_dim)
43 | self.norm_out = torch.nn.LayerNorm(output_dim)
44 | self.layers = torch.nn.ModuleList([])
45 | for _ in range(depth):
46 | self.layers.append(
47 | torch.nn.ModuleList(
48 | [
49 | PerceiverAttention(dim=dim, dim_head=dim_head, heads=heads),
50 | FeedForward(dim=dim, mult=ff_mult),
51 | ]
52 | )
53 | )
54 |
55 | def forward(self, latents, x):
56 | x = self.proj_in(x)
57 | for attn, ff in self.layers:
58 | latents = attn(x, latents) + latents
59 | latents = ff(latents) + latents
60 | latents = self.proj_out(latents)
61 | return self.norm_out(latents)
62 |
63 |
64 | class MLPProjModel(torch.nn.Module):
65 | def __init__(self, cross_attention_dim=768, id_embeddings_dim=512, num_tokens=4):
66 | super().__init__()
67 |
68 | self.cross_attention_dim = cross_attention_dim
69 | self.num_tokens = num_tokens
70 |
71 | self.proj = torch.nn.Sequential(
72 | torch.nn.Linear(id_embeddings_dim, id_embeddings_dim*2),
73 | torch.nn.GELU(),
74 | torch.nn.Linear(id_embeddings_dim*2, cross_attention_dim*num_tokens),
75 | )
76 | self.norm = torch.nn.LayerNorm(cross_attention_dim)
77 |
78 | def forward(self, id_embeds):
79 | x = self.proj(id_embeds)
80 | x = x.reshape(-1, self.num_tokens, self.cross_attention_dim)
81 | x = self.norm(x)
82 | return x
83 |
84 |
85 | class ProjPlusModel(torch.nn.Module):
86 | def __init__(self, cross_attention_dim=768, id_embeddings_dim=512, clip_embeddings_dim=1280, num_tokens=4):
87 | super().__init__()
88 |
89 | self.cross_attention_dim = cross_attention_dim
90 | self.num_tokens = num_tokens
91 |
92 | self.proj = torch.nn.Sequential(
93 | torch.nn.Linear(id_embeddings_dim, id_embeddings_dim*2),
94 | torch.nn.GELU(),
95 | torch.nn.Linear(id_embeddings_dim*2, cross_attention_dim*num_tokens),
96 | )
97 | self.norm = torch.nn.LayerNorm(cross_attention_dim)
98 |
99 | self.perceiver_resampler = FacePerceiverResampler(
100 | dim=cross_attention_dim,
101 | depth=4,
102 | dim_head=64,
103 | heads=cross_attention_dim // 64,
104 | embedding_dim=clip_embeddings_dim,
105 | output_dim=cross_attention_dim,
106 | ff_mult=4,
107 | )
108 |
109 | def forward(self, id_embeds, clip_embeds, shortcut=False, scale=1.0):
110 |
111 | x = self.proj(id_embeds)
112 | x = x.reshape(-1, self.num_tokens, self.cross_attention_dim)
113 | x = self.norm(x)
114 | out = self.perceiver_resampler(x, clip_embeds)
115 | if shortcut:
116 | out = x + scale * out
117 | return out
118 |
119 |
120 | class IPAdapterFaceID:
121 | def __init__(self, sd_pipe, ip_ckpt, device, num_tokens=4, torch_dtype=torch.float16):
122 | self.device = device
123 | self.ip_ckpt = ip_ckpt
124 | self.num_tokens = num_tokens
125 | self.torch_dtype = torch_dtype
126 |
127 | self.pipe = sd_pipe.to(self.device)
128 | self.set_ip_adapter()
129 |
130 | # image proj model
131 | self.image_proj_model = self.init_proj()
132 |
133 | self.load_ip_adapter()
134 |
135 | def init_proj(self):
136 | image_proj_model = MLPProjModel(
137 | cross_attention_dim=self.pipe.unet.config.cross_attention_dim,
138 | id_embeddings_dim=512,
139 | num_tokens=self.num_tokens,
140 | ).to(self.device, dtype=self.torch_dtype)
141 | return image_proj_model
142 |
143 | def set_ip_adapter(self):
144 | unet = self.pipe.unet
145 | attn_procs = {}
146 | for name in unet.attn_processors.keys():
147 | cross_attention_dim = None if name.endswith("attn1.processor") else unet.config.cross_attention_dim
148 | if name.startswith("mid_block"):
149 | hidden_size = unet.config.block_out_channels[-1]
150 | elif name.startswith("up_blocks"):
151 | block_id = int(name[len("up_blocks.")])
152 | hidden_size = list(reversed(unet.config.block_out_channels))[block_id]
153 | elif name.startswith("down_blocks"):
154 | block_id = int(name[len("down_blocks.")])
155 | hidden_size = unet.config.block_out_channels[block_id]
156 | if cross_attention_dim is None:
157 | attn_procs[name] = AttnProcessor()
158 | else:
159 | attn_procs[name] = IPAttnProcessor(
160 | hidden_size=hidden_size, cross_attention_dim=cross_attention_dim, scale=1.0, num_tokens=self.num_tokens,
161 | ).to(self.device, dtype=self.torch_dtype)
162 | unet.set_attn_processor(attn_procs)
163 |
164 | def load_ip_adapter(self):
165 | if os.path.splitext(self.ip_ckpt)[-1] == ".safetensors":
166 | state_dict = {"image_proj": {}, "ip_adapter": {}}
167 | with safe_open(self.ip_ckpt, framework="pt", device="cpu") as f:
168 | for key in f.keys():
169 | if key.startswith("image_proj."):
170 | state_dict["image_proj"][key.replace("image_proj.", "")] = f.get_tensor(key)
171 | elif key.startswith("ip_adapter."):
172 | state_dict["ip_adapter"][key.replace("ip_adapter.", "")] = f.get_tensor(key)
173 | else:
174 | state_dict = torch.load(self.ip_ckpt, map_location="cpu")
175 | self.image_proj_model.load_state_dict(state_dict["image_proj"])
176 | ip_layers = torch.nn.ModuleList(self.pipe.unet.attn_processors.values())
177 | ip_layers.load_state_dict(state_dict["ip_adapter"], strict=False)
178 |
179 | @torch.inference_mode()
180 | def get_image_embeds(self, faceid_embeds):
181 |
182 | faceid_embeds = faceid_embeds.to(self.device, dtype=self.torch_dtype)
183 | image_prompt_embeds = self.image_proj_model(faceid_embeds)
184 | uncond_image_prompt_embeds = self.image_proj_model(torch.zeros_like(faceid_embeds))
185 | return image_prompt_embeds, uncond_image_prompt_embeds
186 |
187 | def set_scale(self, scale):
188 | for attn_processor in self.pipe.unet.attn_processors.values():
189 | if isinstance(attn_processor, LoRAIPAttnProcessor):
190 | attn_processor.scale = scale
191 |
192 | def generate(
193 | self,
194 | faceid_embeds=None,
195 | prompt=None,
196 | negative_prompt=None,
197 | scale=1.0,
198 | num_samples=4,
199 | seed=None,
200 | guidance_scale=7.5,
201 | num_inference_steps=30,
202 | **kwargs,
203 | ):
204 | self.set_scale(scale)
205 |
206 |
207 | num_prompts = faceid_embeds.size(0)
208 |
209 | if prompt is None:
210 | prompt = "best quality, high quality"
211 | if negative_prompt is None:
212 | negative_prompt = "monochrome, lowres, bad anatomy, worst quality, low quality"
213 |
214 | if not isinstance(prompt, List):
215 | prompt = [prompt] * num_prompts
216 | if not isinstance(negative_prompt, List):
217 | negative_prompt = [negative_prompt] * num_prompts
218 |
219 | image_prompt_embeds, uncond_image_prompt_embeds = self.get_image_embeds(faceid_embeds)
220 |
221 | bs_embed, seq_len, _ = image_prompt_embeds.shape
222 | image_prompt_embeds = image_prompt_embeds.repeat(1, num_samples, 1)
223 | image_prompt_embeds = image_prompt_embeds.view(bs_embed * num_samples, seq_len, -1)
224 | uncond_image_prompt_embeds = uncond_image_prompt_embeds.repeat(1, num_samples, 1)
225 | uncond_image_prompt_embeds = uncond_image_prompt_embeds.view(bs_embed * num_samples, seq_len, -1)
226 |
227 | with torch.inference_mode():
228 | prompt_embeds_, negative_prompt_embeds_ = self.pipe.encode_prompt(
229 | prompt,
230 | device=self.device,
231 | num_images_per_prompt=num_samples,
232 | do_classifier_free_guidance=True,
233 | negative_prompt=negative_prompt,
234 | )
235 | prompt_embeds = torch.cat([prompt_embeds_, image_prompt_embeds], dim=1)
236 | negative_prompt_embeds = torch.cat([negative_prompt_embeds_, uncond_image_prompt_embeds], dim=1)
237 |
238 | generator = torch.Generator(self.device).manual_seed(seed) if seed is not None else None
239 | images = self.pipe(
240 | prompt_embeds=prompt_embeds,
241 | negative_prompt_embeds=negative_prompt_embeds,
242 | guidance_scale=guidance_scale,
243 | num_inference_steps=num_inference_steps,
244 | generator=generator,
245 | **kwargs,
246 | ).images
247 |
248 | return images
249 |
250 |
251 | class IPAdapterFaceIDPlus:
252 | def __init__(self, sd_pipe, image_encoder_path, ip_ckpt, device, num_tokens=4, torch_dtype=torch.float16):
253 | self.device = device
254 | self.image_encoder_path = image_encoder_path
255 | self.ip_ckpt = ip_ckpt
256 | self.num_tokens = num_tokens
257 | self.torch_dtype = torch_dtype
258 |
259 | self.pipe = sd_pipe.to(self.device)
260 | self.set_ip_adapter()
261 |
262 | # load image encoder
263 | self.image_encoder = CLIPVisionModelWithProjection.from_pretrained(self.image_encoder_path).to(
264 | self.device, dtype=self.torch_dtype
265 | )
266 | self.clip_image_processor = CLIPImageProcessor()
267 | # image proj model
268 | self.image_proj_model = self.init_proj()
269 |
270 | self.load_ip_adapter()
271 |
272 | def init_proj(self):
273 | image_proj_model = ProjPlusModel(
274 | cross_attention_dim=self.pipe.unet.config.cross_attention_dim,
275 | id_embeddings_dim=512,
276 | clip_embeddings_dim=self.image_encoder.config.hidden_size,
277 | num_tokens=self.num_tokens,
278 | ).to(self.device, dtype=self.torch_dtype)
279 | return image_proj_model
280 |
281 | def set_ip_adapter(self):
282 | unet = self.pipe.unet
283 | attn_procs = {}
284 | for name in unet.attn_processors.keys():
285 | cross_attention_dim = None if name.endswith("attn1.processor") else unet.config.cross_attention_dim
286 | if name.startswith("mid_block"):
287 | hidden_size = unet.config.block_out_channels[-1]
288 | elif name.startswith("up_blocks"):
289 | block_id = int(name[len("up_blocks.")])
290 | hidden_size = list(reversed(unet.config.block_out_channels))[block_id]
291 | elif name.startswith("down_blocks"):
292 | block_id = int(name[len("down_blocks.")])
293 | hidden_size = unet.config.block_out_channels[block_id]
294 | if cross_attention_dim is None:
295 | attn_procs[name] = AttnProcessor()
296 | else:
297 | attn_procs[name] = IPAttnProcessor(
298 | hidden_size=hidden_size, cross_attention_dim=cross_attention_dim, scale=1.0, num_tokens=self.num_tokens,
299 | ).to(self.device, dtype=self.torch_dtype)
300 | unet.set_attn_processor(attn_procs)
301 |
302 | def load_ip_adapter(self):
303 | if os.path.splitext(self.ip_ckpt)[-1] == ".safetensors":
304 | state_dict = {"image_proj": {}, "ip_adapter": {}}
305 | with safe_open(self.ip_ckpt, framework="pt", device="cpu") as f:
306 | for key in f.keys():
307 | if key.startswith("image_proj."):
308 | state_dict["image_proj"][key.replace("image_proj.", "")] = f.get_tensor(key)
309 | elif key.startswith("ip_adapter."):
310 | state_dict["ip_adapter"][key.replace("ip_adapter.", "")] = f.get_tensor(key)
311 | else:
312 | state_dict = torch.load(self.ip_ckpt, map_location="cpu")
313 | self.image_proj_model.load_state_dict(state_dict["image_proj"])
314 | ip_layers = torch.nn.ModuleList(self.pipe.unet.attn_processors.values())
315 | ip_layers.load_state_dict(state_dict["ip_adapter"], strict=False)
316 |
317 | @torch.inference_mode()
318 | def get_image_embeds(self, faceid_embeds, face_image, s_scale, shortcut):
319 | if isinstance(face_image, Image.Image):
320 | pil_image = [face_image]
321 | clip_image = self.clip_image_processor(images=face_image, return_tensors="pt").pixel_values
322 | clip_image = clip_image.to(self.device, dtype=self.torch_dtype)
323 | clip_image_embeds = self.image_encoder(clip_image, output_hidden_states=True).hidden_states[-2]
324 | uncond_clip_image_embeds = self.image_encoder(
325 | torch.zeros_like(clip_image), output_hidden_states=True
326 | ).hidden_states[-2]
327 |
328 | faceid_embeds = faceid_embeds.to(self.device, dtype=self.torch_dtype)
329 | image_prompt_embeds = self.image_proj_model(faceid_embeds, clip_image_embeds, shortcut=shortcut, scale=s_scale)
330 | uncond_image_prompt_embeds = self.image_proj_model(torch.zeros_like(faceid_embeds), uncond_clip_image_embeds, shortcut=shortcut, scale=s_scale)
331 | return image_prompt_embeds, uncond_image_prompt_embeds
332 |
333 | def set_scale(self, scale):
334 | for attn_processor in self.pipe.unet.attn_processors.values():
335 | if isinstance(attn_processor, LoRAIPAttnProcessor):
336 | attn_processor.scale = scale
337 |
338 | def generate(
339 | self,
340 | face_image=None,
341 | faceid_embeds=None,
342 | prompt=None,
343 | negative_prompt=None,
344 | scale=1.0,
345 | num_samples=4,
346 | seed=None,
347 | guidance_scale=7.5,
348 | num_inference_steps=30,
349 | s_scale=1.0,
350 | shortcut=False,
351 | **kwargs,
352 | ):
353 | self.set_scale(scale)
354 |
355 |
356 | num_prompts = faceid_embeds.size(0)
357 |
358 | if prompt is None:
359 | prompt = "best quality, high quality"
360 | if negative_prompt is None:
361 | negative_prompt = "monochrome, lowres, bad anatomy, worst quality, low quality"
362 |
363 | if not isinstance(prompt, List):
364 | prompt = [prompt] * num_prompts
365 | if not isinstance(negative_prompt, List):
366 | negative_prompt = [negative_prompt] * num_prompts
367 |
368 | image_prompt_embeds, uncond_image_prompt_embeds = self.get_image_embeds(faceid_embeds, face_image, s_scale, shortcut)
369 |
370 | bs_embed, seq_len, _ = image_prompt_embeds.shape
371 | image_prompt_embeds = image_prompt_embeds.repeat(1, num_samples, 1)
372 | image_prompt_embeds = image_prompt_embeds.view(bs_embed * num_samples, seq_len, -1)
373 | uncond_image_prompt_embeds = uncond_image_prompt_embeds.repeat(1, num_samples, 1)
374 | uncond_image_prompt_embeds = uncond_image_prompt_embeds.view(bs_embed * num_samples, seq_len, -1)
375 |
376 | with torch.inference_mode():
377 | prompt_embeds_, negative_prompt_embeds_ = self.pipe.encode_prompt(
378 | prompt,
379 | device=self.device,
380 | num_images_per_prompt=num_samples,
381 | do_classifier_free_guidance=True,
382 | negative_prompt=negative_prompt,
383 | )
384 | prompt_embeds = torch.cat([prompt_embeds_, image_prompt_embeds], dim=1)
385 | negative_prompt_embeds = torch.cat([negative_prompt_embeds_, uncond_image_prompt_embeds], dim=1)
386 |
387 | generator = torch.Generator(self.device).manual_seed(seed) if seed is not None else None
388 | images = self.pipe(
389 | prompt_embeds=prompt_embeds,
390 | negative_prompt_embeds=negative_prompt_embeds,
391 | guidance_scale=guidance_scale,
392 | num_inference_steps=num_inference_steps,
393 | generator=generator,
394 | **kwargs,
395 | ).images
396 |
397 | return images
398 |
399 |
400 | class IPAdapterFaceIDXL(IPAdapterFaceID):
401 | """SDXL"""
402 |
403 | def generate(
404 | self,
405 | faceid_embeds=None,
406 | prompt=None,
407 | negative_prompt=None,
408 | scale=1.0,
409 | num_samples=4,
410 | seed=None,
411 | num_inference_steps=30,
412 | **kwargs,
413 | ):
414 | self.set_scale(scale)
415 |
416 | num_prompts = faceid_embeds.size(0)
417 |
418 | if prompt is None:
419 | prompt = "best quality, high quality"
420 | if negative_prompt is None:
421 | negative_prompt = "monochrome, lowres, bad anatomy, worst quality, low quality"
422 |
423 | if not isinstance(prompt, List):
424 | prompt = [prompt] * num_prompts
425 | if not isinstance(negative_prompt, List):
426 | negative_prompt = [negative_prompt] * num_prompts
427 |
428 | image_prompt_embeds, uncond_image_prompt_embeds = self.get_image_embeds(faceid_embeds)
429 |
430 | bs_embed, seq_len, _ = image_prompt_embeds.shape
431 | image_prompt_embeds = image_prompt_embeds.repeat(1, num_samples, 1)
432 | image_prompt_embeds = image_prompt_embeds.view(bs_embed * num_samples, seq_len, -1)
433 | uncond_image_prompt_embeds = uncond_image_prompt_embeds.repeat(1, num_samples, 1)
434 | uncond_image_prompt_embeds = uncond_image_prompt_embeds.view(bs_embed * num_samples, seq_len, -1)
435 |
436 | with torch.inference_mode():
437 | (
438 | prompt_embeds,
439 | negative_prompt_embeds,
440 | pooled_prompt_embeds,
441 | negative_pooled_prompt_embeds,
442 | ) = self.pipe.encode_prompt(
443 | prompt,
444 | num_images_per_prompt=num_samples,
445 | do_classifier_free_guidance=True,
446 | negative_prompt=negative_prompt,
447 | )
448 | prompt_embeds = torch.cat([prompt_embeds, image_prompt_embeds], dim=1)
449 | negative_prompt_embeds = torch.cat([negative_prompt_embeds, uncond_image_prompt_embeds], dim=1)
450 |
451 | generator = torch.Generator(self.device).manual_seed(seed) if seed is not None else None
452 | images = self.pipe(
453 | prompt_embeds=prompt_embeds,
454 | negative_prompt_embeds=negative_prompt_embeds,
455 | pooled_prompt_embeds=pooled_prompt_embeds,
456 | negative_pooled_prompt_embeds=negative_pooled_prompt_embeds,
457 | num_inference_steps=num_inference_steps,
458 | generator=generator,
459 | **kwargs,
460 | ).images
461 |
462 | return images
463 |
--------------------------------------------------------------------------------
/ip_adapter/resampler.py:
--------------------------------------------------------------------------------
1 | # modified from https://github.com/mlfoundations/open_flamingo/blob/main/open_flamingo/src/helpers.py
2 | # and https://github.com/lucidrains/imagen-pytorch/blob/main/imagen_pytorch/imagen_pytorch.py
3 |
4 | import math
5 |
6 | import torch
7 | import torch.nn as nn
8 | from einops import rearrange
9 | from einops.layers.torch import Rearrange
10 |
11 |
12 | # FFN
13 | def FeedForward(dim, mult=4):
14 | inner_dim = int(dim * mult)
15 | return nn.Sequential(
16 | nn.LayerNorm(dim),
17 | nn.Linear(dim, inner_dim, bias=False),
18 | nn.GELU(),
19 | nn.Linear(inner_dim, dim, bias=False),
20 | )
21 |
22 |
23 | def reshape_tensor(x, heads):
24 | bs, length, width = x.shape
25 | # (bs, length, width) --> (bs, length, n_heads, dim_per_head)
26 | x = x.view(bs, length, heads, -1)
27 | # (bs, length, n_heads, dim_per_head) --> (bs, n_heads, length, dim_per_head)
28 | x = x.transpose(1, 2)
29 | # (bs, n_heads, length, dim_per_head) --> (bs*n_heads, length, dim_per_head)
30 | x = x.reshape(bs, heads, length, -1)
31 | return x
32 |
33 |
34 | class PerceiverAttention(nn.Module):
35 | def __init__(self, *, dim, dim_head=64, heads=8):
36 | super().__init__()
37 | self.scale = dim_head**-0.5
38 | self.dim_head = dim_head
39 | self.heads = heads
40 | inner_dim = dim_head * heads
41 |
42 | self.norm1 = nn.LayerNorm(dim)
43 | self.norm2 = nn.LayerNorm(dim)
44 |
45 | self.to_q = nn.Linear(dim, inner_dim, bias=False)
46 | self.to_kv = nn.Linear(dim, inner_dim * 2, bias=False)
47 | self.to_out = nn.Linear(inner_dim, dim, bias=False)
48 |
49 | def forward(self, x, latents):
50 | """
51 | Args:
52 | x (torch.Tensor): image features
53 | shape (b, n1, D)
54 | latent (torch.Tensor): latent features
55 | shape (b, n2, D)
56 | """
57 | x = self.norm1(x)
58 | latents = self.norm2(latents)
59 |
60 | b, l, _ = latents.shape
61 |
62 | q = self.to_q(latents)
63 | kv_input = torch.cat((x, latents), dim=-2)
64 | k, v = self.to_kv(kv_input).chunk(2, dim=-1)
65 |
66 | q = reshape_tensor(q, self.heads)
67 | k = reshape_tensor(k, self.heads)
68 | v = reshape_tensor(v, self.heads)
69 |
70 | # attention
71 | scale = 1 / math.sqrt(math.sqrt(self.dim_head))
72 | weight = (q * scale) @ (k * scale).transpose(-2, -1) # More stable with f16 than dividing afterwards
73 | weight = torch.softmax(weight.float(), dim=-1).type(weight.dtype)
74 | out = weight @ v
75 |
76 | out = out.permute(0, 2, 1, 3).reshape(b, l, -1)
77 |
78 | return self.to_out(out)
79 |
80 |
81 | class Resampler(nn.Module):
82 | def __init__(
83 | self,
84 | dim=1024,
85 | depth=8,
86 | dim_head=64,
87 | heads=16,
88 | num_queries=8,
89 | embedding_dim=768,
90 | output_dim=1024,
91 | ff_mult=4,
92 | max_seq_len: int = 257, # CLIP tokens + CLS token
93 | apply_pos_emb: bool = False,
94 | num_latents_mean_pooled: int = 0, # number of latents derived from mean pooled representation of the sequence
95 | ):
96 | super().__init__()
97 | self.pos_emb = nn.Embedding(max_seq_len, embedding_dim) if apply_pos_emb else None
98 |
99 | self.latents = nn.Parameter(torch.randn(1, num_queries, dim) / dim**0.5)
100 |
101 | self.proj_in = nn.Linear(embedding_dim, dim)
102 |
103 | self.proj_out = nn.Linear(dim, output_dim)
104 | self.norm_out = nn.LayerNorm(output_dim)
105 |
106 | self.to_latents_from_mean_pooled_seq = (
107 | nn.Sequential(
108 | nn.LayerNorm(dim),
109 | nn.Linear(dim, dim * num_latents_mean_pooled),
110 | Rearrange("b (n d) -> b n d", n=num_latents_mean_pooled),
111 | )
112 | if num_latents_mean_pooled > 0
113 | else None
114 | )
115 |
116 | self.layers = nn.ModuleList([])
117 | for _ in range(depth):
118 | self.layers.append(
119 | nn.ModuleList(
120 | [
121 | PerceiverAttention(dim=dim, dim_head=dim_head, heads=heads),
122 | FeedForward(dim=dim, mult=ff_mult),
123 | ]
124 | )
125 | )
126 |
127 | def forward(self, x):
128 | if self.pos_emb is not None:
129 | n, device = x.shape[1], x.device
130 | pos_emb = self.pos_emb(torch.arange(n, device=device))
131 | x = x + pos_emb
132 |
133 | latents = self.latents.repeat(x.size(0), 1, 1)
134 |
135 | x = self.proj_in(x)
136 |
137 | if self.to_latents_from_mean_pooled_seq:
138 | meanpooled_seq = masked_mean(x, dim=1, mask=torch.ones(x.shape[:2], device=x.device, dtype=torch.bool))
139 | meanpooled_latents = self.to_latents_from_mean_pooled_seq(meanpooled_seq)
140 | latents = torch.cat((meanpooled_latents, latents), dim=-2)
141 |
142 | for attn, ff in self.layers:
143 | latents = attn(x, latents) + latents
144 | latents = ff(latents) + latents
145 |
146 | latents = self.proj_out(latents)
147 | return self.norm_out(latents)
148 |
149 |
150 | def masked_mean(t, *, dim, mask=None):
151 | if mask is None:
152 | return t.mean(dim=dim)
153 |
154 | denom = mask.sum(dim=dim, keepdim=True)
155 | mask = rearrange(mask, "b n -> b n 1")
156 | masked_t = t.masked_fill(~mask, 0.0)
157 |
158 | return masked_t.sum(dim=dim) / denom.clamp(min=1e-5)
159 |
--------------------------------------------------------------------------------
/ip_adapter/test_resampler.py:
--------------------------------------------------------------------------------
1 | import torch
2 | from resampler import Resampler
3 | from transformers import CLIPVisionModel
4 |
5 | BATCH_SIZE = 2
6 | OUTPUT_DIM = 1280
7 | NUM_QUERIES = 8
8 | NUM_LATENTS_MEAN_POOLED = 4 # 0 for no mean pooling (previous behavior)
9 | APPLY_POS_EMB = True # False for no positional embeddings (previous behavior)
10 | IMAGE_ENCODER_NAME_OR_PATH = "laion/CLIP-ViT-H-14-laion2B-s32B-b79K"
11 |
12 |
13 | def main():
14 | image_encoder = CLIPVisionModel.from_pretrained(IMAGE_ENCODER_NAME_OR_PATH)
15 | embedding_dim = image_encoder.config.hidden_size
16 | print(f"image_encoder hidden size: ", embedding_dim)
17 |
18 | image_proj_model = Resampler(
19 | dim=1024,
20 | depth=2,
21 | dim_head=64,
22 | heads=16,
23 | num_queries=NUM_QUERIES,
24 | embedding_dim=embedding_dim,
25 | output_dim=OUTPUT_DIM,
26 | ff_mult=2,
27 | max_seq_len=257,
28 | apply_pos_emb=APPLY_POS_EMB,
29 | num_latents_mean_pooled=NUM_LATENTS_MEAN_POOLED,
30 | )
31 |
32 | dummy_images = torch.randn(BATCH_SIZE, 3, 224, 224)
33 | with torch.no_grad():
34 | image_embeds = image_encoder(dummy_images, output_hidden_states=True).hidden_states[-2]
35 | print("image_embds shape: ", image_embeds.shape)
36 |
37 | with torch.no_grad():
38 | ip_tokens = image_proj_model(image_embeds)
39 | print("ip_tokens shape:", ip_tokens.shape)
40 | assert ip_tokens.shape == (BATCH_SIZE, NUM_QUERIES + NUM_LATENTS_MEAN_POOLED, OUTPUT_DIM)
41 |
42 |
43 | if __name__ == "__main__":
44 | main()
45 |
--------------------------------------------------------------------------------
/ip_adapter/utils.py:
--------------------------------------------------------------------------------
1 | import torch
2 | import torch.nn.functional as F
3 | import numpy as np
4 | from PIL import Image
5 |
6 | attn_maps = {}
7 | def hook_fn(name):
8 | def forward_hook(module, input, output):
9 | if hasattr(module.processor, "attn_map"):
10 | attn_maps[name] = module.processor.attn_map
11 | del module.processor.attn_map
12 |
13 | return forward_hook
14 |
15 | def register_cross_attention_hook(unet):
16 | for name, module in unet.named_modules():
17 | if name.split('.')[-1].startswith('attn2'):
18 | module.register_forward_hook(hook_fn(name))
19 |
20 | return unet
21 |
22 | def upscale(attn_map, target_size):
23 | attn_map = torch.mean(attn_map, dim=0)
24 | attn_map = attn_map.permute(1,0)
25 | temp_size = None
26 |
27 | for i in range(0,5):
28 | scale = 2 ** i
29 | if ( target_size[0] // scale ) * ( target_size[1] // scale) == attn_map.shape[1]*64:
30 | temp_size = (target_size[0]//(scale*8), target_size[1]//(scale*8))
31 | break
32 |
33 | assert temp_size is not None, "temp_size cannot is None"
34 |
35 | attn_map = attn_map.view(attn_map.shape[0], *temp_size)
36 |
37 | attn_map = F.interpolate(
38 | attn_map.unsqueeze(0).to(dtype=torch.float32),
39 | size=target_size,
40 | mode='bilinear',
41 | align_corners=False
42 | )[0]
43 |
44 | attn_map = torch.softmax(attn_map, dim=0)
45 | return attn_map
46 | def get_net_attn_map(image_size, batch_size=2, instance_or_negative=False, detach=True):
47 |
48 | idx = 0 if instance_or_negative else 1
49 | net_attn_maps = []
50 |
51 | for name, attn_map in attn_maps.items():
52 | attn_map = attn_map.cpu() if detach else attn_map
53 | attn_map = torch.chunk(attn_map, batch_size)[idx].squeeze()
54 | attn_map = upscale(attn_map, image_size)
55 | net_attn_maps.append(attn_map)
56 |
57 | net_attn_maps = torch.mean(torch.stack(net_attn_maps,dim=0),dim=0)
58 |
59 | return net_attn_maps
60 |
61 | def attnmaps2images(net_attn_maps):
62 |
63 | #total_attn_scores = 0
64 | images = []
65 |
66 | for attn_map in net_attn_maps:
67 | attn_map = attn_map.cpu().numpy()
68 | #total_attn_scores += attn_map.mean().item()
69 |
70 | normalized_attn_map = (attn_map - np.min(attn_map)) / (np.max(attn_map) - np.min(attn_map)) * 255
71 | normalized_attn_map = normalized_attn_map.astype(np.uint8)
72 | #print("norm: ", normalized_attn_map.shape)
73 | image = Image.fromarray(normalized_attn_map)
74 |
75 | #image = fix_save_attn_map(attn_map)
76 | images.append(image)
77 |
78 | #print(total_attn_scores)
79 | return images
80 | def is_torch2_available():
81 | return hasattr(F, "scaled_dot_product_attention")
82 |
--------------------------------------------------------------------------------
/pyproject.toml:
--------------------------------------------------------------------------------
1 | [tool.poetry]
2 | name = "ip-adapter"
3 | version = "0.1.0"
4 | description = "IP-Adapter: Text Compatible Image Prompt Adapter for Text-to-Image Diffusion Models"
5 | authors = ["Ye, Hu", "Zhang, Jun", "Liu, Sibo", "Han, Xiao", "Yang, Wei"]
6 | license = "Apache-2.0"
7 | readme = "README.md"
8 | packages = [{ include = "ip_adapter" }]
9 |
10 | [tool.poetry.dependencies]
11 | python = ">=3.6"
12 |
13 | [tool.ruff]
14 | line-length = 119
15 | # Deprecation of Cuda 11.6 and Python 3.7 support for PyTorch 2.0
16 | target-version = "py38"
17 |
18 | # A list of file patterns to omit from linting, in addition to those specified by exclude.
19 | extend-exclude = ["__pycache__", "*.pyc", "*.egg-info", ".cache"]
20 |
21 | select = ["E", "F", "W", "C90", "I", "UP", "B", "C4", "RET", "RUF", "SIM"]
22 |
23 |
24 | ignore = [
25 | "UP006", # UP006: Use list instead of typing.List for type annotations
26 | "UP007", # UP007: Use X | Y for type annotations
27 | "UP009",
28 | "UP035",
29 | "UP038",
30 | "E402",
31 | "RET504",
32 | ]
33 |
34 | [tool.isort]
35 | profile = "black"
36 |
37 | [tool.black]
38 | line-length = 119
39 | skip-string-normalization = 1
40 |
41 | [build-system]
42 | requires = ["poetry-core"]
43 | build-backend = "poetry.core.masonry.api"
44 |
--------------------------------------------------------------------------------
/tutorial_train.py:
--------------------------------------------------------------------------------
1 | import os
2 | import random
3 | import argparse
4 | from pathlib import Path
5 | import json
6 | import itertools
7 | import time
8 |
9 | import torch
10 | import torch.nn.functional as F
11 | from torchvision import transforms
12 | from PIL import Image
13 | from transformers import CLIPImageProcessor
14 | from accelerate import Accelerator
15 | from accelerate.logging import get_logger
16 | from accelerate.utils import ProjectConfiguration
17 | from diffusers import AutoencoderKL, DDPMScheduler, UNet2DConditionModel
18 | from transformers import CLIPTextModel, CLIPTokenizer, CLIPVisionModelWithProjection
19 |
20 | from ip_adapter.ip_adapter import ImageProjModel
21 | from ip_adapter.utils import is_torch2_available
22 | if is_torch2_available():
23 | from ip_adapter.attention_processor import IPAttnProcessor2_0 as IPAttnProcessor, AttnProcessor2_0 as AttnProcessor
24 | else:
25 | from ip_adapter.attention_processor import IPAttnProcessor, AttnProcessor
26 |
27 |
28 | # Dataset
29 | class MyDataset(torch.utils.data.Dataset):
30 |
31 | def __init__(self, json_file, tokenizer, size=512, t_drop_rate=0.05, i_drop_rate=0.05, ti_drop_rate=0.05, image_root_path=""):
32 | super().__init__()
33 |
34 | self.tokenizer = tokenizer
35 | self.size = size
36 | self.i_drop_rate = i_drop_rate
37 | self.t_drop_rate = t_drop_rate
38 | self.ti_drop_rate = ti_drop_rate
39 | self.image_root_path = image_root_path
40 |
41 | self.data = json.load(open(json_file)) # list of dict: [{"image_file": "1.png", "text": "A dog"}]
42 |
43 | self.transform = transforms.Compose([
44 | transforms.Resize(self.size, interpolation=transforms.InterpolationMode.BILINEAR),
45 | transforms.CenterCrop(self.size),
46 | transforms.ToTensor(),
47 | transforms.Normalize([0.5], [0.5]),
48 | ])
49 | self.clip_image_processor = CLIPImageProcessor()
50 |
51 | def __getitem__(self, idx):
52 | item = self.data[idx]
53 | text = item["text"]
54 | image_file = item["image_file"]
55 |
56 | # read image
57 | raw_image = Image.open(os.path.join(self.image_root_path, image_file))
58 | image = self.transform(raw_image.convert("RGB"))
59 | clip_image = self.clip_image_processor(images=raw_image, return_tensors="pt").pixel_values
60 |
61 | # drop
62 | drop_image_embed = 0
63 | rand_num = random.random()
64 | if rand_num < self.i_drop_rate:
65 | drop_image_embed = 1
66 | elif rand_num < (self.i_drop_rate + self.t_drop_rate):
67 | text = ""
68 | elif rand_num < (self.i_drop_rate + self.t_drop_rate + self.ti_drop_rate):
69 | text = ""
70 | drop_image_embed = 1
71 | # get text and tokenize
72 | text_input_ids = self.tokenizer(
73 | text,
74 | max_length=self.tokenizer.model_max_length,
75 | padding="max_length",
76 | truncation=True,
77 | return_tensors="pt"
78 | ).input_ids
79 |
80 | return {
81 | "image": image,
82 | "text_input_ids": text_input_ids,
83 | "clip_image": clip_image,
84 | "drop_image_embed": drop_image_embed
85 | }
86 |
87 | def __len__(self):
88 | return len(self.data)
89 |
90 |
91 | def collate_fn(data):
92 | images = torch.stack([example["image"] for example in data])
93 | text_input_ids = torch.cat([example["text_input_ids"] for example in data], dim=0)
94 | clip_images = torch.cat([example["clip_image"] for example in data], dim=0)
95 | drop_image_embeds = [example["drop_image_embed"] for example in data]
96 |
97 | return {
98 | "images": images,
99 | "text_input_ids": text_input_ids,
100 | "clip_images": clip_images,
101 | "drop_image_embeds": drop_image_embeds
102 | }
103 |
104 |
105 | class IPAdapter(torch.nn.Module):
106 | """IP-Adapter"""
107 | def __init__(self, unet, image_proj_model, adapter_modules, ckpt_path=None):
108 | super().__init__()
109 | self.unet = unet
110 | self.image_proj_model = image_proj_model
111 | self.adapter_modules = adapter_modules
112 |
113 | if ckpt_path is not None:
114 | self.load_from_checkpoint(ckpt_path)
115 |
116 | def forward(self, noisy_latents, timesteps, encoder_hidden_states, image_embeds):
117 | ip_tokens = self.image_proj_model(image_embeds)
118 | encoder_hidden_states = torch.cat([encoder_hidden_states, ip_tokens], dim=1)
119 | # Predict the noise residual
120 | noise_pred = self.unet(noisy_latents, timesteps, encoder_hidden_states).sample
121 | return noise_pred
122 |
123 | def load_from_checkpoint(self, ckpt_path: str):
124 | # Calculate original checksums
125 | orig_ip_proj_sum = torch.sum(torch.stack([torch.sum(p) for p in self.image_proj_model.parameters()]))
126 | orig_adapter_sum = torch.sum(torch.stack([torch.sum(p) for p in self.adapter_modules.parameters()]))
127 |
128 | state_dict = torch.load(ckpt_path, map_location="cpu")
129 |
130 | # Load state dict for image_proj_model and adapter_modules
131 | self.image_proj_model.load_state_dict(state_dict["image_proj"], strict=True)
132 | self.adapter_modules.load_state_dict(state_dict["ip_adapter"], strict=True)
133 |
134 | # Calculate new checksums
135 | new_ip_proj_sum = torch.sum(torch.stack([torch.sum(p) for p in self.image_proj_model.parameters()]))
136 | new_adapter_sum = torch.sum(torch.stack([torch.sum(p) for p in self.adapter_modules.parameters()]))
137 |
138 | # Verify if the weights have changed
139 | assert orig_ip_proj_sum != new_ip_proj_sum, "Weights of image_proj_model did not change!"
140 | assert orig_adapter_sum != new_adapter_sum, "Weights of adapter_modules did not change!"
141 |
142 | print(f"Successfully loaded weights from checkpoint {ckpt_path}")
143 |
144 |
145 |
146 | def parse_args():
147 | parser = argparse.ArgumentParser(description="Simple example of a training script.")
148 | parser.add_argument(
149 | "--pretrained_model_name_or_path",
150 | type=str,
151 | default=None,
152 | required=True,
153 | help="Path to pretrained model or model identifier from huggingface.co/models.",
154 | )
155 | parser.add_argument(
156 | "--pretrained_ip_adapter_path",
157 | type=str,
158 | default=None,
159 | help="Path to pretrained ip adapter model. If not specified weights are initialized randomly.",
160 | )
161 | parser.add_argument(
162 | "--data_json_file",
163 | type=str,
164 | default=None,
165 | required=True,
166 | help="Training data",
167 | )
168 | parser.add_argument(
169 | "--data_root_path",
170 | type=str,
171 | default="",
172 | required=True,
173 | help="Training data root path",
174 | )
175 | parser.add_argument(
176 | "--image_encoder_path",
177 | type=str,
178 | default=None,
179 | required=True,
180 | help="Path to CLIP image encoder",
181 | )
182 | parser.add_argument(
183 | "--output_dir",
184 | type=str,
185 | default="sd-ip_adapter",
186 | help="The output directory where the model predictions and checkpoints will be written.",
187 | )
188 | parser.add_argument(
189 | "--logging_dir",
190 | type=str,
191 | default="logs",
192 | help=(
193 | "[TensorBoard](https://www.tensorflow.org/tensorboard) log directory. Will default to"
194 | " *output_dir/runs/**CURRENT_DATETIME_HOSTNAME***."
195 | ),
196 | )
197 | parser.add_argument(
198 | "--resolution",
199 | type=int,
200 | default=512,
201 | help=(
202 | "The resolution for input images"
203 | ),
204 | )
205 | parser.add_argument(
206 | "--learning_rate",
207 | type=float,
208 | default=1e-4,
209 | help="Learning rate to use.",
210 | )
211 | parser.add_argument("--weight_decay", type=float, default=1e-2, help="Weight decay to use.")
212 | parser.add_argument("--num_train_epochs", type=int, default=100)
213 | parser.add_argument(
214 | "--train_batch_size", type=int, default=8, help="Batch size (per device) for the training dataloader."
215 | )
216 | parser.add_argument(
217 | "--dataloader_num_workers",
218 | type=int,
219 | default=0,
220 | help=(
221 | "Number of subprocesses to use for data loading. 0 means that the data will be loaded in the main process."
222 | ),
223 | )
224 | parser.add_argument(
225 | "--save_steps",
226 | type=int,
227 | default=2000,
228 | help=(
229 | "Save a checkpoint of the training state every X updates"
230 | ),
231 | )
232 | parser.add_argument(
233 | "--mixed_precision",
234 | type=str,
235 | default=None,
236 | choices=["no", "fp16", "bf16"],
237 | help=(
238 | "Whether to use mixed precision. Choose between fp16 and bf16 (bfloat16). Bf16 requires PyTorch >="
239 | " 1.10.and an Nvidia Ampere GPU. Default to the value of accelerate config of the current system or the"
240 | " flag passed with the `accelerate.launch` command. Use this argument to override the accelerate config."
241 | ),
242 | )
243 | parser.add_argument(
244 | "--report_to",
245 | type=str,
246 | default="tensorboard",
247 | help=(
248 | 'The integration to report the results and logs to. Supported platforms are `"tensorboard"`'
249 | ' (default), `"wandb"` and `"comet_ml"`. Use `"all"` to report to all integrations.'
250 | ),
251 | )
252 | parser.add_argument("--local_rank", type=int, default=-1, help="For distributed training: local_rank")
253 |
254 | args = parser.parse_args()
255 | env_local_rank = int(os.environ.get("LOCAL_RANK", -1))
256 | if env_local_rank != -1 and env_local_rank != args.local_rank:
257 | args.local_rank = env_local_rank
258 |
259 | return args
260 |
261 |
262 | def main():
263 | args = parse_args()
264 | logging_dir = Path(args.output_dir, args.logging_dir)
265 |
266 | accelerator_project_config = ProjectConfiguration(project_dir=args.output_dir, logging_dir=logging_dir)
267 |
268 | accelerator = Accelerator(
269 | mixed_precision=args.mixed_precision,
270 | log_with=args.report_to,
271 | project_config=accelerator_project_config,
272 | )
273 |
274 | if accelerator.is_main_process:
275 | if args.output_dir is not None:
276 | os.makedirs(args.output_dir, exist_ok=True)
277 |
278 | # Load scheduler, tokenizer and models.
279 | noise_scheduler = DDPMScheduler.from_pretrained(args.pretrained_model_name_or_path, subfolder="scheduler")
280 | tokenizer = CLIPTokenizer.from_pretrained(args.pretrained_model_name_or_path, subfolder="tokenizer")
281 | text_encoder = CLIPTextModel.from_pretrained(args.pretrained_model_name_or_path, subfolder="text_encoder")
282 | vae = AutoencoderKL.from_pretrained(args.pretrained_model_name_or_path, subfolder="vae")
283 | unet = UNet2DConditionModel.from_pretrained(args.pretrained_model_name_or_path, subfolder="unet")
284 | image_encoder = CLIPVisionModelWithProjection.from_pretrained(args.image_encoder_path)
285 | # freeze parameters of models to save more memory
286 | unet.requires_grad_(False)
287 | vae.requires_grad_(False)
288 | text_encoder.requires_grad_(False)
289 | image_encoder.requires_grad_(False)
290 |
291 | #ip-adapter
292 | image_proj_model = ImageProjModel(
293 | cross_attention_dim=unet.config.cross_attention_dim,
294 | clip_embeddings_dim=image_encoder.config.projection_dim,
295 | clip_extra_context_tokens=4,
296 | )
297 | # init adapter modules
298 | attn_procs = {}
299 | unet_sd = unet.state_dict()
300 | for name in unet.attn_processors.keys():
301 | cross_attention_dim = None if name.endswith("attn1.processor") else unet.config.cross_attention_dim
302 | if name.startswith("mid_block"):
303 | hidden_size = unet.config.block_out_channels[-1]
304 | elif name.startswith("up_blocks"):
305 | block_id = int(name[len("up_blocks.")])
306 | hidden_size = list(reversed(unet.config.block_out_channels))[block_id]
307 | elif name.startswith("down_blocks"):
308 | block_id = int(name[len("down_blocks.")])
309 | hidden_size = unet.config.block_out_channels[block_id]
310 | if cross_attention_dim is None:
311 | attn_procs[name] = AttnProcessor()
312 | else:
313 | layer_name = name.split(".processor")[0]
314 | weights = {
315 | "to_k_ip.weight": unet_sd[layer_name + ".to_k.weight"],
316 | "to_v_ip.weight": unet_sd[layer_name + ".to_v.weight"],
317 | }
318 | attn_procs[name] = IPAttnProcessor(hidden_size=hidden_size, cross_attention_dim=cross_attention_dim)
319 | attn_procs[name].load_state_dict(weights)
320 | unet.set_attn_processor(attn_procs)
321 | adapter_modules = torch.nn.ModuleList(unet.attn_processors.values())
322 |
323 | ip_adapter = IPAdapter(unet, image_proj_model, adapter_modules, args.pretrained_ip_adapter_path)
324 |
325 | weight_dtype = torch.float32
326 | if accelerator.mixed_precision == "fp16":
327 | weight_dtype = torch.float16
328 | elif accelerator.mixed_precision == "bf16":
329 | weight_dtype = torch.bfloat16
330 | #unet.to(accelerator.device, dtype=weight_dtype)
331 | vae.to(accelerator.device, dtype=weight_dtype)
332 | text_encoder.to(accelerator.device, dtype=weight_dtype)
333 | image_encoder.to(accelerator.device, dtype=weight_dtype)
334 |
335 | # optimizer
336 | params_to_opt = itertools.chain(ip_adapter.image_proj_model.parameters(), ip_adapter.adapter_modules.parameters())
337 | optimizer = torch.optim.AdamW(params_to_opt, lr=args.learning_rate, weight_decay=args.weight_decay)
338 |
339 | # dataloader
340 | train_dataset = MyDataset(args.data_json_file, tokenizer=tokenizer, size=args.resolution, image_root_path=args.data_root_path)
341 | train_dataloader = torch.utils.data.DataLoader(
342 | train_dataset,
343 | shuffle=True,
344 | collate_fn=collate_fn,
345 | batch_size=args.train_batch_size,
346 | num_workers=args.dataloader_num_workers,
347 | )
348 |
349 | # Prepare everything with our `accelerator`.
350 | ip_adapter, optimizer, train_dataloader = accelerator.prepare(ip_adapter, optimizer, train_dataloader)
351 |
352 | global_step = 0
353 | for epoch in range(0, args.num_train_epochs):
354 | begin = time.perf_counter()
355 | for step, batch in enumerate(train_dataloader):
356 | load_data_time = time.perf_counter() - begin
357 | with accelerator.accumulate(ip_adapter):
358 | # Convert images to latent space
359 | with torch.no_grad():
360 | latents = vae.encode(batch["images"].to(accelerator.device, dtype=weight_dtype)).latent_dist.sample()
361 | latents = latents * vae.config.scaling_factor
362 |
363 | # Sample noise that we'll add to the latents
364 | noise = torch.randn_like(latents)
365 | bsz = latents.shape[0]
366 | # Sample a random timestep for each image
367 | timesteps = torch.randint(0, noise_scheduler.num_train_timesteps, (bsz,), device=latents.device)
368 | timesteps = timesteps.long()
369 |
370 | # Add noise to the latents according to the noise magnitude at each timestep
371 | # (this is the forward diffusion process)
372 | noisy_latents = noise_scheduler.add_noise(latents, noise, timesteps)
373 |
374 | with torch.no_grad():
375 | image_embeds = image_encoder(batch["clip_images"].to(accelerator.device, dtype=weight_dtype)).image_embeds
376 | image_embeds_ = []
377 | for image_embed, drop_image_embed in zip(image_embeds, batch["drop_image_embeds"]):
378 | if drop_image_embed == 1:
379 | image_embeds_.append(torch.zeros_like(image_embed))
380 | else:
381 | image_embeds_.append(image_embed)
382 | image_embeds = torch.stack(image_embeds_)
383 |
384 | with torch.no_grad():
385 | encoder_hidden_states = text_encoder(batch["text_input_ids"].to(accelerator.device))[0]
386 |
387 | noise_pred = ip_adapter(noisy_latents, timesteps, encoder_hidden_states, image_embeds)
388 |
389 | loss = F.mse_loss(noise_pred.float(), noise.float(), reduction="mean")
390 |
391 | # Gather the losses across all processes for logging (if we use distributed training).
392 | avg_loss = accelerator.gather(loss.repeat(args.train_batch_size)).mean().item()
393 |
394 | # Backpropagate
395 | accelerator.backward(loss)
396 | optimizer.step()
397 | optimizer.zero_grad()
398 |
399 | if accelerator.is_main_process:
400 | print("Epoch {}, step {}, data_time: {}, time: {}, step_loss: {}".format(
401 | epoch, step, load_data_time, time.perf_counter() - begin, avg_loss))
402 |
403 | global_step += 1
404 |
405 | if global_step % args.save_steps == 0:
406 | save_path = os.path.join(args.output_dir, f"checkpoint-{global_step}")
407 | accelerator.save_state(save_path)
408 |
409 | begin = time.perf_counter()
410 |
411 | if __name__ == "__main__":
412 | main()
413 |
--------------------------------------------------------------------------------
/tutorial_train_faceid.py:
--------------------------------------------------------------------------------
1 | import os
2 | import random
3 | import argparse
4 | from pathlib import Path
5 | import json
6 | import itertools
7 | import time
8 |
9 | import torch
10 | import torch.nn.functional as F
11 | from torchvision import transforms
12 | from PIL import Image
13 | from transformers import CLIPImageProcessor
14 | from accelerate import Accelerator
15 | from accelerate.logging import get_logger
16 | from accelerate.utils import ProjectConfiguration
17 | from diffusers import AutoencoderKL, DDPMScheduler, UNet2DConditionModel
18 | from transformers import CLIPTextModel, CLIPTokenizer, CLIPVisionModelWithProjection
19 |
20 | from ip_adapter.ip_adapter_faceid import MLPProjModel
21 | from ip_adapter.utils import is_torch2_available
22 | from ip_adapter.attention_processor_faceid import LoRAAttnProcessor, LoRAIPAttnProcessor
23 |
24 |
25 | # Dataset
26 | class MyDataset(torch.utils.data.Dataset):
27 |
28 | def __init__(self, json_file, tokenizer, size=512, t_drop_rate=0.05, i_drop_rate=0.05, ti_drop_rate=0.05, image_root_path=""):
29 | super().__init__()
30 |
31 | self.tokenizer = tokenizer
32 | self.size = size
33 | self.i_drop_rate = i_drop_rate
34 | self.t_drop_rate = t_drop_rate
35 | self.ti_drop_rate = ti_drop_rate
36 | self.image_root_path = image_root_path
37 |
38 | self.data = json.load(open(json_file)) # list of dict: [{"image_file": "1.png", "id_embed_file": "faceid.bin"}]
39 |
40 | self.transform = transforms.Compose([
41 | transforms.Resize(self.size, interpolation=transforms.InterpolationMode.BILINEAR),
42 | transforms.CenterCrop(self.size),
43 | transforms.ToTensor(),
44 | transforms.Normalize([0.5], [0.5]),
45 | ])
46 |
47 |
48 |
49 | def __getitem__(self, idx):
50 | item = self.data[idx]
51 | text = item["text"]
52 | image_file = item["image_file"]
53 |
54 | # read image
55 | raw_image = Image.open(os.path.join(self.image_root_path, image_file))
56 | image = self.transform(raw_image.convert("RGB"))
57 |
58 | face_id_embed = torch.load(item["id_embed_file"], map_location="cpu")
59 | face_id_embed = torch.from_numpy(face_id_embed)
60 |
61 | # drop
62 | drop_image_embed = 0
63 | rand_num = random.random()
64 | if rand_num < self.i_drop_rate:
65 | drop_image_embed = 1
66 | elif rand_num < (self.i_drop_rate + self.t_drop_rate):
67 | text = ""
68 | elif rand_num < (self.i_drop_rate + self.t_drop_rate + self.ti_drop_rate):
69 | text = ""
70 | drop_image_embed = 1
71 | if drop_image_embed:
72 | face_id_embed = torch.zeros_like(face_id_embed)
73 | # get text and tokenize
74 | text_input_ids = self.tokenizer(
75 | text,
76 | max_length=self.tokenizer.model_max_length,
77 | padding="max_length",
78 | truncation=True,
79 | return_tensors="pt"
80 | ).input_ids
81 |
82 | return {
83 | "image": image,
84 | "text_input_ids": text_input_ids,
85 | "face_id_embed": face_id_embed,
86 | "drop_image_embed": drop_image_embed
87 | }
88 |
89 | def __len__(self):
90 | return len(self.data)
91 |
92 |
93 | def collate_fn(data):
94 | images = torch.stack([example["image"] for example in data])
95 | text_input_ids = torch.cat([example["text_input_ids"] for example in data], dim=0)
96 | face_id_embed = torch.stack([example["face_id_embed"] for example in data])
97 | drop_image_embeds = [example["drop_image_embed"] for example in data]
98 |
99 | return {
100 | "images": images,
101 | "text_input_ids": text_input_ids,
102 | "face_id_embed": face_id_embed,
103 | "drop_image_embeds": drop_image_embeds
104 | }
105 |
106 |
107 | class IPAdapter(torch.nn.Module):
108 | """IP-Adapter"""
109 | def __init__(self, unet, image_proj_model, adapter_modules, ckpt_path=None):
110 | super().__init__()
111 | self.unet = unet
112 | self.image_proj_model = image_proj_model
113 | self.adapter_modules = adapter_modules
114 |
115 | if ckpt_path is not None:
116 | self.load_from_checkpoint(ckpt_path)
117 |
118 | def forward(self, noisy_latents, timesteps, encoder_hidden_states, image_embeds):
119 | ip_tokens = self.image_proj_model(image_embeds)
120 | encoder_hidden_states = torch.cat([encoder_hidden_states, ip_tokens], dim=1)
121 | # Predict the noise residual
122 | noise_pred = self.unet(noisy_latents, timesteps, encoder_hidden_states).sample
123 | return noise_pred
124 |
125 | def load_from_checkpoint(self, ckpt_path: str):
126 | # Calculate original checksums
127 | orig_ip_proj_sum = torch.sum(torch.stack([torch.sum(p) for p in self.image_proj_model.parameters()]))
128 | orig_adapter_sum = torch.sum(torch.stack([torch.sum(p) for p in self.adapter_modules.parameters()]))
129 |
130 | state_dict = torch.load(ckpt_path, map_location="cpu")
131 |
132 | # Load state dict for image_proj_model and adapter_modules
133 | self.image_proj_model.load_state_dict(state_dict["image_proj"], strict=True)
134 | self.adapter_modules.load_state_dict(state_dict["ip_adapter"], strict=True)
135 |
136 | # Calculate new checksums
137 | new_ip_proj_sum = torch.sum(torch.stack([torch.sum(p) for p in self.image_proj_model.parameters()]))
138 | new_adapter_sum = torch.sum(torch.stack([torch.sum(p) for p in self.adapter_modules.parameters()]))
139 |
140 | # Verify if the weights have changed
141 | assert orig_ip_proj_sum != new_ip_proj_sum, "Weights of image_proj_model did not change!"
142 | assert orig_adapter_sum != new_adapter_sum, "Weights of adapter_modules did not change!"
143 |
144 | print(f"Successfully loaded weights from checkpoint {ckpt_path}")
145 |
146 |
147 |
148 | def parse_args():
149 | parser = argparse.ArgumentParser(description="Simple example of a training script.")
150 | parser.add_argument(
151 | "--pretrained_model_name_or_path",
152 | type=str,
153 | default=None,
154 | required=True,
155 | help="Path to pretrained model or model identifier from huggingface.co/models.",
156 | )
157 | parser.add_argument(
158 | "--pretrained_ip_adapter_path",
159 | type=str,
160 | default=None,
161 | help="Path to pretrained ip adapter model. If not specified weights are initialized randomly.",
162 | )
163 | parser.add_argument(
164 | "--data_json_file",
165 | type=str,
166 | default=None,
167 | required=True,
168 | help="Training data",
169 | )
170 | parser.add_argument(
171 | "--data_root_path",
172 | type=str,
173 | default="",
174 | required=True,
175 | help="Training data root path",
176 | )
177 | parser.add_argument(
178 | "--image_encoder_path",
179 | type=str,
180 | default=None,
181 | required=True,
182 | help="Path to CLIP image encoder",
183 | )
184 | parser.add_argument(
185 | "--output_dir",
186 | type=str,
187 | default="sd-ip_adapter",
188 | help="The output directory where the model predictions and checkpoints will be written.",
189 | )
190 | parser.add_argument(
191 | "--logging_dir",
192 | type=str,
193 | default="logs",
194 | help=(
195 | "[TensorBoard](https://www.tensorflow.org/tensorboard) log directory. Will default to"
196 | " *output_dir/runs/**CURRENT_DATETIME_HOSTNAME***."
197 | ),
198 | )
199 | parser.add_argument(
200 | "--resolution",
201 | type=int,
202 | default=512,
203 | help=(
204 | "The resolution for input images"
205 | ),
206 | )
207 | parser.add_argument(
208 | "--learning_rate",
209 | type=float,
210 | default=1e-4,
211 | help="Learning rate to use.",
212 | )
213 | parser.add_argument("--weight_decay", type=float, default=1e-2, help="Weight decay to use.")
214 | parser.add_argument("--num_train_epochs", type=int, default=100)
215 | parser.add_argument(
216 | "--train_batch_size", type=int, default=8, help="Batch size (per device) for the training dataloader."
217 | )
218 | parser.add_argument(
219 | "--dataloader_num_workers",
220 | type=int,
221 | default=0,
222 | help=(
223 | "Number of subprocesses to use for data loading. 0 means that the data will be loaded in the main process."
224 | ),
225 | )
226 | parser.add_argument(
227 | "--save_steps",
228 | type=int,
229 | default=2000,
230 | help=(
231 | "Save a checkpoint of the training state every X updates"
232 | ),
233 | )
234 | parser.add_argument(
235 | "--mixed_precision",
236 | type=str,
237 | default=None,
238 | choices=["no", "fp16", "bf16"],
239 | help=(
240 | "Whether to use mixed precision. Choose between fp16 and bf16 (bfloat16). Bf16 requires PyTorch >="
241 | " 1.10.and an Nvidia Ampere GPU. Default to the value of accelerate config of the current system or the"
242 | " flag passed with the `accelerate.launch` command. Use this argument to override the accelerate config."
243 | ),
244 | )
245 | parser.add_argument(
246 | "--report_to",
247 | type=str,
248 | default="tensorboard",
249 | help=(
250 | 'The integration to report the results and logs to. Supported platforms are `"tensorboard"`'
251 | ' (default), `"wandb"` and `"comet_ml"`. Use `"all"` to report to all integrations.'
252 | ),
253 | )
254 | parser.add_argument("--local_rank", type=int, default=-1, help="For distributed training: local_rank")
255 |
256 | args = parser.parse_args()
257 | env_local_rank = int(os.environ.get("LOCAL_RANK", -1))
258 | if env_local_rank != -1 and env_local_rank != args.local_rank:
259 | args.local_rank = env_local_rank
260 |
261 | return args
262 |
263 |
264 | def main():
265 | args = parse_args()
266 | logging_dir = Path(args.output_dir, args.logging_dir)
267 |
268 | accelerator_project_config = ProjectConfiguration(project_dir=args.output_dir, logging_dir=logging_dir)
269 |
270 | accelerator = Accelerator(
271 | mixed_precision=args.mixed_precision,
272 | log_with=args.report_to,
273 | project_config=accelerator_project_config,
274 | )
275 |
276 | if accelerator.is_main_process:
277 | if args.output_dir is not None:
278 | os.makedirs(args.output_dir, exist_ok=True)
279 |
280 | # Load scheduler, tokenizer and models.
281 | noise_scheduler = DDPMScheduler.from_pretrained(args.pretrained_model_name_or_path, subfolder="scheduler")
282 | tokenizer = CLIPTokenizer.from_pretrained(args.pretrained_model_name_or_path, subfolder="tokenizer")
283 | text_encoder = CLIPTextModel.from_pretrained(args.pretrained_model_name_or_path, subfolder="text_encoder")
284 | vae = AutoencoderKL.from_pretrained(args.pretrained_model_name_or_path, subfolder="vae")
285 | unet = UNet2DConditionModel.from_pretrained(args.pretrained_model_name_or_path, subfolder="unet")
286 | # image_encoder = CLIPVisionModelWithProjection.from_pretrained(args.image_encoder_path)
287 | # freeze parameters of models to save more memory
288 | unet.requires_grad_(False)
289 | vae.requires_grad_(False)
290 | text_encoder.requires_grad_(False)
291 | #image_encoder.requires_grad_(False)
292 |
293 | #ip-adapter
294 | image_proj_model = MLPProjModel(
295 | cross_attention_dim=unet.config.cross_attention_dim,
296 | id_embeddings_dim=512,
297 | num_tokens=4,
298 | )
299 | # init adapter modules
300 | lora_rank = 128
301 | attn_procs = {}
302 | unet_sd = unet.state_dict()
303 | for name in unet.attn_processors.keys():
304 | cross_attention_dim = None if name.endswith("attn1.processor") else unet.config.cross_attention_dim
305 | if name.startswith("mid_block"):
306 | hidden_size = unet.config.block_out_channels[-1]
307 | elif name.startswith("up_blocks"):
308 | block_id = int(name[len("up_blocks.")])
309 | hidden_size = list(reversed(unet.config.block_out_channels))[block_id]
310 | elif name.startswith("down_blocks"):
311 | block_id = int(name[len("down_blocks.")])
312 | hidden_size = unet.config.block_out_channels[block_id]
313 | if cross_attention_dim is None:
314 | attn_procs[name] = LoRAAttnProcessor(hidden_size=hidden_size, cross_attention_dim=cross_attention_dim, rank=lora_rank)
315 | else:
316 | layer_name = name.split(".processor")[0]
317 | weights = {
318 | "to_k_ip.weight": unet_sd[layer_name + ".to_k.weight"],
319 | "to_v_ip.weight": unet_sd[layer_name + ".to_v.weight"],
320 | }
321 | attn_procs[name] = LoRAIPAttnProcessor(hidden_size=hidden_size, cross_attention_dim=cross_attention_dim, rank=lora_rank)
322 | attn_procs[name].load_state_dict(weights, strict=False)
323 | unet.set_attn_processor(attn_procs)
324 | adapter_modules = torch.nn.ModuleList(unet.attn_processors.values())
325 |
326 | ip_adapter = IPAdapter(unet, image_proj_model, adapter_modules, args.pretrained_ip_adapter_path)
327 |
328 | weight_dtype = torch.float32
329 | if accelerator.mixed_precision == "fp16":
330 | weight_dtype = torch.float16
331 | elif accelerator.mixed_precision == "bf16":
332 | weight_dtype = torch.bfloat16
333 | #unet.to(accelerator.device, dtype=weight_dtype)
334 | vae.to(accelerator.device, dtype=weight_dtype)
335 | text_encoder.to(accelerator.device, dtype=weight_dtype)
336 | #image_encoder.to(accelerator.device, dtype=weight_dtype)
337 |
338 | # optimizer
339 | params_to_opt = itertools.chain(ip_adapter.image_proj_model.parameters(), ip_adapter.adapter_modules.parameters())
340 | optimizer = torch.optim.AdamW(params_to_opt, lr=args.learning_rate, weight_decay=args.weight_decay)
341 |
342 | # dataloader
343 | train_dataset = MyDataset(args.data_json_file, tokenizer=tokenizer, size=args.resolution, image_root_path=args.data_root_path)
344 | train_dataloader = torch.utils.data.DataLoader(
345 | train_dataset,
346 | shuffle=True,
347 | collate_fn=collate_fn,
348 | batch_size=args.train_batch_size,
349 | num_workers=args.dataloader_num_workers,
350 | )
351 |
352 | # Prepare everything with our `accelerator`.
353 | ip_adapter, optimizer, train_dataloader = accelerator.prepare(ip_adapter, optimizer, train_dataloader)
354 |
355 | global_step = 0
356 | for epoch in range(0, args.num_train_epochs):
357 | begin = time.perf_counter()
358 | for step, batch in enumerate(train_dataloader):
359 | load_data_time = time.perf_counter() - begin
360 | with accelerator.accumulate(ip_adapter):
361 | # Convert images to latent space
362 | with torch.no_grad():
363 | latents = vae.encode(batch["images"].to(accelerator.device, dtype=weight_dtype)).latent_dist.sample()
364 | latents = latents * vae.config.scaling_factor
365 |
366 | # Sample noise that we'll add to the latents
367 | noise = torch.randn_like(latents)
368 | bsz = latents.shape[0]
369 | # Sample a random timestep for each image
370 | timesteps = torch.randint(0, noise_scheduler.num_train_timesteps, (bsz,), device=latents.device)
371 | timesteps = timesteps.long()
372 |
373 | # Add noise to the latents according to the noise magnitude at each timestep
374 | # (this is the forward diffusion process)
375 | noisy_latents = noise_scheduler.add_noise(latents, noise, timesteps)
376 |
377 | image_embeds = batch["face_id_embed"].to(accelerator.device, dtype=weight_dtype)
378 |
379 | with torch.no_grad():
380 | encoder_hidden_states = text_encoder(batch["text_input_ids"].to(accelerator.device))[0]
381 |
382 | noise_pred = ip_adapter(noisy_latents, timesteps, encoder_hidden_states, image_embeds)
383 |
384 | loss = F.mse_loss(noise_pred.float(), noise.float(), reduction="mean")
385 |
386 | # Gather the losses across all processes for logging (if we use distributed training).
387 | avg_loss = accelerator.gather(loss.repeat(args.train_batch_size)).mean().item()
388 |
389 | # Backpropagate
390 | accelerator.backward(loss)
391 | optimizer.step()
392 | optimizer.zero_grad()
393 |
394 | if accelerator.is_main_process:
395 | print("Epoch {}, step {}, data_time: {}, time: {}, step_loss: {}".format(
396 | epoch, step, load_data_time, time.perf_counter() - begin, avg_loss))
397 |
398 | global_step += 1
399 |
400 | if global_step % args.save_steps == 0:
401 | save_path = os.path.join(args.output_dir, f"checkpoint-{global_step}")
402 | accelerator.save_state(save_path)
403 |
404 | begin = time.perf_counter()
405 |
406 | if __name__ == "__main__":
407 | main()
408 |
--------------------------------------------------------------------------------
/tutorial_train_plus.py:
--------------------------------------------------------------------------------
1 | import os
2 | import random
3 | import argparse
4 | from pathlib import Path
5 | import json
6 | import itertools
7 | import time
8 |
9 | import torch
10 | import torch.nn.functional as F
11 | from torchvision import transforms
12 | from PIL import Image
13 | from transformers import CLIPImageProcessor
14 | from accelerate import Accelerator
15 | from accelerate.logging import get_logger
16 | from accelerate.utils import ProjectConfiguration
17 | from diffusers import AutoencoderKL, DDPMScheduler, UNet2DConditionModel
18 | from transformers import CLIPTextModel, CLIPTokenizer, CLIPVisionModelWithProjection
19 |
20 | from ip_adapter.resampler import Resampler
21 | from ip_adapter.utils import is_torch2_available
22 | if is_torch2_available():
23 | from ip_adapter.attention_processor import IPAttnProcessor2_0 as IPAttnProcessor, AttnProcessor2_0 as AttnProcessor
24 | else:
25 | from ip_adapter.attention_processor import IPAttnProcessor, AttnProcessor
26 |
27 |
28 | # Dataset
29 | class MyDataset(torch.utils.data.Dataset):
30 |
31 | def __init__(self, json_file, tokenizer, size=512, t_drop_rate=0.05, i_drop_rate=0.05, ti_drop_rate=0.05, image_root_path=""):
32 | super().__init__()
33 |
34 | self.tokenizer = tokenizer
35 | self.size = size
36 | self.i_drop_rate = i_drop_rate
37 | self.t_drop_rate = t_drop_rate
38 | self.ti_drop_rate = ti_drop_rate
39 | self.image_root_path = image_root_path
40 |
41 | self.data = json.load(open(json_file)) # list of dict: [{"image_file": "1.png", "text": "A dog"}]
42 |
43 | self.transform = transforms.Compose([
44 | transforms.Resize(self.size, interpolation=transforms.InterpolationMode.BILINEAR),
45 | transforms.CenterCrop(self.size),
46 | transforms.ToTensor(),
47 | transforms.Normalize([0.5], [0.5]),
48 | ])
49 | self.clip_image_processor = CLIPImageProcessor()
50 |
51 | def __getitem__(self, idx):
52 | item = self.data[idx]
53 | text = item["text"]
54 | image_file = item["image_file"]
55 |
56 | # read image
57 | raw_image = Image.open(os.path.join(self.image_root_path, image_file))
58 | image = self.transform(raw_image.convert("RGB"))
59 | clip_image = self.clip_image_processor(images=raw_image, return_tensors="pt").pixel_values
60 |
61 | # drop
62 | drop_image_embed = 0
63 | rand_num = random.random()
64 | if rand_num < self.i_drop_rate:
65 | drop_image_embed = 1
66 | elif rand_num < (self.i_drop_rate + self.t_drop_rate):
67 | text = ""
68 | elif rand_num < (self.i_drop_rate + self.t_drop_rate + self.ti_drop_rate):
69 | text = ""
70 | drop_image_embed = 1
71 | # get text and tokenize
72 | text_input_ids = self.tokenizer(
73 | text,
74 | max_length=self.tokenizer.model_max_length,
75 | padding="max_length",
76 | truncation=True,
77 | return_tensors="pt"
78 | ).input_ids
79 |
80 | return {
81 | "image": image,
82 | "text_input_ids": text_input_ids,
83 | "clip_image": clip_image,
84 | "drop_image_embed": drop_image_embed
85 | }
86 |
87 | def __len__(self):
88 | return len(self.data)
89 |
90 |
91 | def collate_fn(data):
92 | images = torch.stack([example["image"] for example in data])
93 | text_input_ids = torch.cat([example["text_input_ids"] for example in data], dim=0)
94 | clip_images = torch.cat([example["clip_image"] for example in data], dim=0)
95 | drop_image_embeds = [example["drop_image_embed"] for example in data]
96 |
97 | return {
98 | "images": images,
99 | "text_input_ids": text_input_ids,
100 | "clip_images": clip_images,
101 | "drop_image_embeds": drop_image_embeds
102 | }
103 |
104 |
105 | class IPAdapter(torch.nn.Module):
106 | """IP-Adapter"""
107 | def __init__(self, unet, image_proj_model, adapter_modules, ckpt_path=None):
108 | super().__init__()
109 | self.unet = unet
110 | self.image_proj_model = image_proj_model
111 | self.adapter_modules = adapter_modules
112 |
113 | if ckpt_path is not None:
114 | self.load_from_checkpoint(ckpt_path)
115 |
116 | def forward(self, noisy_latents, timesteps, encoder_hidden_states, image_embeds):
117 | ip_tokens = self.image_proj_model(image_embeds)
118 | encoder_hidden_states = torch.cat([encoder_hidden_states, ip_tokens], dim=1)
119 | # Predict the noise residual
120 | noise_pred = self.unet(noisy_latents, timesteps, encoder_hidden_states).sample
121 | return noise_pred
122 |
123 | def load_from_checkpoint(self, ckpt_path: str):
124 | # Calculate original checksums
125 | orig_ip_proj_sum = torch.sum(torch.stack([torch.sum(p) for p in self.image_proj_model.parameters()]))
126 | orig_adapter_sum = torch.sum(torch.stack([torch.sum(p) for p in self.adapter_modules.parameters()]))
127 |
128 | state_dict = torch.load(ckpt_path, map_location="cpu")
129 |
130 | # Check if 'latents' exists in both the saved state_dict and the current model's state_dict
131 | strict_load_image_proj_model = True
132 | if "latents" in state_dict["image_proj"] and "latents" in self.image_proj_model.state_dict():
133 | # Check if the shapes are mismatched
134 | if state_dict["image_proj"]["latents"].shape != self.image_proj_model.state_dict()["latents"].shape:
135 | print(f"Shapes of 'image_proj.latents' in checkpoint {ckpt_path} and current model do not match.")
136 | print("Removing 'latents' from checkpoint and loading the rest of the weights.")
137 | del state_dict["image_proj"]["latents"]
138 | strict_load_image_proj_model = False
139 |
140 | # Load state dict for image_proj_model and adapter_modules
141 | self.image_proj_model.load_state_dict(state_dict["image_proj"], strict=strict_load_image_proj_model)
142 | self.adapter_modules.load_state_dict(state_dict["ip_adapter"], strict=True)
143 |
144 | # Calculate new checksums
145 | new_ip_proj_sum = torch.sum(torch.stack([torch.sum(p) for p in self.image_proj_model.parameters()]))
146 | new_adapter_sum = torch.sum(torch.stack([torch.sum(p) for p in self.adapter_modules.parameters()]))
147 |
148 | # Verify if the weights have changed
149 | assert orig_ip_proj_sum != new_ip_proj_sum, "Weights of image_proj_model did not change!"
150 | assert orig_adapter_sum != new_adapter_sum, "Weights of adapter_modules did not change!"
151 |
152 | print(f"Successfully loaded weights from checkpoint {ckpt_path}")
153 |
154 |
155 |
156 | def parse_args():
157 | parser = argparse.ArgumentParser(description="Simple example of a training script.")
158 | parser.add_argument(
159 | "--pretrained_model_name_or_path",
160 | type=str,
161 | default=None,
162 | required=True,
163 | help="Path to pretrained model or model identifier from huggingface.co/models.",
164 | )
165 | parser.add_argument(
166 | "--pretrained_ip_adapter_path",
167 | type=str,
168 | default=None,
169 | help="Path to pretrained ip adapter model. If not specified weights are initialized randomly.",
170 | )
171 | parser.add_argument(
172 | "--num_tokens",
173 | type=int,
174 | default=16,
175 | help="Number of tokens to query from the CLIP image encoding.",
176 | )
177 | parser.add_argument(
178 | "--data_json_file",
179 | type=str,
180 | default=None,
181 | required=True,
182 | help="Training data",
183 | )
184 | parser.add_argument(
185 | "--data_root_path",
186 | type=str,
187 | default="",
188 | required=True,
189 | help="Training data root path",
190 | )
191 | parser.add_argument(
192 | "--image_encoder_path",
193 | type=str,
194 | default=None,
195 | required=True,
196 | help="Path to CLIP image encoder",
197 | )
198 | parser.add_argument(
199 | "--output_dir",
200 | type=str,
201 | default="sd-ip_adapter",
202 | help="The output directory where the model predictions and checkpoints will be written.",
203 | )
204 | parser.add_argument(
205 | "--logging_dir",
206 | type=str,
207 | default="logs",
208 | help=(
209 | "[TensorBoard](https://www.tensorflow.org/tensorboard) log directory. Will default to"
210 | " *output_dir/runs/**CURRENT_DATETIME_HOSTNAME***."
211 | ),
212 | )
213 | parser.add_argument(
214 | "--resolution",
215 | type=int,
216 | default=512,
217 | help=(
218 | "The resolution for input images"
219 | ),
220 | )
221 | parser.add_argument(
222 | "--learning_rate",
223 | type=float,
224 | default=1e-4,
225 | help="Learning rate to use.",
226 | )
227 | parser.add_argument("--weight_decay", type=float, default=1e-2, help="Weight decay to use.")
228 | parser.add_argument("--num_train_epochs", type=int, default=100)
229 | parser.add_argument(
230 | "--train_batch_size", type=int, default=8, help="Batch size (per device) for the training dataloader."
231 | )
232 | parser.add_argument(
233 | "--dataloader_num_workers",
234 | type=int,
235 | default=0,
236 | help=(
237 | "Number of subprocesses to use for data loading. 0 means that the data will be loaded in the main process."
238 | ),
239 | )
240 | parser.add_argument(
241 | "--save_steps",
242 | type=int,
243 | default=2000,
244 | help=(
245 | "Save a checkpoint of the training state every X updates"
246 | ),
247 | )
248 | parser.add_argument(
249 | "--mixed_precision",
250 | type=str,
251 | default=None,
252 | choices=["no", "fp16", "bf16"],
253 | help=(
254 | "Whether to use mixed precision. Choose between fp16 and bf16 (bfloat16). Bf16 requires PyTorch >="
255 | " 1.10.and an Nvidia Ampere GPU. Default to the value of accelerate config of the current system or the"
256 | " flag passed with the `accelerate.launch` command. Use this argument to override the accelerate config."
257 | ),
258 | )
259 | parser.add_argument(
260 | "--report_to",
261 | type=str,
262 | default="tensorboard",
263 | help=(
264 | 'The integration to report the results and logs to. Supported platforms are `"tensorboard"`'
265 | ' (default), `"wandb"` and `"comet_ml"`. Use `"all"` to report to all integrations.'
266 | ),
267 | )
268 | parser.add_argument("--local_rank", type=int, default=-1, help="For distributed training: local_rank")
269 |
270 | args = parser.parse_args()
271 | env_local_rank = int(os.environ.get("LOCAL_RANK", -1))
272 | if env_local_rank != -1 and env_local_rank != args.local_rank:
273 | args.local_rank = env_local_rank
274 |
275 | return args
276 |
277 |
278 | def main():
279 | args = parse_args()
280 | logging_dir = Path(args.output_dir, args.logging_dir)
281 |
282 | accelerator_project_config = ProjectConfiguration(project_dir=args.output_dir, logging_dir=logging_dir)
283 |
284 | accelerator = Accelerator(
285 | mixed_precision=args.mixed_precision,
286 | log_with=args.report_to,
287 | project_config=accelerator_project_config,
288 | )
289 |
290 | if accelerator.is_main_process:
291 | if args.output_dir is not None:
292 | os.makedirs(args.output_dir, exist_ok=True)
293 |
294 | # Load scheduler, tokenizer and models.
295 | noise_scheduler = DDPMScheduler.from_pretrained(args.pretrained_model_name_or_path, subfolder="scheduler")
296 | tokenizer = CLIPTokenizer.from_pretrained(args.pretrained_model_name_or_path, subfolder="tokenizer")
297 | text_encoder = CLIPTextModel.from_pretrained(args.pretrained_model_name_or_path, subfolder="text_encoder")
298 | vae = AutoencoderKL.from_pretrained(args.pretrained_model_name_or_path, subfolder="vae")
299 | unet = UNet2DConditionModel.from_pretrained(args.pretrained_model_name_or_path, subfolder="unet")
300 | image_encoder = CLIPVisionModelWithProjection.from_pretrained(args.image_encoder_path)
301 | # freeze parameters of models to save more memory
302 | unet.requires_grad_(False)
303 | vae.requires_grad_(False)
304 | text_encoder.requires_grad_(False)
305 | image_encoder.requires_grad_(False)
306 |
307 | #ip-adapter-plus
308 | image_proj_model = Resampler(
309 | dim=unet.config.cross_attention_dim,
310 | depth=4,
311 | dim_head=64,
312 | heads=12,
313 | num_queries=args.num_tokens,
314 | embedding_dim=image_encoder.config.hidden_size,
315 | output_dim=unet.config.cross_attention_dim,
316 | ff_mult=4
317 | )
318 | # init adapter modules
319 | attn_procs = {}
320 | unet_sd = unet.state_dict()
321 | for name in unet.attn_processors.keys():
322 | cross_attention_dim = None if name.endswith("attn1.processor") else unet.config.cross_attention_dim
323 | if name.startswith("mid_block"):
324 | hidden_size = unet.config.block_out_channels[-1]
325 | elif name.startswith("up_blocks"):
326 | block_id = int(name[len("up_blocks.")])
327 | hidden_size = list(reversed(unet.config.block_out_channels))[block_id]
328 | elif name.startswith("down_blocks"):
329 | block_id = int(name[len("down_blocks.")])
330 | hidden_size = unet.config.block_out_channels[block_id]
331 | if cross_attention_dim is None:
332 | attn_procs[name] = AttnProcessor()
333 | else:
334 | layer_name = name.split(".processor")[0]
335 | weights = {
336 | "to_k_ip.weight": unet_sd[layer_name + ".to_k.weight"],
337 | "to_v_ip.weight": unet_sd[layer_name + ".to_v.weight"],
338 | }
339 | attn_procs[name] = IPAttnProcessor(hidden_size=hidden_size, cross_attention_dim=cross_attention_dim, num_tokens=args.num_tokens)
340 | attn_procs[name].load_state_dict(weights)
341 | unet.set_attn_processor(attn_procs)
342 | adapter_modules = torch.nn.ModuleList(unet.attn_processors.values())
343 |
344 | ip_adapter = IPAdapter(unet, image_proj_model, adapter_modules, args.pretrained_ip_adapter_path)
345 |
346 | weight_dtype = torch.float32
347 | if accelerator.mixed_precision == "fp16":
348 | weight_dtype = torch.float16
349 | elif accelerator.mixed_precision == "bf16":
350 | weight_dtype = torch.bfloat16
351 | #unet.to(accelerator.device, dtype=weight_dtype)
352 | vae.to(accelerator.device, dtype=weight_dtype)
353 | text_encoder.to(accelerator.device, dtype=weight_dtype)
354 | image_encoder.to(accelerator.device, dtype=weight_dtype)
355 |
356 | # optimizer
357 | params_to_opt = itertools.chain(ip_adapter.image_proj_model.parameters(), ip_adapter.adapter_modules.parameters())
358 | optimizer = torch.optim.AdamW(params_to_opt, lr=args.learning_rate, weight_decay=args.weight_decay)
359 |
360 | # dataloader
361 | train_dataset = MyDataset(args.data_json_file, tokenizer=tokenizer, size=args.resolution, image_root_path=args.data_root_path)
362 | train_dataloader = torch.utils.data.DataLoader(
363 | train_dataset,
364 | shuffle=True,
365 | collate_fn=collate_fn,
366 | batch_size=args.train_batch_size,
367 | num_workers=args.dataloader_num_workers,
368 | )
369 |
370 | # Prepare everything with our `accelerator`.
371 | ip_adapter, optimizer, train_dataloader = accelerator.prepare(ip_adapter, optimizer, train_dataloader)
372 |
373 | global_step = 0
374 | for epoch in range(0, args.num_train_epochs):
375 | begin = time.perf_counter()
376 | for step, batch in enumerate(train_dataloader):
377 | load_data_time = time.perf_counter() - begin
378 | with accelerator.accumulate(ip_adapter):
379 | # Convert images to latent space
380 | with torch.no_grad():
381 | latents = vae.encode(batch["images"].to(accelerator.device, dtype=weight_dtype)).latent_dist.sample()
382 | latents = latents * vae.config.scaling_factor
383 |
384 | # Sample noise that we'll add to the latents
385 | noise = torch.randn_like(latents)
386 | bsz = latents.shape[0]
387 | # Sample a random timestep for each image
388 | timesteps = torch.randint(0, noise_scheduler.num_train_timesteps, (bsz,), device=latents.device)
389 | timesteps = timesteps.long()
390 |
391 | # Add noise to the latents according to the noise magnitude at each timestep
392 | # (this is the forward diffusion process)
393 | noisy_latents = noise_scheduler.add_noise(latents, noise, timesteps)
394 |
395 | clip_images = []
396 | for clip_image, drop_image_embed in zip(batch["clip_images"], batch["drop_image_embeds"]):
397 | if drop_image_embed == 1:
398 | clip_images.append(torch.zeros_like(clip_image))
399 | else:
400 | clip_images.append(clip_image)
401 | clip_images = torch.stack(clip_images, dim=0)
402 | with torch.no_grad():
403 | image_embeds = image_encoder(clip_images.to(accelerator.device, dtype=weight_dtype), output_hidden_states=True).hidden_states[-2]
404 |
405 | with torch.no_grad():
406 | encoder_hidden_states = text_encoder(batch["text_input_ids"].to(accelerator.device))[0]
407 |
408 | noise_pred = ip_adapter(noisy_latents, timesteps, encoder_hidden_states, image_embeds)
409 |
410 | loss = F.mse_loss(noise_pred.float(), noise.float(), reduction="mean")
411 |
412 | # Gather the losses across all processes for logging (if we use distributed training).
413 | avg_loss = accelerator.gather(loss.repeat(args.train_batch_size)).mean().item()
414 |
415 | # Backpropagate
416 | accelerator.backward(loss)
417 | optimizer.step()
418 | optimizer.zero_grad()
419 |
420 | if accelerator.is_main_process:
421 | print("Epoch {}, step {}, data_time: {}, time: {}, step_loss: {}".format(
422 | epoch, step, load_data_time, time.perf_counter() - begin, avg_loss))
423 |
424 | global_step += 1
425 |
426 | if global_step % args.save_steps == 0:
427 | save_path = os.path.join(args.output_dir, f"checkpoint-{global_step}")
428 | accelerator.save_state(save_path)
429 |
430 | begin = time.perf_counter()
431 |
432 | if __name__ == "__main__":
433 | main()
434 |
--------------------------------------------------------------------------------
/tutorial_train_sdxl.py:
--------------------------------------------------------------------------------
1 | import os
2 | import random
3 | import argparse
4 | from pathlib import Path
5 | import json
6 | import itertools
7 | import time
8 |
9 | import torch
10 | import torch.nn.functional as F
11 | import numpy as np
12 | from torchvision import transforms
13 | from PIL import Image
14 | from transformers import CLIPImageProcessor
15 | from accelerate import Accelerator
16 | from accelerate.logging import get_logger
17 | from accelerate.utils import ProjectConfiguration
18 | from diffusers import AutoencoderKL, DDPMScheduler, UNet2DConditionModel
19 | from transformers import CLIPTextModel, CLIPTokenizer, CLIPVisionModelWithProjection, CLIPTextModelWithProjection
20 |
21 | from ip_adapter.ip_adapter import ImageProjModel
22 | from ip_adapter.utils import is_torch2_available
23 | if is_torch2_available():
24 | from ip_adapter.attention_processor import IPAttnProcessor2_0 as IPAttnProcessor, AttnProcessor2_0 as AttnProcessor
25 | else:
26 | from ip_adapter.attention_processor import IPAttnProcessor, AttnProcessor
27 |
28 |
29 | # Dataset
30 | class MyDataset(torch.utils.data.Dataset):
31 |
32 | def __init__(self, json_file, tokenizer, tokenizer_2, size=1024, center_crop=True, t_drop_rate=0.05, i_drop_rate=0.05, ti_drop_rate=0.05, image_root_path=""):
33 | super().__init__()
34 |
35 | self.tokenizer = tokenizer
36 | self.tokenizer_2 = tokenizer_2
37 | self.size = size
38 | self.center_crop = center_crop
39 | self.i_drop_rate = i_drop_rate
40 | self.t_drop_rate = t_drop_rate
41 | self.ti_drop_rate = ti_drop_rate
42 | self.image_root_path = image_root_path
43 |
44 | self.data = json.load(open(json_file)) # list of dict: [{"image_file": "1.png", "text": "A dog"}]
45 |
46 | self.transform = transforms.Compose([
47 | transforms.Resize(self.size, interpolation=transforms.InterpolationMode.BILINEAR),
48 | transforms.ToTensor(),
49 | transforms.Normalize([0.5], [0.5]),
50 | ])
51 |
52 | self.clip_image_processor = CLIPImageProcessor()
53 |
54 | def __getitem__(self, idx):
55 | item = self.data[idx]
56 | text = item["text"]
57 | image_file = item["image_file"]
58 |
59 | # read image
60 | raw_image = Image.open(os.path.join(self.image_root_path, image_file))
61 |
62 | # original size
63 | original_width, original_height = raw_image.size
64 | original_size = torch.tensor([original_height, original_width])
65 |
66 | image_tensor = self.transform(raw_image.convert("RGB"))
67 | # random crop
68 | delta_h = image_tensor.shape[1] - self.size
69 | delta_w = image_tensor.shape[2] - self.size
70 | assert not all([delta_h, delta_w])
71 |
72 | if self.center_crop:
73 | top = delta_h // 2
74 | left = delta_w // 2
75 | else:
76 | top = np.random.randint(0, delta_h + 1)
77 | left = np.random.randint(0, delta_w + 1)
78 | image = transforms.functional.crop(
79 | image_tensor, top=top, left=left, height=self.size, width=self.size
80 | )
81 | crop_coords_top_left = torch.tensor([top, left])
82 |
83 | clip_image = self.clip_image_processor(images=raw_image, return_tensors="pt").pixel_values
84 |
85 | # drop
86 | drop_image_embed = 0
87 | rand_num = random.random()
88 | if rand_num < self.i_drop_rate:
89 | drop_image_embed = 1
90 | elif rand_num < (self.i_drop_rate + self.t_drop_rate):
91 | text = ""
92 | elif rand_num < (self.i_drop_rate + self.t_drop_rate + self.ti_drop_rate):
93 | text = ""
94 | drop_image_embed = 1
95 |
96 | # get text and tokenize
97 | text_input_ids = self.tokenizer(
98 | text,
99 | max_length=self.tokenizer.model_max_length,
100 | padding="max_length",
101 | truncation=True,
102 | return_tensors="pt"
103 | ).input_ids
104 |
105 | text_input_ids_2 = self.tokenizer_2(
106 | text,
107 | max_length=self.tokenizer_2.model_max_length,
108 | padding="max_length",
109 | truncation=True,
110 | return_tensors="pt"
111 | ).input_ids
112 |
113 | return {
114 | "image": image,
115 | "text_input_ids": text_input_ids,
116 | "text_input_ids_2": text_input_ids_2,
117 | "clip_image": clip_image,
118 | "drop_image_embed": drop_image_embed,
119 | "original_size": original_size,
120 | "crop_coords_top_left": crop_coords_top_left,
121 | "target_size": torch.tensor([self.size, self.size]),
122 | }
123 |
124 |
125 | def __len__(self):
126 | return len(self.data)
127 |
128 |
129 | def collate_fn(data):
130 | images = torch.stack([example["image"] for example in data])
131 | text_input_ids = torch.cat([example["text_input_ids"] for example in data], dim=0)
132 | text_input_ids_2 = torch.cat([example["text_input_ids_2"] for example in data], dim=0)
133 | clip_images = torch.cat([example["clip_image"] for example in data], dim=0)
134 | drop_image_embeds = [example["drop_image_embed"] for example in data]
135 | original_size = torch.stack([example["original_size"] for example in data])
136 | crop_coords_top_left = torch.stack([example["crop_coords_top_left"] for example in data])
137 | target_size = torch.stack([example["target_size"] for example in data])
138 |
139 | return {
140 | "images": images,
141 | "text_input_ids": text_input_ids,
142 | "text_input_ids_2": text_input_ids_2,
143 | "clip_images": clip_images,
144 | "drop_image_embeds": drop_image_embeds,
145 | "original_size": original_size,
146 | "crop_coords_top_left": crop_coords_top_left,
147 | "target_size": target_size,
148 | }
149 |
150 |
151 | class IPAdapter(torch.nn.Module):
152 | """IP-Adapter"""
153 | def __init__(self, unet, image_proj_model, adapter_modules, ckpt_path=None):
154 | super().__init__()
155 | self.unet = unet
156 | self.image_proj_model = image_proj_model
157 | self.adapter_modules = adapter_modules
158 |
159 | if ckpt_path is not None:
160 | self.load_from_checkpoint(ckpt_path)
161 |
162 | def forward(self, noisy_latents, timesteps, encoder_hidden_states, unet_added_cond_kwargs, image_embeds):
163 | ip_tokens = self.image_proj_model(image_embeds)
164 | encoder_hidden_states = torch.cat([encoder_hidden_states, ip_tokens], dim=1)
165 | # Predict the noise residual
166 | noise_pred = self.unet(noisy_latents, timesteps, encoder_hidden_states, added_cond_kwargs=unet_added_cond_kwargs).sample
167 | return noise_pred
168 |
169 | def load_from_checkpoint(self, ckpt_path: str):
170 | # Calculate original checksums
171 | orig_ip_proj_sum = torch.sum(torch.stack([torch.sum(p) for p in self.image_proj_model.parameters()]))
172 | orig_adapter_sum = torch.sum(torch.stack([torch.sum(p) for p in self.adapter_modules.parameters()]))
173 |
174 | state_dict = torch.load(ckpt_path, map_location="cpu")
175 |
176 | # Load state dict for image_proj_model and adapter_modules
177 | self.image_proj_model.load_state_dict(state_dict["image_proj"], strict=True)
178 | self.adapter_modules.load_state_dict(state_dict["ip_adapter"], strict=True)
179 |
180 | # Calculate new checksums
181 | new_ip_proj_sum = torch.sum(torch.stack([torch.sum(p) for p in self.image_proj_model.parameters()]))
182 | new_adapter_sum = torch.sum(torch.stack([torch.sum(p) for p in self.adapter_modules.parameters()]))
183 |
184 | # Verify if the weights have changed
185 | assert orig_ip_proj_sum != new_ip_proj_sum, "Weights of image_proj_model did not change!"
186 | assert orig_adapter_sum != new_adapter_sum, "Weights of adapter_modules did not change!"
187 |
188 | print(f"Successfully loaded weights from checkpoint {ckpt_path}")
189 |
190 |
191 | def parse_args():
192 | parser = argparse.ArgumentParser(description="Simple example of a training script.")
193 | parser.add_argument(
194 | "--pretrained_model_name_or_path",
195 | type=str,
196 | default=None,
197 | required=True,
198 | help="Path to pretrained model or model identifier from huggingface.co/models.",
199 | )
200 | parser.add_argument(
201 | "--pretrained_ip_adapter_path",
202 | type=str,
203 | default=None,
204 | help="Path to pretrained ip adapter model. If not specified weights are initialized randomly.",
205 | )
206 | parser.add_argument(
207 | "--data_json_file",
208 | type=str,
209 | default=None,
210 | required=True,
211 | help="Training data",
212 | )
213 | parser.add_argument(
214 | "--data_root_path",
215 | type=str,
216 | default="",
217 | required=True,
218 | help="Training data root path",
219 | )
220 | parser.add_argument(
221 | "--image_encoder_path",
222 | type=str,
223 | default=None,
224 | required=True,
225 | help="Path to CLIP image encoder",
226 | )
227 | parser.add_argument(
228 | "--output_dir",
229 | type=str,
230 | default="sd-ip_adapter",
231 | help="The output directory where the model predictions and checkpoints will be written.",
232 | )
233 | parser.add_argument(
234 | "--logging_dir",
235 | type=str,
236 | default="logs",
237 | help=(
238 | "[TensorBoard](https://www.tensorflow.org/tensorboard) log directory. Will default to"
239 | " *output_dir/runs/**CURRENT_DATETIME_HOSTNAME***."
240 | ),
241 | )
242 | parser.add_argument(
243 | "--resolution",
244 | type=int,
245 | default=512,
246 | help=(
247 | "The resolution for input images"
248 | ),
249 | )
250 | parser.add_argument(
251 | "--learning_rate",
252 | type=float,
253 | default=1e-4,
254 | help="Learning rate to use.",
255 | )
256 | parser.add_argument("--weight_decay", type=float, default=1e-2, help="Weight decay to use.")
257 | parser.add_argument("--num_train_epochs", type=int, default=100)
258 | parser.add_argument(
259 | "--train_batch_size", type=int, default=8, help="Batch size (per device) for the training dataloader."
260 | )
261 | parser.add_argument("--noise_offset", type=float, default=None, help="noise offset")
262 | parser.add_argument(
263 | "--dataloader_num_workers",
264 | type=int,
265 | default=0,
266 | help=(
267 | "Number of subprocesses to use for data loading. 0 means that the data will be loaded in the main process."
268 | ),
269 | )
270 | parser.add_argument(
271 | "--save_steps",
272 | type=int,
273 | default=2000,
274 | help=(
275 | "Save a checkpoint of the training state every X updates"
276 | ),
277 | )
278 | parser.add_argument(
279 | "--mixed_precision",
280 | type=str,
281 | default=None,
282 | choices=["no", "fp16", "bf16"],
283 | help=(
284 | "Whether to use mixed precision. Choose between fp16 and bf16 (bfloat16). Bf16 requires PyTorch >="
285 | " 1.10.and an Nvidia Ampere GPU. Default to the value of accelerate config of the current system or the"
286 | " flag passed with the `accelerate.launch` command. Use this argument to override the accelerate config."
287 | ),
288 | )
289 | parser.add_argument(
290 | "--report_to",
291 | type=str,
292 | default="tensorboard",
293 | help=(
294 | 'The integration to report the results and logs to. Supported platforms are `"tensorboard"`'
295 | ' (default), `"wandb"` and `"comet_ml"`. Use `"all"` to report to all integrations.'
296 | ),
297 | )
298 | parser.add_argument("--local_rank", type=int, default=-1, help="For distributed training: local_rank")
299 |
300 | args = parser.parse_args()
301 | env_local_rank = int(os.environ.get("LOCAL_RANK", -1))
302 | if env_local_rank != -1 and env_local_rank != args.local_rank:
303 | args.local_rank = env_local_rank
304 |
305 | return args
306 |
307 |
308 | def main():
309 | args = parse_args()
310 | logging_dir = Path(args.output_dir, args.logging_dir)
311 |
312 | accelerator_project_config = ProjectConfiguration(project_dir=args.output_dir, logging_dir=logging_dir)
313 |
314 | accelerator = Accelerator(
315 | mixed_precision=args.mixed_precision,
316 | log_with=args.report_to,
317 | project_config=accelerator_project_config,
318 | )
319 |
320 | if accelerator.is_main_process:
321 | if args.output_dir is not None:
322 | os.makedirs(args.output_dir, exist_ok=True)
323 |
324 | # Load scheduler, tokenizer and models.
325 | noise_scheduler = DDPMScheduler.from_pretrained(args.pretrained_model_name_or_path, subfolder="scheduler")
326 | tokenizer = CLIPTokenizer.from_pretrained(args.pretrained_model_name_or_path, subfolder="tokenizer")
327 | text_encoder = CLIPTextModel.from_pretrained(args.pretrained_model_name_or_path, subfolder="text_encoder")
328 | tokenizer_2 = CLIPTokenizer.from_pretrained(args.pretrained_model_name_or_path, subfolder="tokenizer_2")
329 | text_encoder_2 = CLIPTextModelWithProjection.from_pretrained(args.pretrained_model_name_or_path, subfolder="text_encoder_2")
330 | vae = AutoencoderKL.from_pretrained(args.pretrained_model_name_or_path, subfolder="vae")
331 | unet = UNet2DConditionModel.from_pretrained(args.pretrained_model_name_or_path, subfolder="unet")
332 | image_encoder = CLIPVisionModelWithProjection.from_pretrained(args.image_encoder_path)
333 | # freeze parameters of models to save more memory
334 | unet.requires_grad_(False)
335 | vae.requires_grad_(False)
336 | text_encoder.requires_grad_(False)
337 | text_encoder_2.requires_grad_(False)
338 | image_encoder.requires_grad_(False)
339 |
340 | #ip-adapter
341 | num_tokens = 4
342 | image_proj_model = ImageProjModel(
343 | cross_attention_dim=unet.config.cross_attention_dim,
344 | clip_embeddings_dim=image_encoder.config.projection_dim,
345 | clip_extra_context_tokens=num_tokens,
346 | )
347 | # init adapter modules
348 | attn_procs = {}
349 | unet_sd = unet.state_dict()
350 | for name in unet.attn_processors.keys():
351 | cross_attention_dim = None if name.endswith("attn1.processor") else unet.config.cross_attention_dim
352 | if name.startswith("mid_block"):
353 | hidden_size = unet.config.block_out_channels[-1]
354 | elif name.startswith("up_blocks"):
355 | block_id = int(name[len("up_blocks.")])
356 | hidden_size = list(reversed(unet.config.block_out_channels))[block_id]
357 | elif name.startswith("down_blocks"):
358 | block_id = int(name[len("down_blocks.")])
359 | hidden_size = unet.config.block_out_channels[block_id]
360 | if cross_attention_dim is None:
361 | attn_procs[name] = AttnProcessor()
362 | else:
363 | layer_name = name.split(".processor")[0]
364 | weights = {
365 | "to_k_ip.weight": unet_sd[layer_name + ".to_k.weight"],
366 | "to_v_ip.weight": unet_sd[layer_name + ".to_v.weight"],
367 | }
368 | attn_procs[name] = IPAttnProcessor(hidden_size=hidden_size, cross_attention_dim=cross_attention_dim, num_tokens=num_tokens)
369 | attn_procs[name].load_state_dict(weights)
370 | unet.set_attn_processor(attn_procs)
371 | adapter_modules = torch.nn.ModuleList(unet.attn_processors.values())
372 |
373 | ip_adapter = IPAdapter(unet, image_proj_model, adapter_modules, args.pretrained_ip_adapter_path)
374 |
375 | weight_dtype = torch.float32
376 | if accelerator.mixed_precision == "fp16":
377 | weight_dtype = torch.float16
378 | elif accelerator.mixed_precision == "bf16":
379 | weight_dtype = torch.bfloat16
380 | #unet.to(accelerator.device, dtype=weight_dtype)
381 | vae.to(accelerator.device) # use fp32
382 | text_encoder.to(accelerator.device, dtype=weight_dtype)
383 | text_encoder_2.to(accelerator.device, dtype=weight_dtype)
384 | image_encoder.to(accelerator.device, dtype=weight_dtype)
385 |
386 | # optimizer
387 | params_to_opt = itertools.chain(ip_adapter.image_proj_model.parameters(), ip_adapter.adapter_modules.parameters())
388 | optimizer = torch.optim.AdamW(params_to_opt, lr=args.learning_rate, weight_decay=args.weight_decay)
389 |
390 | # dataloader
391 | train_dataset = MyDataset(args.data_json_file, tokenizer=tokenizer, tokenizer_2=tokenizer_2, size=args.resolution, image_root_path=args.data_root_path)
392 | train_dataloader = torch.utils.data.DataLoader(
393 | train_dataset,
394 | shuffle=True,
395 | collate_fn=collate_fn,
396 | batch_size=args.train_batch_size,
397 | num_workers=args.dataloader_num_workers,
398 | )
399 |
400 | # Prepare everything with our `accelerator`.
401 | ip_adapter, optimizer, train_dataloader = accelerator.prepare(ip_adapter, optimizer, train_dataloader)
402 |
403 | global_step = 0
404 | for epoch in range(0, args.num_train_epochs):
405 | begin = time.perf_counter()
406 | for step, batch in enumerate(train_dataloader):
407 | load_data_time = time.perf_counter() - begin
408 | with accelerator.accumulate(ip_adapter):
409 | # Convert images to latent space
410 | with torch.no_grad():
411 | # vae of sdxl should use fp32
412 | latents = vae.encode(batch["images"].to(accelerator.device, dtype=torch.float32)).latent_dist.sample()
413 | latents = latents * vae.config.scaling_factor
414 | latents = latents.to(accelerator.device, dtype=weight_dtype)
415 |
416 | # Sample noise that we'll add to the latents
417 | noise = torch.randn_like(latents)
418 | if args.noise_offset:
419 | # https://www.crosslabs.org//blog/diffusion-with-offset-noise
420 | noise += args.noise_offset * torch.randn((latents.shape[0], latents.shape[1], 1, 1)).to(accelerator.device, dtype=weight_dtype)
421 |
422 | bsz = latents.shape[0]
423 | # Sample a random timestep for each image
424 | timesteps = torch.randint(0, noise_scheduler.num_train_timesteps, (bsz,), device=latents.device)
425 | timesteps = timesteps.long()
426 |
427 | # Add noise to the latents according to the noise magnitude at each timestep
428 | # (this is the forward diffusion process)
429 | noisy_latents = noise_scheduler.add_noise(latents, noise, timesteps)
430 |
431 | with torch.no_grad():
432 | image_embeds = image_encoder(batch["clip_images"].to(accelerator.device, dtype=weight_dtype)).image_embeds
433 | image_embeds_ = []
434 | for image_embed, drop_image_embed in zip(image_embeds, batch["drop_image_embeds"]):
435 | if drop_image_embed == 1:
436 | image_embeds_.append(torch.zeros_like(image_embed))
437 | else:
438 | image_embeds_.append(image_embed)
439 | image_embeds = torch.stack(image_embeds_)
440 |
441 | with torch.no_grad():
442 | encoder_output = text_encoder(batch['text_input_ids'].to(accelerator.device), output_hidden_states=True)
443 | text_embeds = encoder_output.hidden_states[-2]
444 | encoder_output_2 = text_encoder_2(batch['text_input_ids_2'].to(accelerator.device), output_hidden_states=True)
445 | pooled_text_embeds = encoder_output_2[0]
446 | text_embeds_2 = encoder_output_2.hidden_states[-2]
447 | text_embeds = torch.concat([text_embeds, text_embeds_2], dim=-1) # concat
448 |
449 | # add cond
450 | add_time_ids = [
451 | batch["original_size"].to(accelerator.device),
452 | batch["crop_coords_top_left"].to(accelerator.device),
453 | batch["target_size"].to(accelerator.device),
454 | ]
455 | add_time_ids = torch.cat(add_time_ids, dim=1).to(accelerator.device, dtype=weight_dtype)
456 | unet_added_cond_kwargs = {"text_embeds": pooled_text_embeds, "time_ids": add_time_ids}
457 |
458 | noise_pred = ip_adapter(noisy_latents, timesteps, text_embeds, unet_added_cond_kwargs, image_embeds)
459 |
460 | loss = F.mse_loss(noise_pred.float(), noise.float(), reduction="mean")
461 |
462 | # Gather the losses across all processes for logging (if we use distributed training).
463 | avg_loss = accelerator.gather(loss.repeat(args.train_batch_size)).mean().item()
464 |
465 | # Backpropagate
466 | accelerator.backward(loss)
467 | optimizer.step()
468 | optimizer.zero_grad()
469 |
470 | if accelerator.is_main_process:
471 | print("Epoch {}, step {}, data_time: {}, time: {}, step_loss: {}".format(
472 | epoch, step, load_data_time, time.perf_counter() - begin, avg_loss))
473 |
474 | global_step += 1
475 |
476 | if global_step % args.save_steps == 0:
477 | save_path = os.path.join(args.output_dir, f"checkpoint-{global_step}")
478 | accelerator.save_state(save_path)
479 |
480 | begin = time.perf_counter()
481 |
482 | if __name__ == "__main__":
483 | main()
484 |
--------------------------------------------------------------------------------