├── LICENSE ├── README.md ├── assets ├── demo │ ├── crop_and_resize_cmp.jpg │ ├── image-to-image.jpg │ ├── image_variations.jpg │ ├── inpainting.jpg │ ├── ip_adpter_plus_image_variations.jpg │ ├── ip_adpter_plus_multi.jpg │ ├── multi_prompts.jpg │ ├── sd15_face.jpg │ ├── sdxl_cmp.jpg │ ├── structural_cond.jpg │ └── t2i-adapter_demo.jpg ├── figs │ └── fig1.png ├── images │ ├── ai_face.png │ ├── ai_face2.png │ ├── girl.png │ ├── river.png │ ├── statue.png │ ├── stone.png │ ├── vermeer.jpg │ └── woman.png ├── inpainting │ ├── image.png │ └── mask.png └── structure_controls │ ├── depth.png │ ├── depth2.png │ └── openpose.png ├── demo.ipynb ├── ip_adapter-full-face_demo.ipynb ├── ip_adapter-plus-face_demo.ipynb ├── ip_adapter-plus_demo.ipynb ├── ip_adapter-plus_sdxl_demo.ipynb ├── ip_adapter ├── __init__.py ├── attention_processor.py ├── attention_processor_faceid.py ├── custom_pipelines.py ├── ip_adapter.py ├── ip_adapter_faceid.py ├── ip_adapter_faceid_separate.py ├── resampler.py ├── test_resampler.py └── utils.py ├── ip_adapter_controlnet_demo_new.ipynb ├── ip_adapter_demo.ipynb ├── ip_adapter_multimodal_prompts_demo.ipynb ├── ip_adapter_sdxl_controlnet_demo.ipynb ├── ip_adapter_sdxl_demo.ipynb ├── ip_adapter_sdxl_plus-face_demo.ipynb ├── ip_adapter_t2i-adapter_demo.ipynb ├── ip_adapter_t2i_demo.ipynb ├── pyproject.toml ├── tutorial_train.py ├── tutorial_train_faceid.py ├── tutorial_train_plus.py ├── tutorial_train_sdxl.py ├── visualization_attnmap_faceid.ipynb └── visualization_attnmap_sdxl_plus-face.ipynb /LICENSE: -------------------------------------------------------------------------------- 1 | Apache License 2 | Version 2.0, January 2004 3 | http://www.apache.org/licenses/ 4 | 5 | TERMS AND CONDITIONS FOR USE, REPRODUCTION, AND DISTRIBUTION 6 | 7 | 1. Definitions. 8 | 9 | "License" shall mean the terms and conditions for use, reproduction, 10 | and distribution as defined by Sections 1 through 9 of this document. 11 | 12 | "Licensor" shall mean the copyright owner or entity authorized by 13 | the copyright owner that is granting the License. 14 | 15 | "Legal Entity" shall mean the union of the acting entity and all 16 | other entities that control, are controlled by, or are under common 17 | control with that entity. For the purposes of this definition, 18 | "control" means (i) the power, direct or indirect, to cause the 19 | direction or management of such entity, whether by contract or 20 | otherwise, or (ii) ownership of fifty percent (50%) or more of the 21 | outstanding shares, or (iii) beneficial ownership of such entity. 22 | 23 | "You" (or "Your") shall mean an individual or Legal Entity 24 | exercising permissions granted by this License. 25 | 26 | "Source" form shall mean the preferred form for making modifications, 27 | including but not limited to software source code, documentation 28 | source, and configuration files. 29 | 30 | "Object" form shall mean any form resulting from mechanical 31 | transformation or translation of a Source form, including but 32 | not limited to compiled object code, generated documentation, 33 | and conversions to other media types. 34 | 35 | "Work" shall mean the work of authorship, whether in Source or 36 | Object form, made available under the License, as indicated by a 37 | copyright notice that is included in or attached to the work 38 | (an example is provided in the Appendix below). 39 | 40 | "Derivative Works" shall mean any work, whether in Source or Object 41 | form, that is based on (or derived from) the Work and for which the 42 | editorial revisions, annotations, elaborations, or other modifications 43 | represent, as a whole, an original work of authorship. For the purposes 44 | of this License, Derivative Works shall not include works that remain 45 | separable from, or merely link (or bind by name) to the interfaces of, 46 | the Work and Derivative Works thereof. 47 | 48 | "Contribution" shall mean any work of authorship, including 49 | the original version of the Work and any modifications or additions 50 | to that Work or Derivative Works thereof, that is intentionally 51 | submitted to Licensor for inclusion in the Work by the copyright owner 52 | or by an individual or Legal Entity authorized to submit on behalf of 53 | the copyright owner. For the purposes of this definition, "submitted" 54 | means any form of electronic, verbal, or written communication sent 55 | to the Licensor or its representatives, including but not limited to 56 | communication on electronic mailing lists, source code control systems, 57 | and issue tracking systems that are managed by, or on behalf of, the 58 | Licensor for the purpose of discussing and improving the Work, but 59 | excluding communication that is conspicuously marked or otherwise 60 | designated in writing by the copyright owner as "Not a Contribution." 61 | 62 | "Contributor" shall mean Licensor and any individual or Legal Entity 63 | on behalf of whom a Contribution has been received by Licensor and 64 | subsequently incorporated within the Work. 65 | 66 | 2. Grant of Copyright License. Subject to the terms and conditions of 67 | this License, each Contributor hereby grants to You a perpetual, 68 | worldwide, non-exclusive, no-charge, royalty-free, irrevocable 69 | copyright license to reproduce, prepare Derivative Works of, 70 | publicly display, publicly perform, sublicense, and distribute the 71 | Work and such Derivative Works in Source or Object form. 72 | 73 | 3. Grant of Patent License. Subject to the terms and conditions of 74 | this License, each Contributor hereby grants to You a perpetual, 75 | worldwide, non-exclusive, no-charge, royalty-free, irrevocable 76 | (except as stated in this section) patent license to make, have made, 77 | use, offer to sell, sell, import, and otherwise transfer the Work, 78 | where such license applies only to those patent claims licensable 79 | by such Contributor that are necessarily infringed by their 80 | Contribution(s) alone or by combination of their Contribution(s) 81 | with the Work to which such Contribution(s) was submitted. If You 82 | institute patent litigation against any entity (including a 83 | cross-claim or counterclaim in a lawsuit) alleging that the Work 84 | or a Contribution incorporated within the Work constitutes direct 85 | or contributory patent infringement, then any patent licenses 86 | granted to You under this License for that Work shall terminate 87 | as of the date such litigation is filed. 88 | 89 | 4. Redistribution. You may reproduce and distribute copies of the 90 | Work or Derivative Works thereof in any medium, with or without 91 | modifications, and in Source or Object form, provided that You 92 | meet the following conditions: 93 | 94 | (a) You must give any other recipients of the Work or 95 | Derivative Works a copy of this License; and 96 | 97 | (b) You must cause any modified files to carry prominent notices 98 | stating that You changed the files; and 99 | 100 | (c) You must retain, in the Source form of any Derivative Works 101 | that You distribute, all copyright, patent, trademark, and 102 | attribution notices from the Source form of the Work, 103 | excluding those notices that do not pertain to any part of 104 | the Derivative Works; and 105 | 106 | (d) If the Work includes a "NOTICE" text file as part of its 107 | distribution, then any Derivative Works that You distribute must 108 | include a readable copy of the attribution notices contained 109 | within such NOTICE file, excluding those notices that do not 110 | pertain to any part of the Derivative Works, in at least one 111 | of the following places: within a NOTICE text file distributed 112 | as part of the Derivative Works; within the Source form or 113 | documentation, if provided along with the Derivative Works; or, 114 | within a display generated by the Derivative Works, if and 115 | wherever such third-party notices normally appear. The contents 116 | of the NOTICE file are for informational purposes only and 117 | do not modify the License. You may add Your own attribution 118 | notices within Derivative Works that You distribute, alongside 119 | or as an addendum to the NOTICE text from the Work, provided 120 | that such additional attribution notices cannot be construed 121 | as modifying the License. 122 | 123 | You may add Your own copyright statement to Your modifications and 124 | may provide additional or different license terms and conditions 125 | for use, reproduction, or distribution of Your modifications, or 126 | for any such Derivative Works as a whole, provided Your use, 127 | reproduction, and distribution of the Work otherwise complies with 128 | the conditions stated in this License. 129 | 130 | 5. Submission of Contributions. Unless You explicitly state otherwise, 131 | any Contribution intentionally submitted for inclusion in the Work 132 | by You to the Licensor shall be under the terms and conditions of 133 | this License, without any additional terms or conditions. 134 | Notwithstanding the above, nothing herein shall supersede or modify 135 | the terms of any separate license agreement you may have executed 136 | with Licensor regarding such Contributions. 137 | 138 | 6. Trademarks. This License does not grant permission to use the trade 139 | names, trademarks, service marks, or product names of the Licensor, 140 | except as required for reasonable and customary use in describing the 141 | origin of the Work and reproducing the content of the NOTICE file. 142 | 143 | 7. Disclaimer of Warranty. Unless required by applicable law or 144 | agreed to in writing, Licensor provides the Work (and each 145 | Contributor provides its Contributions) on an "AS IS" BASIS, 146 | WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or 147 | implied, including, without limitation, any warranties or conditions 148 | of TITLE, NON-INFRINGEMENT, MERCHANTABILITY, or FITNESS FOR A 149 | PARTICULAR PURPOSE. You are solely responsible for determining the 150 | appropriateness of using or redistributing the Work and assume any 151 | risks associated with Your exercise of permissions under this License. 152 | 153 | 8. Limitation of Liability. In no event and under no legal theory, 154 | whether in tort (including negligence), contract, or otherwise, 155 | unless required by applicable law (such as deliberate and grossly 156 | negligent acts) or agreed to in writing, shall any Contributor be 157 | liable to You for damages, including any direct, indirect, special, 158 | incidental, or consequential damages of any character arising as a 159 | result of this License or out of the use or inability to use the 160 | Work (including but not limited to damages for loss of goodwill, 161 | work stoppage, computer failure or malfunction, or any and all 162 | other commercial damages or losses), even if such Contributor 163 | has been advised of the possibility of such damages. 164 | 165 | 9. Accepting Warranty or Additional Liability. While redistributing 166 | the Work or Derivative Works thereof, You may choose to offer, 167 | and charge a fee for, acceptance of support, warranty, indemnity, 168 | or other liability obligations and/or rights consistent with this 169 | License. However, in accepting such obligations, You may act only 170 | on Your own behalf and on Your sole responsibility, not on behalf 171 | of any other Contributor, and only if You agree to indemnify, 172 | defend, and hold each Contributor harmless for any liability 173 | incurred by, or claims asserted against, such Contributor by reason 174 | of your accepting any such warranty or additional liability. 175 | 176 | END OF TERMS AND CONDITIONS 177 | 178 | APPENDIX: How to apply the Apache License to your work. 179 | 180 | To apply the Apache License to your work, attach the following 181 | boilerplate notice, with the fields enclosed by brackets "[]" 182 | replaced with your own identifying information. (Don't include 183 | the brackets!) The text should be enclosed in the appropriate 184 | comment syntax for the file format. We also recommend that a 185 | file or class name and description of purpose be included on the 186 | same "printed page" as the copyright notice for easier 187 | identification within third-party archives. 188 | 189 | Copyright [yyyy] [name of copyright owner] 190 | 191 | Licensed under the Apache License, Version 2.0 (the "License"); 192 | you may not use this file except in compliance with the License. 193 | You may obtain a copy of the License at 194 | 195 | http://www.apache.org/licenses/LICENSE-2.0 196 | 197 | Unless required by applicable law or agreed to in writing, software 198 | distributed under the License is distributed on an "AS IS" BASIS, 199 | WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 200 | See the License for the specific language governing permissions and 201 | limitations under the License. 202 | -------------------------------------------------------------------------------- /README.md: -------------------------------------------------------------------------------- 1 | # Stable Diffusion IP-Adapter with a negative image prompt 2 | Experimental implementation of negative image prompts. An image embedding is used for the unconditioned embedding in the same way as the negative prompt. Read the write-up [here](https://stable-diffusion-art.com/negative-image-prompt). 3 | 4 | Update: [ComfyUI version](https://github.com/sagiodev/ComfyUI_IPAdapter_plus) 5 | 6 | ## Usage 7 | Run `demo.ipynb` for a GUI. [![Open In Colab](https://colab.research.google.com/assets/colab-badge.svg)](https://colab.research.google.com/github/sagiodev/IP-Adapter-Negative/blob/main/demo.ipynb) 8 | 9 | You can specify any or all of the following 10 | - Text prompt 11 | - Image prompt 12 | - Negative text prompt 13 | - Negative image prompt 14 | 15 | It seems to behave the best when: 16 | 1. Supply a positive and a negative image prompt. Leave the text prompts empty. 17 | 2. Supply a text prompt and a negative image prompt. Leave the image prompt empty. 18 | 19 | You will need to adjust the positive and negative prompt weight to get the desired effect. 20 | 21 | ## Samples 22 | Cherry-picked examples 🤣 (Positive and negative image prompts only. Text prompts not used.) 23 | 24 | ![image](https://github.com/sagiodev/IP-Adapter-Negative/assets/3319909/ee07e920-e62b-41e7-9857-1fd896de00fb) 25 | 26 | ![image](https://github.com/sagiodev/IP-Adapter-Negative/assets/3319909/534d36a2-7ccc-4ec6-8298-56d40d1321ba) 27 | 28 | ![image](https://github.com/sagiodev/IP-Adapter-Negative/assets/3319909/b532b209-f372-4cf4-8d55-f17ade2e0595) 29 | 30 | 31 | ## IP-adapters 32 | Supports IP-adapter and IP-adapter Plus for SD 1.5 33 | 34 | 35 | 36 | 37 | # ___***IP-Adapter: Text Compatible Image Prompt Adapter for Text-to-Image Diffusion Models***___ 38 | 39 | 40 | 41 | 42 | [![GitHub](https://img.shields.io/github/stars/tencent-ailab/IP-Adapter?style=social)](https://github.com/tencent-ailab/IP-Adapter/) 43 | 44 | 45 | --- 46 | 47 | 48 | ## Introduction 49 | 50 | we present IP-Adapter, an effective and lightweight adapter to achieve image prompt capability for the pre-trained text-to-image diffusion models. An IP-Adapter with only 22M parameters can achieve comparable or even better performance to a fine-tuned image prompt model. IP-Adapter can be generalized not only to other custom models fine-tuned from the same base model, but also to controllable generation using existing controllable tools. Moreover, the image prompt can also work well with the text prompt to accomplish multimodal image generation. 51 | 52 | ![arch](assets/figs/fig1.png) 53 | 54 | ## Release 55 | - [2024/01/04] 🔥 Add an experimental version of IP-Adapter-FaceID for SDXL, more information can be found [here](https://huggingface.co/h94/IP-Adapter-FaceID). 56 | - [2023/12/29] 🔥 Add an experimental version of IP-Adapter-FaceID-PlusV2, more information can be found [here](https://huggingface.co/h94/IP-Adapter-FaceID). 57 | - [2023/12/27] 🔥 Add an experimental version of IP-Adapter-FaceID-Plus, more information can be found [here](https://huggingface.co/h94/IP-Adapter-FaceID). 58 | - [2023/12/20] 🔥 Add an experimental version of IP-Adapter-FaceID, more information can be found [here](https://huggingface.co/h94/IP-Adapter-FaceID). 59 | - [2023/11/22] IP-Adapter is available in [Diffusers](https://github.com/huggingface/diffusers/pull/5713) thanks to Diffusers Team. 60 | - [2023/11/10] 🔥 Add an updated version of IP-Adapter-Face. The demo is [here](ip_adapter-full-face_demo.ipynb). 61 | - [2023/11/05] 🔥 Add text-to-image [demo](ip_adapter_t2i_demo.ipynb) with IP-Adapter and [Kandinsky 2.2 Prior](https://huggingface.co/kandinsky-community/kandinsky-2-2-prior) 62 | - [2023/11/02] Support [safetensors](https://github.com/huggingface/safetensors) 63 | - [2023/9/08] 🔥 Update a new version of IP-Adapter with SDXL_1.0. More information can be found [here](#sdxl_10). 64 | - [2023/9/05] 🔥🔥🔥 IP-Adapter is supported in [WebUI](https://github.com/Mikubill/sd-webui-controlnet/discussions/2039) and [ComfyUI](https://github.com/laksjdjf/IPAdapter-ComfyUI) (or [ComfyUI_IPAdapter_plus](https://github.com/cubiq/ComfyUI_IPAdapter_plus)). 65 | - [2023/8/30] 🔥 Add an IP-Adapter with face image as prompt. The demo is [here](ip_adapter-plus-face_demo.ipynb). 66 | - [2023/8/29] 🔥 Release the training code. 67 | - [2023/8/23] 🔥 Add code and models of IP-Adapter with fine-grained features. The demo is [here](ip_adapter-plus_demo.ipynb). 68 | - [2023/8/18] 🔥 Add code and models for [SDXL 1.0](https://huggingface.co/stabilityai/stable-diffusion-xl-base-1.0). The demo is [here](ip_adapter_sdxl_demo.ipynb). 69 | - [2023/8/16] 🔥 We release the code and models. 70 | 71 | 72 | ## Installation 73 | 74 | ``` 75 | # install latest diffusers 76 | pip install diffusers==0.22.1 77 | 78 | # install ip-adapter 79 | pip install git+https://github.com/tencent-ailab/IP-Adapter.git 80 | 81 | # download the models 82 | cd IP-Adapter 83 | git lfs install 84 | git clone https://huggingface.co/h94/IP-Adapter 85 | mv IP-Adapter/models models 86 | mv IP-Adapter/sdxl_models sdxl_models 87 | 88 | # then you can use the notebook 89 | ``` 90 | 91 | ## Download Models 92 | 93 | you can download models from [here](https://huggingface.co/h94/IP-Adapter). To run the demo, you should also download the following models: 94 | - [runwayml/stable-diffusion-v1-5](https://huggingface.co/runwayml/stable-diffusion-v1-5) 95 | - [stabilityai/sd-vae-ft-mse](https://huggingface.co/stabilityai/sd-vae-ft-mse) 96 | - [SG161222/Realistic_Vision_V4.0_noVAE](https://huggingface.co/SG161222/Realistic_Vision_V4.0_noVAE) 97 | - [ControlNet models](https://huggingface.co/lllyasviel) 98 | 99 | ## How to Use 100 | 101 | ### SD_1.5 102 | 103 | - [**ip_adapter_demo**](ip_adapter_demo.ipynb): image variations, image-to-image, and inpainting with image prompt. 104 | - [![**ip_adapter_demo**](https://colab.research.google.com/assets/colab-badge.svg)](https://colab.research.google.com/github/tencent-ailab/IP-Adapter/blob/main/ip_adapter_demo.ipynb) 105 | 106 | ![image variations](assets/demo/image_variations.jpg) 107 | 108 | ![image-to-image](assets/demo/image-to-image.jpg) 109 | 110 | ![inpainting](assets/demo/inpainting.jpg) 111 | 112 | - [**ip_adapter_controlnet_demo**](ip_adapter_controlnet_demo_new.ipynb), [**ip_adapter_t2i-adapter**](ip_adapter_t2i-adapter_demo.ipynb): structural generation with image prompt. 113 | - [![**ip_adapter_controlnet_demo**](https://colab.research.google.com/assets/colab-badge.svg)](https://colab.research.google.com/github/tencent-ailab/IP-Adapter/blob/main/ip_adapter_controlnet_demo.ipynb) 114 | 115 | ![structural_cond](assets/demo/structural_cond.jpg) 116 | ![structural_cond2](assets/demo/t2i-adapter_demo.jpg) 117 | 118 | - [**ip_adapter_multimodal_prompts_demo**](ip_adapter_multimodal_prompts_demo.ipynb): generation with multimodal prompts. 119 | - [![**ip_adapter_multimodal_prompts_demo**](https://colab.research.google.com/assets/colab-badge.svg)](https://colab.research.google.com/github/tencent-ailab/IP-Adapter/blob/main/ip_adapter_multimodal_prompts_demo.ipynb) 120 | 121 | ![multi_prompts](assets/demo/multi_prompts.jpg) 122 | 123 | - [**ip_adapter-plus_demo**](ip_adapter-plus_demo.ipynb): the demo of IP-Adapter with fine-grained features. 124 | 125 | ![ip_adpter_plus_image_variations](assets/demo/ip_adpter_plus_image_variations.jpg) 126 | ![ip_adpter_plus_multi](assets/demo/ip_adpter_plus_multi.jpg) 127 | 128 | - [**ip_adapter-plus-face_demo**](ip_adapter-plus-face_demo.ipynb): generation with face image as prompt. 129 | 130 | ![ip_adpter_plus_face](assets/demo/sd15_face.jpg) 131 | 132 | **Best Practice** 133 | - If you only use the image prompt, you can set the `scale=1.0` and `text_prompt=""`(or some generic text prompts, e.g. "best quality", you can also use any negative text prompt). If you lower the `scale`, more diverse images can be generated, but they may not be as consistent with the image prompt. 134 | - For multimodal prompts, you can adjust the `scale` to get the best results. In most cases, setting `scale=0.5` can get good results. For the version of SD 1.5, we recommend using community models to generate good images. 135 | 136 | **IP-Adapter for non-square images** 137 | 138 | As the image is center cropped in the default image processor of CLIP, IP-Adapter works best for square images. For the non square images, it will miss the information outside the center. But you can just resize to 224x224 for non-square images, the comparison is as follows: 139 | 140 | ![](assets/demo/crop_and_resize_cmp.jpg) 141 | 142 | ### SDXL_1.0 143 | 144 | - [**ip_adapter_sdxl_demo**](ip_adapter_sdxl_demo.ipynb): image variations with image prompt. 145 | - [**ip_adapter_sdxl_controlnet_demo**](ip_adapter_sdxl_controlnet_demo.ipynb): structural generation with image prompt. 146 | 147 | The comparison of **IP-Adapter_XL** with [Reimagine XL](https://clipdrop.co/stable-diffusion-reimagine) is shown as follows: 148 | 149 | ![sdxl_demo](assets/demo/sdxl_cmp.jpg) 150 | 151 | **Improvements in new version (2023.9.8)**: 152 | - **Switch to CLIP-ViT-H**: we trained the new IP-Adapter with [OpenCLIP-ViT-H-14](https://huggingface.co/laion/CLIP-ViT-H-14-laion2B-s32B-b79K) instead of [OpenCLIP-ViT-bigG-14](https://huggingface.co/laion/CLIP-ViT-bigG-14-laion2B-39B-b160k). Although ViT-bigG is much larger than ViT-H, our experimental results did not find a significant difference, and the smaller model can reduce the memory usage in the inference phase. 153 | - **A Faster and better training recipe**: In our previous version, training directly at a resolution of 1024x1024 proved to be highly inefficient. However, in the new version, we have implemented a more effective two-stage training strategy. Firstly, we perform pre-training at a resolution of 512x512. Then, we employ a multi-scale strategy for fine-tuning. (Maybe this training strategy can also be used to speed up the training of controlnet). 154 | 155 | ## How to Train 156 | For training, you should install [accelerate](https://github.com/huggingface/accelerate) and make your own dataset into a json file. 157 | 158 | ``` 159 | accelerate launch --num_processes 8 --multi_gpu --mixed_precision "fp16" \ 160 | tutorial_train.py \ 161 | --pretrained_model_name_or_path="runwayml/stable-diffusion-v1-5/" \ 162 | --image_encoder_path="{image_encoder_path}" \ 163 | --data_json_file="{data.json}" \ 164 | --data_root_path="{image_path}" \ 165 | --mixed_precision="fp16" \ 166 | --resolution=512 \ 167 | --train_batch_size=8 \ 168 | --dataloader_num_workers=4 \ 169 | --learning_rate=1e-04 \ 170 | --weight_decay=0.01 \ 171 | --output_dir="{output_dir}" \ 172 | --save_steps=10000 173 | ``` 174 | 175 | Once training is complete, you can convert the weights with the following code: 176 | 177 | ```python 178 | import torch 179 | ckpt = "checkpoint-50000/pytorch_model.bin" 180 | sd = torch.load(ckpt, map_location="cpu") 181 | image_proj_sd = {} 182 | ip_sd = {} 183 | for k in sd: 184 | if k.startswith("unet"): 185 | pass 186 | elif k.startswith("image_proj_model"): 187 | image_proj_sd[k.replace("image_proj_model.", "")] = sd[k] 188 | elif k.startswith("adapter_modules"): 189 | ip_sd[k.replace("adapter_modules.", "")] = sd[k] 190 | 191 | torch.save({"image_proj": image_proj_sd, "ip_adapter": ip_sd}, "ip_adapter.bin") 192 | ``` 193 | 194 | ## Third-party Usage 195 | - [IP-Adapter for WebUI](https://github.com/Mikubill/sd-webui-controlnet) [[release notes](https://github.com/Mikubill/sd-webui-controlnet/discussions/2039)] 196 | - IP-Adapter for ComfyUI [[IPAdapter-ComfyUI](https://github.com/laksjdjf/IPAdapter-ComfyUI) or [ComfyUI_IPAdapter_plus](https://github.com/cubiq/ComfyUI_IPAdapter_plus)] 197 | - [IP-Adapter for InvokeAI](https://github.com/invoke-ai/InvokeAI) [[release notes](https://github.com/invoke-ai/InvokeAI/releases/tag/v3.2.0)] 198 | - [IP-Adapter for AnimateDiff prompt travel](https://github.com/s9roll7/animatediff-cli-prompt-travel) 199 | - [Diffusers_IPAdapter](https://github.com/cubiq/Diffusers_IPAdapter): more features such as supporting multiple input images 200 | - [Official Diffusers ](https://github.com/huggingface/diffusers/pull/5713) 201 | 202 | ## Disclaimer 203 | 204 | This project strives to positively impact the domain of AI-driven image generation. Users are granted the freedom to create images using this tool, but they are expected to comply with local laws and utilize it in a responsible manner. **The developers do not assume any responsibility for potential misuse by users.** 205 | 206 | ## Citation 207 | If you find IP-Adapter useful for your research and applications, please cite using this BibTeX: 208 | ```bibtex 209 | @article{ye2023ip-adapter, 210 | title={IP-Adapter: Text Compatible Image Prompt Adapter for Text-to-Image Diffusion Models}, 211 | author={Ye, Hu and Zhang, Jun and Liu, Sibo and Han, Xiao and Yang, Wei}, 212 | booktitle={arXiv preprint arxiv:2308.06721}, 213 | year={2023} 214 | } 215 | ``` 216 | -------------------------------------------------------------------------------- /assets/demo/crop_and_resize_cmp.jpg: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/sagiodev/IP-Adapter-Negative/912a655ef473d887847c8d61bb66d106eb39a26c/assets/demo/crop_and_resize_cmp.jpg -------------------------------------------------------------------------------- /assets/demo/image-to-image.jpg: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/sagiodev/IP-Adapter-Negative/912a655ef473d887847c8d61bb66d106eb39a26c/assets/demo/image-to-image.jpg -------------------------------------------------------------------------------- /assets/demo/image_variations.jpg: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/sagiodev/IP-Adapter-Negative/912a655ef473d887847c8d61bb66d106eb39a26c/assets/demo/image_variations.jpg -------------------------------------------------------------------------------- /assets/demo/inpainting.jpg: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/sagiodev/IP-Adapter-Negative/912a655ef473d887847c8d61bb66d106eb39a26c/assets/demo/inpainting.jpg -------------------------------------------------------------------------------- /assets/demo/ip_adpter_plus_image_variations.jpg: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/sagiodev/IP-Adapter-Negative/912a655ef473d887847c8d61bb66d106eb39a26c/assets/demo/ip_adpter_plus_image_variations.jpg -------------------------------------------------------------------------------- /assets/demo/ip_adpter_plus_multi.jpg: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/sagiodev/IP-Adapter-Negative/912a655ef473d887847c8d61bb66d106eb39a26c/assets/demo/ip_adpter_plus_multi.jpg -------------------------------------------------------------------------------- /assets/demo/multi_prompts.jpg: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/sagiodev/IP-Adapter-Negative/912a655ef473d887847c8d61bb66d106eb39a26c/assets/demo/multi_prompts.jpg -------------------------------------------------------------------------------- /assets/demo/sd15_face.jpg: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/sagiodev/IP-Adapter-Negative/912a655ef473d887847c8d61bb66d106eb39a26c/assets/demo/sd15_face.jpg -------------------------------------------------------------------------------- /assets/demo/sdxl_cmp.jpg: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/sagiodev/IP-Adapter-Negative/912a655ef473d887847c8d61bb66d106eb39a26c/assets/demo/sdxl_cmp.jpg -------------------------------------------------------------------------------- /assets/demo/structural_cond.jpg: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/sagiodev/IP-Adapter-Negative/912a655ef473d887847c8d61bb66d106eb39a26c/assets/demo/structural_cond.jpg -------------------------------------------------------------------------------- /assets/demo/t2i-adapter_demo.jpg: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/sagiodev/IP-Adapter-Negative/912a655ef473d887847c8d61bb66d106eb39a26c/assets/demo/t2i-adapter_demo.jpg -------------------------------------------------------------------------------- /assets/figs/fig1.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/sagiodev/IP-Adapter-Negative/912a655ef473d887847c8d61bb66d106eb39a26c/assets/figs/fig1.png -------------------------------------------------------------------------------- /assets/images/ai_face.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/sagiodev/IP-Adapter-Negative/912a655ef473d887847c8d61bb66d106eb39a26c/assets/images/ai_face.png -------------------------------------------------------------------------------- /assets/images/ai_face2.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/sagiodev/IP-Adapter-Negative/912a655ef473d887847c8d61bb66d106eb39a26c/assets/images/ai_face2.png -------------------------------------------------------------------------------- /assets/images/girl.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/sagiodev/IP-Adapter-Negative/912a655ef473d887847c8d61bb66d106eb39a26c/assets/images/girl.png -------------------------------------------------------------------------------- /assets/images/river.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/sagiodev/IP-Adapter-Negative/912a655ef473d887847c8d61bb66d106eb39a26c/assets/images/river.png -------------------------------------------------------------------------------- /assets/images/statue.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/sagiodev/IP-Adapter-Negative/912a655ef473d887847c8d61bb66d106eb39a26c/assets/images/statue.png -------------------------------------------------------------------------------- /assets/images/stone.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/sagiodev/IP-Adapter-Negative/912a655ef473d887847c8d61bb66d106eb39a26c/assets/images/stone.png -------------------------------------------------------------------------------- /assets/images/vermeer.jpg: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/sagiodev/IP-Adapter-Negative/912a655ef473d887847c8d61bb66d106eb39a26c/assets/images/vermeer.jpg -------------------------------------------------------------------------------- /assets/images/woman.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/sagiodev/IP-Adapter-Negative/912a655ef473d887847c8d61bb66d106eb39a26c/assets/images/woman.png -------------------------------------------------------------------------------- /assets/inpainting/image.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/sagiodev/IP-Adapter-Negative/912a655ef473d887847c8d61bb66d106eb39a26c/assets/inpainting/image.png -------------------------------------------------------------------------------- /assets/inpainting/mask.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/sagiodev/IP-Adapter-Negative/912a655ef473d887847c8d61bb66d106eb39a26c/assets/inpainting/mask.png -------------------------------------------------------------------------------- /assets/structure_controls/depth.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/sagiodev/IP-Adapter-Negative/912a655ef473d887847c8d61bb66d106eb39a26c/assets/structure_controls/depth.png -------------------------------------------------------------------------------- /assets/structure_controls/depth2.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/sagiodev/IP-Adapter-Negative/912a655ef473d887847c8d61bb66d106eb39a26c/assets/structure_controls/depth2.png -------------------------------------------------------------------------------- /assets/structure_controls/openpose.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/sagiodev/IP-Adapter-Negative/912a655ef473d887847c8d61bb66d106eb39a26c/assets/structure_controls/openpose.png -------------------------------------------------------------------------------- /ip_adapter/__init__.py: -------------------------------------------------------------------------------- 1 | from .ip_adapter import IPAdapter, IPAdapterPlus, IPAdapterPlusXL, IPAdapterXL, IPAdapterFull 2 | 3 | __all__ = [ 4 | "IPAdapter", 5 | "IPAdapterPlus", 6 | "IPAdapterPlusXL", 7 | "IPAdapterXL", 8 | "IPAdapterFull", 9 | ] 10 | -------------------------------------------------------------------------------- /ip_adapter/attention_processor.py: -------------------------------------------------------------------------------- 1 | # modified from https://github.com/huggingface/diffusers/blob/main/src/diffusers/models/attention_processor.py 2 | import torch 3 | import torch.nn as nn 4 | import torch.nn.functional as F 5 | 6 | 7 | class AttnProcessor(nn.Module): 8 | r""" 9 | Default processor for performing attention-related computations. 10 | """ 11 | 12 | def __init__( 13 | self, 14 | hidden_size=None, 15 | cross_attention_dim=None, 16 | ): 17 | super().__init__() 18 | 19 | def __call__( 20 | self, 21 | attn, 22 | hidden_states, 23 | encoder_hidden_states=None, 24 | attention_mask=None, 25 | temb=None, 26 | ): 27 | residual = hidden_states 28 | 29 | if attn.spatial_norm is not None: 30 | hidden_states = attn.spatial_norm(hidden_states, temb) 31 | 32 | input_ndim = hidden_states.ndim 33 | 34 | if input_ndim == 4: 35 | batch_size, channel, height, width = hidden_states.shape 36 | hidden_states = hidden_states.view(batch_size, channel, height * width).transpose(1, 2) 37 | 38 | batch_size, sequence_length, _ = ( 39 | hidden_states.shape if encoder_hidden_states is None else encoder_hidden_states.shape 40 | ) 41 | attention_mask = attn.prepare_attention_mask(attention_mask, sequence_length, batch_size) 42 | 43 | if attn.group_norm is not None: 44 | hidden_states = attn.group_norm(hidden_states.transpose(1, 2)).transpose(1, 2) 45 | 46 | query = attn.to_q(hidden_states) 47 | 48 | if encoder_hidden_states is None: 49 | encoder_hidden_states = hidden_states 50 | elif attn.norm_cross: 51 | encoder_hidden_states = attn.norm_encoder_hidden_states(encoder_hidden_states) 52 | 53 | key = attn.to_k(encoder_hidden_states) 54 | value = attn.to_v(encoder_hidden_states) 55 | 56 | query = attn.head_to_batch_dim(query) 57 | key = attn.head_to_batch_dim(key) 58 | value = attn.head_to_batch_dim(value) 59 | 60 | attention_probs = attn.get_attention_scores(query, key, attention_mask) 61 | hidden_states = torch.bmm(attention_probs, value) 62 | hidden_states = attn.batch_to_head_dim(hidden_states) 63 | 64 | # linear proj 65 | hidden_states = attn.to_out[0](hidden_states) 66 | # dropout 67 | hidden_states = attn.to_out[1](hidden_states) 68 | 69 | if input_ndim == 4: 70 | hidden_states = hidden_states.transpose(-1, -2).reshape(batch_size, channel, height, width) 71 | 72 | if attn.residual_connection: 73 | hidden_states = hidden_states + residual 74 | 75 | hidden_states = hidden_states / attn.rescale_output_factor 76 | 77 | return hidden_states 78 | 79 | 80 | class IPAttnProcessor(nn.Module): 81 | r""" 82 | Attention processor for IP-Adapater. 83 | Args: 84 | hidden_size (`int`): 85 | The hidden size of the attention layer. 86 | cross_attention_dim (`int`): 87 | The number of channels in the `encoder_hidden_states`. 88 | scale (`float`, defaults to 1.0): 89 | the weight scale of image prompt. 90 | num_tokens (`int`, defaults to 4 when do ip_adapter_plus it should be 16): 91 | The context length of the image features. 92 | """ 93 | 94 | def __init__(self, hidden_size, cross_attention_dim=None, scale=1.0, num_tokens=4): 95 | super().__init__() 96 | 97 | self.hidden_size = hidden_size 98 | self.cross_attention_dim = cross_attention_dim 99 | self.scale = scale 100 | self.num_tokens = num_tokens 101 | 102 | self.to_k_ip = nn.Linear(cross_attention_dim or hidden_size, hidden_size, bias=False) 103 | self.to_v_ip = nn.Linear(cross_attention_dim or hidden_size, hidden_size, bias=False) 104 | 105 | def __call__( 106 | self, 107 | attn, 108 | hidden_states, 109 | encoder_hidden_states=None, 110 | attention_mask=None, 111 | temb=None, 112 | ): 113 | residual = hidden_states 114 | 115 | if attn.spatial_norm is not None: 116 | hidden_states = attn.spatial_norm(hidden_states, temb) 117 | 118 | input_ndim = hidden_states.ndim 119 | 120 | if input_ndim == 4: 121 | batch_size, channel, height, width = hidden_states.shape 122 | hidden_states = hidden_states.view(batch_size, channel, height * width).transpose(1, 2) 123 | 124 | batch_size, sequence_length, _ = ( 125 | hidden_states.shape if encoder_hidden_states is None else encoder_hidden_states.shape 126 | ) 127 | attention_mask = attn.prepare_attention_mask(attention_mask, sequence_length, batch_size) 128 | 129 | if attn.group_norm is not None: 130 | hidden_states = attn.group_norm(hidden_states.transpose(1, 2)).transpose(1, 2) 131 | 132 | query = attn.to_q(hidden_states) 133 | 134 | if encoder_hidden_states is None: 135 | encoder_hidden_states = hidden_states 136 | else: 137 | # get encoder_hidden_states, ip_hidden_states 138 | end_pos = encoder_hidden_states.shape[1] - self.num_tokens 139 | encoder_hidden_states, ip_hidden_states = ( 140 | encoder_hidden_states[:, :end_pos, :], 141 | encoder_hidden_states[:, end_pos:, :], 142 | ) 143 | if attn.norm_cross: 144 | encoder_hidden_states = attn.norm_encoder_hidden_states(encoder_hidden_states) 145 | 146 | key = attn.to_k(encoder_hidden_states) 147 | value = attn.to_v(encoder_hidden_states) 148 | 149 | query = attn.head_to_batch_dim(query) 150 | key = attn.head_to_batch_dim(key) 151 | value = attn.head_to_batch_dim(value) 152 | 153 | attention_probs = attn.get_attention_scores(query, key, attention_mask) 154 | hidden_states = torch.bmm(attention_probs, value) 155 | hidden_states = attn.batch_to_head_dim(hidden_states) 156 | 157 | # for ip-adapter 158 | ip_key = self.to_k_ip(ip_hidden_states) 159 | ip_value = self.to_v_ip(ip_hidden_states) 160 | 161 | ip_key = attn.head_to_batch_dim(ip_key) 162 | ip_value = attn.head_to_batch_dim(ip_value) 163 | 164 | ip_attention_probs = attn.get_attention_scores(query, ip_key, None) 165 | self.attn_map = ip_attention_probs 166 | ip_hidden_states = torch.bmm(ip_attention_probs, ip_value) 167 | ip_hidden_states = attn.batch_to_head_dim(ip_hidden_states) 168 | 169 | hidden_states = hidden_states + self.scale * ip_hidden_states 170 | 171 | # linear proj 172 | hidden_states = attn.to_out[0](hidden_states) 173 | # dropout 174 | hidden_states = attn.to_out[1](hidden_states) 175 | 176 | if input_ndim == 4: 177 | hidden_states = hidden_states.transpose(-1, -2).reshape(batch_size, channel, height, width) 178 | 179 | if attn.residual_connection: 180 | hidden_states = hidden_states + residual 181 | 182 | hidden_states = hidden_states / attn.rescale_output_factor 183 | 184 | return hidden_states 185 | 186 | 187 | class AttnProcessor2_0(torch.nn.Module): 188 | r""" 189 | Processor for implementing scaled dot-product attention (enabled by default if you're using PyTorch 2.0). 190 | """ 191 | 192 | def __init__( 193 | self, 194 | hidden_size=None, 195 | cross_attention_dim=None, 196 | ): 197 | super().__init__() 198 | if not hasattr(F, "scaled_dot_product_attention"): 199 | raise ImportError("AttnProcessor2_0 requires PyTorch 2.0, to use it, please upgrade PyTorch to 2.0.") 200 | 201 | def __call__( 202 | self, 203 | attn, 204 | hidden_states, 205 | encoder_hidden_states=None, 206 | attention_mask=None, 207 | temb=None, 208 | ): 209 | residual = hidden_states 210 | 211 | if attn.spatial_norm is not None: 212 | hidden_states = attn.spatial_norm(hidden_states, temb) 213 | 214 | input_ndim = hidden_states.ndim 215 | 216 | if input_ndim == 4: 217 | batch_size, channel, height, width = hidden_states.shape 218 | hidden_states = hidden_states.view(batch_size, channel, height * width).transpose(1, 2) 219 | 220 | batch_size, sequence_length, _ = ( 221 | hidden_states.shape if encoder_hidden_states is None else encoder_hidden_states.shape 222 | ) 223 | 224 | if attention_mask is not None: 225 | attention_mask = attn.prepare_attention_mask(attention_mask, sequence_length, batch_size) 226 | # scaled_dot_product_attention expects attention_mask shape to be 227 | # (batch, heads, source_length, target_length) 228 | attention_mask = attention_mask.view(batch_size, attn.heads, -1, attention_mask.shape[-1]) 229 | 230 | if attn.group_norm is not None: 231 | hidden_states = attn.group_norm(hidden_states.transpose(1, 2)).transpose(1, 2) 232 | 233 | query = attn.to_q(hidden_states) 234 | 235 | if encoder_hidden_states is None: 236 | encoder_hidden_states = hidden_states 237 | elif attn.norm_cross: 238 | encoder_hidden_states = attn.norm_encoder_hidden_states(encoder_hidden_states) 239 | 240 | key = attn.to_k(encoder_hidden_states) 241 | value = attn.to_v(encoder_hidden_states) 242 | 243 | inner_dim = key.shape[-1] 244 | head_dim = inner_dim // attn.heads 245 | 246 | query = query.view(batch_size, -1, attn.heads, head_dim).transpose(1, 2) 247 | 248 | key = key.view(batch_size, -1, attn.heads, head_dim).transpose(1, 2) 249 | value = value.view(batch_size, -1, attn.heads, head_dim).transpose(1, 2) 250 | 251 | # the output of sdp = (batch, num_heads, seq_len, head_dim) 252 | # TODO: add support for attn.scale when we move to Torch 2.1 253 | hidden_states = F.scaled_dot_product_attention( 254 | query, key, value, attn_mask=attention_mask, dropout_p=0.0, is_causal=False 255 | ) 256 | 257 | hidden_states = hidden_states.transpose(1, 2).reshape(batch_size, -1, attn.heads * head_dim) 258 | hidden_states = hidden_states.to(query.dtype) 259 | 260 | # linear proj 261 | hidden_states = attn.to_out[0](hidden_states) 262 | # dropout 263 | hidden_states = attn.to_out[1](hidden_states) 264 | 265 | if input_ndim == 4: 266 | hidden_states = hidden_states.transpose(-1, -2).reshape(batch_size, channel, height, width) 267 | 268 | if attn.residual_connection: 269 | hidden_states = hidden_states + residual 270 | 271 | hidden_states = hidden_states / attn.rescale_output_factor 272 | 273 | return hidden_states 274 | 275 | 276 | class IPAttnProcessor2_0(torch.nn.Module): 277 | r""" 278 | Attention processor for IP-Adapater for PyTorch 2.0. 279 | Args: 280 | hidden_size (`int`): 281 | The hidden size of the attention layer. 282 | cross_attention_dim (`int`): 283 | The number of channels in the `encoder_hidden_states`. 284 | scale (`float`, defaults to 1.0): 285 | the weight scale of image prompt. 286 | num_tokens (`int`, defaults to 4 when do ip_adapter_plus it should be 16): 287 | The context length of the image features. 288 | """ 289 | 290 | def __init__(self, hidden_size, cross_attention_dim=None, scale=1.0, num_tokens=4): 291 | super().__init__() 292 | 293 | if not hasattr(F, "scaled_dot_product_attention"): 294 | raise ImportError("AttnProcessor2_0 requires PyTorch 2.0, to use it, please upgrade PyTorch to 2.0.") 295 | 296 | self.hidden_size = hidden_size 297 | self.cross_attention_dim = cross_attention_dim 298 | self.scale = scale 299 | self.num_tokens = num_tokens 300 | 301 | self.to_k_ip = nn.Linear(cross_attention_dim or hidden_size, hidden_size, bias=False) 302 | self.to_v_ip = nn.Linear(cross_attention_dim or hidden_size, hidden_size, bias=False) 303 | 304 | def __call__( 305 | self, 306 | attn, 307 | hidden_states, 308 | encoder_hidden_states=None, 309 | attention_mask=None, 310 | temb=None, 311 | ): 312 | residual = hidden_states 313 | 314 | if attn.spatial_norm is not None: 315 | hidden_states = attn.spatial_norm(hidden_states, temb) 316 | 317 | input_ndim = hidden_states.ndim 318 | 319 | if input_ndim == 4: 320 | batch_size, channel, height, width = hidden_states.shape 321 | hidden_states = hidden_states.view(batch_size, channel, height * width).transpose(1, 2) 322 | 323 | batch_size, sequence_length, _ = ( 324 | hidden_states.shape if encoder_hidden_states is None else encoder_hidden_states.shape 325 | ) 326 | 327 | if attention_mask is not None: 328 | attention_mask = attn.prepare_attention_mask(attention_mask, sequence_length, batch_size) 329 | # scaled_dot_product_attention expects attention_mask shape to be 330 | # (batch, heads, source_length, target_length) 331 | attention_mask = attention_mask.view(batch_size, attn.heads, -1, attention_mask.shape[-1]) 332 | 333 | if attn.group_norm is not None: 334 | hidden_states = attn.group_norm(hidden_states.transpose(1, 2)).transpose(1, 2) 335 | 336 | query = attn.to_q(hidden_states) 337 | 338 | if encoder_hidden_states is None: 339 | encoder_hidden_states = hidden_states 340 | else: 341 | # get encoder_hidden_states, ip_hidden_states 342 | end_pos = encoder_hidden_states.shape[1] - self.num_tokens 343 | encoder_hidden_states, ip_hidden_states = ( 344 | encoder_hidden_states[:, :end_pos, :], 345 | encoder_hidden_states[:, end_pos:, :], 346 | ) 347 | if attn.norm_cross: 348 | encoder_hidden_states = attn.norm_encoder_hidden_states(encoder_hidden_states) 349 | 350 | key = attn.to_k(encoder_hidden_states) 351 | value = attn.to_v(encoder_hidden_states) 352 | 353 | inner_dim = key.shape[-1] 354 | head_dim = inner_dim // attn.heads 355 | 356 | query = query.view(batch_size, -1, attn.heads, head_dim).transpose(1, 2) 357 | 358 | key = key.view(batch_size, -1, attn.heads, head_dim).transpose(1, 2) 359 | value = value.view(batch_size, -1, attn.heads, head_dim).transpose(1, 2) 360 | 361 | # the output of sdp = (batch, num_heads, seq_len, head_dim) 362 | # TODO: add support for attn.scale when we move to Torch 2.1 363 | hidden_states = F.scaled_dot_product_attention( 364 | query, key, value, attn_mask=attention_mask, dropout_p=0.0, is_causal=False 365 | ) 366 | 367 | hidden_states = hidden_states.transpose(1, 2).reshape(batch_size, -1, attn.heads * head_dim) 368 | hidden_states = hidden_states.to(query.dtype) 369 | 370 | # for ip-adapter 371 | ip_key = self.to_k_ip(ip_hidden_states) 372 | ip_value = self.to_v_ip(ip_hidden_states) 373 | 374 | ip_key = ip_key.view(batch_size, -1, attn.heads, head_dim).transpose(1, 2) 375 | ip_value = ip_value.view(batch_size, -1, attn.heads, head_dim).transpose(1, 2) 376 | 377 | # the output of sdp = (batch, num_heads, seq_len, head_dim) 378 | # TODO: add support for attn.scale when we move to Torch 2.1 379 | ip_hidden_states = F.scaled_dot_product_attention( 380 | query, ip_key, ip_value, attn_mask=None, dropout_p=0.0, is_causal=False 381 | ) 382 | with torch.no_grad(): 383 | self.attn_map = query @ ip_key.transpose(-2, -1).softmax(dim=-1) 384 | #print(self.attn_map.shape) 385 | 386 | ip_hidden_states = ip_hidden_states.transpose(1, 2).reshape(batch_size, -1, attn.heads * head_dim) 387 | ip_hidden_states = ip_hidden_states.to(query.dtype) 388 | 389 | scales = torch.ones(ip_hidden_states.shape, dtype= ip_hidden_states.dtype) 390 | scales[0:int(scales.shape[0]/2)] = self.scale[0] # uncond 391 | scales[int(scales.shape[0]/2):] = self.scale[1] # cond 392 | scales = scales.to(hidden_states.device) 393 | 394 | hidden_states = hidden_states + scales * ip_hidden_states 395 | 396 | # linear proj 397 | hidden_states = attn.to_out[0](hidden_states) 398 | # dropout 399 | hidden_states = attn.to_out[1](hidden_states) 400 | 401 | if input_ndim == 4: 402 | hidden_states = hidden_states.transpose(-1, -2).reshape(batch_size, channel, height, width) 403 | 404 | if attn.residual_connection: 405 | hidden_states = hidden_states + residual 406 | 407 | hidden_states = hidden_states / attn.rescale_output_factor 408 | 409 | return hidden_states 410 | 411 | 412 | ## for controlnet 413 | class CNAttnProcessor: 414 | r""" 415 | Default processor for performing attention-related computations. 416 | """ 417 | 418 | def __init__(self, num_tokens=4): 419 | self.num_tokens = num_tokens 420 | 421 | def __call__(self, attn, hidden_states, encoder_hidden_states=None, attention_mask=None, temb=None): 422 | residual = hidden_states 423 | 424 | if attn.spatial_norm is not None: 425 | hidden_states = attn.spatial_norm(hidden_states, temb) 426 | 427 | input_ndim = hidden_states.ndim 428 | 429 | if input_ndim == 4: 430 | batch_size, channel, height, width = hidden_states.shape 431 | hidden_states = hidden_states.view(batch_size, channel, height * width).transpose(1, 2) 432 | 433 | batch_size, sequence_length, _ = ( 434 | hidden_states.shape if encoder_hidden_states is None else encoder_hidden_states.shape 435 | ) 436 | attention_mask = attn.prepare_attention_mask(attention_mask, sequence_length, batch_size) 437 | 438 | if attn.group_norm is not None: 439 | hidden_states = attn.group_norm(hidden_states.transpose(1, 2)).transpose(1, 2) 440 | 441 | query = attn.to_q(hidden_states) 442 | 443 | if encoder_hidden_states is None: 444 | encoder_hidden_states = hidden_states 445 | else: 446 | end_pos = encoder_hidden_states.shape[1] - self.num_tokens 447 | encoder_hidden_states = encoder_hidden_states[:, :end_pos] # only use text 448 | if attn.norm_cross: 449 | encoder_hidden_states = attn.norm_encoder_hidden_states(encoder_hidden_states) 450 | 451 | key = attn.to_k(encoder_hidden_states) 452 | value = attn.to_v(encoder_hidden_states) 453 | 454 | query = attn.head_to_batch_dim(query) 455 | key = attn.head_to_batch_dim(key) 456 | value = attn.head_to_batch_dim(value) 457 | 458 | attention_probs = attn.get_attention_scores(query, key, attention_mask) 459 | hidden_states = torch.bmm(attention_probs, value) 460 | hidden_states = attn.batch_to_head_dim(hidden_states) 461 | 462 | # linear proj 463 | hidden_states = attn.to_out[0](hidden_states) 464 | # dropout 465 | hidden_states = attn.to_out[1](hidden_states) 466 | 467 | if input_ndim == 4: 468 | hidden_states = hidden_states.transpose(-1, -2).reshape(batch_size, channel, height, width) 469 | 470 | if attn.residual_connection: 471 | hidden_states = hidden_states + residual 472 | 473 | hidden_states = hidden_states / attn.rescale_output_factor 474 | 475 | return hidden_states 476 | 477 | 478 | class CNAttnProcessor2_0: 479 | r""" 480 | Processor for implementing scaled dot-product attention (enabled by default if you're using PyTorch 2.0). 481 | """ 482 | 483 | def __init__(self, num_tokens=4): 484 | if not hasattr(F, "scaled_dot_product_attention"): 485 | raise ImportError("AttnProcessor2_0 requires PyTorch 2.0, to use it, please upgrade PyTorch to 2.0.") 486 | self.num_tokens = num_tokens 487 | 488 | def __call__( 489 | self, 490 | attn, 491 | hidden_states, 492 | encoder_hidden_states=None, 493 | attention_mask=None, 494 | temb=None, 495 | ): 496 | residual = hidden_states 497 | 498 | if attn.spatial_norm is not None: 499 | hidden_states = attn.spatial_norm(hidden_states, temb) 500 | 501 | input_ndim = hidden_states.ndim 502 | 503 | if input_ndim == 4: 504 | batch_size, channel, height, width = hidden_states.shape 505 | hidden_states = hidden_states.view(batch_size, channel, height * width).transpose(1, 2) 506 | 507 | batch_size, sequence_length, _ = ( 508 | hidden_states.shape if encoder_hidden_states is None else encoder_hidden_states.shape 509 | ) 510 | 511 | if attention_mask is not None: 512 | attention_mask = attn.prepare_attention_mask(attention_mask, sequence_length, batch_size) 513 | # scaled_dot_product_attention expects attention_mask shape to be 514 | # (batch, heads, source_length, target_length) 515 | attention_mask = attention_mask.view(batch_size, attn.heads, -1, attention_mask.shape[-1]) 516 | 517 | if attn.group_norm is not None: 518 | hidden_states = attn.group_norm(hidden_states.transpose(1, 2)).transpose(1, 2) 519 | 520 | query = attn.to_q(hidden_states) 521 | 522 | if encoder_hidden_states is None: 523 | encoder_hidden_states = hidden_states 524 | else: 525 | end_pos = encoder_hidden_states.shape[1] - self.num_tokens 526 | encoder_hidden_states = encoder_hidden_states[:, :end_pos] # only use text 527 | if attn.norm_cross: 528 | encoder_hidden_states = attn.norm_encoder_hidden_states(encoder_hidden_states) 529 | 530 | key = attn.to_k(encoder_hidden_states) 531 | value = attn.to_v(encoder_hidden_states) 532 | 533 | inner_dim = key.shape[-1] 534 | head_dim = inner_dim // attn.heads 535 | 536 | query = query.view(batch_size, -1, attn.heads, head_dim).transpose(1, 2) 537 | 538 | key = key.view(batch_size, -1, attn.heads, head_dim).transpose(1, 2) 539 | value = value.view(batch_size, -1, attn.heads, head_dim).transpose(1, 2) 540 | 541 | # the output of sdp = (batch, num_heads, seq_len, head_dim) 542 | # TODO: add support for attn.scale when we move to Torch 2.1 543 | hidden_states = F.scaled_dot_product_attention( 544 | query, key, value, attn_mask=attention_mask, dropout_p=0.0, is_causal=False 545 | ) 546 | 547 | hidden_states = hidden_states.transpose(1, 2).reshape(batch_size, -1, attn.heads * head_dim) 548 | hidden_states = hidden_states.to(query.dtype) 549 | 550 | # linear proj 551 | hidden_states = attn.to_out[0](hidden_states) 552 | # dropout 553 | hidden_states = attn.to_out[1](hidden_states) 554 | 555 | if input_ndim == 4: 556 | hidden_states = hidden_states.transpose(-1, -2).reshape(batch_size, channel, height, width) 557 | 558 | if attn.residual_connection: 559 | hidden_states = hidden_states + residual 560 | 561 | hidden_states = hidden_states / attn.rescale_output_factor 562 | 563 | return hidden_states 564 | -------------------------------------------------------------------------------- /ip_adapter/attention_processor_faceid.py: -------------------------------------------------------------------------------- 1 | # modified from https://github.com/huggingface/diffusers/blob/main/src/diffusers/models/attention_processor.py 2 | import torch 3 | import torch.nn as nn 4 | import torch.nn.functional as F 5 | 6 | from diffusers.models.lora import LoRALinearLayer 7 | 8 | 9 | class LoRAAttnProcessor(nn.Module): 10 | r""" 11 | Default processor for performing attention-related computations. 12 | """ 13 | 14 | def __init__( 15 | self, 16 | hidden_size=None, 17 | cross_attention_dim=None, 18 | rank=4, 19 | network_alpha=None, 20 | lora_scale=1.0, 21 | ): 22 | super().__init__() 23 | 24 | self.rank = rank 25 | self.lora_scale = lora_scale 26 | 27 | self.to_q_lora = LoRALinearLayer(hidden_size, hidden_size, rank, network_alpha) 28 | self.to_k_lora = LoRALinearLayer(cross_attention_dim or hidden_size, hidden_size, rank, network_alpha) 29 | self.to_v_lora = LoRALinearLayer(cross_attention_dim or hidden_size, hidden_size, rank, network_alpha) 30 | self.to_out_lora = LoRALinearLayer(hidden_size, hidden_size, rank, network_alpha) 31 | 32 | def __call__( 33 | self, 34 | attn, 35 | hidden_states, 36 | encoder_hidden_states=None, 37 | attention_mask=None, 38 | temb=None, 39 | ): 40 | residual = hidden_states 41 | 42 | if attn.spatial_norm is not None: 43 | hidden_states = attn.spatial_norm(hidden_states, temb) 44 | 45 | input_ndim = hidden_states.ndim 46 | 47 | if input_ndim == 4: 48 | batch_size, channel, height, width = hidden_states.shape 49 | hidden_states = hidden_states.view(batch_size, channel, height * width).transpose(1, 2) 50 | 51 | batch_size, sequence_length, _ = ( 52 | hidden_states.shape if encoder_hidden_states is None else encoder_hidden_states.shape 53 | ) 54 | attention_mask = attn.prepare_attention_mask(attention_mask, sequence_length, batch_size) 55 | 56 | if attn.group_norm is not None: 57 | hidden_states = attn.group_norm(hidden_states.transpose(1, 2)).transpose(1, 2) 58 | 59 | query = attn.to_q(hidden_states) + self.lora_scale * self.to_q_lora(hidden_states) 60 | 61 | if encoder_hidden_states is None: 62 | encoder_hidden_states = hidden_states 63 | elif attn.norm_cross: 64 | encoder_hidden_states = attn.norm_encoder_hidden_states(encoder_hidden_states) 65 | 66 | key = attn.to_k(encoder_hidden_states) + self.lora_scale * self.to_k_lora(encoder_hidden_states) 67 | value = attn.to_v(encoder_hidden_states) + self.lora_scale * self.to_v_lora(encoder_hidden_states) 68 | 69 | query = attn.head_to_batch_dim(query) 70 | key = attn.head_to_batch_dim(key) 71 | value = attn.head_to_batch_dim(value) 72 | 73 | attention_probs = attn.get_attention_scores(query, key, attention_mask) 74 | hidden_states = torch.bmm(attention_probs, value) 75 | hidden_states = attn.batch_to_head_dim(hidden_states) 76 | 77 | # linear proj 78 | hidden_states = attn.to_out[0](hidden_states) + self.lora_scale * self.to_out_lora(hidden_states) 79 | # dropout 80 | hidden_states = attn.to_out[1](hidden_states) 81 | 82 | if input_ndim == 4: 83 | hidden_states = hidden_states.transpose(-1, -2).reshape(batch_size, channel, height, width) 84 | 85 | if attn.residual_connection: 86 | hidden_states = hidden_states + residual 87 | 88 | hidden_states = hidden_states / attn.rescale_output_factor 89 | 90 | return hidden_states 91 | 92 | 93 | class LoRAIPAttnProcessor(nn.Module): 94 | r""" 95 | Attention processor for IP-Adapater. 96 | Args: 97 | hidden_size (`int`): 98 | The hidden size of the attention layer. 99 | cross_attention_dim (`int`): 100 | The number of channels in the `encoder_hidden_states`. 101 | scale (`float`, defaults to 1.0): 102 | the weight scale of image prompt. 103 | num_tokens (`int`, defaults to 4 when do ip_adapter_plus it should be 16): 104 | The context length of the image features. 105 | """ 106 | 107 | def __init__(self, hidden_size, cross_attention_dim=None, rank=4, network_alpha=None, lora_scale=1.0, scale=1.0, num_tokens=4): 108 | super().__init__() 109 | 110 | self.rank = rank 111 | self.lora_scale = lora_scale 112 | 113 | self.to_q_lora = LoRALinearLayer(hidden_size, hidden_size, rank, network_alpha) 114 | self.to_k_lora = LoRALinearLayer(cross_attention_dim or hidden_size, hidden_size, rank, network_alpha) 115 | self.to_v_lora = LoRALinearLayer(cross_attention_dim or hidden_size, hidden_size, rank, network_alpha) 116 | self.to_out_lora = LoRALinearLayer(hidden_size, hidden_size, rank, network_alpha) 117 | 118 | self.hidden_size = hidden_size 119 | self.cross_attention_dim = cross_attention_dim 120 | self.scale = scale 121 | self.num_tokens = num_tokens 122 | 123 | self.to_k_ip = nn.Linear(cross_attention_dim or hidden_size, hidden_size, bias=False) 124 | self.to_v_ip = nn.Linear(cross_attention_dim or hidden_size, hidden_size, bias=False) 125 | 126 | def __call__( 127 | self, 128 | attn, 129 | hidden_states, 130 | encoder_hidden_states=None, 131 | attention_mask=None, 132 | temb=None, 133 | ): 134 | residual = hidden_states 135 | 136 | if attn.spatial_norm is not None: 137 | hidden_states = attn.spatial_norm(hidden_states, temb) 138 | 139 | input_ndim = hidden_states.ndim 140 | 141 | if input_ndim == 4: 142 | batch_size, channel, height, width = hidden_states.shape 143 | hidden_states = hidden_states.view(batch_size, channel, height * width).transpose(1, 2) 144 | 145 | batch_size, sequence_length, _ = ( 146 | hidden_states.shape if encoder_hidden_states is None else encoder_hidden_states.shape 147 | ) 148 | attention_mask = attn.prepare_attention_mask(attention_mask, sequence_length, batch_size) 149 | 150 | if attn.group_norm is not None: 151 | hidden_states = attn.group_norm(hidden_states.transpose(1, 2)).transpose(1, 2) 152 | 153 | query = attn.to_q(hidden_states) + self.lora_scale * self.to_q_lora(hidden_states) 154 | 155 | if encoder_hidden_states is None: 156 | encoder_hidden_states = hidden_states 157 | else: 158 | # get encoder_hidden_states, ip_hidden_states 159 | end_pos = encoder_hidden_states.shape[1] - self.num_tokens 160 | encoder_hidden_states, ip_hidden_states = ( 161 | encoder_hidden_states[:, :end_pos, :], 162 | encoder_hidden_states[:, end_pos:, :], 163 | ) 164 | if attn.norm_cross: 165 | encoder_hidden_states = attn.norm_encoder_hidden_states(encoder_hidden_states) 166 | 167 | key = attn.to_k(encoder_hidden_states) + self.lora_scale * self.to_k_lora(encoder_hidden_states) 168 | value = attn.to_v(encoder_hidden_states) + self.lora_scale * self.to_v_lora(encoder_hidden_states) 169 | 170 | query = attn.head_to_batch_dim(query) 171 | key = attn.head_to_batch_dim(key) 172 | value = attn.head_to_batch_dim(value) 173 | 174 | attention_probs = attn.get_attention_scores(query, key, attention_mask) 175 | hidden_states = torch.bmm(attention_probs, value) 176 | hidden_states = attn.batch_to_head_dim(hidden_states) 177 | 178 | # for ip-adapter 179 | ip_key = self.to_k_ip(ip_hidden_states) 180 | ip_value = self.to_v_ip(ip_hidden_states) 181 | 182 | ip_key = attn.head_to_batch_dim(ip_key) 183 | ip_value = attn.head_to_batch_dim(ip_value) 184 | 185 | ip_attention_probs = attn.get_attention_scores(query, ip_key, None) 186 | self.attn_map = ip_attention_probs 187 | ip_hidden_states = torch.bmm(ip_attention_probs, ip_value) 188 | ip_hidden_states = attn.batch_to_head_dim(ip_hidden_states) 189 | 190 | hidden_states = hidden_states + self.scale * ip_hidden_states 191 | 192 | # linear proj 193 | hidden_states = attn.to_out[0](hidden_states) + self.lora_scale * self.to_out_lora(hidden_states) 194 | # dropout 195 | hidden_states = attn.to_out[1](hidden_states) 196 | 197 | if input_ndim == 4: 198 | hidden_states = hidden_states.transpose(-1, -2).reshape(batch_size, channel, height, width) 199 | 200 | if attn.residual_connection: 201 | hidden_states = hidden_states + residual 202 | 203 | hidden_states = hidden_states / attn.rescale_output_factor 204 | 205 | return hidden_states 206 | 207 | 208 | class LoRAAttnProcessor2_0(nn.Module): 209 | 210 | r""" 211 | Default processor for performing attention-related computations. 212 | """ 213 | 214 | def __init__( 215 | self, 216 | hidden_size=None, 217 | cross_attention_dim=None, 218 | rank=4, 219 | network_alpha=None, 220 | lora_scale=1.0, 221 | ): 222 | super().__init__() 223 | 224 | self.rank = rank 225 | self.lora_scale = lora_scale 226 | 227 | self.to_q_lora = LoRALinearLayer(hidden_size, hidden_size, rank, network_alpha) 228 | self.to_k_lora = LoRALinearLayer(cross_attention_dim or hidden_size, hidden_size, rank, network_alpha) 229 | self.to_v_lora = LoRALinearLayer(cross_attention_dim or hidden_size, hidden_size, rank, network_alpha) 230 | self.to_out_lora = LoRALinearLayer(hidden_size, hidden_size, rank, network_alpha) 231 | 232 | def __call__( 233 | self, 234 | attn, 235 | hidden_states, 236 | encoder_hidden_states=None, 237 | attention_mask=None, 238 | temb=None, 239 | ): 240 | residual = hidden_states 241 | 242 | if attn.spatial_norm is not None: 243 | hidden_states = attn.spatial_norm(hidden_states, temb) 244 | 245 | input_ndim = hidden_states.ndim 246 | 247 | if input_ndim == 4: 248 | batch_size, channel, height, width = hidden_states.shape 249 | hidden_states = hidden_states.view(batch_size, channel, height * width).transpose(1, 2) 250 | 251 | batch_size, sequence_length, _ = ( 252 | hidden_states.shape if encoder_hidden_states is None else encoder_hidden_states.shape 253 | ) 254 | attention_mask = attn.prepare_attention_mask(attention_mask, sequence_length, batch_size) 255 | 256 | if attn.group_norm is not None: 257 | hidden_states = attn.group_norm(hidden_states.transpose(1, 2)).transpose(1, 2) 258 | 259 | query = attn.to_q(hidden_states) + self.lora_scale * self.to_q_lora(hidden_states) 260 | 261 | if encoder_hidden_states is None: 262 | encoder_hidden_states = hidden_states 263 | elif attn.norm_cross: 264 | encoder_hidden_states = attn.norm_encoder_hidden_states(encoder_hidden_states) 265 | 266 | key = attn.to_k(encoder_hidden_states) + self.lora_scale * self.to_k_lora(encoder_hidden_states) 267 | value = attn.to_v(encoder_hidden_states) + self.lora_scale * self.to_v_lora(encoder_hidden_states) 268 | 269 | inner_dim = key.shape[-1] 270 | head_dim = inner_dim // attn.heads 271 | 272 | query = query.view(batch_size, -1, attn.heads, head_dim).transpose(1, 2) 273 | 274 | key = key.view(batch_size, -1, attn.heads, head_dim).transpose(1, 2) 275 | value = value.view(batch_size, -1, attn.heads, head_dim).transpose(1, 2) 276 | 277 | # the output of sdp = (batch, num_heads, seq_len, head_dim) 278 | # TODO: add support for attn.scale when we move to Torch 2.1 279 | hidden_states = F.scaled_dot_product_attention( 280 | query, key, value, attn_mask=attention_mask, dropout_p=0.0, is_causal=False 281 | ) 282 | 283 | hidden_states = hidden_states.transpose(1, 2).reshape(batch_size, -1, attn.heads * head_dim) 284 | hidden_states = hidden_states.to(query.dtype) 285 | 286 | # linear proj 287 | hidden_states = attn.to_out[0](hidden_states) + self.lora_scale * self.to_out_lora(hidden_states) 288 | # dropout 289 | hidden_states = attn.to_out[1](hidden_states) 290 | 291 | if input_ndim == 4: 292 | hidden_states = hidden_states.transpose(-1, -2).reshape(batch_size, channel, height, width) 293 | 294 | if attn.residual_connection: 295 | hidden_states = hidden_states + residual 296 | 297 | hidden_states = hidden_states / attn.rescale_output_factor 298 | 299 | return hidden_states 300 | 301 | 302 | class LoRAIPAttnProcessor2_0(nn.Module): 303 | r""" 304 | Processor for implementing the LoRA attention mechanism. 305 | 306 | Args: 307 | hidden_size (`int`, *optional*): 308 | The hidden size of the attention layer. 309 | cross_attention_dim (`int`, *optional*): 310 | The number of channels in the `encoder_hidden_states`. 311 | rank (`int`, defaults to 4): 312 | The dimension of the LoRA update matrices. 313 | network_alpha (`int`, *optional*): 314 | Equivalent to `alpha` but it's usage is specific to Kohya (A1111) style LoRAs. 315 | """ 316 | 317 | def __init__(self, hidden_size, cross_attention_dim=None, rank=4, network_alpha=None, lora_scale=1.0, scale=1.0, num_tokens=4): 318 | super().__init__() 319 | 320 | self.rank = rank 321 | self.lora_scale = lora_scale 322 | self.num_tokens = num_tokens 323 | 324 | self.to_q_lora = LoRALinearLayer(hidden_size, hidden_size, rank, network_alpha) 325 | self.to_k_lora = LoRALinearLayer(cross_attention_dim or hidden_size, hidden_size, rank, network_alpha) 326 | self.to_v_lora = LoRALinearLayer(cross_attention_dim or hidden_size, hidden_size, rank, network_alpha) 327 | self.to_out_lora = LoRALinearLayer(hidden_size, hidden_size, rank, network_alpha) 328 | 329 | 330 | self.hidden_size = hidden_size 331 | self.cross_attention_dim = cross_attention_dim 332 | self.scale = scale 333 | 334 | self.to_k_ip = nn.Linear(cross_attention_dim or hidden_size, hidden_size, bias=False) 335 | self.to_v_ip = nn.Linear(cross_attention_dim or hidden_size, hidden_size, bias=False) 336 | 337 | def __call__( 338 | self, attn, hidden_states, encoder_hidden_states=None, attention_mask=None, scale=1.0, temb=None 339 | ): 340 | residual = hidden_states 341 | 342 | if attn.spatial_norm is not None: 343 | hidden_states = attn.spatial_norm(hidden_states, temb) 344 | 345 | input_ndim = hidden_states.ndim 346 | 347 | if input_ndim == 4: 348 | batch_size, channel, height, width = hidden_states.shape 349 | hidden_states = hidden_states.view(batch_size, channel, height * width).transpose(1, 2) 350 | 351 | batch_size, sequence_length, _ = ( 352 | hidden_states.shape if encoder_hidden_states is None else encoder_hidden_states.shape 353 | ) 354 | attention_mask = attn.prepare_attention_mask(attention_mask, sequence_length, batch_size) 355 | 356 | if attn.group_norm is not None: 357 | hidden_states = attn.group_norm(hidden_states.transpose(1, 2)).transpose(1, 2) 358 | 359 | query = attn.to_q(hidden_states) + self.lora_scale * self.to_q_lora(hidden_states) 360 | #query = attn.head_to_batch_dim(query) 361 | 362 | if encoder_hidden_states is None: 363 | encoder_hidden_states = hidden_states 364 | else: 365 | # get encoder_hidden_states, ip_hidden_states 366 | end_pos = encoder_hidden_states.shape[1] - self.num_tokens 367 | encoder_hidden_states, ip_hidden_states = ( 368 | encoder_hidden_states[:, :end_pos, :], 369 | encoder_hidden_states[:, end_pos:, :], 370 | ) 371 | if attn.norm_cross: 372 | encoder_hidden_states = attn.norm_encoder_hidden_states(encoder_hidden_states) 373 | 374 | # for text 375 | key = attn.to_k(encoder_hidden_states) + self.lora_scale * self.to_k_lora(encoder_hidden_states) 376 | value = attn.to_v(encoder_hidden_states) + self.lora_scale * self.to_v_lora(encoder_hidden_states) 377 | 378 | inner_dim = key.shape[-1] 379 | head_dim = inner_dim // attn.heads 380 | 381 | query = query.view(batch_size, -1, attn.heads, head_dim).transpose(1, 2) 382 | 383 | key = key.view(batch_size, -1, attn.heads, head_dim).transpose(1, 2) 384 | value = value.view(batch_size, -1, attn.heads, head_dim).transpose(1, 2) 385 | 386 | # the output of sdp = (batch, num_heads, seq_len, head_dim) 387 | # TODO: add support for attn.scale when we move to Torch 2.1 388 | hidden_states = F.scaled_dot_product_attention( 389 | query, key, value, attn_mask=attention_mask, dropout_p=0.0, is_causal=False 390 | ) 391 | 392 | hidden_states = hidden_states.transpose(1, 2).reshape(batch_size, -1, attn.heads * head_dim) 393 | hidden_states = hidden_states.to(query.dtype) 394 | 395 | # for ip 396 | ip_key = self.to_k_ip(ip_hidden_states) 397 | ip_value = self.to_v_ip(ip_hidden_states) 398 | 399 | ip_key = ip_key.view(batch_size, -1, attn.heads, head_dim).transpose(1, 2) 400 | ip_value = ip_value.view(batch_size, -1, attn.heads, head_dim).transpose(1, 2) 401 | 402 | # the output of sdp = (batch, num_heads, seq_len, head_dim) 403 | # TODO: add support for attn.scale when we move to Torch 2.1 404 | ip_hidden_states = F.scaled_dot_product_attention( 405 | query, ip_key, ip_value, attn_mask=None, dropout_p=0.0, is_causal=False 406 | ) 407 | 408 | 409 | ip_hidden_states = ip_hidden_states.transpose(1, 2).reshape(batch_size, -1, attn.heads * head_dim) 410 | ip_hidden_states = ip_hidden_states.to(query.dtype) 411 | 412 | hidden_states = hidden_states + self.scale * ip_hidden_states 413 | 414 | # linear proj 415 | hidden_states = attn.to_out[0](hidden_states) + self.lora_scale * self.to_out_lora(hidden_states) 416 | # dropout 417 | hidden_states = attn.to_out[1](hidden_states) 418 | 419 | if input_ndim == 4: 420 | hidden_states = hidden_states.transpose(-1, -2).reshape(batch_size, channel, height, width) 421 | 422 | if attn.residual_connection: 423 | hidden_states = hidden_states + residual 424 | 425 | hidden_states = hidden_states / attn.rescale_output_factor 426 | 427 | return hidden_states 428 | -------------------------------------------------------------------------------- /ip_adapter/ip_adapter.py: -------------------------------------------------------------------------------- 1 | import os 2 | from typing import List 3 | 4 | import torch 5 | from diffusers import StableDiffusionPipeline 6 | from diffusers.pipelines.controlnet import MultiControlNetModel 7 | from PIL import Image 8 | from safetensors import safe_open 9 | from transformers import CLIPImageProcessor, CLIPVisionModelWithProjection 10 | 11 | from .utils import is_torch2_available 12 | 13 | if is_torch2_available(): 14 | from .attention_processor import ( 15 | AttnProcessor2_0 as AttnProcessor, 16 | ) 17 | from .attention_processor import ( 18 | CNAttnProcessor2_0 as CNAttnProcessor, 19 | ) 20 | from .attention_processor import ( 21 | IPAttnProcessor2_0 as IPAttnProcessor, 22 | ) 23 | else: 24 | from .attention_processor import AttnProcessor, CNAttnProcessor, IPAttnProcessor 25 | from .resampler import Resampler 26 | 27 | 28 | class ImageProjModel(torch.nn.Module): 29 | """Projection Model""" 30 | 31 | def __init__(self, cross_attention_dim=1024, clip_embeddings_dim=1024, clip_extra_context_tokens=4): 32 | super().__init__() 33 | 34 | self.cross_attention_dim = cross_attention_dim 35 | self.clip_extra_context_tokens = clip_extra_context_tokens 36 | self.proj = torch.nn.Linear(clip_embeddings_dim, self.clip_extra_context_tokens * cross_attention_dim) 37 | self.norm = torch.nn.LayerNorm(cross_attention_dim) 38 | 39 | def forward(self, image_embeds): 40 | embeds = image_embeds 41 | clip_extra_context_tokens = self.proj(embeds).reshape( 42 | -1, self.clip_extra_context_tokens, self.cross_attention_dim 43 | ) 44 | clip_extra_context_tokens = self.norm(clip_extra_context_tokens) 45 | return clip_extra_context_tokens 46 | 47 | 48 | class MLPProjModel(torch.nn.Module): 49 | """SD model with image prompt""" 50 | def __init__(self, cross_attention_dim=1024, clip_embeddings_dim=1024): 51 | super().__init__() 52 | 53 | self.proj = torch.nn.Sequential( 54 | torch.nn.Linear(clip_embeddings_dim, clip_embeddings_dim), 55 | torch.nn.GELU(), 56 | torch.nn.Linear(clip_embeddings_dim, cross_attention_dim), 57 | torch.nn.LayerNorm(cross_attention_dim) 58 | ) 59 | 60 | def forward(self, image_embeds): 61 | clip_extra_context_tokens = self.proj(image_embeds) 62 | return clip_extra_context_tokens 63 | 64 | 65 | class IPAdapter: 66 | def __init__(self, sd_pipe, image_encoder_path, ip_ckpt, device, num_tokens=4): 67 | self.device = device 68 | self.image_encoder_path = image_encoder_path 69 | self.ip_ckpt = ip_ckpt 70 | self.num_tokens = num_tokens 71 | 72 | self.pipe = sd_pipe.to(self.device) 73 | self.set_ip_adapter() 74 | 75 | # load image encoder 76 | self.image_encoder = CLIPVisionModelWithProjection.from_pretrained(self.image_encoder_path).to( 77 | self.device, dtype=torch.float16 78 | ) 79 | self.clip_image_processor = CLIPImageProcessor() 80 | # image proj model 81 | self.image_proj_model = self.init_proj() 82 | 83 | self.load_ip_adapter() 84 | 85 | def init_proj(self): 86 | image_proj_model = ImageProjModel( 87 | cross_attention_dim=self.pipe.unet.config.cross_attention_dim, 88 | clip_embeddings_dim=self.image_encoder.config.projection_dim, 89 | clip_extra_context_tokens=self.num_tokens, 90 | ).to(self.device, dtype=torch.float16) 91 | return image_proj_model 92 | 93 | def set_ip_adapter(self): 94 | unet = self.pipe.unet 95 | attn_procs = {} 96 | for name in unet.attn_processors.keys(): 97 | cross_attention_dim = None if name.endswith("attn1.processor") else unet.config.cross_attention_dim 98 | if name.startswith("mid_block"): 99 | hidden_size = unet.config.block_out_channels[-1] 100 | elif name.startswith("up_blocks"): 101 | block_id = int(name[len("up_blocks.")]) 102 | hidden_size = list(reversed(unet.config.block_out_channels))[block_id] 103 | elif name.startswith("down_blocks"): 104 | block_id = int(name[len("down_blocks.")]) 105 | hidden_size = unet.config.block_out_channels[block_id] 106 | if cross_attention_dim is None: 107 | attn_procs[name] = AttnProcessor() 108 | else: 109 | attn_procs[name] = IPAttnProcessor( 110 | hidden_size=hidden_size, 111 | cross_attention_dim=cross_attention_dim, 112 | scale=1.0, 113 | num_tokens=self.num_tokens, 114 | ).to(self.device, dtype=torch.float16) 115 | unet.set_attn_processor(attn_procs) 116 | if hasattr(self.pipe, "controlnet"): 117 | if isinstance(self.pipe.controlnet, MultiControlNetModel): 118 | for controlnet in self.pipe.controlnet.nets: 119 | controlnet.set_attn_processor(CNAttnProcessor(num_tokens=self.num_tokens)) 120 | else: 121 | self.pipe.controlnet.set_attn_processor(CNAttnProcessor(num_tokens=self.num_tokens)) 122 | 123 | def load_ip_adapter(self): 124 | if os.path.splitext(self.ip_ckpt)[-1] == ".safetensors": 125 | state_dict = {"image_proj": {}, "ip_adapter": {}} 126 | with safe_open(self.ip_ckpt, framework="pt", device="cpu") as f: 127 | for key in f.keys(): 128 | if key.startswith("image_proj."): 129 | state_dict["image_proj"][key.replace("image_proj.", "")] = f.get_tensor(key) 130 | elif key.startswith("ip_adapter."): 131 | state_dict["ip_adapter"][key.replace("ip_adapter.", "")] = f.get_tensor(key) 132 | else: 133 | state_dict = torch.load(self.ip_ckpt, map_location="cpu") 134 | self.image_proj_model.load_state_dict(state_dict["image_proj"]) 135 | ip_layers = torch.nn.ModuleList(self.pipe.unet.attn_processors.values()) 136 | ip_layers.load_state_dict(state_dict["ip_adapter"]) 137 | 138 | @torch.inference_mode() 139 | def get_image_embeds(self, pil_image=None, clip_image_embeds=None): 140 | if pil_image is not None: 141 | if isinstance(pil_image, Image.Image): 142 | pil_image = [pil_image] 143 | clip_image = self.clip_image_processor(images=pil_image, return_tensors="pt").pixel_values 144 | clip_image_embeds = self.image_encoder(clip_image.to(self.device, dtype=torch.float16)).image_embeds 145 | else: 146 | clip_image_embeds = clip_image_embeds.to(self.device, dtype=torch.float16) 147 | image_prompt_embeds = self.image_proj_model(clip_image_embeds) 148 | uncond_image_prompt_embeds = self.image_proj_model(torch.zeros_like(clip_image_embeds)) 149 | return image_prompt_embeds, uncond_image_prompt_embeds 150 | 151 | def set_scale(self, scale): 152 | for attn_processor in self.pipe.unet.attn_processors.values(): 153 | if isinstance(attn_processor, IPAttnProcessor): 154 | attn_processor.scale = scale 155 | 156 | def generate( 157 | self, 158 | pil_image=None, 159 | clip_image_embeds=None, 160 | negative_pil_image=None, 161 | negative_clip_image_embeds=None, 162 | prompt=None, 163 | negative_prompt=None, 164 | scale=1.0, # weight for image prompt 165 | scale_start = 0.0, 166 | scale_stop = 1.0, 167 | scale_neg = 0.0, # weight for negative image prompt 168 | scale_neg_start = 0.0, 169 | scale_neg_stop = 1.0, 170 | num_samples=4, 171 | seed=None, 172 | guidance_scale=7.5, 173 | num_inference_steps=30, 174 | **kwargs, 175 | ): 176 | 177 | # set scales for negative and positive image prompt 178 | self.set_scale( (scale_neg if scale_neg_start == 0. else 0., scale if scale_start == 0. else 0.) ) 179 | self.pipe.ip_scale = scale 180 | self.pipe.ip_scale_start = scale_start 181 | self.pipe.ip_scale_stop = scale_stop 182 | self.pipe.ip_scale_neg = scale 183 | self.pipe.ip_scale_neg_start = scale_neg_start 184 | self.pipe.ip_scale_neg_stop = scale_neg_stop 185 | 186 | 187 | if pil_image is not None: 188 | num_prompts = 1 if isinstance(pil_image, Image.Image) else len(pil_image) 189 | elif negative_pil_image: 190 | num_prompts = 1 if isinstance(negative_pil_image, Image.Image) else len(negative_pil_image) 191 | else: 192 | num_prompts = clip_image_embeds.size(0) 193 | 194 | if prompt is None: 195 | prompt = "best quality, high quality" 196 | if negative_prompt is None: 197 | negative_prompt = "monochrome, lowres, bad anatomy, worst quality, low quality" 198 | 199 | if not isinstance(prompt, List): 200 | prompt = [prompt] * num_prompts 201 | if not isinstance(negative_prompt, List): 202 | negative_prompt = [negative_prompt] * num_prompts 203 | 204 | if (pil_image or clip_image_embeds) and not (negative_pil_image or negative_clip_image_embeds): 205 | # positive image only 206 | print('positive image') 207 | image_prompt_embeds, uncond_image_prompt_embeds = self.get_image_embeds( 208 | pil_image=pil_image, clip_image_embeds=clip_image_embeds 209 | ) 210 | elif not (pil_image or clip_image_embeds) and (negative_pil_image or negative_clip_image_embeds): 211 | # negative prompt only 212 | print('negative image') 213 | uncond_image_prompt_embeds, image_prompt_embeds = self.get_image_embeds( 214 | pil_image=negative_pil_image, clip_image_embeds=negative_clip_image_embeds 215 | ) 216 | elif (pil_image or clip_image_embeds) and (negative_pil_image or negative_clip_image_embeds): 217 | # positive and negative images 218 | print('positive and negative image') 219 | image_prompt_embeds, uncond_image_prompt_embeds = self.get_image_embeds( 220 | pil_image=pil_image, clip_image_embeds=clip_image_embeds 221 | ) 222 | uncond_image_prompt_embeds, _ = self.get_image_embeds( 223 | pil_image=negative_pil_image, clip_image_embeds=negative_clip_image_embeds 224 | ) 225 | else: 226 | # No postive or negative images 227 | NotImplementedError("") 228 | 229 | bs_embed, seq_len, _ = image_prompt_embeds.shape 230 | image_prompt_embeds = image_prompt_embeds.repeat(1, num_samples, 1) 231 | image_prompt_embeds = image_prompt_embeds.view(bs_embed * num_samples, seq_len, -1) 232 | uncond_image_prompt_embeds = uncond_image_prompt_embeds.repeat(1, num_samples, 1) 233 | uncond_image_prompt_embeds = uncond_image_prompt_embeds.view(bs_embed * num_samples, seq_len, -1) 234 | 235 | with torch.inference_mode(): 236 | prompt_embeds_, negative_prompt_embeds_ = self.pipe.encode_prompt( 237 | prompt, 238 | device=self.device, 239 | num_images_per_prompt=num_samples, 240 | do_classifier_free_guidance=True, 241 | negative_prompt=negative_prompt, 242 | ) 243 | prompt_embeds = torch.cat([prompt_embeds_, image_prompt_embeds], dim=1) 244 | negative_prompt_embeds = torch.cat([negative_prompt_embeds_, uncond_image_prompt_embeds], dim=1) 245 | 246 | def callback_weight_schedule(pipe, step_index, timestep, callback_kwargs): 247 | 248 | # turn on ip scale when it is between start and stop 249 | if step_index >= int(pipe.num_timesteps * pipe.ip_scale_start) and step_index <= int(pipe.num_timesteps * pipe.ip_scale_stop): 250 | scale = pipe.ip_scale 251 | else: 252 | scale = 0 253 | 254 | # turn on ip negative scale when it is between start and stop 255 | if step_index >= int(pipe.num_timesteps * pipe.ip_scale_neg_start) and step_index <= int(pipe.num_timesteps * pipe.ip_scale_neg_stop): 256 | scale_neg = pipe.ip_scale_neg 257 | else: 258 | scale_neg = 0 259 | 260 | # update guidance_scale and prompt_embeds 261 | self.set_scale((scale_neg, scale)) 262 | return callback_kwargs 263 | 264 | generator = torch.Generator(self.device).manual_seed(seed) if seed is not None else None 265 | images = self.pipe( 266 | prompt_embeds=prompt_embeds, 267 | negative_prompt_embeds=negative_prompt_embeds, 268 | guidance_scale=guidance_scale, 269 | num_inference_steps=num_inference_steps, 270 | generator=generator, 271 | callback_on_step_end= callback_weight_schedule, 272 | **kwargs, 273 | ).images 274 | 275 | return images 276 | 277 | 278 | class IPAdapterXL(IPAdapter): 279 | """SDXL""" 280 | 281 | def generate( 282 | self, 283 | pil_image, 284 | prompt=None, 285 | negative_prompt=None, 286 | scale=1.0, 287 | num_samples=4, 288 | seed=None, 289 | num_inference_steps=30, 290 | **kwargs, 291 | ): 292 | self.set_scale(scale) 293 | 294 | num_prompts = 1 if isinstance(pil_image, Image.Image) else len(pil_image) 295 | 296 | if prompt is None: 297 | prompt = "best quality, high quality" 298 | if negative_prompt is None: 299 | negative_prompt = "monochrome, lowres, bad anatomy, worst quality, low quality" 300 | 301 | if not isinstance(prompt, List): 302 | prompt = [prompt] * num_prompts 303 | if not isinstance(negative_prompt, List): 304 | negative_prompt = [negative_prompt] * num_prompts 305 | 306 | image_prompt_embeds, uncond_image_prompt_embeds = self.get_image_embeds(pil_image) 307 | bs_embed, seq_len, _ = image_prompt_embeds.shape 308 | image_prompt_embeds = image_prompt_embeds.repeat(1, num_samples, 1) 309 | image_prompt_embeds = image_prompt_embeds.view(bs_embed * num_samples, seq_len, -1) 310 | uncond_image_prompt_embeds = uncond_image_prompt_embeds.repeat(1, num_samples, 1) 311 | uncond_image_prompt_embeds = uncond_image_prompt_embeds.view(bs_embed * num_samples, seq_len, -1) 312 | 313 | with torch.inference_mode(): 314 | ( 315 | prompt_embeds, 316 | negative_prompt_embeds, 317 | pooled_prompt_embeds, 318 | negative_pooled_prompt_embeds, 319 | ) = self.pipe.encode_prompt( 320 | prompt, 321 | num_images_per_prompt=num_samples, 322 | do_classifier_free_guidance=True, 323 | negative_prompt=negative_prompt, 324 | ) 325 | prompt_embeds = torch.cat([prompt_embeds, image_prompt_embeds], dim=1) 326 | negative_prompt_embeds = torch.cat([negative_prompt_embeds, uncond_image_prompt_embeds], dim=1) 327 | 328 | generator = torch.Generator(self.device).manual_seed(seed) if seed is not None else None 329 | images = self.pipe( 330 | prompt_embeds=prompt_embeds, 331 | negative_prompt_embeds=negative_prompt_embeds, 332 | pooled_prompt_embeds=pooled_prompt_embeds, 333 | negative_pooled_prompt_embeds=negative_pooled_prompt_embeds, 334 | num_inference_steps=num_inference_steps, 335 | 336 | generator=generator, 337 | **kwargs, 338 | ).images 339 | 340 | return images 341 | 342 | 343 | class IPAdapterPlus(IPAdapter): 344 | """IP-Adapter with fine-grained features""" 345 | 346 | def init_proj(self): 347 | image_proj_model = Resampler( 348 | dim=self.pipe.unet.config.cross_attention_dim, 349 | depth=4, 350 | dim_head=64, 351 | heads=12, 352 | num_queries=self.num_tokens, 353 | embedding_dim=self.image_encoder.config.hidden_size, 354 | output_dim=self.pipe.unet.config.cross_attention_dim, 355 | ff_mult=4, 356 | ).to(self.device, dtype=torch.float16) 357 | return image_proj_model 358 | 359 | @torch.inference_mode() 360 | def get_image_embeds(self, pil_image=None, clip_image_embeds=None): 361 | if isinstance(pil_image, Image.Image): 362 | pil_image = [pil_image] 363 | clip_image = self.clip_image_processor(images=pil_image, return_tensors="pt").pixel_values 364 | clip_image = clip_image.to(self.device, dtype=torch.float16) 365 | clip_image_embeds = self.image_encoder(clip_image, output_hidden_states=True).hidden_states[-2] 366 | image_prompt_embeds = self.image_proj_model(clip_image_embeds) 367 | uncond_clip_image_embeds = self.image_encoder( 368 | torch.zeros_like(clip_image), output_hidden_states=True 369 | ).hidden_states[-2] 370 | uncond_image_prompt_embeds = self.image_proj_model(uncond_clip_image_embeds) 371 | return image_prompt_embeds, uncond_image_prompt_embeds 372 | 373 | 374 | class IPAdapterFull(IPAdapterPlus): 375 | """IP-Adapter with full features""" 376 | 377 | def init_proj(self): 378 | image_proj_model = MLPProjModel( 379 | cross_attention_dim=self.pipe.unet.config.cross_attention_dim, 380 | clip_embeddings_dim=self.image_encoder.config.hidden_size, 381 | ).to(self.device, dtype=torch.float16) 382 | return image_proj_model 383 | 384 | 385 | class IPAdapterPlusXL(IPAdapter): 386 | """SDXL""" 387 | 388 | def init_proj(self): 389 | image_proj_model = Resampler( 390 | dim=1280, 391 | depth=4, 392 | dim_head=64, 393 | heads=20, 394 | num_queries=self.num_tokens, 395 | embedding_dim=self.image_encoder.config.hidden_size, 396 | output_dim=self.pipe.unet.config.cross_attention_dim, 397 | ff_mult=4, 398 | ).to(self.device, dtype=torch.float16) 399 | return image_proj_model 400 | 401 | @torch.inference_mode() 402 | def get_image_embeds(self, pil_image): 403 | if isinstance(pil_image, Image.Image): 404 | pil_image = [pil_image] 405 | clip_image = self.clip_image_processor(images=pil_image, return_tensors="pt").pixel_values 406 | clip_image = clip_image.to(self.device, dtype=torch.float16) 407 | clip_image_embeds = self.image_encoder(clip_image, output_hidden_states=True).hidden_states[-2] 408 | image_prompt_embeds = self.image_proj_model(clip_image_embeds) 409 | uncond_clip_image_embeds = self.image_encoder( 410 | torch.zeros_like(clip_image), output_hidden_states=True 411 | ).hidden_states[-2] 412 | uncond_image_prompt_embeds = self.image_proj_model(uncond_clip_image_embeds) 413 | return image_prompt_embeds, uncond_image_prompt_embeds 414 | 415 | def generate( 416 | self, 417 | pil_image, 418 | prompt=None, 419 | negative_prompt=None, 420 | scale=1.0, 421 | num_samples=4, 422 | seed=None, 423 | num_inference_steps=30, 424 | **kwargs, 425 | ): 426 | self.set_scale(scale) 427 | 428 | num_prompts = 1 if isinstance(pil_image, Image.Image) else len(pil_image) 429 | 430 | if prompt is None: 431 | prompt = "best quality, high quality" 432 | if negative_prompt is None: 433 | negative_prompt = "monochrome, lowres, bad anatomy, worst quality, low quality" 434 | 435 | if not isinstance(prompt, List): 436 | prompt = [prompt] * num_prompts 437 | if not isinstance(negative_prompt, List): 438 | negative_prompt = [negative_prompt] * num_prompts 439 | 440 | image_prompt_embeds, uncond_image_prompt_embeds = self.get_image_embeds(pil_image) 441 | bs_embed, seq_len, _ = image_prompt_embeds.shape 442 | image_prompt_embeds = image_prompt_embeds.repeat(1, num_samples, 1) 443 | image_prompt_embeds = image_prompt_embeds.view(bs_embed * num_samples, seq_len, -1) 444 | uncond_image_prompt_embeds = uncond_image_prompt_embeds.repeat(1, num_samples, 1) 445 | uncond_image_prompt_embeds = uncond_image_prompt_embeds.view(bs_embed * num_samples, seq_len, -1) 446 | 447 | with torch.inference_mode(): 448 | ( 449 | prompt_embeds, 450 | negative_prompt_embeds, 451 | pooled_prompt_embeds, 452 | negative_pooled_prompt_embeds, 453 | ) = self.pipe.encode_prompt( 454 | prompt, 455 | num_images_per_prompt=num_samples, 456 | do_classifier_free_guidance=True, 457 | negative_prompt=negative_prompt, 458 | ) 459 | prompt_embeds = torch.cat([prompt_embeds, image_prompt_embeds], dim=1) 460 | negative_prompt_embeds = torch.cat([negative_prompt_embeds, uncond_image_prompt_embeds], dim=1) 461 | 462 | generator = torch.Generator(self.device).manual_seed(seed) if seed is not None else None 463 | images = self.pipe( 464 | prompt_embeds=prompt_embeds, 465 | negative_prompt_embeds=negative_prompt_embeds, 466 | pooled_prompt_embeds=pooled_prompt_embeds, 467 | negative_pooled_prompt_embeds=negative_pooled_prompt_embeds, 468 | num_inference_steps=num_inference_steps, 469 | generator=generator, 470 | **kwargs, 471 | ).images 472 | 473 | return images 474 | -------------------------------------------------------------------------------- /ip_adapter/ip_adapter_faceid.py: -------------------------------------------------------------------------------- 1 | import os 2 | from typing import List 3 | 4 | import torch 5 | from diffusers import StableDiffusionPipeline 6 | from diffusers.pipelines.controlnet import MultiControlNetModel 7 | from PIL import Image 8 | from safetensors import safe_open 9 | from transformers import CLIPImageProcessor, CLIPVisionModelWithProjection 10 | 11 | from .attention_processor_faceid import LoRAAttnProcessor, LoRAIPAttnProcessor 12 | from .utils import is_torch2_available 13 | 14 | USE_DAFAULT_ATTN = False # should be True for visualization_attnmap 15 | if is_torch2_available() and (not USE_DAFAULT_ATTN): 16 | from .attention_processor_faceid import ( 17 | LoRAAttnProcessor2_0 as LoRAAttnProcessor, 18 | ) 19 | from .attention_processor_faceid import ( 20 | LoRAIPAttnProcessor2_0 as LoRAIPAttnProcessor, 21 | ) 22 | else: 23 | from .attention_processor_faceid import LoRAAttnProcessor, LoRAIPAttnProcessor 24 | from .resampler import PerceiverAttention, FeedForward 25 | 26 | 27 | class FacePerceiverResampler(torch.nn.Module): 28 | def __init__( 29 | self, 30 | *, 31 | dim=768, 32 | depth=4, 33 | dim_head=64, 34 | heads=16, 35 | embedding_dim=1280, 36 | output_dim=768, 37 | ff_mult=4, 38 | ): 39 | super().__init__() 40 | 41 | self.proj_in = torch.nn.Linear(embedding_dim, dim) 42 | self.proj_out = torch.nn.Linear(dim, output_dim) 43 | self.norm_out = torch.nn.LayerNorm(output_dim) 44 | self.layers = torch.nn.ModuleList([]) 45 | for _ in range(depth): 46 | self.layers.append( 47 | torch.nn.ModuleList( 48 | [ 49 | PerceiverAttention(dim=dim, dim_head=dim_head, heads=heads), 50 | FeedForward(dim=dim, mult=ff_mult), 51 | ] 52 | ) 53 | ) 54 | 55 | def forward(self, latents, x): 56 | x = self.proj_in(x) 57 | for attn, ff in self.layers: 58 | latents = attn(x, latents) + latents 59 | latents = ff(latents) + latents 60 | latents = self.proj_out(latents) 61 | return self.norm_out(latents) 62 | 63 | 64 | class MLPProjModel(torch.nn.Module): 65 | def __init__(self, cross_attention_dim=768, id_embeddings_dim=512, num_tokens=4): 66 | super().__init__() 67 | 68 | self.cross_attention_dim = cross_attention_dim 69 | self.num_tokens = num_tokens 70 | 71 | self.proj = torch.nn.Sequential( 72 | torch.nn.Linear(id_embeddings_dim, id_embeddings_dim*2), 73 | torch.nn.GELU(), 74 | torch.nn.Linear(id_embeddings_dim*2, cross_attention_dim*num_tokens), 75 | ) 76 | self.norm = torch.nn.LayerNorm(cross_attention_dim) 77 | 78 | def forward(self, id_embeds): 79 | x = self.proj(id_embeds) 80 | x = x.reshape(-1, self.num_tokens, self.cross_attention_dim) 81 | x = self.norm(x) 82 | return x 83 | 84 | 85 | class ProjPlusModel(torch.nn.Module): 86 | def __init__(self, cross_attention_dim=768, id_embeddings_dim=512, clip_embeddings_dim=1280, num_tokens=4): 87 | super().__init__() 88 | 89 | self.cross_attention_dim = cross_attention_dim 90 | self.num_tokens = num_tokens 91 | 92 | self.proj = torch.nn.Sequential( 93 | torch.nn.Linear(id_embeddings_dim, id_embeddings_dim*2), 94 | torch.nn.GELU(), 95 | torch.nn.Linear(id_embeddings_dim*2, cross_attention_dim*num_tokens), 96 | ) 97 | self.norm = torch.nn.LayerNorm(cross_attention_dim) 98 | 99 | self.perceiver_resampler = FacePerceiverResampler( 100 | dim=cross_attention_dim, 101 | depth=4, 102 | dim_head=64, 103 | heads=cross_attention_dim // 64, 104 | embedding_dim=clip_embeddings_dim, 105 | output_dim=cross_attention_dim, 106 | ff_mult=4, 107 | ) 108 | 109 | def forward(self, id_embeds, clip_embeds, shortcut=False, scale=1.0): 110 | 111 | x = self.proj(id_embeds) 112 | x = x.reshape(-1, self.num_tokens, self.cross_attention_dim) 113 | x = self.norm(x) 114 | out = self.perceiver_resampler(x, clip_embeds) 115 | if shortcut: 116 | out = x + scale * out 117 | return out 118 | 119 | 120 | class IPAdapterFaceID: 121 | def __init__(self, sd_pipe, ip_ckpt, device, lora_rank=128, num_tokens=4, torch_dtype=torch.float16): 122 | self.device = device 123 | self.ip_ckpt = ip_ckpt 124 | self.lora_rank = lora_rank 125 | self.num_tokens = num_tokens 126 | self.torch_dtype = torch_dtype 127 | 128 | self.pipe = sd_pipe.to(self.device) 129 | self.set_ip_adapter() 130 | 131 | # image proj model 132 | self.image_proj_model = self.init_proj() 133 | 134 | self.load_ip_adapter() 135 | 136 | def init_proj(self): 137 | image_proj_model = MLPProjModel( 138 | cross_attention_dim=self.pipe.unet.config.cross_attention_dim, 139 | id_embeddings_dim=512, 140 | num_tokens=self.num_tokens, 141 | ).to(self.device, dtype=self.torch_dtype) 142 | return image_proj_model 143 | 144 | def set_ip_adapter(self): 145 | unet = self.pipe.unet 146 | attn_procs = {} 147 | for name in unet.attn_processors.keys(): 148 | cross_attention_dim = None if name.endswith("attn1.processor") else unet.config.cross_attention_dim 149 | if name.startswith("mid_block"): 150 | hidden_size = unet.config.block_out_channels[-1] 151 | elif name.startswith("up_blocks"): 152 | block_id = int(name[len("up_blocks.")]) 153 | hidden_size = list(reversed(unet.config.block_out_channels))[block_id] 154 | elif name.startswith("down_blocks"): 155 | block_id = int(name[len("down_blocks.")]) 156 | hidden_size = unet.config.block_out_channels[block_id] 157 | if cross_attention_dim is None: 158 | attn_procs[name] = LoRAAttnProcessor( 159 | hidden_size=hidden_size, cross_attention_dim=cross_attention_dim, rank=self.lora_rank, 160 | ).to(self.device, dtype=self.torch_dtype) 161 | else: 162 | attn_procs[name] = LoRAIPAttnProcessor( 163 | hidden_size=hidden_size, cross_attention_dim=cross_attention_dim, scale=1.0, rank=self.lora_rank, num_tokens=self.num_tokens, 164 | ).to(self.device, dtype=self.torch_dtype) 165 | unet.set_attn_processor(attn_procs) 166 | 167 | def load_ip_adapter(self): 168 | if os.path.splitext(self.ip_ckpt)[-1] == ".safetensors": 169 | state_dict = {"image_proj": {}, "ip_adapter": {}} 170 | with safe_open(self.ip_ckpt, framework="pt", device="cpu") as f: 171 | for key in f.keys(): 172 | if key.startswith("image_proj."): 173 | state_dict["image_proj"][key.replace("image_proj.", "")] = f.get_tensor(key) 174 | elif key.startswith("ip_adapter."): 175 | state_dict["ip_adapter"][key.replace("ip_adapter.", "")] = f.get_tensor(key) 176 | else: 177 | state_dict = torch.load(self.ip_ckpt, map_location="cpu") 178 | self.image_proj_model.load_state_dict(state_dict["image_proj"]) 179 | ip_layers = torch.nn.ModuleList(self.pipe.unet.attn_processors.values()) 180 | ip_layers.load_state_dict(state_dict["ip_adapter"]) 181 | 182 | @torch.inference_mode() 183 | def get_image_embeds(self, faceid_embeds): 184 | 185 | faceid_embeds = faceid_embeds.to(self.device, dtype=self.torch_dtype) 186 | image_prompt_embeds = self.image_proj_model(faceid_embeds) 187 | uncond_image_prompt_embeds = self.image_proj_model(torch.zeros_like(faceid_embeds)) 188 | return image_prompt_embeds, uncond_image_prompt_embeds 189 | 190 | def set_scale(self, scale): 191 | for attn_processor in self.pipe.unet.attn_processors.values(): 192 | if isinstance(attn_processor, LoRAIPAttnProcessor): 193 | attn_processor.scale = scale 194 | 195 | def generate( 196 | self, 197 | faceid_embeds=None, 198 | prompt=None, 199 | negative_prompt=None, 200 | scale=1.0, 201 | num_samples=4, 202 | seed=None, 203 | guidance_scale=7.5, 204 | num_inference_steps=30, 205 | **kwargs, 206 | ): 207 | self.set_scale(scale) 208 | 209 | 210 | num_prompts = faceid_embeds.size(0) 211 | 212 | if prompt is None: 213 | prompt = "best quality, high quality" 214 | if negative_prompt is None: 215 | negative_prompt = "monochrome, lowres, bad anatomy, worst quality, low quality" 216 | 217 | if not isinstance(prompt, List): 218 | prompt = [prompt] * num_prompts 219 | if not isinstance(negative_prompt, List): 220 | negative_prompt = [negative_prompt] * num_prompts 221 | 222 | image_prompt_embeds, uncond_image_prompt_embeds = self.get_image_embeds(faceid_embeds) 223 | 224 | bs_embed, seq_len, _ = image_prompt_embeds.shape 225 | image_prompt_embeds = image_prompt_embeds.repeat(1, num_samples, 1) 226 | image_prompt_embeds = image_prompt_embeds.view(bs_embed * num_samples, seq_len, -1) 227 | uncond_image_prompt_embeds = uncond_image_prompt_embeds.repeat(1, num_samples, 1) 228 | uncond_image_prompt_embeds = uncond_image_prompt_embeds.view(bs_embed * num_samples, seq_len, -1) 229 | 230 | with torch.inference_mode(): 231 | prompt_embeds_, negative_prompt_embeds_ = self.pipe.encode_prompt( 232 | prompt, 233 | device=self.device, 234 | num_images_per_prompt=num_samples, 235 | do_classifier_free_guidance=True, 236 | negative_prompt=negative_prompt, 237 | ) 238 | prompt_embeds = torch.cat([prompt_embeds_, image_prompt_embeds], dim=1) 239 | negative_prompt_embeds = torch.cat([negative_prompt_embeds_, uncond_image_prompt_embeds], dim=1) 240 | 241 | generator = torch.Generator(self.device).manual_seed(seed) if seed is not None else None 242 | images = self.pipe( 243 | prompt_embeds=prompt_embeds, 244 | negative_prompt_embeds=negative_prompt_embeds, 245 | guidance_scale=guidance_scale, 246 | num_inference_steps=num_inference_steps, 247 | generator=generator, 248 | **kwargs, 249 | ).images 250 | 251 | return images 252 | 253 | 254 | class IPAdapterFaceIDPlus: 255 | def __init__(self, sd_pipe, image_encoder_path, ip_ckpt, device, lora_rank=128, num_tokens=4, torch_dtype=torch.float16): 256 | self.device = device 257 | self.image_encoder_path = image_encoder_path 258 | self.ip_ckpt = ip_ckpt 259 | self.lora_rank = lora_rank 260 | self.num_tokens = num_tokens 261 | self.torch_dtype = torch_dtype 262 | 263 | self.pipe = sd_pipe.to(self.device) 264 | self.set_ip_adapter() 265 | 266 | # load image encoder 267 | self.image_encoder = CLIPVisionModelWithProjection.from_pretrained(self.image_encoder_path).to( 268 | self.device, dtype=self.torch_dtype 269 | ) 270 | self.clip_image_processor = CLIPImageProcessor() 271 | # image proj model 272 | self.image_proj_model = self.init_proj() 273 | 274 | self.load_ip_adapter() 275 | 276 | def init_proj(self): 277 | image_proj_model = ProjPlusModel( 278 | cross_attention_dim=self.pipe.unet.config.cross_attention_dim, 279 | id_embeddings_dim=512, 280 | clip_embeddings_dim=self.image_encoder.config.hidden_size, 281 | num_tokens=self.num_tokens, 282 | ).to(self.device, dtype=self.torch_dtype) 283 | return image_proj_model 284 | 285 | def set_ip_adapter(self): 286 | unet = self.pipe.unet 287 | attn_procs = {} 288 | for name in unet.attn_processors.keys(): 289 | cross_attention_dim = None if name.endswith("attn1.processor") else unet.config.cross_attention_dim 290 | if name.startswith("mid_block"): 291 | hidden_size = unet.config.block_out_channels[-1] 292 | elif name.startswith("up_blocks"): 293 | block_id = int(name[len("up_blocks.")]) 294 | hidden_size = list(reversed(unet.config.block_out_channels))[block_id] 295 | elif name.startswith("down_blocks"): 296 | block_id = int(name[len("down_blocks.")]) 297 | hidden_size = unet.config.block_out_channels[block_id] 298 | if cross_attention_dim is None: 299 | attn_procs[name] = LoRAAttnProcessor( 300 | hidden_size=hidden_size, cross_attention_dim=cross_attention_dim, rank=self.lora_rank, 301 | ).to(self.device, dtype=self.torch_dtype) 302 | else: 303 | attn_procs[name] = LoRAIPAttnProcessor( 304 | hidden_size=hidden_size, cross_attention_dim=cross_attention_dim, scale=1.0, rank=self.lora_rank, num_tokens=self.num_tokens, 305 | ).to(self.device, dtype=self.torch_dtype) 306 | unet.set_attn_processor(attn_procs) 307 | 308 | def load_ip_adapter(self): 309 | if os.path.splitext(self.ip_ckpt)[-1] == ".safetensors": 310 | state_dict = {"image_proj": {}, "ip_adapter": {}} 311 | with safe_open(self.ip_ckpt, framework="pt", device="cpu") as f: 312 | for key in f.keys(): 313 | if key.startswith("image_proj."): 314 | state_dict["image_proj"][key.replace("image_proj.", "")] = f.get_tensor(key) 315 | elif key.startswith("ip_adapter."): 316 | state_dict["ip_adapter"][key.replace("ip_adapter.", "")] = f.get_tensor(key) 317 | else: 318 | state_dict = torch.load(self.ip_ckpt, map_location="cpu") 319 | self.image_proj_model.load_state_dict(state_dict["image_proj"]) 320 | ip_layers = torch.nn.ModuleList(self.pipe.unet.attn_processors.values()) 321 | ip_layers.load_state_dict(state_dict["ip_adapter"]) 322 | 323 | @torch.inference_mode() 324 | def get_image_embeds(self, faceid_embeds, face_image, s_scale, shortcut): 325 | if isinstance(face_image, Image.Image): 326 | pil_image = [face_image] 327 | clip_image = self.clip_image_processor(images=face_image, return_tensors="pt").pixel_values 328 | clip_image = clip_image.to(self.device, dtype=self.torch_dtype) 329 | clip_image_embeds = self.image_encoder(clip_image, output_hidden_states=True).hidden_states[-2] 330 | uncond_clip_image_embeds = self.image_encoder( 331 | torch.zeros_like(clip_image), output_hidden_states=True 332 | ).hidden_states[-2] 333 | 334 | faceid_embeds = faceid_embeds.to(self.device, dtype=self.torch_dtype) 335 | image_prompt_embeds = self.image_proj_model(faceid_embeds, clip_image_embeds, shortcut=shortcut, scale=s_scale) 336 | uncond_image_prompt_embeds = self.image_proj_model(torch.zeros_like(faceid_embeds), uncond_clip_image_embeds, shortcut=shortcut, scale=s_scale) 337 | return image_prompt_embeds, uncond_image_prompt_embeds 338 | 339 | def set_scale(self, scale): 340 | for attn_processor in self.pipe.unet.attn_processors.values(): 341 | if isinstance(attn_processor, LoRAIPAttnProcessor): 342 | attn_processor.scale = scale 343 | 344 | def generate( 345 | self, 346 | face_image=None, 347 | faceid_embeds=None, 348 | prompt=None, 349 | negative_prompt=None, 350 | scale=1.0, 351 | num_samples=4, 352 | seed=None, 353 | guidance_scale=7.5, 354 | num_inference_steps=30, 355 | s_scale=1.0, 356 | shortcut=False, 357 | **kwargs, 358 | ): 359 | self.set_scale(scale) 360 | 361 | 362 | num_prompts = faceid_embeds.size(0) 363 | 364 | if prompt is None: 365 | prompt = "best quality, high quality" 366 | if negative_prompt is None: 367 | negative_prompt = "monochrome, lowres, bad anatomy, worst quality, low quality" 368 | 369 | if not isinstance(prompt, List): 370 | prompt = [prompt] * num_prompts 371 | if not isinstance(negative_prompt, List): 372 | negative_prompt = [negative_prompt] * num_prompts 373 | 374 | image_prompt_embeds, uncond_image_prompt_embeds = self.get_image_embeds(faceid_embeds, face_image, s_scale, shortcut) 375 | 376 | bs_embed, seq_len, _ = image_prompt_embeds.shape 377 | image_prompt_embeds = image_prompt_embeds.repeat(1, num_samples, 1) 378 | image_prompt_embeds = image_prompt_embeds.view(bs_embed * num_samples, seq_len, -1) 379 | uncond_image_prompt_embeds = uncond_image_prompt_embeds.repeat(1, num_samples, 1) 380 | uncond_image_prompt_embeds = uncond_image_prompt_embeds.view(bs_embed * num_samples, seq_len, -1) 381 | 382 | with torch.inference_mode(): 383 | prompt_embeds_, negative_prompt_embeds_ = self.pipe.encode_prompt( 384 | prompt, 385 | device=self.device, 386 | num_images_per_prompt=num_samples, 387 | do_classifier_free_guidance=True, 388 | negative_prompt=negative_prompt, 389 | ) 390 | prompt_embeds = torch.cat([prompt_embeds_, image_prompt_embeds], dim=1) 391 | negative_prompt_embeds = torch.cat([negative_prompt_embeds_, uncond_image_prompt_embeds], dim=1) 392 | 393 | generator = torch.Generator(self.device).manual_seed(seed) if seed is not None else None 394 | images = self.pipe( 395 | prompt_embeds=prompt_embeds, 396 | negative_prompt_embeds=negative_prompt_embeds, 397 | guidance_scale=guidance_scale, 398 | num_inference_steps=num_inference_steps, 399 | generator=generator, 400 | **kwargs, 401 | ).images 402 | 403 | return images 404 | 405 | 406 | class IPAdapterFaceIDXL(IPAdapterFaceID): 407 | """SDXL""" 408 | 409 | def generate( 410 | self, 411 | faceid_embeds=None, 412 | prompt=None, 413 | negative_prompt=None, 414 | scale=1.0, 415 | num_samples=4, 416 | seed=None, 417 | num_inference_steps=30, 418 | **kwargs, 419 | ): 420 | self.set_scale(scale) 421 | 422 | num_prompts = faceid_embeds.size(0) 423 | 424 | if prompt is None: 425 | prompt = "best quality, high quality" 426 | if negative_prompt is None: 427 | negative_prompt = "monochrome, lowres, bad anatomy, worst quality, low quality" 428 | 429 | if not isinstance(prompt, List): 430 | prompt = [prompt] * num_prompts 431 | if not isinstance(negative_prompt, List): 432 | negative_prompt = [negative_prompt] * num_prompts 433 | 434 | image_prompt_embeds, uncond_image_prompt_embeds = self.get_image_embeds(faceid_embeds) 435 | 436 | bs_embed, seq_len, _ = image_prompt_embeds.shape 437 | image_prompt_embeds = image_prompt_embeds.repeat(1, num_samples, 1) 438 | image_prompt_embeds = image_prompt_embeds.view(bs_embed * num_samples, seq_len, -1) 439 | uncond_image_prompt_embeds = uncond_image_prompt_embeds.repeat(1, num_samples, 1) 440 | uncond_image_prompt_embeds = uncond_image_prompt_embeds.view(bs_embed * num_samples, seq_len, -1) 441 | 442 | with torch.inference_mode(): 443 | ( 444 | prompt_embeds, 445 | negative_prompt_embeds, 446 | pooled_prompt_embeds, 447 | negative_pooled_prompt_embeds, 448 | ) = self.pipe.encode_prompt( 449 | prompt, 450 | num_images_per_prompt=num_samples, 451 | do_classifier_free_guidance=True, 452 | negative_prompt=negative_prompt, 453 | ) 454 | prompt_embeds = torch.cat([prompt_embeds, image_prompt_embeds], dim=1) 455 | negative_prompt_embeds = torch.cat([negative_prompt_embeds, uncond_image_prompt_embeds], dim=1) 456 | 457 | generator = torch.Generator(self.device).manual_seed(seed) if seed is not None else None 458 | images = self.pipe( 459 | prompt_embeds=prompt_embeds, 460 | negative_prompt_embeds=negative_prompt_embeds, 461 | pooled_prompt_embeds=pooled_prompt_embeds, 462 | negative_pooled_prompt_embeds=negative_pooled_prompt_embeds, 463 | num_inference_steps=num_inference_steps, 464 | generator=generator, 465 | **kwargs, 466 | ).images 467 | 468 | return images 469 | -------------------------------------------------------------------------------- /ip_adapter/ip_adapter_faceid_separate.py: -------------------------------------------------------------------------------- 1 | import os 2 | from typing import List 3 | 4 | import torch 5 | from diffusers import StableDiffusionPipeline 6 | from diffusers.pipelines.controlnet import MultiControlNetModel 7 | from PIL import Image 8 | from safetensors import safe_open 9 | from transformers import CLIPImageProcessor, CLIPVisionModelWithProjection 10 | 11 | from .attention_processor_faceid import LoRAAttnProcessor, LoRAIPAttnProcessor 12 | from .utils import is_torch2_available 13 | 14 | USE_DAFAULT_ATTN = False # should be True for visualization_attnmap 15 | if is_torch2_available() and (not USE_DAFAULT_ATTN): 16 | from .attention_processor import ( 17 | AttnProcessor2_0 as AttnProcessor, 18 | ) 19 | from .attention_processor import ( 20 | IPAttnProcessor2_0 as IPAttnProcessor, 21 | ) 22 | else: 23 | from .attention_processor import AttnProcessor, IPAttnProcessor 24 | from .resampler import PerceiverAttention, FeedForward 25 | 26 | 27 | class FacePerceiverResampler(torch.nn.Module): 28 | def __init__( 29 | self, 30 | *, 31 | dim=768, 32 | depth=4, 33 | dim_head=64, 34 | heads=16, 35 | embedding_dim=1280, 36 | output_dim=768, 37 | ff_mult=4, 38 | ): 39 | super().__init__() 40 | 41 | self.proj_in = torch.nn.Linear(embedding_dim, dim) 42 | self.proj_out = torch.nn.Linear(dim, output_dim) 43 | self.norm_out = torch.nn.LayerNorm(output_dim) 44 | self.layers = torch.nn.ModuleList([]) 45 | for _ in range(depth): 46 | self.layers.append( 47 | torch.nn.ModuleList( 48 | [ 49 | PerceiverAttention(dim=dim, dim_head=dim_head, heads=heads), 50 | FeedForward(dim=dim, mult=ff_mult), 51 | ] 52 | ) 53 | ) 54 | 55 | def forward(self, latents, x): 56 | x = self.proj_in(x) 57 | for attn, ff in self.layers: 58 | latents = attn(x, latents) + latents 59 | latents = ff(latents) + latents 60 | latents = self.proj_out(latents) 61 | return self.norm_out(latents) 62 | 63 | 64 | class MLPProjModel(torch.nn.Module): 65 | def __init__(self, cross_attention_dim=768, id_embeddings_dim=512, num_tokens=4): 66 | super().__init__() 67 | 68 | self.cross_attention_dim = cross_attention_dim 69 | self.num_tokens = num_tokens 70 | 71 | self.proj = torch.nn.Sequential( 72 | torch.nn.Linear(id_embeddings_dim, id_embeddings_dim*2), 73 | torch.nn.GELU(), 74 | torch.nn.Linear(id_embeddings_dim*2, cross_attention_dim*num_tokens), 75 | ) 76 | self.norm = torch.nn.LayerNorm(cross_attention_dim) 77 | 78 | def forward(self, id_embeds): 79 | x = self.proj(id_embeds) 80 | x = x.reshape(-1, self.num_tokens, self.cross_attention_dim) 81 | x = self.norm(x) 82 | return x 83 | 84 | 85 | class ProjPlusModel(torch.nn.Module): 86 | def __init__(self, cross_attention_dim=768, id_embeddings_dim=512, clip_embeddings_dim=1280, num_tokens=4): 87 | super().__init__() 88 | 89 | self.cross_attention_dim = cross_attention_dim 90 | self.num_tokens = num_tokens 91 | 92 | self.proj = torch.nn.Sequential( 93 | torch.nn.Linear(id_embeddings_dim, id_embeddings_dim*2), 94 | torch.nn.GELU(), 95 | torch.nn.Linear(id_embeddings_dim*2, cross_attention_dim*num_tokens), 96 | ) 97 | self.norm = torch.nn.LayerNorm(cross_attention_dim) 98 | 99 | self.perceiver_resampler = FacePerceiverResampler( 100 | dim=cross_attention_dim, 101 | depth=4, 102 | dim_head=64, 103 | heads=cross_attention_dim // 64, 104 | embedding_dim=clip_embeddings_dim, 105 | output_dim=cross_attention_dim, 106 | ff_mult=4, 107 | ) 108 | 109 | def forward(self, id_embeds, clip_embeds, shortcut=False, scale=1.0): 110 | 111 | x = self.proj(id_embeds) 112 | x = x.reshape(-1, self.num_tokens, self.cross_attention_dim) 113 | x = self.norm(x) 114 | out = self.perceiver_resampler(x, clip_embeds) 115 | if shortcut: 116 | out = x + scale * out 117 | return out 118 | 119 | 120 | class IPAdapterFaceID: 121 | def __init__(self, sd_pipe, ip_ckpt, device, num_tokens=4, torch_dtype=torch.float16): 122 | self.device = device 123 | self.ip_ckpt = ip_ckpt 124 | self.num_tokens = num_tokens 125 | self.torch_dtype = torch_dtype 126 | 127 | self.pipe = sd_pipe.to(self.device) 128 | self.set_ip_adapter() 129 | 130 | # image proj model 131 | self.image_proj_model = self.init_proj() 132 | 133 | self.load_ip_adapter() 134 | 135 | def init_proj(self): 136 | image_proj_model = MLPProjModel( 137 | cross_attention_dim=self.pipe.unet.config.cross_attention_dim, 138 | id_embeddings_dim=512, 139 | num_tokens=self.num_tokens, 140 | ).to(self.device, dtype=self.torch_dtype) 141 | return image_proj_model 142 | 143 | def set_ip_adapter(self): 144 | unet = self.pipe.unet 145 | attn_procs = {} 146 | for name in unet.attn_processors.keys(): 147 | cross_attention_dim = None if name.endswith("attn1.processor") else unet.config.cross_attention_dim 148 | if name.startswith("mid_block"): 149 | hidden_size = unet.config.block_out_channels[-1] 150 | elif name.startswith("up_blocks"): 151 | block_id = int(name[len("up_blocks.")]) 152 | hidden_size = list(reversed(unet.config.block_out_channels))[block_id] 153 | elif name.startswith("down_blocks"): 154 | block_id = int(name[len("down_blocks.")]) 155 | hidden_size = unet.config.block_out_channels[block_id] 156 | if cross_attention_dim is None: 157 | attn_procs[name] = AttnProcessor() 158 | else: 159 | attn_procs[name] = IPAttnProcessor( 160 | hidden_size=hidden_size, cross_attention_dim=cross_attention_dim, scale=1.0, num_tokens=self.num_tokens, 161 | ).to(self.device, dtype=self.torch_dtype) 162 | unet.set_attn_processor(attn_procs) 163 | 164 | def load_ip_adapter(self): 165 | if os.path.splitext(self.ip_ckpt)[-1] == ".safetensors": 166 | state_dict = {"image_proj": {}, "ip_adapter": {}} 167 | with safe_open(self.ip_ckpt, framework="pt", device="cpu") as f: 168 | for key in f.keys(): 169 | if key.startswith("image_proj."): 170 | state_dict["image_proj"][key.replace("image_proj.", "")] = f.get_tensor(key) 171 | elif key.startswith("ip_adapter."): 172 | state_dict["ip_adapter"][key.replace("ip_adapter.", "")] = f.get_tensor(key) 173 | else: 174 | state_dict = torch.load(self.ip_ckpt, map_location="cpu") 175 | self.image_proj_model.load_state_dict(state_dict["image_proj"]) 176 | ip_layers = torch.nn.ModuleList(self.pipe.unet.attn_processors.values()) 177 | ip_layers.load_state_dict(state_dict["ip_adapter"], strict=False) 178 | 179 | @torch.inference_mode() 180 | def get_image_embeds(self, faceid_embeds): 181 | 182 | faceid_embeds = faceid_embeds.to(self.device, dtype=self.torch_dtype) 183 | image_prompt_embeds = self.image_proj_model(faceid_embeds) 184 | uncond_image_prompt_embeds = self.image_proj_model(torch.zeros_like(faceid_embeds)) 185 | return image_prompt_embeds, uncond_image_prompt_embeds 186 | 187 | def set_scale(self, scale): 188 | for attn_processor in self.pipe.unet.attn_processors.values(): 189 | if isinstance(attn_processor, LoRAIPAttnProcessor): 190 | attn_processor.scale = scale 191 | 192 | def generate( 193 | self, 194 | faceid_embeds=None, 195 | prompt=None, 196 | negative_prompt=None, 197 | scale=1.0, 198 | num_samples=4, 199 | seed=None, 200 | guidance_scale=7.5, 201 | num_inference_steps=30, 202 | **kwargs, 203 | ): 204 | self.set_scale(scale) 205 | 206 | 207 | num_prompts = faceid_embeds.size(0) 208 | 209 | if prompt is None: 210 | prompt = "best quality, high quality" 211 | if negative_prompt is None: 212 | negative_prompt = "monochrome, lowres, bad anatomy, worst quality, low quality" 213 | 214 | if not isinstance(prompt, List): 215 | prompt = [prompt] * num_prompts 216 | if not isinstance(negative_prompt, List): 217 | negative_prompt = [negative_prompt] * num_prompts 218 | 219 | image_prompt_embeds, uncond_image_prompt_embeds = self.get_image_embeds(faceid_embeds) 220 | 221 | bs_embed, seq_len, _ = image_prompt_embeds.shape 222 | image_prompt_embeds = image_prompt_embeds.repeat(1, num_samples, 1) 223 | image_prompt_embeds = image_prompt_embeds.view(bs_embed * num_samples, seq_len, -1) 224 | uncond_image_prompt_embeds = uncond_image_prompt_embeds.repeat(1, num_samples, 1) 225 | uncond_image_prompt_embeds = uncond_image_prompt_embeds.view(bs_embed * num_samples, seq_len, -1) 226 | 227 | with torch.inference_mode(): 228 | prompt_embeds_, negative_prompt_embeds_ = self.pipe.encode_prompt( 229 | prompt, 230 | device=self.device, 231 | num_images_per_prompt=num_samples, 232 | do_classifier_free_guidance=True, 233 | negative_prompt=negative_prompt, 234 | ) 235 | prompt_embeds = torch.cat([prompt_embeds_, image_prompt_embeds], dim=1) 236 | negative_prompt_embeds = torch.cat([negative_prompt_embeds_, uncond_image_prompt_embeds], dim=1) 237 | 238 | generator = torch.Generator(self.device).manual_seed(seed) if seed is not None else None 239 | images = self.pipe( 240 | prompt_embeds=prompt_embeds, 241 | negative_prompt_embeds=negative_prompt_embeds, 242 | guidance_scale=guidance_scale, 243 | num_inference_steps=num_inference_steps, 244 | generator=generator, 245 | **kwargs, 246 | ).images 247 | 248 | return images 249 | 250 | 251 | class IPAdapterFaceIDPlus: 252 | def __init__(self, sd_pipe, image_encoder_path, ip_ckpt, device, num_tokens=4, torch_dtype=torch.float16): 253 | self.device = device 254 | self.image_encoder_path = image_encoder_path 255 | self.ip_ckpt = ip_ckpt 256 | self.num_tokens = num_tokens 257 | self.torch_dtype = torch_dtype 258 | 259 | self.pipe = sd_pipe.to(self.device) 260 | self.set_ip_adapter() 261 | 262 | # load image encoder 263 | self.image_encoder = CLIPVisionModelWithProjection.from_pretrained(self.image_encoder_path).to( 264 | self.device, dtype=self.torch_dtype 265 | ) 266 | self.clip_image_processor = CLIPImageProcessor() 267 | # image proj model 268 | self.image_proj_model = self.init_proj() 269 | 270 | self.load_ip_adapter() 271 | 272 | def init_proj(self): 273 | image_proj_model = ProjPlusModel( 274 | cross_attention_dim=self.pipe.unet.config.cross_attention_dim, 275 | id_embeddings_dim=512, 276 | clip_embeddings_dim=self.image_encoder.config.hidden_size, 277 | num_tokens=self.num_tokens, 278 | ).to(self.device, dtype=self.torch_dtype) 279 | return image_proj_model 280 | 281 | def set_ip_adapter(self): 282 | unet = self.pipe.unet 283 | attn_procs = {} 284 | for name in unet.attn_processors.keys(): 285 | cross_attention_dim = None if name.endswith("attn1.processor") else unet.config.cross_attention_dim 286 | if name.startswith("mid_block"): 287 | hidden_size = unet.config.block_out_channels[-1] 288 | elif name.startswith("up_blocks"): 289 | block_id = int(name[len("up_blocks.")]) 290 | hidden_size = list(reversed(unet.config.block_out_channels))[block_id] 291 | elif name.startswith("down_blocks"): 292 | block_id = int(name[len("down_blocks.")]) 293 | hidden_size = unet.config.block_out_channels[block_id] 294 | if cross_attention_dim is None: 295 | attn_procs[name] = AttnProcessor() 296 | else: 297 | attn_procs[name] = IPAttnProcessor( 298 | hidden_size=hidden_size, cross_attention_dim=cross_attention_dim, scale=1.0, num_tokens=self.num_tokens, 299 | ).to(self.device, dtype=self.torch_dtype) 300 | unet.set_attn_processor(attn_procs) 301 | 302 | def load_ip_adapter(self): 303 | if os.path.splitext(self.ip_ckpt)[-1] == ".safetensors": 304 | state_dict = {"image_proj": {}, "ip_adapter": {}} 305 | with safe_open(self.ip_ckpt, framework="pt", device="cpu") as f: 306 | for key in f.keys(): 307 | if key.startswith("image_proj."): 308 | state_dict["image_proj"][key.replace("image_proj.", "")] = f.get_tensor(key) 309 | elif key.startswith("ip_adapter."): 310 | state_dict["ip_adapter"][key.replace("ip_adapter.", "")] = f.get_tensor(key) 311 | else: 312 | state_dict = torch.load(self.ip_ckpt, map_location="cpu") 313 | self.image_proj_model.load_state_dict(state_dict["image_proj"]) 314 | ip_layers = torch.nn.ModuleList(self.pipe.unet.attn_processors.values()) 315 | ip_layers.load_state_dict(state_dict["ip_adapter"], strict=False) 316 | 317 | @torch.inference_mode() 318 | def get_image_embeds(self, faceid_embeds, face_image, s_scale, shortcut): 319 | if isinstance(face_image, Image.Image): 320 | pil_image = [face_image] 321 | clip_image = self.clip_image_processor(images=face_image, return_tensors="pt").pixel_values 322 | clip_image = clip_image.to(self.device, dtype=self.torch_dtype) 323 | clip_image_embeds = self.image_encoder(clip_image, output_hidden_states=True).hidden_states[-2] 324 | uncond_clip_image_embeds = self.image_encoder( 325 | torch.zeros_like(clip_image), output_hidden_states=True 326 | ).hidden_states[-2] 327 | 328 | faceid_embeds = faceid_embeds.to(self.device, dtype=self.torch_dtype) 329 | image_prompt_embeds = self.image_proj_model(faceid_embeds, clip_image_embeds, shortcut=shortcut, scale=s_scale) 330 | uncond_image_prompt_embeds = self.image_proj_model(torch.zeros_like(faceid_embeds), uncond_clip_image_embeds, shortcut=shortcut, scale=s_scale) 331 | return image_prompt_embeds, uncond_image_prompt_embeds 332 | 333 | def set_scale(self, scale): 334 | for attn_processor in self.pipe.unet.attn_processors.values(): 335 | if isinstance(attn_processor, LoRAIPAttnProcessor): 336 | attn_processor.scale = scale 337 | 338 | def generate( 339 | self, 340 | face_image=None, 341 | faceid_embeds=None, 342 | prompt=None, 343 | negative_prompt=None, 344 | scale=1.0, 345 | num_samples=4, 346 | seed=None, 347 | guidance_scale=7.5, 348 | num_inference_steps=30, 349 | s_scale=1.0, 350 | shortcut=False, 351 | **kwargs, 352 | ): 353 | self.set_scale(scale) 354 | 355 | 356 | num_prompts = faceid_embeds.size(0) 357 | 358 | if prompt is None: 359 | prompt = "best quality, high quality" 360 | if negative_prompt is None: 361 | negative_prompt = "monochrome, lowres, bad anatomy, worst quality, low quality" 362 | 363 | if not isinstance(prompt, List): 364 | prompt = [prompt] * num_prompts 365 | if not isinstance(negative_prompt, List): 366 | negative_prompt = [negative_prompt] * num_prompts 367 | 368 | image_prompt_embeds, uncond_image_prompt_embeds = self.get_image_embeds(faceid_embeds, face_image, s_scale, shortcut) 369 | 370 | bs_embed, seq_len, _ = image_prompt_embeds.shape 371 | image_prompt_embeds = image_prompt_embeds.repeat(1, num_samples, 1) 372 | image_prompt_embeds = image_prompt_embeds.view(bs_embed * num_samples, seq_len, -1) 373 | uncond_image_prompt_embeds = uncond_image_prompt_embeds.repeat(1, num_samples, 1) 374 | uncond_image_prompt_embeds = uncond_image_prompt_embeds.view(bs_embed * num_samples, seq_len, -1) 375 | 376 | with torch.inference_mode(): 377 | prompt_embeds_, negative_prompt_embeds_ = self.pipe.encode_prompt( 378 | prompt, 379 | device=self.device, 380 | num_images_per_prompt=num_samples, 381 | do_classifier_free_guidance=True, 382 | negative_prompt=negative_prompt, 383 | ) 384 | prompt_embeds = torch.cat([prompt_embeds_, image_prompt_embeds], dim=1) 385 | negative_prompt_embeds = torch.cat([negative_prompt_embeds_, uncond_image_prompt_embeds], dim=1) 386 | 387 | generator = torch.Generator(self.device).manual_seed(seed) if seed is not None else None 388 | images = self.pipe( 389 | prompt_embeds=prompt_embeds, 390 | negative_prompt_embeds=negative_prompt_embeds, 391 | guidance_scale=guidance_scale, 392 | num_inference_steps=num_inference_steps, 393 | generator=generator, 394 | **kwargs, 395 | ).images 396 | 397 | return images 398 | 399 | 400 | class IPAdapterFaceIDXL(IPAdapterFaceID): 401 | """SDXL""" 402 | 403 | def generate( 404 | self, 405 | faceid_embeds=None, 406 | prompt=None, 407 | negative_prompt=None, 408 | scale=1.0, 409 | num_samples=4, 410 | seed=None, 411 | num_inference_steps=30, 412 | **kwargs, 413 | ): 414 | self.set_scale(scale) 415 | 416 | num_prompts = faceid_embeds.size(0) 417 | 418 | if prompt is None: 419 | prompt = "best quality, high quality" 420 | if negative_prompt is None: 421 | negative_prompt = "monochrome, lowres, bad anatomy, worst quality, low quality" 422 | 423 | if not isinstance(prompt, List): 424 | prompt = [prompt] * num_prompts 425 | if not isinstance(negative_prompt, List): 426 | negative_prompt = [negative_prompt] * num_prompts 427 | 428 | image_prompt_embeds, uncond_image_prompt_embeds = self.get_image_embeds(faceid_embeds) 429 | 430 | bs_embed, seq_len, _ = image_prompt_embeds.shape 431 | image_prompt_embeds = image_prompt_embeds.repeat(1, num_samples, 1) 432 | image_prompt_embeds = image_prompt_embeds.view(bs_embed * num_samples, seq_len, -1) 433 | uncond_image_prompt_embeds = uncond_image_prompt_embeds.repeat(1, num_samples, 1) 434 | uncond_image_prompt_embeds = uncond_image_prompt_embeds.view(bs_embed * num_samples, seq_len, -1) 435 | 436 | with torch.inference_mode(): 437 | ( 438 | prompt_embeds, 439 | negative_prompt_embeds, 440 | pooled_prompt_embeds, 441 | negative_pooled_prompt_embeds, 442 | ) = self.pipe.encode_prompt( 443 | prompt, 444 | num_images_per_prompt=num_samples, 445 | do_classifier_free_guidance=True, 446 | negative_prompt=negative_prompt, 447 | ) 448 | prompt_embeds = torch.cat([prompt_embeds, image_prompt_embeds], dim=1) 449 | negative_prompt_embeds = torch.cat([negative_prompt_embeds, uncond_image_prompt_embeds], dim=1) 450 | 451 | generator = torch.Generator(self.device).manual_seed(seed) if seed is not None else None 452 | images = self.pipe( 453 | prompt_embeds=prompt_embeds, 454 | negative_prompt_embeds=negative_prompt_embeds, 455 | pooled_prompt_embeds=pooled_prompt_embeds, 456 | negative_pooled_prompt_embeds=negative_pooled_prompt_embeds, 457 | num_inference_steps=num_inference_steps, 458 | generator=generator, 459 | **kwargs, 460 | ).images 461 | 462 | return images 463 | -------------------------------------------------------------------------------- /ip_adapter/resampler.py: -------------------------------------------------------------------------------- 1 | # modified from https://github.com/mlfoundations/open_flamingo/blob/main/open_flamingo/src/helpers.py 2 | # and https://github.com/lucidrains/imagen-pytorch/blob/main/imagen_pytorch/imagen_pytorch.py 3 | 4 | import math 5 | 6 | import torch 7 | import torch.nn as nn 8 | from einops import rearrange 9 | from einops.layers.torch import Rearrange 10 | 11 | 12 | # FFN 13 | def FeedForward(dim, mult=4): 14 | inner_dim = int(dim * mult) 15 | return nn.Sequential( 16 | nn.LayerNorm(dim), 17 | nn.Linear(dim, inner_dim, bias=False), 18 | nn.GELU(), 19 | nn.Linear(inner_dim, dim, bias=False), 20 | ) 21 | 22 | 23 | def reshape_tensor(x, heads): 24 | bs, length, width = x.shape 25 | # (bs, length, width) --> (bs, length, n_heads, dim_per_head) 26 | x = x.view(bs, length, heads, -1) 27 | # (bs, length, n_heads, dim_per_head) --> (bs, n_heads, length, dim_per_head) 28 | x = x.transpose(1, 2) 29 | # (bs, n_heads, length, dim_per_head) --> (bs*n_heads, length, dim_per_head) 30 | x = x.reshape(bs, heads, length, -1) 31 | return x 32 | 33 | 34 | class PerceiverAttention(nn.Module): 35 | def __init__(self, *, dim, dim_head=64, heads=8): 36 | super().__init__() 37 | self.scale = dim_head**-0.5 38 | self.dim_head = dim_head 39 | self.heads = heads 40 | inner_dim = dim_head * heads 41 | 42 | self.norm1 = nn.LayerNorm(dim) 43 | self.norm2 = nn.LayerNorm(dim) 44 | 45 | self.to_q = nn.Linear(dim, inner_dim, bias=False) 46 | self.to_kv = nn.Linear(dim, inner_dim * 2, bias=False) 47 | self.to_out = nn.Linear(inner_dim, dim, bias=False) 48 | 49 | def forward(self, x, latents): 50 | """ 51 | Args: 52 | x (torch.Tensor): image features 53 | shape (b, n1, D) 54 | latent (torch.Tensor): latent features 55 | shape (b, n2, D) 56 | """ 57 | x = self.norm1(x) 58 | latents = self.norm2(latents) 59 | 60 | b, l, _ = latents.shape 61 | 62 | q = self.to_q(latents) 63 | kv_input = torch.cat((x, latents), dim=-2) 64 | k, v = self.to_kv(kv_input).chunk(2, dim=-1) 65 | 66 | q = reshape_tensor(q, self.heads) 67 | k = reshape_tensor(k, self.heads) 68 | v = reshape_tensor(v, self.heads) 69 | 70 | # attention 71 | scale = 1 / math.sqrt(math.sqrt(self.dim_head)) 72 | weight = (q * scale) @ (k * scale).transpose(-2, -1) # More stable with f16 than dividing afterwards 73 | weight = torch.softmax(weight.float(), dim=-1).type(weight.dtype) 74 | out = weight @ v 75 | 76 | out = out.permute(0, 2, 1, 3).reshape(b, l, -1) 77 | 78 | return self.to_out(out) 79 | 80 | 81 | class Resampler(nn.Module): 82 | def __init__( 83 | self, 84 | dim=1024, 85 | depth=8, 86 | dim_head=64, 87 | heads=16, 88 | num_queries=8, 89 | embedding_dim=768, 90 | output_dim=1024, 91 | ff_mult=4, 92 | max_seq_len: int = 257, # CLIP tokens + CLS token 93 | apply_pos_emb: bool = False, 94 | num_latents_mean_pooled: int = 0, # number of latents derived from mean pooled representation of the sequence 95 | ): 96 | super().__init__() 97 | self.pos_emb = nn.Embedding(max_seq_len, embedding_dim) if apply_pos_emb else None 98 | 99 | self.latents = nn.Parameter(torch.randn(1, num_queries, dim) / dim**0.5) 100 | 101 | self.proj_in = nn.Linear(embedding_dim, dim) 102 | 103 | self.proj_out = nn.Linear(dim, output_dim) 104 | self.norm_out = nn.LayerNorm(output_dim) 105 | 106 | self.to_latents_from_mean_pooled_seq = ( 107 | nn.Sequential( 108 | nn.LayerNorm(dim), 109 | nn.Linear(dim, dim * num_latents_mean_pooled), 110 | Rearrange("b (n d) -> b n d", n=num_latents_mean_pooled), 111 | ) 112 | if num_latents_mean_pooled > 0 113 | else None 114 | ) 115 | 116 | self.layers = nn.ModuleList([]) 117 | for _ in range(depth): 118 | self.layers.append( 119 | nn.ModuleList( 120 | [ 121 | PerceiverAttention(dim=dim, dim_head=dim_head, heads=heads), 122 | FeedForward(dim=dim, mult=ff_mult), 123 | ] 124 | ) 125 | ) 126 | 127 | def forward(self, x): 128 | if self.pos_emb is not None: 129 | n, device = x.shape[1], x.device 130 | pos_emb = self.pos_emb(torch.arange(n, device=device)) 131 | x = x + pos_emb 132 | 133 | latents = self.latents.repeat(x.size(0), 1, 1) 134 | 135 | x = self.proj_in(x) 136 | 137 | if self.to_latents_from_mean_pooled_seq: 138 | meanpooled_seq = masked_mean(x, dim=1, mask=torch.ones(x.shape[:2], device=x.device, dtype=torch.bool)) 139 | meanpooled_latents = self.to_latents_from_mean_pooled_seq(meanpooled_seq) 140 | latents = torch.cat((meanpooled_latents, latents), dim=-2) 141 | 142 | for attn, ff in self.layers: 143 | latents = attn(x, latents) + latents 144 | latents = ff(latents) + latents 145 | 146 | latents = self.proj_out(latents) 147 | return self.norm_out(latents) 148 | 149 | 150 | def masked_mean(t, *, dim, mask=None): 151 | if mask is None: 152 | return t.mean(dim=dim) 153 | 154 | denom = mask.sum(dim=dim, keepdim=True) 155 | mask = rearrange(mask, "b n -> b n 1") 156 | masked_t = t.masked_fill(~mask, 0.0) 157 | 158 | return masked_t.sum(dim=dim) / denom.clamp(min=1e-5) 159 | -------------------------------------------------------------------------------- /ip_adapter/test_resampler.py: -------------------------------------------------------------------------------- 1 | import torch 2 | from resampler import Resampler 3 | from transformers import CLIPVisionModel 4 | 5 | BATCH_SIZE = 2 6 | OUTPUT_DIM = 1280 7 | NUM_QUERIES = 8 8 | NUM_LATENTS_MEAN_POOLED = 4 # 0 for no mean pooling (previous behavior) 9 | APPLY_POS_EMB = True # False for no positional embeddings (previous behavior) 10 | IMAGE_ENCODER_NAME_OR_PATH = "laion/CLIP-ViT-H-14-laion2B-s32B-b79K" 11 | 12 | 13 | def main(): 14 | image_encoder = CLIPVisionModel.from_pretrained(IMAGE_ENCODER_NAME_OR_PATH) 15 | embedding_dim = image_encoder.config.hidden_size 16 | print(f"image_encoder hidden size: ", embedding_dim) 17 | 18 | image_proj_model = Resampler( 19 | dim=1024, 20 | depth=2, 21 | dim_head=64, 22 | heads=16, 23 | num_queries=NUM_QUERIES, 24 | embedding_dim=embedding_dim, 25 | output_dim=OUTPUT_DIM, 26 | ff_mult=2, 27 | max_seq_len=257, 28 | apply_pos_emb=APPLY_POS_EMB, 29 | num_latents_mean_pooled=NUM_LATENTS_MEAN_POOLED, 30 | ) 31 | 32 | dummy_images = torch.randn(BATCH_SIZE, 3, 224, 224) 33 | with torch.no_grad(): 34 | image_embeds = image_encoder(dummy_images, output_hidden_states=True).hidden_states[-2] 35 | print("image_embds shape: ", image_embeds.shape) 36 | 37 | with torch.no_grad(): 38 | ip_tokens = image_proj_model(image_embeds) 39 | print("ip_tokens shape:", ip_tokens.shape) 40 | assert ip_tokens.shape == (BATCH_SIZE, NUM_QUERIES + NUM_LATENTS_MEAN_POOLED, OUTPUT_DIM) 41 | 42 | 43 | if __name__ == "__main__": 44 | main() 45 | -------------------------------------------------------------------------------- /ip_adapter/utils.py: -------------------------------------------------------------------------------- 1 | import torch 2 | import torch.nn.functional as F 3 | import numpy as np 4 | from PIL import Image 5 | 6 | attn_maps = {} 7 | def hook_fn(name): 8 | def forward_hook(module, input, output): 9 | if hasattr(module.processor, "attn_map"): 10 | attn_maps[name] = module.processor.attn_map 11 | del module.processor.attn_map 12 | 13 | return forward_hook 14 | 15 | def register_cross_attention_hook(unet): 16 | for name, module in unet.named_modules(): 17 | if name.split('.')[-1].startswith('attn2'): 18 | module.register_forward_hook(hook_fn(name)) 19 | 20 | return unet 21 | 22 | def upscale(attn_map, target_size): 23 | attn_map = torch.mean(attn_map, dim=0) 24 | attn_map = attn_map.permute(1,0) 25 | temp_size = None 26 | 27 | for i in range(0,5): 28 | scale = 2 ** i 29 | if ( target_size[0] // scale ) * ( target_size[1] // scale) == attn_map.shape[1]*64: 30 | temp_size = (target_size[0]//(scale*8), target_size[1]//(scale*8)) 31 | break 32 | 33 | assert temp_size is not None, "temp_size cannot is None" 34 | 35 | attn_map = attn_map.view(attn_map.shape[0], *temp_size) 36 | 37 | attn_map = F.interpolate( 38 | attn_map.unsqueeze(0).to(dtype=torch.float32), 39 | size=target_size, 40 | mode='bilinear', 41 | align_corners=False 42 | )[0] 43 | 44 | attn_map = torch.softmax(attn_map, dim=0) 45 | return attn_map 46 | def get_net_attn_map(image_size, batch_size=2, instance_or_negative=False, detach=True): 47 | 48 | idx = 0 if instance_or_negative else 1 49 | net_attn_maps = [] 50 | 51 | for name, attn_map in attn_maps.items(): 52 | attn_map = attn_map.cpu() if detach else attn_map 53 | attn_map = torch.chunk(attn_map, batch_size)[idx].squeeze() 54 | attn_map = upscale(attn_map, image_size) 55 | net_attn_maps.append(attn_map) 56 | 57 | net_attn_maps = torch.mean(torch.stack(net_attn_maps,dim=0),dim=0) 58 | 59 | return net_attn_maps 60 | 61 | def attnmaps2images(net_attn_maps): 62 | 63 | #total_attn_scores = 0 64 | images = [] 65 | 66 | for attn_map in net_attn_maps: 67 | attn_map = attn_map.cpu().numpy() 68 | #total_attn_scores += attn_map.mean().item() 69 | 70 | normalized_attn_map = (attn_map - np.min(attn_map)) / (np.max(attn_map) - np.min(attn_map)) * 255 71 | normalized_attn_map = normalized_attn_map.astype(np.uint8) 72 | #print("norm: ", normalized_attn_map.shape) 73 | image = Image.fromarray(normalized_attn_map) 74 | 75 | #image = fix_save_attn_map(attn_map) 76 | images.append(image) 77 | 78 | #print(total_attn_scores) 79 | return images 80 | def is_torch2_available(): 81 | return hasattr(F, "scaled_dot_product_attention") 82 | -------------------------------------------------------------------------------- /pyproject.toml: -------------------------------------------------------------------------------- 1 | [tool.poetry] 2 | name = "ip-adapter" 3 | version = "0.1.0" 4 | description = "IP-Adapter: Text Compatible Image Prompt Adapter for Text-to-Image Diffusion Models" 5 | authors = ["Ye, Hu", "Zhang, Jun", "Liu, Sibo", "Han, Xiao", "Yang, Wei"] 6 | license = "Apache-2.0" 7 | readme = "README.md" 8 | packages = [{ include = "ip_adapter" }] 9 | 10 | [tool.poetry.dependencies] 11 | python = ">=3.6" 12 | 13 | [tool.ruff] 14 | line-length = 119 15 | # Deprecation of Cuda 11.6 and Python 3.7 support for PyTorch 2.0 16 | target-version = "py38" 17 | 18 | # A list of file patterns to omit from linting, in addition to those specified by exclude. 19 | extend-exclude = ["__pycache__", "*.pyc", "*.egg-info", ".cache"] 20 | 21 | select = ["E", "F", "W", "C90", "I", "UP", "B", "C4", "RET", "RUF", "SIM"] 22 | 23 | 24 | ignore = [ 25 | "UP006", # UP006: Use list instead of typing.List for type annotations 26 | "UP007", # UP007: Use X | Y for type annotations 27 | "UP009", 28 | "UP035", 29 | "UP038", 30 | "E402", 31 | "RET504", 32 | ] 33 | 34 | [tool.isort] 35 | profile = "black" 36 | 37 | [tool.black] 38 | line-length = 119 39 | skip-string-normalization = 1 40 | 41 | [build-system] 42 | requires = ["poetry-core"] 43 | build-backend = "poetry.core.masonry.api" 44 | -------------------------------------------------------------------------------- /tutorial_train.py: -------------------------------------------------------------------------------- 1 | import os 2 | import random 3 | import argparse 4 | from pathlib import Path 5 | import json 6 | import itertools 7 | import time 8 | 9 | import torch 10 | import torch.nn.functional as F 11 | from torchvision import transforms 12 | from PIL import Image 13 | from transformers import CLIPImageProcessor 14 | from accelerate import Accelerator 15 | from accelerate.logging import get_logger 16 | from accelerate.utils import ProjectConfiguration 17 | from diffusers import AutoencoderKL, DDPMScheduler, UNet2DConditionModel 18 | from transformers import CLIPTextModel, CLIPTokenizer, CLIPVisionModelWithProjection 19 | 20 | from ip_adapter.ip_adapter import ImageProjModel 21 | from ip_adapter.utils import is_torch2_available 22 | if is_torch2_available(): 23 | from ip_adapter.attention_processor import IPAttnProcessor2_0 as IPAttnProcessor, AttnProcessor2_0 as AttnProcessor 24 | else: 25 | from ip_adapter.attention_processor import IPAttnProcessor, AttnProcessor 26 | 27 | 28 | # Dataset 29 | class MyDataset(torch.utils.data.Dataset): 30 | 31 | def __init__(self, json_file, tokenizer, size=512, t_drop_rate=0.05, i_drop_rate=0.05, ti_drop_rate=0.05, image_root_path=""): 32 | super().__init__() 33 | 34 | self.tokenizer = tokenizer 35 | self.size = size 36 | self.i_drop_rate = i_drop_rate 37 | self.t_drop_rate = t_drop_rate 38 | self.ti_drop_rate = ti_drop_rate 39 | self.image_root_path = image_root_path 40 | 41 | self.data = json.load(open(json_file)) # list of dict: [{"image_file": "1.png", "text": "A dog"}] 42 | 43 | self.transform = transforms.Compose([ 44 | transforms.Resize(self.size, interpolation=transforms.InterpolationMode.BILINEAR), 45 | transforms.CenterCrop(self.size), 46 | transforms.ToTensor(), 47 | transforms.Normalize([0.5], [0.5]), 48 | ]) 49 | self.clip_image_processor = CLIPImageProcessor() 50 | 51 | def __getitem__(self, idx): 52 | item = self.data[idx] 53 | text = item["text"] 54 | image_file = item["image_file"] 55 | 56 | # read image 57 | raw_image = Image.open(os.path.join(self.image_root_path, image_file)) 58 | image = self.transform(raw_image.convert("RGB")) 59 | clip_image = self.clip_image_processor(images=raw_image, return_tensors="pt").pixel_values 60 | 61 | # drop 62 | drop_image_embed = 0 63 | rand_num = random.random() 64 | if rand_num < self.i_drop_rate: 65 | drop_image_embed = 1 66 | elif rand_num < (self.i_drop_rate + self.t_drop_rate): 67 | text = "" 68 | elif rand_num < (self.i_drop_rate + self.t_drop_rate + self.ti_drop_rate): 69 | text = "" 70 | drop_image_embed = 1 71 | # get text and tokenize 72 | text_input_ids = self.tokenizer( 73 | text, 74 | max_length=self.tokenizer.model_max_length, 75 | padding="max_length", 76 | truncation=True, 77 | return_tensors="pt" 78 | ).input_ids 79 | 80 | return { 81 | "image": image, 82 | "text_input_ids": text_input_ids, 83 | "clip_image": clip_image, 84 | "drop_image_embed": drop_image_embed 85 | } 86 | 87 | def __len__(self): 88 | return len(self.data) 89 | 90 | 91 | def collate_fn(data): 92 | images = torch.stack([example["image"] for example in data]) 93 | text_input_ids = torch.cat([example["text_input_ids"] for example in data], dim=0) 94 | clip_images = torch.cat([example["clip_image"] for example in data], dim=0) 95 | drop_image_embeds = [example["drop_image_embed"] for example in data] 96 | 97 | return { 98 | "images": images, 99 | "text_input_ids": text_input_ids, 100 | "clip_images": clip_images, 101 | "drop_image_embeds": drop_image_embeds 102 | } 103 | 104 | 105 | class IPAdapter(torch.nn.Module): 106 | """IP-Adapter""" 107 | def __init__(self, unet, image_proj_model, adapter_modules, ckpt_path=None): 108 | super().__init__() 109 | self.unet = unet 110 | self.image_proj_model = image_proj_model 111 | self.adapter_modules = adapter_modules 112 | 113 | if ckpt_path is not None: 114 | self.load_from_checkpoint(ckpt_path) 115 | 116 | def forward(self, noisy_latents, timesteps, encoder_hidden_states, image_embeds): 117 | ip_tokens = self.image_proj_model(image_embeds) 118 | encoder_hidden_states = torch.cat([encoder_hidden_states, ip_tokens], dim=1) 119 | # Predict the noise residual 120 | noise_pred = self.unet(noisy_latents, timesteps, encoder_hidden_states).sample 121 | return noise_pred 122 | 123 | def load_from_checkpoint(self, ckpt_path: str): 124 | # Calculate original checksums 125 | orig_ip_proj_sum = torch.sum(torch.stack([torch.sum(p) for p in self.image_proj_model.parameters()])) 126 | orig_adapter_sum = torch.sum(torch.stack([torch.sum(p) for p in self.adapter_modules.parameters()])) 127 | 128 | state_dict = torch.load(ckpt_path, map_location="cpu") 129 | 130 | # Load state dict for image_proj_model and adapter_modules 131 | self.image_proj_model.load_state_dict(state_dict["image_proj"], strict=True) 132 | self.adapter_modules.load_state_dict(state_dict["ip_adapter"], strict=True) 133 | 134 | # Calculate new checksums 135 | new_ip_proj_sum = torch.sum(torch.stack([torch.sum(p) for p in self.image_proj_model.parameters()])) 136 | new_adapter_sum = torch.sum(torch.stack([torch.sum(p) for p in self.adapter_modules.parameters()])) 137 | 138 | # Verify if the weights have changed 139 | assert orig_ip_proj_sum != new_ip_proj_sum, "Weights of image_proj_model did not change!" 140 | assert orig_adapter_sum != new_adapter_sum, "Weights of adapter_modules did not change!" 141 | 142 | print(f"Successfully loaded weights from checkpoint {ckpt_path}") 143 | 144 | 145 | 146 | def parse_args(): 147 | parser = argparse.ArgumentParser(description="Simple example of a training script.") 148 | parser.add_argument( 149 | "--pretrained_model_name_or_path", 150 | type=str, 151 | default=None, 152 | required=True, 153 | help="Path to pretrained model or model identifier from huggingface.co/models.", 154 | ) 155 | parser.add_argument( 156 | "--pretrained_ip_adapter_path", 157 | type=str, 158 | default=None, 159 | help="Path to pretrained ip adapter model. If not specified weights are initialized randomly.", 160 | ) 161 | parser.add_argument( 162 | "--data_json_file", 163 | type=str, 164 | default=None, 165 | required=True, 166 | help="Training data", 167 | ) 168 | parser.add_argument( 169 | "--data_root_path", 170 | type=str, 171 | default="", 172 | required=True, 173 | help="Training data root path", 174 | ) 175 | parser.add_argument( 176 | "--image_encoder_path", 177 | type=str, 178 | default=None, 179 | required=True, 180 | help="Path to CLIP image encoder", 181 | ) 182 | parser.add_argument( 183 | "--output_dir", 184 | type=str, 185 | default="sd-ip_adapter", 186 | help="The output directory where the model predictions and checkpoints will be written.", 187 | ) 188 | parser.add_argument( 189 | "--logging_dir", 190 | type=str, 191 | default="logs", 192 | help=( 193 | "[TensorBoard](https://www.tensorflow.org/tensorboard) log directory. Will default to" 194 | " *output_dir/runs/**CURRENT_DATETIME_HOSTNAME***." 195 | ), 196 | ) 197 | parser.add_argument( 198 | "--resolution", 199 | type=int, 200 | default=512, 201 | help=( 202 | "The resolution for input images" 203 | ), 204 | ) 205 | parser.add_argument( 206 | "--learning_rate", 207 | type=float, 208 | default=1e-4, 209 | help="Learning rate to use.", 210 | ) 211 | parser.add_argument("--weight_decay", type=float, default=1e-2, help="Weight decay to use.") 212 | parser.add_argument("--num_train_epochs", type=int, default=100) 213 | parser.add_argument( 214 | "--train_batch_size", type=int, default=8, help="Batch size (per device) for the training dataloader." 215 | ) 216 | parser.add_argument( 217 | "--dataloader_num_workers", 218 | type=int, 219 | default=0, 220 | help=( 221 | "Number of subprocesses to use for data loading. 0 means that the data will be loaded in the main process." 222 | ), 223 | ) 224 | parser.add_argument( 225 | "--save_steps", 226 | type=int, 227 | default=2000, 228 | help=( 229 | "Save a checkpoint of the training state every X updates" 230 | ), 231 | ) 232 | parser.add_argument( 233 | "--mixed_precision", 234 | type=str, 235 | default=None, 236 | choices=["no", "fp16", "bf16"], 237 | help=( 238 | "Whether to use mixed precision. Choose between fp16 and bf16 (bfloat16). Bf16 requires PyTorch >=" 239 | " 1.10.and an Nvidia Ampere GPU. Default to the value of accelerate config of the current system or the" 240 | " flag passed with the `accelerate.launch` command. Use this argument to override the accelerate config." 241 | ), 242 | ) 243 | parser.add_argument( 244 | "--report_to", 245 | type=str, 246 | default="tensorboard", 247 | help=( 248 | 'The integration to report the results and logs to. Supported platforms are `"tensorboard"`' 249 | ' (default), `"wandb"` and `"comet_ml"`. Use `"all"` to report to all integrations.' 250 | ), 251 | ) 252 | parser.add_argument("--local_rank", type=int, default=-1, help="For distributed training: local_rank") 253 | 254 | args = parser.parse_args() 255 | env_local_rank = int(os.environ.get("LOCAL_RANK", -1)) 256 | if env_local_rank != -1 and env_local_rank != args.local_rank: 257 | args.local_rank = env_local_rank 258 | 259 | return args 260 | 261 | 262 | def main(): 263 | args = parse_args() 264 | logging_dir = Path(args.output_dir, args.logging_dir) 265 | 266 | accelerator_project_config = ProjectConfiguration(project_dir=args.output_dir, logging_dir=logging_dir) 267 | 268 | accelerator = Accelerator( 269 | mixed_precision=args.mixed_precision, 270 | log_with=args.report_to, 271 | project_config=accelerator_project_config, 272 | ) 273 | 274 | if accelerator.is_main_process: 275 | if args.output_dir is not None: 276 | os.makedirs(args.output_dir, exist_ok=True) 277 | 278 | # Load scheduler, tokenizer and models. 279 | noise_scheduler = DDPMScheduler.from_pretrained(args.pretrained_model_name_or_path, subfolder="scheduler") 280 | tokenizer = CLIPTokenizer.from_pretrained(args.pretrained_model_name_or_path, subfolder="tokenizer") 281 | text_encoder = CLIPTextModel.from_pretrained(args.pretrained_model_name_or_path, subfolder="text_encoder") 282 | vae = AutoencoderKL.from_pretrained(args.pretrained_model_name_or_path, subfolder="vae") 283 | unet = UNet2DConditionModel.from_pretrained(args.pretrained_model_name_or_path, subfolder="unet") 284 | image_encoder = CLIPVisionModelWithProjection.from_pretrained(args.image_encoder_path) 285 | # freeze parameters of models to save more memory 286 | unet.requires_grad_(False) 287 | vae.requires_grad_(False) 288 | text_encoder.requires_grad_(False) 289 | image_encoder.requires_grad_(False) 290 | 291 | #ip-adapter 292 | image_proj_model = ImageProjModel( 293 | cross_attention_dim=unet.config.cross_attention_dim, 294 | clip_embeddings_dim=image_encoder.config.projection_dim, 295 | clip_extra_context_tokens=4, 296 | ) 297 | # init adapter modules 298 | attn_procs = {} 299 | unet_sd = unet.state_dict() 300 | for name in unet.attn_processors.keys(): 301 | cross_attention_dim = None if name.endswith("attn1.processor") else unet.config.cross_attention_dim 302 | if name.startswith("mid_block"): 303 | hidden_size = unet.config.block_out_channels[-1] 304 | elif name.startswith("up_blocks"): 305 | block_id = int(name[len("up_blocks.")]) 306 | hidden_size = list(reversed(unet.config.block_out_channels))[block_id] 307 | elif name.startswith("down_blocks"): 308 | block_id = int(name[len("down_blocks.")]) 309 | hidden_size = unet.config.block_out_channels[block_id] 310 | if cross_attention_dim is None: 311 | attn_procs[name] = AttnProcessor() 312 | else: 313 | layer_name = name.split(".processor")[0] 314 | weights = { 315 | "to_k_ip.weight": unet_sd[layer_name + ".to_k.weight"], 316 | "to_v_ip.weight": unet_sd[layer_name + ".to_v.weight"], 317 | } 318 | attn_procs[name] = IPAttnProcessor(hidden_size=hidden_size, cross_attention_dim=cross_attention_dim) 319 | attn_procs[name].load_state_dict(weights) 320 | unet.set_attn_processor(attn_procs) 321 | adapter_modules = torch.nn.ModuleList(unet.attn_processors.values()) 322 | 323 | ip_adapter = IPAdapter(unet, image_proj_model, adapter_modules, args.pretrained_ip_adapter_path) 324 | 325 | weight_dtype = torch.float32 326 | if accelerator.mixed_precision == "fp16": 327 | weight_dtype = torch.float16 328 | elif accelerator.mixed_precision == "bf16": 329 | weight_dtype = torch.bfloat16 330 | #unet.to(accelerator.device, dtype=weight_dtype) 331 | vae.to(accelerator.device, dtype=weight_dtype) 332 | text_encoder.to(accelerator.device, dtype=weight_dtype) 333 | image_encoder.to(accelerator.device, dtype=weight_dtype) 334 | 335 | # optimizer 336 | params_to_opt = itertools.chain(ip_adapter.image_proj_model.parameters(), ip_adapter.adapter_modules.parameters()) 337 | optimizer = torch.optim.AdamW(params_to_opt, lr=args.learning_rate, weight_decay=args.weight_decay) 338 | 339 | # dataloader 340 | train_dataset = MyDataset(args.data_json_file, tokenizer=tokenizer, size=args.resolution, image_root_path=args.data_root_path) 341 | train_dataloader = torch.utils.data.DataLoader( 342 | train_dataset, 343 | shuffle=True, 344 | collate_fn=collate_fn, 345 | batch_size=args.train_batch_size, 346 | num_workers=args.dataloader_num_workers, 347 | ) 348 | 349 | # Prepare everything with our `accelerator`. 350 | ip_adapter, optimizer, train_dataloader = accelerator.prepare(ip_adapter, optimizer, train_dataloader) 351 | 352 | global_step = 0 353 | for epoch in range(0, args.num_train_epochs): 354 | begin = time.perf_counter() 355 | for step, batch in enumerate(train_dataloader): 356 | load_data_time = time.perf_counter() - begin 357 | with accelerator.accumulate(ip_adapter): 358 | # Convert images to latent space 359 | with torch.no_grad(): 360 | latents = vae.encode(batch["images"].to(accelerator.device, dtype=weight_dtype)).latent_dist.sample() 361 | latents = latents * vae.config.scaling_factor 362 | 363 | # Sample noise that we'll add to the latents 364 | noise = torch.randn_like(latents) 365 | bsz = latents.shape[0] 366 | # Sample a random timestep for each image 367 | timesteps = torch.randint(0, noise_scheduler.num_train_timesteps, (bsz,), device=latents.device) 368 | timesteps = timesteps.long() 369 | 370 | # Add noise to the latents according to the noise magnitude at each timestep 371 | # (this is the forward diffusion process) 372 | noisy_latents = noise_scheduler.add_noise(latents, noise, timesteps) 373 | 374 | with torch.no_grad(): 375 | image_embeds = image_encoder(batch["clip_images"].to(accelerator.device, dtype=weight_dtype)).image_embeds 376 | image_embeds_ = [] 377 | for image_embed, drop_image_embed in zip(image_embeds, batch["drop_image_embeds"]): 378 | if drop_image_embed == 1: 379 | image_embeds_.append(torch.zeros_like(image_embed)) 380 | else: 381 | image_embeds_.append(image_embed) 382 | image_embeds = torch.stack(image_embeds_) 383 | 384 | with torch.no_grad(): 385 | encoder_hidden_states = text_encoder(batch["text_input_ids"].to(accelerator.device))[0] 386 | 387 | noise_pred = ip_adapter(noisy_latents, timesteps, encoder_hidden_states, image_embeds) 388 | 389 | loss = F.mse_loss(noise_pred.float(), noise.float(), reduction="mean") 390 | 391 | # Gather the losses across all processes for logging (if we use distributed training). 392 | avg_loss = accelerator.gather(loss.repeat(args.train_batch_size)).mean().item() 393 | 394 | # Backpropagate 395 | accelerator.backward(loss) 396 | optimizer.step() 397 | optimizer.zero_grad() 398 | 399 | if accelerator.is_main_process: 400 | print("Epoch {}, step {}, data_time: {}, time: {}, step_loss: {}".format( 401 | epoch, step, load_data_time, time.perf_counter() - begin, avg_loss)) 402 | 403 | global_step += 1 404 | 405 | if global_step % args.save_steps == 0: 406 | save_path = os.path.join(args.output_dir, f"checkpoint-{global_step}") 407 | accelerator.save_state(save_path) 408 | 409 | begin = time.perf_counter() 410 | 411 | if __name__ == "__main__": 412 | main() 413 | -------------------------------------------------------------------------------- /tutorial_train_faceid.py: -------------------------------------------------------------------------------- 1 | import os 2 | import random 3 | import argparse 4 | from pathlib import Path 5 | import json 6 | import itertools 7 | import time 8 | 9 | import torch 10 | import torch.nn.functional as F 11 | from torchvision import transforms 12 | from PIL import Image 13 | from transformers import CLIPImageProcessor 14 | from accelerate import Accelerator 15 | from accelerate.logging import get_logger 16 | from accelerate.utils import ProjectConfiguration 17 | from diffusers import AutoencoderKL, DDPMScheduler, UNet2DConditionModel 18 | from transformers import CLIPTextModel, CLIPTokenizer, CLIPVisionModelWithProjection 19 | 20 | from ip_adapter.ip_adapter_faceid import MLPProjModel 21 | from ip_adapter.utils import is_torch2_available 22 | from ip_adapter.attention_processor_faceid import LoRAAttnProcessor, LoRAIPAttnProcessor 23 | 24 | 25 | # Dataset 26 | class MyDataset(torch.utils.data.Dataset): 27 | 28 | def __init__(self, json_file, tokenizer, size=512, t_drop_rate=0.05, i_drop_rate=0.05, ti_drop_rate=0.05, image_root_path=""): 29 | super().__init__() 30 | 31 | self.tokenizer = tokenizer 32 | self.size = size 33 | self.i_drop_rate = i_drop_rate 34 | self.t_drop_rate = t_drop_rate 35 | self.ti_drop_rate = ti_drop_rate 36 | self.image_root_path = image_root_path 37 | 38 | self.data = json.load(open(json_file)) # list of dict: [{"image_file": "1.png", "id_embed_file": "faceid.bin"}] 39 | 40 | self.transform = transforms.Compose([ 41 | transforms.Resize(self.size, interpolation=transforms.InterpolationMode.BILINEAR), 42 | transforms.CenterCrop(self.size), 43 | transforms.ToTensor(), 44 | transforms.Normalize([0.5], [0.5]), 45 | ]) 46 | 47 | 48 | 49 | def __getitem__(self, idx): 50 | item = self.data[idx] 51 | text = item["text"] 52 | image_file = item["image_file"] 53 | 54 | # read image 55 | raw_image = Image.open(os.path.join(self.image_root_path, image_file)) 56 | image = self.transform(raw_image.convert("RGB")) 57 | 58 | face_id_embed = torch.load(item["id_embed_file"], map_location="cpu") 59 | face_id_embed = torch.from_numpy(face_id_embed) 60 | 61 | # drop 62 | drop_image_embed = 0 63 | rand_num = random.random() 64 | if rand_num < self.i_drop_rate: 65 | drop_image_embed = 1 66 | elif rand_num < (self.i_drop_rate + self.t_drop_rate): 67 | text = "" 68 | elif rand_num < (self.i_drop_rate + self.t_drop_rate + self.ti_drop_rate): 69 | text = "" 70 | drop_image_embed = 1 71 | if drop_image_embed: 72 | face_id_embed = torch.zeros_like(face_id_embed) 73 | # get text and tokenize 74 | text_input_ids = self.tokenizer( 75 | text, 76 | max_length=self.tokenizer.model_max_length, 77 | padding="max_length", 78 | truncation=True, 79 | return_tensors="pt" 80 | ).input_ids 81 | 82 | return { 83 | "image": image, 84 | "text_input_ids": text_input_ids, 85 | "face_id_embed": face_id_embed, 86 | "drop_image_embed": drop_image_embed 87 | } 88 | 89 | def __len__(self): 90 | return len(self.data) 91 | 92 | 93 | def collate_fn(data): 94 | images = torch.stack([example["image"] for example in data]) 95 | text_input_ids = torch.cat([example["text_input_ids"] for example in data], dim=0) 96 | face_id_embed = torch.stack([example["face_id_embed"] for example in data]) 97 | drop_image_embeds = [example["drop_image_embed"] for example in data] 98 | 99 | return { 100 | "images": images, 101 | "text_input_ids": text_input_ids, 102 | "face_id_embed": face_id_embed, 103 | "drop_image_embeds": drop_image_embeds 104 | } 105 | 106 | 107 | class IPAdapter(torch.nn.Module): 108 | """IP-Adapter""" 109 | def __init__(self, unet, image_proj_model, adapter_modules, ckpt_path=None): 110 | super().__init__() 111 | self.unet = unet 112 | self.image_proj_model = image_proj_model 113 | self.adapter_modules = adapter_modules 114 | 115 | if ckpt_path is not None: 116 | self.load_from_checkpoint(ckpt_path) 117 | 118 | def forward(self, noisy_latents, timesteps, encoder_hidden_states, image_embeds): 119 | ip_tokens = self.image_proj_model(image_embeds) 120 | encoder_hidden_states = torch.cat([encoder_hidden_states, ip_tokens], dim=1) 121 | # Predict the noise residual 122 | noise_pred = self.unet(noisy_latents, timesteps, encoder_hidden_states).sample 123 | return noise_pred 124 | 125 | def load_from_checkpoint(self, ckpt_path: str): 126 | # Calculate original checksums 127 | orig_ip_proj_sum = torch.sum(torch.stack([torch.sum(p) for p in self.image_proj_model.parameters()])) 128 | orig_adapter_sum = torch.sum(torch.stack([torch.sum(p) for p in self.adapter_modules.parameters()])) 129 | 130 | state_dict = torch.load(ckpt_path, map_location="cpu") 131 | 132 | # Load state dict for image_proj_model and adapter_modules 133 | self.image_proj_model.load_state_dict(state_dict["image_proj"], strict=True) 134 | self.adapter_modules.load_state_dict(state_dict["ip_adapter"], strict=True) 135 | 136 | # Calculate new checksums 137 | new_ip_proj_sum = torch.sum(torch.stack([torch.sum(p) for p in self.image_proj_model.parameters()])) 138 | new_adapter_sum = torch.sum(torch.stack([torch.sum(p) for p in self.adapter_modules.parameters()])) 139 | 140 | # Verify if the weights have changed 141 | assert orig_ip_proj_sum != new_ip_proj_sum, "Weights of image_proj_model did not change!" 142 | assert orig_adapter_sum != new_adapter_sum, "Weights of adapter_modules did not change!" 143 | 144 | print(f"Successfully loaded weights from checkpoint {ckpt_path}") 145 | 146 | 147 | 148 | def parse_args(): 149 | parser = argparse.ArgumentParser(description="Simple example of a training script.") 150 | parser.add_argument( 151 | "--pretrained_model_name_or_path", 152 | type=str, 153 | default=None, 154 | required=True, 155 | help="Path to pretrained model or model identifier from huggingface.co/models.", 156 | ) 157 | parser.add_argument( 158 | "--pretrained_ip_adapter_path", 159 | type=str, 160 | default=None, 161 | help="Path to pretrained ip adapter model. If not specified weights are initialized randomly.", 162 | ) 163 | parser.add_argument( 164 | "--data_json_file", 165 | type=str, 166 | default=None, 167 | required=True, 168 | help="Training data", 169 | ) 170 | parser.add_argument( 171 | "--data_root_path", 172 | type=str, 173 | default="", 174 | required=True, 175 | help="Training data root path", 176 | ) 177 | parser.add_argument( 178 | "--image_encoder_path", 179 | type=str, 180 | default=None, 181 | required=True, 182 | help="Path to CLIP image encoder", 183 | ) 184 | parser.add_argument( 185 | "--output_dir", 186 | type=str, 187 | default="sd-ip_adapter", 188 | help="The output directory where the model predictions and checkpoints will be written.", 189 | ) 190 | parser.add_argument( 191 | "--logging_dir", 192 | type=str, 193 | default="logs", 194 | help=( 195 | "[TensorBoard](https://www.tensorflow.org/tensorboard) log directory. Will default to" 196 | " *output_dir/runs/**CURRENT_DATETIME_HOSTNAME***." 197 | ), 198 | ) 199 | parser.add_argument( 200 | "--resolution", 201 | type=int, 202 | default=512, 203 | help=( 204 | "The resolution for input images" 205 | ), 206 | ) 207 | parser.add_argument( 208 | "--learning_rate", 209 | type=float, 210 | default=1e-4, 211 | help="Learning rate to use.", 212 | ) 213 | parser.add_argument("--weight_decay", type=float, default=1e-2, help="Weight decay to use.") 214 | parser.add_argument("--num_train_epochs", type=int, default=100) 215 | parser.add_argument( 216 | "--train_batch_size", type=int, default=8, help="Batch size (per device) for the training dataloader." 217 | ) 218 | parser.add_argument( 219 | "--dataloader_num_workers", 220 | type=int, 221 | default=0, 222 | help=( 223 | "Number of subprocesses to use for data loading. 0 means that the data will be loaded in the main process." 224 | ), 225 | ) 226 | parser.add_argument( 227 | "--save_steps", 228 | type=int, 229 | default=2000, 230 | help=( 231 | "Save a checkpoint of the training state every X updates" 232 | ), 233 | ) 234 | parser.add_argument( 235 | "--mixed_precision", 236 | type=str, 237 | default=None, 238 | choices=["no", "fp16", "bf16"], 239 | help=( 240 | "Whether to use mixed precision. Choose between fp16 and bf16 (bfloat16). Bf16 requires PyTorch >=" 241 | " 1.10.and an Nvidia Ampere GPU. Default to the value of accelerate config of the current system or the" 242 | " flag passed with the `accelerate.launch` command. Use this argument to override the accelerate config." 243 | ), 244 | ) 245 | parser.add_argument( 246 | "--report_to", 247 | type=str, 248 | default="tensorboard", 249 | help=( 250 | 'The integration to report the results and logs to. Supported platforms are `"tensorboard"`' 251 | ' (default), `"wandb"` and `"comet_ml"`. Use `"all"` to report to all integrations.' 252 | ), 253 | ) 254 | parser.add_argument("--local_rank", type=int, default=-1, help="For distributed training: local_rank") 255 | 256 | args = parser.parse_args() 257 | env_local_rank = int(os.environ.get("LOCAL_RANK", -1)) 258 | if env_local_rank != -1 and env_local_rank != args.local_rank: 259 | args.local_rank = env_local_rank 260 | 261 | return args 262 | 263 | 264 | def main(): 265 | args = parse_args() 266 | logging_dir = Path(args.output_dir, args.logging_dir) 267 | 268 | accelerator_project_config = ProjectConfiguration(project_dir=args.output_dir, logging_dir=logging_dir) 269 | 270 | accelerator = Accelerator( 271 | mixed_precision=args.mixed_precision, 272 | log_with=args.report_to, 273 | project_config=accelerator_project_config, 274 | ) 275 | 276 | if accelerator.is_main_process: 277 | if args.output_dir is not None: 278 | os.makedirs(args.output_dir, exist_ok=True) 279 | 280 | # Load scheduler, tokenizer and models. 281 | noise_scheduler = DDPMScheduler.from_pretrained(args.pretrained_model_name_or_path, subfolder="scheduler") 282 | tokenizer = CLIPTokenizer.from_pretrained(args.pretrained_model_name_or_path, subfolder="tokenizer") 283 | text_encoder = CLIPTextModel.from_pretrained(args.pretrained_model_name_or_path, subfolder="text_encoder") 284 | vae = AutoencoderKL.from_pretrained(args.pretrained_model_name_or_path, subfolder="vae") 285 | unet = UNet2DConditionModel.from_pretrained(args.pretrained_model_name_or_path, subfolder="unet") 286 | # image_encoder = CLIPVisionModelWithProjection.from_pretrained(args.image_encoder_path) 287 | # freeze parameters of models to save more memory 288 | unet.requires_grad_(False) 289 | vae.requires_grad_(False) 290 | text_encoder.requires_grad_(False) 291 | #image_encoder.requires_grad_(False) 292 | 293 | #ip-adapter 294 | image_proj_model = MLPProjModel( 295 | cross_attention_dim=unet.config.cross_attention_dim, 296 | id_embeddings_dim=512, 297 | num_tokens=4, 298 | ) 299 | # init adapter modules 300 | lora_rank = 128 301 | attn_procs = {} 302 | unet_sd = unet.state_dict() 303 | for name in unet.attn_processors.keys(): 304 | cross_attention_dim = None if name.endswith("attn1.processor") else unet.config.cross_attention_dim 305 | if name.startswith("mid_block"): 306 | hidden_size = unet.config.block_out_channels[-1] 307 | elif name.startswith("up_blocks"): 308 | block_id = int(name[len("up_blocks.")]) 309 | hidden_size = list(reversed(unet.config.block_out_channels))[block_id] 310 | elif name.startswith("down_blocks"): 311 | block_id = int(name[len("down_blocks.")]) 312 | hidden_size = unet.config.block_out_channels[block_id] 313 | if cross_attention_dim is None: 314 | attn_procs[name] = LoRAAttnProcessor(hidden_size=hidden_size, cross_attention_dim=cross_attention_dim, rank=lora_rank) 315 | else: 316 | layer_name = name.split(".processor")[0] 317 | weights = { 318 | "to_k_ip.weight": unet_sd[layer_name + ".to_k.weight"], 319 | "to_v_ip.weight": unet_sd[layer_name + ".to_v.weight"], 320 | } 321 | attn_procs[name] = LoRAIPAttnProcessor(hidden_size=hidden_size, cross_attention_dim=cross_attention_dim, rank=lora_rank) 322 | attn_procs[name].load_state_dict(weights, strict=False) 323 | unet.set_attn_processor(attn_procs) 324 | adapter_modules = torch.nn.ModuleList(unet.attn_processors.values()) 325 | 326 | ip_adapter = IPAdapter(unet, image_proj_model, adapter_modules, args.pretrained_ip_adapter_path) 327 | 328 | weight_dtype = torch.float32 329 | if accelerator.mixed_precision == "fp16": 330 | weight_dtype = torch.float16 331 | elif accelerator.mixed_precision == "bf16": 332 | weight_dtype = torch.bfloat16 333 | #unet.to(accelerator.device, dtype=weight_dtype) 334 | vae.to(accelerator.device, dtype=weight_dtype) 335 | text_encoder.to(accelerator.device, dtype=weight_dtype) 336 | #image_encoder.to(accelerator.device, dtype=weight_dtype) 337 | 338 | # optimizer 339 | params_to_opt = itertools.chain(ip_adapter.image_proj_model.parameters(), ip_adapter.adapter_modules.parameters()) 340 | optimizer = torch.optim.AdamW(params_to_opt, lr=args.learning_rate, weight_decay=args.weight_decay) 341 | 342 | # dataloader 343 | train_dataset = MyDataset(args.data_json_file, tokenizer=tokenizer, size=args.resolution, image_root_path=args.data_root_path) 344 | train_dataloader = torch.utils.data.DataLoader( 345 | train_dataset, 346 | shuffle=True, 347 | collate_fn=collate_fn, 348 | batch_size=args.train_batch_size, 349 | num_workers=args.dataloader_num_workers, 350 | ) 351 | 352 | # Prepare everything with our `accelerator`. 353 | ip_adapter, optimizer, train_dataloader = accelerator.prepare(ip_adapter, optimizer, train_dataloader) 354 | 355 | global_step = 0 356 | for epoch in range(0, args.num_train_epochs): 357 | begin = time.perf_counter() 358 | for step, batch in enumerate(train_dataloader): 359 | load_data_time = time.perf_counter() - begin 360 | with accelerator.accumulate(ip_adapter): 361 | # Convert images to latent space 362 | with torch.no_grad(): 363 | latents = vae.encode(batch["images"].to(accelerator.device, dtype=weight_dtype)).latent_dist.sample() 364 | latents = latents * vae.config.scaling_factor 365 | 366 | # Sample noise that we'll add to the latents 367 | noise = torch.randn_like(latents) 368 | bsz = latents.shape[0] 369 | # Sample a random timestep for each image 370 | timesteps = torch.randint(0, noise_scheduler.num_train_timesteps, (bsz,), device=latents.device) 371 | timesteps = timesteps.long() 372 | 373 | # Add noise to the latents according to the noise magnitude at each timestep 374 | # (this is the forward diffusion process) 375 | noisy_latents = noise_scheduler.add_noise(latents, noise, timesteps) 376 | 377 | image_embeds = batch["face_id_embed"].to(accelerator.device, dtype=weight_dtype) 378 | 379 | with torch.no_grad(): 380 | encoder_hidden_states = text_encoder(batch["text_input_ids"].to(accelerator.device))[0] 381 | 382 | noise_pred = ip_adapter(noisy_latents, timesteps, encoder_hidden_states, image_embeds) 383 | 384 | loss = F.mse_loss(noise_pred.float(), noise.float(), reduction="mean") 385 | 386 | # Gather the losses across all processes for logging (if we use distributed training). 387 | avg_loss = accelerator.gather(loss.repeat(args.train_batch_size)).mean().item() 388 | 389 | # Backpropagate 390 | accelerator.backward(loss) 391 | optimizer.step() 392 | optimizer.zero_grad() 393 | 394 | if accelerator.is_main_process: 395 | print("Epoch {}, step {}, data_time: {}, time: {}, step_loss: {}".format( 396 | epoch, step, load_data_time, time.perf_counter() - begin, avg_loss)) 397 | 398 | global_step += 1 399 | 400 | if global_step % args.save_steps == 0: 401 | save_path = os.path.join(args.output_dir, f"checkpoint-{global_step}") 402 | accelerator.save_state(save_path) 403 | 404 | begin = time.perf_counter() 405 | 406 | if __name__ == "__main__": 407 | main() 408 | -------------------------------------------------------------------------------- /tutorial_train_plus.py: -------------------------------------------------------------------------------- 1 | import os 2 | import random 3 | import argparse 4 | from pathlib import Path 5 | import json 6 | import itertools 7 | import time 8 | 9 | import torch 10 | import torch.nn.functional as F 11 | from torchvision import transforms 12 | from PIL import Image 13 | from transformers import CLIPImageProcessor 14 | from accelerate import Accelerator 15 | from accelerate.logging import get_logger 16 | from accelerate.utils import ProjectConfiguration 17 | from diffusers import AutoencoderKL, DDPMScheduler, UNet2DConditionModel 18 | from transformers import CLIPTextModel, CLIPTokenizer, CLIPVisionModelWithProjection 19 | 20 | from ip_adapter.resampler import Resampler 21 | from ip_adapter.utils import is_torch2_available 22 | if is_torch2_available(): 23 | from ip_adapter.attention_processor import IPAttnProcessor2_0 as IPAttnProcessor, AttnProcessor2_0 as AttnProcessor 24 | else: 25 | from ip_adapter.attention_processor import IPAttnProcessor, AttnProcessor 26 | 27 | 28 | # Dataset 29 | class MyDataset(torch.utils.data.Dataset): 30 | 31 | def __init__(self, json_file, tokenizer, size=512, t_drop_rate=0.05, i_drop_rate=0.05, ti_drop_rate=0.05, image_root_path=""): 32 | super().__init__() 33 | 34 | self.tokenizer = tokenizer 35 | self.size = size 36 | self.i_drop_rate = i_drop_rate 37 | self.t_drop_rate = t_drop_rate 38 | self.ti_drop_rate = ti_drop_rate 39 | self.image_root_path = image_root_path 40 | 41 | self.data = json.load(open(json_file)) # list of dict: [{"image_file": "1.png", "text": "A dog"}] 42 | 43 | self.transform = transforms.Compose([ 44 | transforms.Resize(self.size, interpolation=transforms.InterpolationMode.BILINEAR), 45 | transforms.CenterCrop(self.size), 46 | transforms.ToTensor(), 47 | transforms.Normalize([0.5], [0.5]), 48 | ]) 49 | self.clip_image_processor = CLIPImageProcessor() 50 | 51 | def __getitem__(self, idx): 52 | item = self.data[idx] 53 | text = item["text"] 54 | image_file = item["image_file"] 55 | 56 | # read image 57 | raw_image = Image.open(os.path.join(self.image_root_path, image_file)) 58 | image = self.transform(raw_image.convert("RGB")) 59 | clip_image = self.clip_image_processor(images=raw_image, return_tensors="pt").pixel_values 60 | 61 | # drop 62 | drop_image_embed = 0 63 | rand_num = random.random() 64 | if rand_num < self.i_drop_rate: 65 | drop_image_embed = 1 66 | elif rand_num < (self.i_drop_rate + self.t_drop_rate): 67 | text = "" 68 | elif rand_num < (self.i_drop_rate + self.t_drop_rate + self.ti_drop_rate): 69 | text = "" 70 | drop_image_embed = 1 71 | # get text and tokenize 72 | text_input_ids = self.tokenizer( 73 | text, 74 | max_length=self.tokenizer.model_max_length, 75 | padding="max_length", 76 | truncation=True, 77 | return_tensors="pt" 78 | ).input_ids 79 | 80 | return { 81 | "image": image, 82 | "text_input_ids": text_input_ids, 83 | "clip_image": clip_image, 84 | "drop_image_embed": drop_image_embed 85 | } 86 | 87 | def __len__(self): 88 | return len(self.data) 89 | 90 | 91 | def collate_fn(data): 92 | images = torch.stack([example["image"] for example in data]) 93 | text_input_ids = torch.cat([example["text_input_ids"] for example in data], dim=0) 94 | clip_images = torch.cat([example["clip_image"] for example in data], dim=0) 95 | drop_image_embeds = [example["drop_image_embed"] for example in data] 96 | 97 | return { 98 | "images": images, 99 | "text_input_ids": text_input_ids, 100 | "clip_images": clip_images, 101 | "drop_image_embeds": drop_image_embeds 102 | } 103 | 104 | 105 | class IPAdapter(torch.nn.Module): 106 | """IP-Adapter""" 107 | def __init__(self, unet, image_proj_model, adapter_modules, ckpt_path=None): 108 | super().__init__() 109 | self.unet = unet 110 | self.image_proj_model = image_proj_model 111 | self.adapter_modules = adapter_modules 112 | 113 | if ckpt_path is not None: 114 | self.load_from_checkpoint(ckpt_path) 115 | 116 | def forward(self, noisy_latents, timesteps, encoder_hidden_states, image_embeds): 117 | ip_tokens = self.image_proj_model(image_embeds) 118 | encoder_hidden_states = torch.cat([encoder_hidden_states, ip_tokens], dim=1) 119 | # Predict the noise residual 120 | noise_pred = self.unet(noisy_latents, timesteps, encoder_hidden_states).sample 121 | return noise_pred 122 | 123 | def load_from_checkpoint(self, ckpt_path: str): 124 | # Calculate original checksums 125 | orig_ip_proj_sum = torch.sum(torch.stack([torch.sum(p) for p in self.image_proj_model.parameters()])) 126 | orig_adapter_sum = torch.sum(torch.stack([torch.sum(p) for p in self.adapter_modules.parameters()])) 127 | 128 | state_dict = torch.load(ckpt_path, map_location="cpu") 129 | 130 | # Check if 'latents' exists in both the saved state_dict and the current model's state_dict 131 | strict_load_image_proj_model = True 132 | if "latents" in state_dict["image_proj"] and "latents" in self.image_proj_model.state_dict(): 133 | # Check if the shapes are mismatched 134 | if state_dict["image_proj"]["latents"].shape != self.image_proj_model.state_dict()["latents"].shape: 135 | print(f"Shapes of 'image_proj.latents' in checkpoint {ckpt_path} and current model do not match.") 136 | print("Removing 'latents' from checkpoint and loading the rest of the weights.") 137 | del state_dict["image_proj"]["latents"] 138 | strict_load_image_proj_model = False 139 | 140 | # Load state dict for image_proj_model and adapter_modules 141 | self.image_proj_model.load_state_dict(state_dict["image_proj"], strict=strict_load_image_proj_model) 142 | self.adapter_modules.load_state_dict(state_dict["ip_adapter"], strict=True) 143 | 144 | # Calculate new checksums 145 | new_ip_proj_sum = torch.sum(torch.stack([torch.sum(p) for p in self.image_proj_model.parameters()])) 146 | new_adapter_sum = torch.sum(torch.stack([torch.sum(p) for p in self.adapter_modules.parameters()])) 147 | 148 | # Verify if the weights have changed 149 | assert orig_ip_proj_sum != new_ip_proj_sum, "Weights of image_proj_model did not change!" 150 | assert orig_adapter_sum != new_adapter_sum, "Weights of adapter_modules did not change!" 151 | 152 | print(f"Successfully loaded weights from checkpoint {ckpt_path}") 153 | 154 | 155 | 156 | def parse_args(): 157 | parser = argparse.ArgumentParser(description="Simple example of a training script.") 158 | parser.add_argument( 159 | "--pretrained_model_name_or_path", 160 | type=str, 161 | default=None, 162 | required=True, 163 | help="Path to pretrained model or model identifier from huggingface.co/models.", 164 | ) 165 | parser.add_argument( 166 | "--pretrained_ip_adapter_path", 167 | type=str, 168 | default=None, 169 | help="Path to pretrained ip adapter model. If not specified weights are initialized randomly.", 170 | ) 171 | parser.add_argument( 172 | "--num_tokens", 173 | type=int, 174 | default=16, 175 | help="Number of tokens to query from the CLIP image encoding.", 176 | ) 177 | parser.add_argument( 178 | "--data_json_file", 179 | type=str, 180 | default=None, 181 | required=True, 182 | help="Training data", 183 | ) 184 | parser.add_argument( 185 | "--data_root_path", 186 | type=str, 187 | default="", 188 | required=True, 189 | help="Training data root path", 190 | ) 191 | parser.add_argument( 192 | "--image_encoder_path", 193 | type=str, 194 | default=None, 195 | required=True, 196 | help="Path to CLIP image encoder", 197 | ) 198 | parser.add_argument( 199 | "--output_dir", 200 | type=str, 201 | default="sd-ip_adapter", 202 | help="The output directory where the model predictions and checkpoints will be written.", 203 | ) 204 | parser.add_argument( 205 | "--logging_dir", 206 | type=str, 207 | default="logs", 208 | help=( 209 | "[TensorBoard](https://www.tensorflow.org/tensorboard) log directory. Will default to" 210 | " *output_dir/runs/**CURRENT_DATETIME_HOSTNAME***." 211 | ), 212 | ) 213 | parser.add_argument( 214 | "--resolution", 215 | type=int, 216 | default=512, 217 | help=( 218 | "The resolution for input images" 219 | ), 220 | ) 221 | parser.add_argument( 222 | "--learning_rate", 223 | type=float, 224 | default=1e-4, 225 | help="Learning rate to use.", 226 | ) 227 | parser.add_argument("--weight_decay", type=float, default=1e-2, help="Weight decay to use.") 228 | parser.add_argument("--num_train_epochs", type=int, default=100) 229 | parser.add_argument( 230 | "--train_batch_size", type=int, default=8, help="Batch size (per device) for the training dataloader." 231 | ) 232 | parser.add_argument( 233 | "--dataloader_num_workers", 234 | type=int, 235 | default=0, 236 | help=( 237 | "Number of subprocesses to use for data loading. 0 means that the data will be loaded in the main process." 238 | ), 239 | ) 240 | parser.add_argument( 241 | "--save_steps", 242 | type=int, 243 | default=2000, 244 | help=( 245 | "Save a checkpoint of the training state every X updates" 246 | ), 247 | ) 248 | parser.add_argument( 249 | "--mixed_precision", 250 | type=str, 251 | default=None, 252 | choices=["no", "fp16", "bf16"], 253 | help=( 254 | "Whether to use mixed precision. Choose between fp16 and bf16 (bfloat16). Bf16 requires PyTorch >=" 255 | " 1.10.and an Nvidia Ampere GPU. Default to the value of accelerate config of the current system or the" 256 | " flag passed with the `accelerate.launch` command. Use this argument to override the accelerate config." 257 | ), 258 | ) 259 | parser.add_argument( 260 | "--report_to", 261 | type=str, 262 | default="tensorboard", 263 | help=( 264 | 'The integration to report the results and logs to. Supported platforms are `"tensorboard"`' 265 | ' (default), `"wandb"` and `"comet_ml"`. Use `"all"` to report to all integrations.' 266 | ), 267 | ) 268 | parser.add_argument("--local_rank", type=int, default=-1, help="For distributed training: local_rank") 269 | 270 | args = parser.parse_args() 271 | env_local_rank = int(os.environ.get("LOCAL_RANK", -1)) 272 | if env_local_rank != -1 and env_local_rank != args.local_rank: 273 | args.local_rank = env_local_rank 274 | 275 | return args 276 | 277 | 278 | def main(): 279 | args = parse_args() 280 | logging_dir = Path(args.output_dir, args.logging_dir) 281 | 282 | accelerator_project_config = ProjectConfiguration(project_dir=args.output_dir, logging_dir=logging_dir) 283 | 284 | accelerator = Accelerator( 285 | mixed_precision=args.mixed_precision, 286 | log_with=args.report_to, 287 | project_config=accelerator_project_config, 288 | ) 289 | 290 | if accelerator.is_main_process: 291 | if args.output_dir is not None: 292 | os.makedirs(args.output_dir, exist_ok=True) 293 | 294 | # Load scheduler, tokenizer and models. 295 | noise_scheduler = DDPMScheduler.from_pretrained(args.pretrained_model_name_or_path, subfolder="scheduler") 296 | tokenizer = CLIPTokenizer.from_pretrained(args.pretrained_model_name_or_path, subfolder="tokenizer") 297 | text_encoder = CLIPTextModel.from_pretrained(args.pretrained_model_name_or_path, subfolder="text_encoder") 298 | vae = AutoencoderKL.from_pretrained(args.pretrained_model_name_or_path, subfolder="vae") 299 | unet = UNet2DConditionModel.from_pretrained(args.pretrained_model_name_or_path, subfolder="unet") 300 | image_encoder = CLIPVisionModelWithProjection.from_pretrained(args.image_encoder_path) 301 | # freeze parameters of models to save more memory 302 | unet.requires_grad_(False) 303 | vae.requires_grad_(False) 304 | text_encoder.requires_grad_(False) 305 | image_encoder.requires_grad_(False) 306 | 307 | #ip-adapter-plus 308 | image_proj_model = Resampler( 309 | dim=unet.config.cross_attention_dim, 310 | depth=4, 311 | dim_head=64, 312 | heads=12, 313 | num_queries=args.num_tokens, 314 | embedding_dim=image_encoder.config.hidden_size, 315 | output_dim=unet.config.cross_attention_dim, 316 | ff_mult=4 317 | ) 318 | # init adapter modules 319 | attn_procs = {} 320 | unet_sd = unet.state_dict() 321 | for name in unet.attn_processors.keys(): 322 | cross_attention_dim = None if name.endswith("attn1.processor") else unet.config.cross_attention_dim 323 | if name.startswith("mid_block"): 324 | hidden_size = unet.config.block_out_channels[-1] 325 | elif name.startswith("up_blocks"): 326 | block_id = int(name[len("up_blocks.")]) 327 | hidden_size = list(reversed(unet.config.block_out_channels))[block_id] 328 | elif name.startswith("down_blocks"): 329 | block_id = int(name[len("down_blocks.")]) 330 | hidden_size = unet.config.block_out_channels[block_id] 331 | if cross_attention_dim is None: 332 | attn_procs[name] = AttnProcessor() 333 | else: 334 | layer_name = name.split(".processor")[0] 335 | weights = { 336 | "to_k_ip.weight": unet_sd[layer_name + ".to_k.weight"], 337 | "to_v_ip.weight": unet_sd[layer_name + ".to_v.weight"], 338 | } 339 | attn_procs[name] = IPAttnProcessor(hidden_size=hidden_size, cross_attention_dim=cross_attention_dim, num_tokens=args.num_tokens) 340 | attn_procs[name].load_state_dict(weights) 341 | unet.set_attn_processor(attn_procs) 342 | adapter_modules = torch.nn.ModuleList(unet.attn_processors.values()) 343 | 344 | ip_adapter = IPAdapter(unet, image_proj_model, adapter_modules, args.pretrained_ip_adapter_path) 345 | 346 | weight_dtype = torch.float32 347 | if accelerator.mixed_precision == "fp16": 348 | weight_dtype = torch.float16 349 | elif accelerator.mixed_precision == "bf16": 350 | weight_dtype = torch.bfloat16 351 | #unet.to(accelerator.device, dtype=weight_dtype) 352 | vae.to(accelerator.device, dtype=weight_dtype) 353 | text_encoder.to(accelerator.device, dtype=weight_dtype) 354 | image_encoder.to(accelerator.device, dtype=weight_dtype) 355 | 356 | # optimizer 357 | params_to_opt = itertools.chain(ip_adapter.image_proj_model.parameters(), ip_adapter.adapter_modules.parameters()) 358 | optimizer = torch.optim.AdamW(params_to_opt, lr=args.learning_rate, weight_decay=args.weight_decay) 359 | 360 | # dataloader 361 | train_dataset = MyDataset(args.data_json_file, tokenizer=tokenizer, size=args.resolution, image_root_path=args.data_root_path) 362 | train_dataloader = torch.utils.data.DataLoader( 363 | train_dataset, 364 | shuffle=True, 365 | collate_fn=collate_fn, 366 | batch_size=args.train_batch_size, 367 | num_workers=args.dataloader_num_workers, 368 | ) 369 | 370 | # Prepare everything with our `accelerator`. 371 | ip_adapter, optimizer, train_dataloader = accelerator.prepare(ip_adapter, optimizer, train_dataloader) 372 | 373 | global_step = 0 374 | for epoch in range(0, args.num_train_epochs): 375 | begin = time.perf_counter() 376 | for step, batch in enumerate(train_dataloader): 377 | load_data_time = time.perf_counter() - begin 378 | with accelerator.accumulate(ip_adapter): 379 | # Convert images to latent space 380 | with torch.no_grad(): 381 | latents = vae.encode(batch["images"].to(accelerator.device, dtype=weight_dtype)).latent_dist.sample() 382 | latents = latents * vae.config.scaling_factor 383 | 384 | # Sample noise that we'll add to the latents 385 | noise = torch.randn_like(latents) 386 | bsz = latents.shape[0] 387 | # Sample a random timestep for each image 388 | timesteps = torch.randint(0, noise_scheduler.num_train_timesteps, (bsz,), device=latents.device) 389 | timesteps = timesteps.long() 390 | 391 | # Add noise to the latents according to the noise magnitude at each timestep 392 | # (this is the forward diffusion process) 393 | noisy_latents = noise_scheduler.add_noise(latents, noise, timesteps) 394 | 395 | clip_images = [] 396 | for clip_image, drop_image_embed in zip(batch["clip_images"], batch["drop_image_embeds"]): 397 | if drop_image_embed == 1: 398 | clip_images.append(torch.zeros_like(clip_image)) 399 | else: 400 | clip_images.append(clip_image) 401 | clip_images = torch.stack(clip_images, dim=0) 402 | with torch.no_grad(): 403 | image_embeds = image_encoder(clip_images.to(accelerator.device, dtype=weight_dtype), output_hidden_states=True).hidden_states[-2] 404 | 405 | with torch.no_grad(): 406 | encoder_hidden_states = text_encoder(batch["text_input_ids"].to(accelerator.device))[0] 407 | 408 | noise_pred = ip_adapter(noisy_latents, timesteps, encoder_hidden_states, image_embeds) 409 | 410 | loss = F.mse_loss(noise_pred.float(), noise.float(), reduction="mean") 411 | 412 | # Gather the losses across all processes for logging (if we use distributed training). 413 | avg_loss = accelerator.gather(loss.repeat(args.train_batch_size)).mean().item() 414 | 415 | # Backpropagate 416 | accelerator.backward(loss) 417 | optimizer.step() 418 | optimizer.zero_grad() 419 | 420 | if accelerator.is_main_process: 421 | print("Epoch {}, step {}, data_time: {}, time: {}, step_loss: {}".format( 422 | epoch, step, load_data_time, time.perf_counter() - begin, avg_loss)) 423 | 424 | global_step += 1 425 | 426 | if global_step % args.save_steps == 0: 427 | save_path = os.path.join(args.output_dir, f"checkpoint-{global_step}") 428 | accelerator.save_state(save_path) 429 | 430 | begin = time.perf_counter() 431 | 432 | if __name__ == "__main__": 433 | main() 434 | -------------------------------------------------------------------------------- /tutorial_train_sdxl.py: -------------------------------------------------------------------------------- 1 | import os 2 | import random 3 | import argparse 4 | from pathlib import Path 5 | import json 6 | import itertools 7 | import time 8 | 9 | import torch 10 | import torch.nn.functional as F 11 | import numpy as np 12 | from torchvision import transforms 13 | from PIL import Image 14 | from transformers import CLIPImageProcessor 15 | from accelerate import Accelerator 16 | from accelerate.logging import get_logger 17 | from accelerate.utils import ProjectConfiguration 18 | from diffusers import AutoencoderKL, DDPMScheduler, UNet2DConditionModel 19 | from transformers import CLIPTextModel, CLIPTokenizer, CLIPVisionModelWithProjection, CLIPTextModelWithProjection 20 | 21 | from ip_adapter.ip_adapter import ImageProjModel 22 | from ip_adapter.utils import is_torch2_available 23 | if is_torch2_available(): 24 | from ip_adapter.attention_processor import IPAttnProcessor2_0 as IPAttnProcessor, AttnProcessor2_0 as AttnProcessor 25 | else: 26 | from ip_adapter.attention_processor import IPAttnProcessor, AttnProcessor 27 | 28 | 29 | # Dataset 30 | class MyDataset(torch.utils.data.Dataset): 31 | 32 | def __init__(self, json_file, tokenizer, tokenizer_2, size=1024, center_crop=True, t_drop_rate=0.05, i_drop_rate=0.05, ti_drop_rate=0.05, image_root_path=""): 33 | super().__init__() 34 | 35 | self.tokenizer = tokenizer 36 | self.tokenizer_2 = tokenizer_2 37 | self.size = size 38 | self.center_crop = center_crop 39 | self.i_drop_rate = i_drop_rate 40 | self.t_drop_rate = t_drop_rate 41 | self.ti_drop_rate = ti_drop_rate 42 | self.image_root_path = image_root_path 43 | 44 | self.data = json.load(open(json_file)) # list of dict: [{"image_file": "1.png", "text": "A dog"}] 45 | 46 | self.transform = transforms.Compose([ 47 | transforms.Resize(self.size, interpolation=transforms.InterpolationMode.BILINEAR), 48 | transforms.ToTensor(), 49 | transforms.Normalize([0.5], [0.5]), 50 | ]) 51 | 52 | self.clip_image_processor = CLIPImageProcessor() 53 | 54 | def __getitem__(self, idx): 55 | item = self.data[idx] 56 | text = item["text"] 57 | image_file = item["image_file"] 58 | 59 | # read image 60 | raw_image = Image.open(os.path.join(self.image_root_path, image_file)) 61 | 62 | # original size 63 | original_width, original_height = raw_image.size 64 | original_size = torch.tensor([original_height, original_width]) 65 | 66 | image_tensor = self.transform(raw_image.convert("RGB")) 67 | # random crop 68 | delta_h = image_tensor.shape[1] - self.size 69 | delta_w = image_tensor.shape[2] - self.size 70 | assert not all([delta_h, delta_w]) 71 | 72 | if self.center_crop: 73 | top = delta_h // 2 74 | left = delta_w // 2 75 | else: 76 | top = np.random.randint(0, delta_h + 1) 77 | left = np.random.randint(0, delta_w + 1) 78 | image = transforms.functional.crop( 79 | image_tensor, top=top, left=left, height=self.size, width=self.size 80 | ) 81 | crop_coords_top_left = torch.tensor([top, left]) 82 | 83 | clip_image = self.clip_image_processor(images=raw_image, return_tensors="pt").pixel_values 84 | 85 | # drop 86 | drop_image_embed = 0 87 | rand_num = random.random() 88 | if rand_num < self.i_drop_rate: 89 | drop_image_embed = 1 90 | elif rand_num < (self.i_drop_rate + self.t_drop_rate): 91 | text = "" 92 | elif rand_num < (self.i_drop_rate + self.t_drop_rate + self.ti_drop_rate): 93 | text = "" 94 | drop_image_embed = 1 95 | 96 | # get text and tokenize 97 | text_input_ids = self.tokenizer( 98 | text, 99 | max_length=self.tokenizer.model_max_length, 100 | padding="max_length", 101 | truncation=True, 102 | return_tensors="pt" 103 | ).input_ids 104 | 105 | text_input_ids_2 = self.tokenizer_2( 106 | text, 107 | max_length=self.tokenizer_2.model_max_length, 108 | padding="max_length", 109 | truncation=True, 110 | return_tensors="pt" 111 | ).input_ids 112 | 113 | return { 114 | "image": image, 115 | "text_input_ids": text_input_ids, 116 | "text_input_ids_2": text_input_ids_2, 117 | "clip_image": clip_image, 118 | "drop_image_embed": drop_image_embed, 119 | "original_size": original_size, 120 | "crop_coords_top_left": crop_coords_top_left, 121 | "target_size": torch.tensor([self.size, self.size]), 122 | } 123 | 124 | 125 | def __len__(self): 126 | return len(self.data) 127 | 128 | 129 | def collate_fn(data): 130 | images = torch.stack([example["image"] for example in data]) 131 | text_input_ids = torch.cat([example["text_input_ids"] for example in data], dim=0) 132 | text_input_ids_2 = torch.cat([example["text_input_ids_2"] for example in data], dim=0) 133 | clip_images = torch.cat([example["clip_image"] for example in data], dim=0) 134 | drop_image_embeds = [example["drop_image_embed"] for example in data] 135 | original_size = torch.stack([example["original_size"] for example in data]) 136 | crop_coords_top_left = torch.stack([example["crop_coords_top_left"] for example in data]) 137 | target_size = torch.stack([example["target_size"] for example in data]) 138 | 139 | return { 140 | "images": images, 141 | "text_input_ids": text_input_ids, 142 | "text_input_ids_2": text_input_ids_2, 143 | "clip_images": clip_images, 144 | "drop_image_embeds": drop_image_embeds, 145 | "original_size": original_size, 146 | "crop_coords_top_left": crop_coords_top_left, 147 | "target_size": target_size, 148 | } 149 | 150 | 151 | class IPAdapter(torch.nn.Module): 152 | """IP-Adapter""" 153 | def __init__(self, unet, image_proj_model, adapter_modules, ckpt_path=None): 154 | super().__init__() 155 | self.unet = unet 156 | self.image_proj_model = image_proj_model 157 | self.adapter_modules = adapter_modules 158 | 159 | if ckpt_path is not None: 160 | self.load_from_checkpoint(ckpt_path) 161 | 162 | def forward(self, noisy_latents, timesteps, encoder_hidden_states, unet_added_cond_kwargs, image_embeds): 163 | ip_tokens = self.image_proj_model(image_embeds) 164 | encoder_hidden_states = torch.cat([encoder_hidden_states, ip_tokens], dim=1) 165 | # Predict the noise residual 166 | noise_pred = self.unet(noisy_latents, timesteps, encoder_hidden_states, added_cond_kwargs=unet_added_cond_kwargs).sample 167 | return noise_pred 168 | 169 | def load_from_checkpoint(self, ckpt_path: str): 170 | # Calculate original checksums 171 | orig_ip_proj_sum = torch.sum(torch.stack([torch.sum(p) for p in self.image_proj_model.parameters()])) 172 | orig_adapter_sum = torch.sum(torch.stack([torch.sum(p) for p in self.adapter_modules.parameters()])) 173 | 174 | state_dict = torch.load(ckpt_path, map_location="cpu") 175 | 176 | # Load state dict for image_proj_model and adapter_modules 177 | self.image_proj_model.load_state_dict(state_dict["image_proj"], strict=True) 178 | self.adapter_modules.load_state_dict(state_dict["ip_adapter"], strict=True) 179 | 180 | # Calculate new checksums 181 | new_ip_proj_sum = torch.sum(torch.stack([torch.sum(p) for p in self.image_proj_model.parameters()])) 182 | new_adapter_sum = torch.sum(torch.stack([torch.sum(p) for p in self.adapter_modules.parameters()])) 183 | 184 | # Verify if the weights have changed 185 | assert orig_ip_proj_sum != new_ip_proj_sum, "Weights of image_proj_model did not change!" 186 | assert orig_adapter_sum != new_adapter_sum, "Weights of adapter_modules did not change!" 187 | 188 | print(f"Successfully loaded weights from checkpoint {ckpt_path}") 189 | 190 | 191 | def parse_args(): 192 | parser = argparse.ArgumentParser(description="Simple example of a training script.") 193 | parser.add_argument( 194 | "--pretrained_model_name_or_path", 195 | type=str, 196 | default=None, 197 | required=True, 198 | help="Path to pretrained model or model identifier from huggingface.co/models.", 199 | ) 200 | parser.add_argument( 201 | "--pretrained_ip_adapter_path", 202 | type=str, 203 | default=None, 204 | help="Path to pretrained ip adapter model. If not specified weights are initialized randomly.", 205 | ) 206 | parser.add_argument( 207 | "--data_json_file", 208 | type=str, 209 | default=None, 210 | required=True, 211 | help="Training data", 212 | ) 213 | parser.add_argument( 214 | "--data_root_path", 215 | type=str, 216 | default="", 217 | required=True, 218 | help="Training data root path", 219 | ) 220 | parser.add_argument( 221 | "--image_encoder_path", 222 | type=str, 223 | default=None, 224 | required=True, 225 | help="Path to CLIP image encoder", 226 | ) 227 | parser.add_argument( 228 | "--output_dir", 229 | type=str, 230 | default="sd-ip_adapter", 231 | help="The output directory where the model predictions and checkpoints will be written.", 232 | ) 233 | parser.add_argument( 234 | "--logging_dir", 235 | type=str, 236 | default="logs", 237 | help=( 238 | "[TensorBoard](https://www.tensorflow.org/tensorboard) log directory. Will default to" 239 | " *output_dir/runs/**CURRENT_DATETIME_HOSTNAME***." 240 | ), 241 | ) 242 | parser.add_argument( 243 | "--resolution", 244 | type=int, 245 | default=512, 246 | help=( 247 | "The resolution for input images" 248 | ), 249 | ) 250 | parser.add_argument( 251 | "--learning_rate", 252 | type=float, 253 | default=1e-4, 254 | help="Learning rate to use.", 255 | ) 256 | parser.add_argument("--weight_decay", type=float, default=1e-2, help="Weight decay to use.") 257 | parser.add_argument("--num_train_epochs", type=int, default=100) 258 | parser.add_argument( 259 | "--train_batch_size", type=int, default=8, help="Batch size (per device) for the training dataloader." 260 | ) 261 | parser.add_argument("--noise_offset", type=float, default=None, help="noise offset") 262 | parser.add_argument( 263 | "--dataloader_num_workers", 264 | type=int, 265 | default=0, 266 | help=( 267 | "Number of subprocesses to use for data loading. 0 means that the data will be loaded in the main process." 268 | ), 269 | ) 270 | parser.add_argument( 271 | "--save_steps", 272 | type=int, 273 | default=2000, 274 | help=( 275 | "Save a checkpoint of the training state every X updates" 276 | ), 277 | ) 278 | parser.add_argument( 279 | "--mixed_precision", 280 | type=str, 281 | default=None, 282 | choices=["no", "fp16", "bf16"], 283 | help=( 284 | "Whether to use mixed precision. Choose between fp16 and bf16 (bfloat16). Bf16 requires PyTorch >=" 285 | " 1.10.and an Nvidia Ampere GPU. Default to the value of accelerate config of the current system or the" 286 | " flag passed with the `accelerate.launch` command. Use this argument to override the accelerate config." 287 | ), 288 | ) 289 | parser.add_argument( 290 | "--report_to", 291 | type=str, 292 | default="tensorboard", 293 | help=( 294 | 'The integration to report the results and logs to. Supported platforms are `"tensorboard"`' 295 | ' (default), `"wandb"` and `"comet_ml"`. Use `"all"` to report to all integrations.' 296 | ), 297 | ) 298 | parser.add_argument("--local_rank", type=int, default=-1, help="For distributed training: local_rank") 299 | 300 | args = parser.parse_args() 301 | env_local_rank = int(os.environ.get("LOCAL_RANK", -1)) 302 | if env_local_rank != -1 and env_local_rank != args.local_rank: 303 | args.local_rank = env_local_rank 304 | 305 | return args 306 | 307 | 308 | def main(): 309 | args = parse_args() 310 | logging_dir = Path(args.output_dir, args.logging_dir) 311 | 312 | accelerator_project_config = ProjectConfiguration(project_dir=args.output_dir, logging_dir=logging_dir) 313 | 314 | accelerator = Accelerator( 315 | mixed_precision=args.mixed_precision, 316 | log_with=args.report_to, 317 | project_config=accelerator_project_config, 318 | ) 319 | 320 | if accelerator.is_main_process: 321 | if args.output_dir is not None: 322 | os.makedirs(args.output_dir, exist_ok=True) 323 | 324 | # Load scheduler, tokenizer and models. 325 | noise_scheduler = DDPMScheduler.from_pretrained(args.pretrained_model_name_or_path, subfolder="scheduler") 326 | tokenizer = CLIPTokenizer.from_pretrained(args.pretrained_model_name_or_path, subfolder="tokenizer") 327 | text_encoder = CLIPTextModel.from_pretrained(args.pretrained_model_name_or_path, subfolder="text_encoder") 328 | tokenizer_2 = CLIPTokenizer.from_pretrained(args.pretrained_model_name_or_path, subfolder="tokenizer_2") 329 | text_encoder_2 = CLIPTextModelWithProjection.from_pretrained(args.pretrained_model_name_or_path, subfolder="text_encoder_2") 330 | vae = AutoencoderKL.from_pretrained(args.pretrained_model_name_or_path, subfolder="vae") 331 | unet = UNet2DConditionModel.from_pretrained(args.pretrained_model_name_or_path, subfolder="unet") 332 | image_encoder = CLIPVisionModelWithProjection.from_pretrained(args.image_encoder_path) 333 | # freeze parameters of models to save more memory 334 | unet.requires_grad_(False) 335 | vae.requires_grad_(False) 336 | text_encoder.requires_grad_(False) 337 | text_encoder_2.requires_grad_(False) 338 | image_encoder.requires_grad_(False) 339 | 340 | #ip-adapter 341 | num_tokens = 4 342 | image_proj_model = ImageProjModel( 343 | cross_attention_dim=unet.config.cross_attention_dim, 344 | clip_embeddings_dim=image_encoder.config.projection_dim, 345 | clip_extra_context_tokens=num_tokens, 346 | ) 347 | # init adapter modules 348 | attn_procs = {} 349 | unet_sd = unet.state_dict() 350 | for name in unet.attn_processors.keys(): 351 | cross_attention_dim = None if name.endswith("attn1.processor") else unet.config.cross_attention_dim 352 | if name.startswith("mid_block"): 353 | hidden_size = unet.config.block_out_channels[-1] 354 | elif name.startswith("up_blocks"): 355 | block_id = int(name[len("up_blocks.")]) 356 | hidden_size = list(reversed(unet.config.block_out_channels))[block_id] 357 | elif name.startswith("down_blocks"): 358 | block_id = int(name[len("down_blocks.")]) 359 | hidden_size = unet.config.block_out_channels[block_id] 360 | if cross_attention_dim is None: 361 | attn_procs[name] = AttnProcessor() 362 | else: 363 | layer_name = name.split(".processor")[0] 364 | weights = { 365 | "to_k_ip.weight": unet_sd[layer_name + ".to_k.weight"], 366 | "to_v_ip.weight": unet_sd[layer_name + ".to_v.weight"], 367 | } 368 | attn_procs[name] = IPAttnProcessor(hidden_size=hidden_size, cross_attention_dim=cross_attention_dim, num_tokens=num_tokens) 369 | attn_procs[name].load_state_dict(weights) 370 | unet.set_attn_processor(attn_procs) 371 | adapter_modules = torch.nn.ModuleList(unet.attn_processors.values()) 372 | 373 | ip_adapter = IPAdapter(unet, image_proj_model, adapter_modules, args.pretrained_ip_adapter_path) 374 | 375 | weight_dtype = torch.float32 376 | if accelerator.mixed_precision == "fp16": 377 | weight_dtype = torch.float16 378 | elif accelerator.mixed_precision == "bf16": 379 | weight_dtype = torch.bfloat16 380 | #unet.to(accelerator.device, dtype=weight_dtype) 381 | vae.to(accelerator.device) # use fp32 382 | text_encoder.to(accelerator.device, dtype=weight_dtype) 383 | text_encoder_2.to(accelerator.device, dtype=weight_dtype) 384 | image_encoder.to(accelerator.device, dtype=weight_dtype) 385 | 386 | # optimizer 387 | params_to_opt = itertools.chain(ip_adapter.image_proj_model.parameters(), ip_adapter.adapter_modules.parameters()) 388 | optimizer = torch.optim.AdamW(params_to_opt, lr=args.learning_rate, weight_decay=args.weight_decay) 389 | 390 | # dataloader 391 | train_dataset = MyDataset(args.data_json_file, tokenizer=tokenizer, tokenizer_2=tokenizer_2, size=args.resolution, image_root_path=args.data_root_path) 392 | train_dataloader = torch.utils.data.DataLoader( 393 | train_dataset, 394 | shuffle=True, 395 | collate_fn=collate_fn, 396 | batch_size=args.train_batch_size, 397 | num_workers=args.dataloader_num_workers, 398 | ) 399 | 400 | # Prepare everything with our `accelerator`. 401 | ip_adapter, optimizer, train_dataloader = accelerator.prepare(ip_adapter, optimizer, train_dataloader) 402 | 403 | global_step = 0 404 | for epoch in range(0, args.num_train_epochs): 405 | begin = time.perf_counter() 406 | for step, batch in enumerate(train_dataloader): 407 | load_data_time = time.perf_counter() - begin 408 | with accelerator.accumulate(ip_adapter): 409 | # Convert images to latent space 410 | with torch.no_grad(): 411 | # vae of sdxl should use fp32 412 | latents = vae.encode(batch["images"].to(accelerator.device, dtype=torch.float32)).latent_dist.sample() 413 | latents = latents * vae.config.scaling_factor 414 | latents = latents.to(accelerator.device, dtype=weight_dtype) 415 | 416 | # Sample noise that we'll add to the latents 417 | noise = torch.randn_like(latents) 418 | if args.noise_offset: 419 | # https://www.crosslabs.org//blog/diffusion-with-offset-noise 420 | noise += args.noise_offset * torch.randn((latents.shape[0], latents.shape[1], 1, 1)).to(accelerator.device, dtype=weight_dtype) 421 | 422 | bsz = latents.shape[0] 423 | # Sample a random timestep for each image 424 | timesteps = torch.randint(0, noise_scheduler.num_train_timesteps, (bsz,), device=latents.device) 425 | timesteps = timesteps.long() 426 | 427 | # Add noise to the latents according to the noise magnitude at each timestep 428 | # (this is the forward diffusion process) 429 | noisy_latents = noise_scheduler.add_noise(latents, noise, timesteps) 430 | 431 | with torch.no_grad(): 432 | image_embeds = image_encoder(batch["clip_images"].to(accelerator.device, dtype=weight_dtype)).image_embeds 433 | image_embeds_ = [] 434 | for image_embed, drop_image_embed in zip(image_embeds, batch["drop_image_embeds"]): 435 | if drop_image_embed == 1: 436 | image_embeds_.append(torch.zeros_like(image_embed)) 437 | else: 438 | image_embeds_.append(image_embed) 439 | image_embeds = torch.stack(image_embeds_) 440 | 441 | with torch.no_grad(): 442 | encoder_output = text_encoder(batch['text_input_ids'].to(accelerator.device), output_hidden_states=True) 443 | text_embeds = encoder_output.hidden_states[-2] 444 | encoder_output_2 = text_encoder_2(batch['text_input_ids_2'].to(accelerator.device), output_hidden_states=True) 445 | pooled_text_embeds = encoder_output_2[0] 446 | text_embeds_2 = encoder_output_2.hidden_states[-2] 447 | text_embeds = torch.concat([text_embeds, text_embeds_2], dim=-1) # concat 448 | 449 | # add cond 450 | add_time_ids = [ 451 | batch["original_size"].to(accelerator.device), 452 | batch["crop_coords_top_left"].to(accelerator.device), 453 | batch["target_size"].to(accelerator.device), 454 | ] 455 | add_time_ids = torch.cat(add_time_ids, dim=1).to(accelerator.device, dtype=weight_dtype) 456 | unet_added_cond_kwargs = {"text_embeds": pooled_text_embeds, "time_ids": add_time_ids} 457 | 458 | noise_pred = ip_adapter(noisy_latents, timesteps, text_embeds, unet_added_cond_kwargs, image_embeds) 459 | 460 | loss = F.mse_loss(noise_pred.float(), noise.float(), reduction="mean") 461 | 462 | # Gather the losses across all processes for logging (if we use distributed training). 463 | avg_loss = accelerator.gather(loss.repeat(args.train_batch_size)).mean().item() 464 | 465 | # Backpropagate 466 | accelerator.backward(loss) 467 | optimizer.step() 468 | optimizer.zero_grad() 469 | 470 | if accelerator.is_main_process: 471 | print("Epoch {}, step {}, data_time: {}, time: {}, step_loss: {}".format( 472 | epoch, step, load_data_time, time.perf_counter() - begin, avg_loss)) 473 | 474 | global_step += 1 475 | 476 | if global_step % args.save_steps == 0: 477 | save_path = os.path.join(args.output_dir, f"checkpoint-{global_step}") 478 | accelerator.save_state(save_path) 479 | 480 | begin = time.perf_counter() 481 | 482 | if __name__ == "__main__": 483 | main() 484 | --------------------------------------------------------------------------------