├── README.md ├── assets ├── 0.jpg ├── 2.jpg ├── 3.jpg ├── 4.jpg ├── 5.jpg ├── background_mask.png ├── comparison.png ├── composition_mask.png ├── example1.png ├── example2.png ├── example3.png ├── female_mask.png ├── ip_background.png ├── ip_composition_image.png ├── ip_female_style.png ├── ip_male_style.png ├── male_mask.png ├── multi_instantstyle.png ├── multi_mask.png ├── overture-creations-5sI6fQgYIuo.png ├── overture-creations-5sI6fQgYIuo_mask.png ├── overture-creations-5sI6fQgYIuo_mask_inverse.png ├── page0.png ├── pipe.png ├── subtraction.png ├── tree.png └── yann-lecun.jpg ├── attn_blocks.py ├── attn_blocks_sd15.py ├── gradio_demo ├── .gitignore ├── README.md ├── app.py ├── assets │ ├── 0.jpg │ ├── 1.jpg │ ├── 2.jpg │ ├── 3.jpg │ └── yann-lecun.jpg └── requirements.txt ├── infer_style.py ├── infer_style_controlnet.py ├── infer_style_inpainting.py ├── infer_style_plus.py ├── infer_style_sd15.py ├── ip_adapter ├── __init__.py ├── attention_processor.py ├── ip_adapter.py ├── resampler.py └── utils.py └── notebooks └── instant_style_controlnet_sdxl_demo.ipynb /README.md: -------------------------------------------------------------------------------- 1 |
2 |

InstantStyle: Free Lunch towards Style-Preserving in Text-to-Image Generation

3 | 4 | [**Haofan Wang**](https://haofanwang.github.io/)* · [**Matteo Spinelli**](https://github.com/cubiq) · [**Qixun Wang**](https://github.com/wangqixun) · [**Xu Bai**](https://huggingface.co/baymin0220) · [**Zekui Qin**](https://github.com/ZekuiQin) · [**Anthony Chen**](https://antonioo-c.github.io/) 5 | 6 | InstantX Team 7 | 8 | *corresponding authors 9 | 10 | 11 | 12 | [![Hugging Face](https://img.shields.io/badge/%F0%9F%A4%97%20Hugging%20Face-Space-red)](https://huggingface.co/spaces/InstantX/InstantStyle) 13 | [![ModelScope](https://img.shields.io/badge/ModelScope-Studios-blue)](https://modelscope.cn/studios/instantx/InstantStyle/summary) 14 | [![GitHub](https://img.shields.io/github/stars/InstantStyle/InstantStyle?style=social)](https://github.com/InstantStyle/InstantStyle) 15 | 16 |
17 | 18 | InstantStyle is a general framework that employs two straightforward yet potent techniques for achieving an effective disentanglement of style and content from reference images. 19 | 20 | 21 | 22 |
23 | 24 |
25 | 26 | 27 | ## Principle 28 | 29 | Separating Content from Image. Benefiting from the good characterization of CLIP global features, after subtracting the content text features from the image features, style and content can be explicitly decoupled. Although simple, this strategy is quite effective in mitigating content leakage. 30 |

31 | 32 |
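In code, this subtraction is plain arithmetic on CLIP's projected global embeddings. The sketch below is only an illustration of the idea: the repository exposes it through the `neg_content_prompt` and `neg_content_scale` arguments of the adapters' `generate()` method, and its internal scaling may differ. The content prompt "a rabbit" and the `content_scale` value are illustrative placeholders; the ViT-H CLIP checkpoint is the one referenced elsewhere in this repo.

```python
import torch
from PIL import Image
from transformers import (
    CLIPImageProcessor,
    CLIPTokenizer,
    CLIPTextModelWithProjection,
    CLIPVisionModelWithProjection,
)

# ViT-H/14 CLIP, the encoder family used by the SDXL IP-Adapter checkpoints
clip_id = "laion/CLIP-ViT-H-14-laion2B-s32B-b79K"
image_encoder = CLIPVisionModelWithProjection.from_pretrained(clip_id)
text_encoder = CLIPTextModelWithProjection.from_pretrained(clip_id)
image_processor = CLIPImageProcessor.from_pretrained(clip_id)
tokenizer = CLIPTokenizer.from_pretrained(clip_id)

with torch.no_grad():
    # global (projected) image embedding of the style reference
    pixels = image_processor(images=Image.open("./assets/0.jpg"), return_tensors="pt").pixel_values
    image_embeds = image_encoder(pixels).image_embeds          # shape (1, 1024)

    # global (projected) text embedding of the content to remove
    tokens = tokenizer(["a rabbit"], padding=True, return_tensors="pt")
    content_embeds = text_encoder(**tokens).text_embeds        # shape (1, 1024)

# style-only condition: image feature minus a scaled content direction
content_scale = 0.5  # illustrative value, analogous to neg_content_scale
style_embeds = image_embeds - content_scale * content_embeds
```

In the repository, such a subtracted embedding is then projected into IP-Adapter tokens in place of the raw CLIP image embedding.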

33 | 34 | Injecting into Style Blocks Only. Empirically, each layer of a deep network captures different semantic information. The key observation in our work is that there exist two specific attention layers handling style. Specifically, we find that up_blocks.0.attentions.1 and down_blocks.2.attentions.1 capture style (color, material, atmosphere) and spatial layout (structure, composition), respectively. 35 |

36 | 37 |
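These block names are simply keys of the UNet's attention-processor dictionary (the full lists live in attn_blocks.py and attn_blocks_sd15.py). The sketch below only illustrates the matching rule used throughout the examples in this README: cross-attention layers whose names contain one of the `target_blocks` prefixes receive the image prompt, while the rest effectively ignore it. The actual processor wiring is done inside `ip_adapter/ip_adapter.py`; this snippet just prints which layers would be selected, assuming the same SDXL base model as in the Usage section below.

```python
import torch
from diffusers import StableDiffusionXLPipeline

# same SDXL base model as in the Usage example below
pipe = StableDiffusionXLPipeline.from_pretrained(
    "stabilityai/stable-diffusion-xl-base-1.0",
    torch_dtype=torch.float16,
    add_watermarker=False,
)

target_blocks = ["up_blocks.0.attentions.1"]  # style only
# target_blocks = ["up_blocks.0.attentions.1", "down_blocks.2.attentions.1"]  # style + layout

for name in pipe.unet.attn_processors.keys():
    if not name.endswith("attn2.processor"):
        continue  # IP-Adapter only hooks cross-attention (attn2) layers
    if any(block in name for block in target_blocks):
        print(f"{name} -> inject image prompt (IP-Adapter)")
    else:
        print(f"{name} -> text-only attention")
```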

38 | 39 | ## Release 40 | - [2024/07/06] 🔥 We release [CSGO](https://github.com/instantX-research/CSGO) page for content-style composition. Code will be released soon. 41 | - [2024/07/01] 🔥 We release [InstantStyle-Plus](https://instantstyle-plus.github.io/) report for content preserving. 42 | - [2024/04/29] 🔥 We support InstantStyle natively in diffusers, usage can be found [here](https://github.com/InstantStyle/InstantStyle?tab=readme-ov-file#use-in-diffusers) 43 | - [2024/04/24] 🔥 InstantStyle for fast generation, find demos at [InstantStyle-SDXL-Lightning](https://huggingface.co/spaces/radames/InstantStyle-SDXL-Lightning) and [InstantStyle-Hyper-SDXL](https://huggingface.co/spaces/radames/InstantStyle-Hyper-SDXL). 44 | - [2024/04/24] 🔥 We support [HiDiffusion](https://github.com/megvii-research/HiDiffusion) for generating highres images, find more information [here](https://github.com/InstantStyle/InstantStyle/tree/main?tab=readme-ov-file#high-resolution-generation). 45 | - [2024/04/23] 🔥 InstantStyle has been natively supported in diffusers, more information can be found [here](https://github.com/huggingface/diffusers/pull/7668). 46 | - [2024/04/20] 🔥 InstantStyle is supported in [Mikubill/sd-webui-controlnet](https://github.com/Mikubill/sd-webui-controlnet/discussions/2770). 47 | - [2024/04/11] 🔥 We add the experimental distributed inference feature. Check it [here](https://github.com/InstantStyle/InstantStyle?tab=readme-ov-file#distributed-inference). 48 | - [2024/04/10] 🔥 We support an [online demo](https://modelscope.cn/studios/instantx/InstantStyle/summary) on ModelScope. 49 | - [2024/04/09] 🔥 We support an [online demo](https://huggingface.co/spaces/InstantX/InstantStyle) on Huggingface. 50 | - [2024/04/09] 🔥 We support SDXL-inpainting, more information can be found [here](https://github.com/InstantStyle/InstantStyle/blob/main/infer_style_inpainting.py). 51 | - [2024/04/08] 🔥 InstantStyle is supported in [AnyV2V](https://tiger-ai-lab.github.io/AnyV2V/) for stylized video-to-video editing, demo can be found [here](https://twitter.com/vinesmsuic/status/1777170927500787782). 52 | - [2024/04/07] 🔥 We support image-based stylization, more information can be found [here](https://github.com/InstantStyle/InstantStyle/blob/main/infer_style_controlnet.py). 53 | - [2024/04/07] 🔥 We support an experimental version for SD1.5, more information can be found [here](https://github.com/InstantStyle/InstantStyle/blob/main/infer_style_sd15.py). 54 | - [2024/04/03] 🔥 InstantStyle is supported in [ComfyUI_IPAdapter_plus](https://github.com/cubiq/ComfyUI_IPAdapter_plus) developed by our co-author. 55 | - [2024/04/03] 🔥 We release the [technical report](https://arxiv.org/abs/2404.02733). 56 | 57 | ## Demos 58 | 59 | ### Stylized Synthesis 60 | 61 |

62 | 63 | 64 |

65 | 66 | ### Image-based Stylized Synthesis 67 | 68 |

69 | 70 |

71 | 72 | ### Comparison with Previous Works 73 | 74 |

75 | 76 |

77 | 78 | ## Download 79 | Follow [IP-Adapter](https://github.com/tencent-ailab/IP-Adapter?tab=readme-ov-file#download-models) to download pre-trained checkpoints from [here](https://huggingface.co/h94/IP-Adapter). 80 | 81 | ``` 82 | git clone https://github.com/InstantStyle/InstantStyle.git 83 | cd InstantStyle 84 | 85 | # download the models 86 | git lfs install 87 | git clone https://huggingface.co/h94/IP-Adapter 88 | mv IP-Adapter/models models 89 | mv IP-Adapter/sdxl_models sdxl_models 90 | ``` 91 | 92 | ## Usage 93 | 94 | Our method is fully compatible with [IP-Adapter](https://github.com/tencent-ailab/IP-Adapter). Note that feature subtraction only works on the global image feature, not on patch features. For SD1.5, you can find a demo at [infer_style_sd15.py](https://github.com/InstantStyle/InstantStyle/blob/main/infer_style_sd15.py), but we find that SD1.5 has a weaker perception and understanding of style information, so this demo is experimental only. All block names can be found in [attn_blocks.py](https://github.com/InstantStyle/InstantStyle/blob/main/attn_blocks.py) and [attn_blocks_sd15.py](https://github.com/InstantStyle/InstantStyle/blob/main/attn_blocks_sd15.py) for SDXL and SD1.5, respectively. 95 | 96 | ```python 97 | import torch 98 | from diffusers import StableDiffusionXLPipeline 99 | from PIL import Image 100 | 101 | from ip_adapter import IPAdapterXL 102 | 103 | base_model_path = "stabilityai/stable-diffusion-xl-base-1.0" 104 | image_encoder_path = "sdxl_models/image_encoder" 105 | ip_ckpt = "sdxl_models/ip-adapter_sdxl.bin" 106 | device = "cuda" 107 | 108 | # load SDXL pipeline 109 | pipe = StableDiffusionXLPipeline.from_pretrained( 110 | base_model_path, 111 | torch_dtype=torch.float16, 112 | add_watermarker=False, 113 | ) 114 | 115 | # reduce memory consumption 116 | pipe.enable_vae_tiling() 117 | 118 | # load ip-adapter 119 | # target_blocks=["block"] for original IP-Adapter 120 | # target_blocks=["up_blocks.0.attentions.1"] for style blocks only 121 | # target_blocks = ["up_blocks.0.attentions.1", "down_blocks.2.attentions.1"] # for style+layout blocks 122 | ip_model = IPAdapterXL(pipe, image_encoder_path, ip_ckpt, device, target_blocks=["up_blocks.0.attentions.1"]) 123 | 124 | image = "./assets/0.jpg" 125 | image = Image.open(image) 126 | image = image.resize((512, 512)) 127 | 128 | # generate image variations with only image prompt 129 | images = ip_model.generate(pil_image=image, 130 | prompt="a cat, masterpiece, best quality, high quality", 131 | negative_prompt= "text, watermark, lowres, low quality, worst quality, deformed, glitch, low contrast, noisy, saturation, blurry", 132 | scale=1.0, 133 | guidance_scale=5, 134 | num_samples=1, 135 | num_inference_steps=30, 136 | seed=42, 137 | #neg_content_prompt="a rabbit", 138 | #neg_content_scale=0.5, 139 | ) 140 | 141 | images[0].save("result.png") 142 | ``` 143 | 144 | ## Use in diffusers 145 | InstantStyle has already been integrated into [diffusers](https://huggingface.co/docs/diffusers/main/en/using-diffusers/ip_adapter#style--layout-control) (please make sure that you have installed diffusers>=0.28.0.dev0), making the usage significantly simpler.
You can now control the per-transformer behavior of each IP-Adapter with the set_ip_adapter_scale() method, using a configuration dictionary as shown below: 146 | 147 | ```python 148 | from diffusers import StableDiffusionXLPipeline 149 | from PIL import Image 150 | import torch 151 | 152 | # load SDXL pipeline 153 | pipe = StableDiffusionXLPipeline.from_pretrained( 154 | "stabilityai/stable-diffusion-xl-base-1.0", 155 | torch_dtype=torch.float16, 156 | add_watermarker=False, 157 | ) 158 | 159 | # load ip-adapter 160 | pipe.load_ip_adapter("h94/IP-Adapter", subfolder="sdxl_models", weight_name="ip-adapter_sdxl.bin") 161 | pipe.enable_vae_tiling() 162 | 163 | # configure ip-adapter scales. 164 | scale = { 165 | "down": {"block_2": [0.0, 1.0]}, 166 | "up": {"block_0": [0.0, 1.0, 0.0]}, 167 | } 168 | pipe.set_ip_adapter_scale(scale) 169 | ``` 170 | 171 | In this example, we set ```scale=1.0``` for the IP-Adapter in the second transformer of down-part block 2 and in the second transformer of up-part block 0. Note that there are 2 transformers in down-part block 2, so the list is of length 2; the same applies to up-part block 0. All remaining scales are set to zero, which disables the IP-Adapter in all other layers. 172 | 173 | With the help of ```set_ip_adapter_scale()```, we can now reconfigure the IP-Adapter without needing to reload it every time we want to test a different behavior. 174 | 175 | ```python 176 | # for original IP-Adapter 177 | scale = 1.0 178 | pipe.set_ip_adapter_scale(scale) 179 | 180 | # for style blocks only 181 | scale = { 182 | "up": {"block_0": [0.0, 1.0, 0.0]}, 183 | } 184 | pipe.set_ip_adapter_scale(scale) 185 | ``` 186 | 187 | ### Multiple IP-Adapter images with masks 188 | You can also load multiple IP-Adapters, together with multiple IP-Adapter images and masks, for more precise layout control, just as in [IP-Adapter masking](https://huggingface.co/docs/diffusers/main/en/using-diffusers/ip_adapter#ip-adapter-masking). 189 | 190 | ```python 191 | from diffusers import StableDiffusionXLPipeline 192 | from diffusers.image_processor import IPAdapterMaskProcessor 193 | from transformers import CLIPVisionModelWithProjection 194 | from PIL import Image 195 | import torch 196 | 197 | image_encoder = CLIPVisionModelWithProjection.from_pretrained( 198 | "h94/IP-Adapter", subfolder="models/image_encoder", torch_dtype=torch.float16 199 | ).to("cuda") 200 | 201 | pipe = StableDiffusionXLPipeline.from_pretrained( 202 | "RunDiffusion/Juggernaut-XL-v9", torch_dtype=torch.float16, image_encoder=image_encoder, variant="fp16" 203 | ).to("cuda") 204 | 205 | pipe.load_ip_adapter( 206 | ["ostris/ip-composition-adapter", "h94/IP-Adapter"], 207 | subfolder=["", "sdxl_models"], 208 | weight_name=[ 209 | "ip_plus_composition_sdxl.safetensors", 210 | "ip-adapter_sdxl_vit-h.safetensors", 211 | ], 212 | image_encoder_folder=None, 213 | ) 214 | 215 | scale_1 = { 216 | "down": [[0.0, 0.0, 1.0]], 217 | "mid": [[0.0, 0.0, 1.0]], 218 | "up": {"block_0": [[0.0, 0.0, 1.0], [1.0, 1.0, 1.0], [0.0, 0.0, 1.0]], "block_1": [[0.0, 0.0, 1.0]]}, 219 | } 220 | # activate the first IP-Adapter everywhere in the model, 221 | # configure the second one for precise style control of each masked input.
222 | pipe.set_ip_adapter_scale([1.0, scale_1]) 223 | 224 | processor = IPAdapterMaskProcessor() 225 | female_mask = Image.open("./assets/female_mask.png") 226 | male_mask = Image.open("./assets/male_mask.png") 227 | background_mask = Image.open("./assets/background_mask.png") 228 | composition_mask = Image.open("./assets/composition_mask.png") 229 | mask1 = processor.preprocess([composition_mask], height=1024, width=1024) 230 | mask2 = processor.preprocess([female_mask, male_mask, background_mask], height=1024, width=1024) 231 | mask2 = mask2.reshape(1, mask2.shape[0], mask2.shape[2], mask2.shape[3]) # output -> (1, 3, 1024, 1024) 232 | 233 | ip_female_style = Image.open("./assets/ip_female_style.png") 234 | ip_male_style = Image.open("./assets/ip_male_style.png") 235 | ip_background = Image.open("./assets/ip_background.png") 236 | ip_composition_image = Image.open("./assets/ip_composition_image.png") 237 | 238 | image = pipe( 239 | prompt="high quality, cinematic photo, cinemascope, 35mm, film grain, highly detailed", 240 | negative_prompt="", 241 | ip_adapter_image=[ip_composition_image, [ip_female_style, ip_male_style, ip_background]], 242 | cross_attention_kwargs={"ip_adapter_masks": [mask1, mask2]}, 243 | guidance_scale=6.5, 244 | num_inference_steps=25, 245 | ).images[0] 246 | image 247 | 248 | ``` 249 | 250 |

251 | 252 |

253 | 254 | ## High Resolution Generation 255 | We employ [HiDiffusion](https://github.com/megvii-research/HiDiffusion) to seamlessly generate high-resolution images; you can install it via `pip install hidiffusion`. 256 | 257 | ```python 258 | from hidiffusion import apply_hidiffusion, remove_hidiffusion 259 | 260 | # reduce memory consumption 261 | pipe.enable_vae_tiling() 262 | 263 | # apply hidiffusion with a single line of code. 264 | apply_hidiffusion(pipe) 265 | 266 | ... 267 | 268 | # generate image at higher resolution 269 | images = ip_model.generate(pil_image=image, 270 | prompt="a cat, masterpiece, best quality, high quality", 271 | negative_prompt= "text, watermark, lowres, low quality, worst quality, deformed, glitch, low contrast, noisy, saturation, blurry", 272 | scale=1.0, 273 | guidance_scale=5, 274 | num_samples=1, 275 | num_inference_steps=30, 276 | seed=42, 277 | height=2048, 278 | width=2048 279 | ) 280 | ``` 281 | 282 | ## Distributed Inference 283 | On distributed setups, you can run inference across multiple GPUs with 🤗 Accelerate or PyTorch Distributed, which is useful for generating with multiple prompts in parallel when you have limited VRAM on each GPU. More information can be found [here](https://huggingface.co/docs/diffusers/main/en/training/distributed_inference#device-placement). Make sure you have installed diffusers from source and the latest accelerate. 284 | 285 | ```python 286 | max_memory = {0:"10GB", 1:"10GB"} 287 | pipe = StableDiffusionXLPipeline.from_pretrained( 288 | base_model_path, 289 | torch_dtype=torch.float16, 290 | add_watermarker=False, 291 | device_map="balanced", 292 | max_memory=max_memory 293 | ) 294 | ``` 295 | 296 | ## Start a local gradio demo 297 | Run the following command: 298 | ```sh 299 | git clone https://github.com/InstantStyle/InstantStyle.git 300 | cd ./InstantStyle/gradio_demo/ 301 | pip install -r requirements.txt 302 | python app.py 303 | ``` 304 | 305 | ## Resources 306 | - [InstantStyle for WebUI](https://github.com/Mikubill/sd-webui-controlnet/discussions/2770) 307 | - [InstantStyle for ComfyUI](https://github.com/cubiq/ComfyUI_IPAdapter_plus) 308 | - [InstantID](https://github.com/InstantID/InstantID) 309 | 310 | ## Disclaimer 311 | The pretrained checkpoints follow the license in [IP-Adapter](https://github.com/tencent-ailab/IP-Adapter?tab=readme-ov-file#download-models). Users are granted the freedom to create images using this tool, but they are obligated to comply with local laws and utilize it responsibly. The developers will not assume any responsibility for potential misuse by users. 312 | 313 | ## Acknowledgements 314 | InstantStyle is developed by the InstantX team and builds heavily on [IP-Adapter](https://github.com/tencent-ailab/IP-Adapter), which has been unfairly compared against by many other works. We at InstantStyle make IP-Adapter great again. Additionally, we acknowledge [Hu Ye](https://github.com/xiaohu2015) for his valuable discussion.
315 | 316 | ## Star History 317 | [![Star History Chart](https://api.star-history.com/svg?repos=InstantStyle/InstantStyle&type=Date)](https://star-history.com/#InstantStyle/InstantStyle&Date) 318 | 319 | ## Cite 320 | If you find InstantStyle useful for your research and applications, please cite us using this BibTeX: 321 | 322 | ```bibtex 323 | @article{wang2024instantstyle, 324 | title={InstantStyle-Plus: Style Transfer with Content-Preserving in Text-to-Image Generation}, 325 | author={Wang, Haofan and Xing, Peng and Huang, Renyuan and Ai, Hao and Wang, Qixun and Bai, Xu}, 326 | journal={arXiv preprint arXiv:2407.00788}, 327 | year={2024} 328 | } 329 | 330 | @article{wang2024instantstyle, 331 | title={InstantStyle: Free Lunch towards Style-Preserving in Text-to-Image Generation}, 332 | author={Wang, Haofan and Wang, Qixun and Bai, Xu and Qin, Zekui and Chen, Anthony}, 333 | journal={arXiv preprint arXiv:2404.02733}, 334 | year={2024} 335 | } 336 | ``` 337 | 338 | For any question, feel free to contact us via haofanwang.ai@gmail.com. 339 | -------------------------------------------------------------------------------- /assets/0.jpg: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/instantX-research/InstantStyle/6b40588e263c958653353ec24eb7eb990cfa3da7/assets/0.jpg -------------------------------------------------------------------------------- /assets/2.jpg: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/instantX-research/InstantStyle/6b40588e263c958653353ec24eb7eb990cfa3da7/assets/2.jpg -------------------------------------------------------------------------------- /assets/3.jpg: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/instantX-research/InstantStyle/6b40588e263c958653353ec24eb7eb990cfa3da7/assets/3.jpg -------------------------------------------------------------------------------- /assets/4.jpg: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/instantX-research/InstantStyle/6b40588e263c958653353ec24eb7eb990cfa3da7/assets/4.jpg -------------------------------------------------------------------------------- /assets/5.jpg: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/instantX-research/InstantStyle/6b40588e263c958653353ec24eb7eb990cfa3da7/assets/5.jpg -------------------------------------------------------------------------------- /assets/background_mask.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/instantX-research/InstantStyle/6b40588e263c958653353ec24eb7eb990cfa3da7/assets/background_mask.png -------------------------------------------------------------------------------- /assets/comparison.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/instantX-research/InstantStyle/6b40588e263c958653353ec24eb7eb990cfa3da7/assets/comparison.png -------------------------------------------------------------------------------- /assets/composition_mask.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/instantX-research/InstantStyle/6b40588e263c958653353ec24eb7eb990cfa3da7/assets/composition_mask.png 
-------------------------------------------------------------------------------- /assets/example1.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/instantX-research/InstantStyle/6b40588e263c958653353ec24eb7eb990cfa3da7/assets/example1.png -------------------------------------------------------------------------------- /assets/example2.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/instantX-research/InstantStyle/6b40588e263c958653353ec24eb7eb990cfa3da7/assets/example2.png -------------------------------------------------------------------------------- /assets/example3.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/instantX-research/InstantStyle/6b40588e263c958653353ec24eb7eb990cfa3da7/assets/example3.png -------------------------------------------------------------------------------- /assets/female_mask.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/instantX-research/InstantStyle/6b40588e263c958653353ec24eb7eb990cfa3da7/assets/female_mask.png -------------------------------------------------------------------------------- /assets/ip_background.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/instantX-research/InstantStyle/6b40588e263c958653353ec24eb7eb990cfa3da7/assets/ip_background.png -------------------------------------------------------------------------------- /assets/ip_composition_image.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/instantX-research/InstantStyle/6b40588e263c958653353ec24eb7eb990cfa3da7/assets/ip_composition_image.png -------------------------------------------------------------------------------- /assets/ip_female_style.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/instantX-research/InstantStyle/6b40588e263c958653353ec24eb7eb990cfa3da7/assets/ip_female_style.png -------------------------------------------------------------------------------- /assets/ip_male_style.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/instantX-research/InstantStyle/6b40588e263c958653353ec24eb7eb990cfa3da7/assets/ip_male_style.png -------------------------------------------------------------------------------- /assets/male_mask.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/instantX-research/InstantStyle/6b40588e263c958653353ec24eb7eb990cfa3da7/assets/male_mask.png -------------------------------------------------------------------------------- /assets/multi_instantstyle.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/instantX-research/InstantStyle/6b40588e263c958653353ec24eb7eb990cfa3da7/assets/multi_instantstyle.png -------------------------------------------------------------------------------- /assets/multi_mask.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/instantX-research/InstantStyle/6b40588e263c958653353ec24eb7eb990cfa3da7/assets/multi_mask.png 
-------------------------------------------------------------------------------- /assets/overture-creations-5sI6fQgYIuo.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/instantX-research/InstantStyle/6b40588e263c958653353ec24eb7eb990cfa3da7/assets/overture-creations-5sI6fQgYIuo.png -------------------------------------------------------------------------------- /assets/overture-creations-5sI6fQgYIuo_mask.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/instantX-research/InstantStyle/6b40588e263c958653353ec24eb7eb990cfa3da7/assets/overture-creations-5sI6fQgYIuo_mask.png -------------------------------------------------------------------------------- /assets/overture-creations-5sI6fQgYIuo_mask_inverse.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/instantX-research/InstantStyle/6b40588e263c958653353ec24eb7eb990cfa3da7/assets/overture-creations-5sI6fQgYIuo_mask_inverse.png -------------------------------------------------------------------------------- /assets/page0.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/instantX-research/InstantStyle/6b40588e263c958653353ec24eb7eb990cfa3da7/assets/page0.png -------------------------------------------------------------------------------- /assets/pipe.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/instantX-research/InstantStyle/6b40588e263c958653353ec24eb7eb990cfa3da7/assets/pipe.png -------------------------------------------------------------------------------- /assets/subtraction.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/instantX-research/InstantStyle/6b40588e263c958653353ec24eb7eb990cfa3da7/assets/subtraction.png -------------------------------------------------------------------------------- /assets/tree.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/instantX-research/InstantStyle/6b40588e263c958653353ec24eb7eb990cfa3da7/assets/tree.png -------------------------------------------------------------------------------- /assets/yann-lecun.jpg: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/instantX-research/InstantStyle/6b40588e263c958653353ec24eb7eb990cfa3da7/assets/yann-lecun.jpg -------------------------------------------------------------------------------- /attn_blocks.py: -------------------------------------------------------------------------------- 1 | 'down_blocks.1.attentions.0.transformer_blocks.0.attn1.processor', 2 | 'down_blocks.1.attentions.0.transformer_blocks.0.attn2.processor', 3 | 'down_blocks.1.attentions.0.transformer_blocks.1.attn1.processor', 4 | 'down_blocks.1.attentions.0.transformer_blocks.1.attn2.processor', 5 | 'down_blocks.1.attentions.1.transformer_blocks.0.attn1.processor', 6 | 'down_blocks.1.attentions.1.transformer_blocks.0.attn2.processor', 7 | 'down_blocks.1.attentions.1.transformer_blocks.1.attn1.processor', 8 | 'down_blocks.1.attentions.1.transformer_blocks.1.attn2.processor', 9 | 'down_blocks.2.attentions.0.transformer_blocks.0.attn1.processor', 10 | 'down_blocks.2.attentions.0.transformer_blocks.0.attn2.processor', 11 | 
'down_blocks.2.attentions.0.transformer_blocks.1.attn1.processor', 12 | 'down_blocks.2.attentions.0.transformer_blocks.1.attn2.processor', 13 | 'down_blocks.2.attentions.0.transformer_blocks.2.attn1.processor', 14 | 'down_blocks.2.attentions.0.transformer_blocks.2.attn2.processor', 15 | 'down_blocks.2.attentions.0.transformer_blocks.3.attn1.processor', 16 | 'down_blocks.2.attentions.0.transformer_blocks.3.attn2.processor', 17 | 'down_blocks.2.attentions.0.transformer_blocks.4.attn1.processor', 18 | 'down_blocks.2.attentions.0.transformer_blocks.4.attn2.processor', 19 | 'down_blocks.2.attentions.0.transformer_blocks.5.attn1.processor', 20 | 'down_blocks.2.attentions.0.transformer_blocks.5.attn2.processor', 21 | 'down_blocks.2.attentions.0.transformer_blocks.6.attn1.processor', 22 | 'down_blocks.2.attentions.0.transformer_blocks.6.attn2.processor', 23 | 'down_blocks.2.attentions.0.transformer_blocks.7.attn1.processor', 24 | 'down_blocks.2.attentions.0.transformer_blocks.7.attn2.processor', 25 | 'down_blocks.2.attentions.0.transformer_blocks.8.attn1.processor', 26 | 'down_blocks.2.attentions.0.transformer_blocks.8.attn2.processor', 27 | 'down_blocks.2.attentions.0.transformer_blocks.9.attn1.processor', 28 | 'down_blocks.2.attentions.0.transformer_blocks.9.attn2.processor', 29 | 'down_blocks.2.attentions.1.transformer_blocks.0.attn1.processor', 30 | 'down_blocks.2.attentions.1.transformer_blocks.0.attn2.processor', 31 | 'down_blocks.2.attentions.1.transformer_blocks.1.attn1.processor', 32 | 'down_blocks.2.attentions.1.transformer_blocks.1.attn2.processor', 33 | 'down_blocks.2.attentions.1.transformer_blocks.2.attn1.processor', 34 | 'down_blocks.2.attentions.1.transformer_blocks.2.attn2.processor', 35 | 'down_blocks.2.attentions.1.transformer_blocks.3.attn1.processor', 36 | 'down_blocks.2.attentions.1.transformer_blocks.3.attn2.processor', 37 | 'down_blocks.2.attentions.1.transformer_blocks.4.attn1.processor', 38 | 'down_blocks.2.attentions.1.transformer_blocks.4.attn2.processor', 39 | 'down_blocks.2.attentions.1.transformer_blocks.5.attn1.processor', 40 | 'down_blocks.2.attentions.1.transformer_blocks.5.attn2.processor', 41 | 'down_blocks.2.attentions.1.transformer_blocks.6.attn1.processor', 42 | 'down_blocks.2.attentions.1.transformer_blocks.6.attn2.processor', 43 | 'down_blocks.2.attentions.1.transformer_blocks.7.attn1.processor', 44 | 'down_blocks.2.attentions.1.transformer_blocks.7.attn2.processor', 45 | 'down_blocks.2.attentions.1.transformer_blocks.8.attn1.processor', 46 | 'down_blocks.2.attentions.1.transformer_blocks.8.attn2.processor', 47 | 'down_blocks.2.attentions.1.transformer_blocks.9.attn1.processor', 48 | 'down_blocks.2.attentions.1.transformer_blocks.9.attn2.processor', 49 | 'up_blocks.0.attentions.0.transformer_blocks.0.attn1.processor', 50 | 'up_blocks.0.attentions.0.transformer_blocks.0.attn2.processor', 51 | 'up_blocks.0.attentions.0.transformer_blocks.1.attn1.processor', 52 | 'up_blocks.0.attentions.0.transformer_blocks.1.attn2.processor', 53 | 'up_blocks.0.attentions.0.transformer_blocks.2.attn1.processor', 54 | 'up_blocks.0.attentions.0.transformer_blocks.2.attn2.processor', 55 | 'up_blocks.0.attentions.0.transformer_blocks.3.attn1.processor', 56 | 'up_blocks.0.attentions.0.transformer_blocks.3.attn2.processor', 57 | 'up_blocks.0.attentions.0.transformer_blocks.4.attn1.processor', 58 | 'up_blocks.0.attentions.0.transformer_blocks.4.attn2.processor', 59 | 'up_blocks.0.attentions.0.transformer_blocks.5.attn1.processor', 60 | 
'up_blocks.0.attentions.0.transformer_blocks.5.attn2.processor', 61 | 'up_blocks.0.attentions.0.transformer_blocks.6.attn1.processor', 62 | 'up_blocks.0.attentions.0.transformer_blocks.6.attn2.processor', 63 | 'up_blocks.0.attentions.0.transformer_blocks.7.attn1.processor', 64 | 'up_blocks.0.attentions.0.transformer_blocks.7.attn2.processor', 65 | 'up_blocks.0.attentions.0.transformer_blocks.8.attn1.processor', 66 | 'up_blocks.0.attentions.0.transformer_blocks.8.attn2.processor', 67 | 'up_blocks.0.attentions.0.transformer_blocks.9.attn1.processor', 68 | 'up_blocks.0.attentions.0.transformer_blocks.9.attn2.processor', 69 | 'up_blocks.0.attentions.1.transformer_blocks.0.attn1.processor', 70 | 'up_blocks.0.attentions.1.transformer_blocks.0.attn2.processor', 71 | 'up_blocks.0.attentions.1.transformer_blocks.1.attn1.processor', 72 | 'up_blocks.0.attentions.1.transformer_blocks.1.attn2.processor', 73 | 'up_blocks.0.attentions.1.transformer_blocks.2.attn1.processor', 74 | 'up_blocks.0.attentions.1.transformer_blocks.2.attn2.processor', 75 | 'up_blocks.0.attentions.1.transformer_blocks.3.attn1.processor', 76 | 'up_blocks.0.attentions.1.transformer_blocks.3.attn2.processor', 77 | 'up_blocks.0.attentions.1.transformer_blocks.4.attn1.processor', 78 | 'up_blocks.0.attentions.1.transformer_blocks.4.attn2.processor', 79 | 'up_blocks.0.attentions.1.transformer_blocks.5.attn1.processor', 80 | 'up_blocks.0.attentions.1.transformer_blocks.5.attn2.processor', 81 | 'up_blocks.0.attentions.1.transformer_blocks.6.attn1.processor', 82 | 'up_blocks.0.attentions.1.transformer_blocks.6.attn2.processor', 83 | 'up_blocks.0.attentions.1.transformer_blocks.7.attn1.processor', 84 | 'up_blocks.0.attentions.1.transformer_blocks.7.attn2.processor', 85 | 'up_blocks.0.attentions.1.transformer_blocks.8.attn1.processor', 86 | 'up_blocks.0.attentions.1.transformer_blocks.8.attn2.processor', 87 | 'up_blocks.0.attentions.1.transformer_blocks.9.attn1.processor', 88 | 'up_blocks.0.attentions.1.transformer_blocks.9.attn2.processor', 89 | 'up_blocks.0.attentions.2.transformer_blocks.0.attn1.processor', 90 | 'up_blocks.0.attentions.2.transformer_blocks.0.attn2.processor', 91 | 'up_blocks.0.attentions.2.transformer_blocks.1.attn1.processor', 92 | 'up_blocks.0.attentions.2.transformer_blocks.1.attn2.processor', 93 | 'up_blocks.0.attentions.2.transformer_blocks.2.attn1.processor', 94 | 'up_blocks.0.attentions.2.transformer_blocks.2.attn2.processor', 95 | 'up_blocks.0.attentions.2.transformer_blocks.3.attn1.processor', 96 | 'up_blocks.0.attentions.2.transformer_blocks.3.attn2.processor', 97 | 'up_blocks.0.attentions.2.transformer_blocks.4.attn1.processor', 98 | 'up_blocks.0.attentions.2.transformer_blocks.4.attn2.processor', 99 | 'up_blocks.0.attentions.2.transformer_blocks.5.attn1.processor', 100 | 'up_blocks.0.attentions.2.transformer_blocks.5.attn2.processor', 101 | 'up_blocks.0.attentions.2.transformer_blocks.6.attn1.processor', 102 | 'up_blocks.0.attentions.2.transformer_blocks.6.attn2.processor', 103 | 'up_blocks.0.attentions.2.transformer_blocks.7.attn1.processor', 104 | 'up_blocks.0.attentions.2.transformer_blocks.7.attn2.processor', 105 | 'up_blocks.0.attentions.2.transformer_blocks.8.attn1.processor', 106 | 'up_blocks.0.attentions.2.transformer_blocks.8.attn2.processor', 107 | 'up_blocks.0.attentions.2.transformer_blocks.9.attn1.processor', 108 | 'up_blocks.0.attentions.2.transformer_blocks.9.attn2.processor', 109 | 'up_blocks.1.attentions.0.transformer_blocks.0.attn1.processor', 110 | 
'up_blocks.1.attentions.0.transformer_blocks.0.attn2.processor', 111 | 'up_blocks.1.attentions.0.transformer_blocks.1.attn1.processor', 112 | 'up_blocks.1.attentions.0.transformer_blocks.1.attn2.processor', 113 | 'up_blocks.1.attentions.1.transformer_blocks.0.attn1.processor', 114 | 'up_blocks.1.attentions.1.transformer_blocks.0.attn2.processor', 115 | 'up_blocks.1.attentions.1.transformer_blocks.1.attn1.processor', 116 | 'up_blocks.1.attentions.1.transformer_blocks.1.attn2.processor', 117 | 'up_blocks.1.attentions.2.transformer_blocks.0.attn1.processor', 118 | 'up_blocks.1.attentions.2.transformer_blocks.0.attn2.processor', 119 | 'up_blocks.1.attentions.2.transformer_blocks.1.attn1.processor', 120 | 'up_blocks.1.attentions.2.transformer_blocks.1.attn2.processor', 121 | 'mid_block.attentions.0.transformer_blocks.0.attn1.processor', 122 | 'mid_block.attentions.0.transformer_blocks.0.attn2.processor', 123 | 'mid_block.attentions.0.transformer_blocks.1.attn1.processor', 124 | 'mid_block.attentions.0.transformer_blocks.1.attn2.processor', 125 | 'mid_block.attentions.0.transformer_blocks.2.attn1.processor', 126 | 'mid_block.attentions.0.transformer_blocks.2.attn2.processor', 127 | 'mid_block.attentions.0.transformer_blocks.3.attn1.processor', 128 | 'mid_block.attentions.0.transformer_blocks.3.attn2.processor', 129 | 'mid_block.attentions.0.transformer_blocks.4.attn1.processor', 130 | 'mid_block.attentions.0.transformer_blocks.4.attn2.processor', 131 | 'mid_block.attentions.0.transformer_blocks.5.attn1.processor', 132 | 'mid_block.attentions.0.transformer_blocks.5.attn2.processor', 133 | 'mid_block.attentions.0.transformer_blocks.6.attn1.processor', 134 | 'mid_block.attentions.0.transformer_blocks.6.attn2.processor', 135 | 'mid_block.attentions.0.transformer_blocks.7.attn1.processor', 136 | 'mid_block.attentions.0.transformer_blocks.7.attn2.processor', 137 | 'mid_block.attentions.0.transformer_blocks.8.attn1.processor', 138 | 'mid_block.attentions.0.transformer_blocks.8.attn2.processor', 139 | 'mid_block.attentions.0.transformer_blocks.9.attn1.processor', 140 | 'mid_block.attentions.0.transformer_blocks.9.attn2.processor' -------------------------------------------------------------------------------- /attn_blocks_sd15.py: -------------------------------------------------------------------------------- 1 | 'down_blocks.0.attentions.0.transformer_blocks.0.attn1.processor', 2 | 'down_blocks.0.attentions.0.transformer_blocks.0.attn2.processor', 3 | 'down_blocks.0.attentions.1.transformer_blocks.0.attn1.processor', 4 | 'down_blocks.0.attentions.1.transformer_blocks.0.attn2.processor', 5 | 'down_blocks.1.attentions.0.transformer_blocks.0.attn1.processor', 6 | 'down_blocks.1.attentions.0.transformer_blocks.0.attn2.processor', 7 | 'down_blocks.1.attentions.1.transformer_blocks.0.attn1.processor', 8 | 'down_blocks.1.attentions.1.transformer_blocks.0.attn2.processor', 9 | 'down_blocks.2.attentions.0.transformer_blocks.0.attn1.processor', 10 | 'down_blocks.2.attentions.0.transformer_blocks.0.attn2.processor', 11 | 'down_blocks.2.attentions.1.transformer_blocks.0.attn1.processor', 12 | 'down_blocks.2.attentions.1.transformer_blocks.0.attn2.processor', 13 | 'up_blocks.1.attentions.0.transformer_blocks.0.attn1.processor', 14 | 'up_blocks.1.attentions.0.transformer_blocks.0.attn2.processor', 15 | 'up_blocks.1.attentions.1.transformer_blocks.0.attn1.processor', 16 | 'up_blocks.1.attentions.1.transformer_blocks.0.attn2.processor', 17 | 'up_blocks.1.attentions.2.transformer_blocks.0.attn1.processor', 18 | 
'up_blocks.1.attentions.2.transformer_blocks.0.attn2.processor', 19 | 'up_blocks.2.attentions.0.transformer_blocks.0.attn1.processor', 20 | 'up_blocks.2.attentions.0.transformer_blocks.0.attn2.processor', 21 | 'up_blocks.2.attentions.1.transformer_blocks.0.attn1.processor', 22 | 'up_blocks.2.attentions.1.transformer_blocks.0.attn2.processor', 23 | 'up_blocks.2.attentions.2.transformer_blocks.0.attn1.processor', 24 | 'up_blocks.2.attentions.2.transformer_blocks.0.attn2.processor', 25 | 'up_blocks.3.attentions.0.transformer_blocks.0.attn1.processor', 26 | 'up_blocks.3.attentions.0.transformer_blocks.0.attn2.processor', 27 | 'up_blocks.3.attentions.1.transformer_blocks.0.attn1.processor', 28 | 'up_blocks.3.attentions.1.transformer_blocks.0.attn2.processor', 29 | 'up_blocks.3.attentions.2.transformer_blocks.0.attn1.processor', 30 | 'up_blocks.3.attentions.2.transformer_blocks.0.attn2.processor', 31 | 'mid_block.attentions.0.transformer_blocks.0.attn1.processor', 32 | 'mid_block.attentions.0.transformer_blocks.0.attn2.processor' -------------------------------------------------------------------------------- /gradio_demo/.gitignore: -------------------------------------------------------------------------------- 1 | models/* 2 | sdxl_models/* 3 | 4 | # Byte-compiled / optimized / DLL files 5 | __pycache__/ 6 | *.py[cod] 7 | *$py.class 8 | 9 | # C extensions 10 | *.so 11 | 12 | # Distribution / packaging 13 | .Python 14 | build/ 15 | develop-eggs/ 16 | dist/ 17 | downloads/ 18 | eggs/ 19 | .eggs/ 20 | lib/ 21 | lib64/ 22 | parts/ 23 | sdist/ 24 | var/ 25 | wheels/ 26 | share/python-wheels/ 27 | *.egg-info/ 28 | .installed.cfg 29 | *.egg 30 | MANIFEST 31 | 32 | # PyInstaller 33 | # Usually these files are written by a python script from a template 34 | # before PyInstaller builds the exe, so as to inject date/other infos into it. 35 | *.manifest 36 | *.spec 37 | 38 | # Installer logs 39 | pip-log.txt 40 | pip-delete-this-directory.txt 41 | 42 | # Unit test / coverage reports 43 | htmlcov/ 44 | .tox/ 45 | .nox/ 46 | .coverage 47 | .coverage.* 48 | .cache 49 | nosetests.xml 50 | coverage.xml 51 | *.cover 52 | *.py,cover 53 | .hypothesis/ 54 | .pytest_cache/ 55 | cover/ 56 | 57 | # Translations 58 | *.mo 59 | *.pot 60 | 61 | # Django stuff: 62 | *.log 63 | local_settings.py 64 | db.sqlite3 65 | db.sqlite3-journal 66 | 67 | # Flask stuff: 68 | instance/ 69 | .webassets-cache 70 | 71 | # Scrapy stuff: 72 | .scrapy 73 | 74 | # Sphinx documentation 75 | docs/_build/ 76 | 77 | # PyBuilder 78 | .pybuilder/ 79 | target/ 80 | 81 | # Jupyter Notebook 82 | .ipynb_checkpoints 83 | 84 | # IPython 85 | profile_default/ 86 | ipython_config.py 87 | 88 | # pyenv 89 | # For a library or package, you might want to ignore these files since the code is 90 | # intended to run in multiple environments; otherwise, check them in: 91 | # .python-version 92 | 93 | # pipenv 94 | # According to pypa/pipenv#598, it is recommended to include Pipfile.lock in version control. 95 | # However, in case of collaboration, if having platform-specific dependencies or dependencies 96 | # having no cross-platform support, pipenv may install dependencies that don't work, or not 97 | # install all needed dependencies. 98 | #Pipfile.lock 99 | 100 | # poetry 101 | # Similar to Pipfile.lock, it is generally recommended to include poetry.lock in version control. 102 | # This is especially recommended for binary packages to ensure reproducibility, and is more 103 | # commonly ignored for libraries. 
104 | # https://python-poetry.org/docs/basic-usage/#commit-your-poetrylock-file-to-version-control 105 | #poetry.lock 106 | 107 | # pdm 108 | # Similar to Pipfile.lock, it is generally recommended to include pdm.lock in version control. 109 | #pdm.lock 110 | # pdm stores project-wide configurations in .pdm.toml, but it is recommended to not include it 111 | # in version control. 112 | # https://pdm.fming.dev/#use-with-ide 113 | .pdm.toml 114 | 115 | # PEP 582; used by e.g. github.com/David-OConnor/pyflow and github.com/pdm-project/pdm 116 | __pypackages__/ 117 | 118 | # Celery stuff 119 | celerybeat-schedule 120 | celerybeat.pid 121 | 122 | # SageMath parsed files 123 | *.sage.py 124 | 125 | # Environments 126 | .env 127 | .venv 128 | env/ 129 | venv/ 130 | ENV/ 131 | env.bak/ 132 | venv.bak/ 133 | 134 | # Spyder project settings 135 | .spyderproject 136 | .spyproject 137 | 138 | # Rope project settings 139 | .ropeproject 140 | 141 | # mkdocs documentation 142 | /site 143 | 144 | # mypy 145 | .mypy_cache/ 146 | .dmypy.json 147 | dmypy.json 148 | 149 | # Pyre type checker 150 | .pyre/ 151 | 152 | # pytype static type analyzer 153 | .pytype/ 154 | 155 | # Cython debug symbols 156 | cython_debug/ 157 | 158 | # PyCharm 159 | # JetBrains specific template is maintained in a separate JetBrains.gitignore that can 160 | # be found at https://github.com/github/gitignore/blob/main/Global/JetBrains.gitignore 161 | # and can be added to the global gitignore or merged into this file. For a more nuclear 162 | # option (not recommended) you can uncomment the following to ignore the entire idea folder. 163 | #.idea/ 164 | huggingface/ 165 | checkpoints/ 166 | models/ 167 | 168 | # Cog 169 | .cog 170 | 171 | gradio_cached_examples -------------------------------------------------------------------------------- /gradio_demo/README.md: -------------------------------------------------------------------------------- 1 | title: InstantStyle 2 | emoji: 👁 3 | colorFrom: blue 4 | colorTo: purple 5 | sdk: gradio 6 | sdk_version: 4.26.0 7 | app_file: app.py 8 | pinned: false 9 | license: apache-2.0 -------------------------------------------------------------------------------- /gradio_demo/app.py: -------------------------------------------------------------------------------- 1 | import sys 2 | sys.path.append('../') 3 | 4 | import os 5 | import cv2 6 | import torch 7 | import random 8 | import numpy as np 9 | from PIL import Image 10 | from diffusers import ControlNetModel, StableDiffusionXLControlNetPipeline 11 | 12 | import gradio as gr 13 | 14 | from ip_adapter import IPAdapterXL 15 | 16 | # global variable 17 | MAX_SEED = np.iinfo(np.int32).max 18 | device = "cuda" if torch.cuda.is_available() else "cpu" 19 | dtype = torch.float16 if str(device).__contains__("cuda") else torch.float32 20 | 21 | # initialization 22 | base_model_path = "stabilityai/stable-diffusion-xl-base-1.0" 23 | image_encoder_path = "sdxl_models/image_encoder" 24 | ip_ckpt = "sdxl_models/ip-adapter_sdxl.bin" 25 | 26 | controlnet_path = "diffusers/controlnet-canny-sdxl-1.0" 27 | controlnet = ControlNetModel.from_pretrained(controlnet_path, use_safetensors=False, torch_dtype=torch.float16).to(device) 28 | 29 | # load SDXL pipeline 30 | pipe = StableDiffusionXLControlNetPipeline.from_pretrained( 31 | base_model_path, 32 | controlnet=controlnet, 33 | torch_dtype=torch.float16, 34 | add_watermarker=False, 35 | ) 36 | pipe.enable_vae_tiling() 37 | 38 | # load ip-adapter 39 | # target_blocks=["block"] for original IP-Adapter 40 | # 
target_blocks=["up_blocks.0.attentions.1"] for style blocks only 41 | # target_blocks = ["up_blocks.0.attentions.1", "down_blocks.2.attentions.1"] # for style+layout blocks 42 | ip_model = IPAdapterXL(pipe, image_encoder_path, ip_ckpt, device, target_blocks=["up_blocks.0.attentions.1"]) 43 | 44 | def randomize_seed_fn(seed: int, randomize_seed: bool) -> int: 45 | if randomize_seed: 46 | seed = random.randint(0, MAX_SEED) 47 | return seed 48 | 49 | def resize_img( 50 | input_image, 51 | max_side=1280, 52 | min_side=1024, 53 | size=None, 54 | pad_to_max_side=False, 55 | mode=Image.BILINEAR, 56 | base_pixel_number=64, 57 | ): 58 | w, h = input_image.size 59 | if size is not None: 60 | w_resize_new, h_resize_new = size 61 | else: 62 | ratio = min_side / min(h, w) 63 | w, h = round(ratio * w), round(ratio * h) 64 | ratio = max_side / max(h, w) 65 | input_image = input_image.resize([round(ratio * w), round(ratio * h)], mode) 66 | w_resize_new = (round(ratio * w) // base_pixel_number) * base_pixel_number 67 | h_resize_new = (round(ratio * h) // base_pixel_number) * base_pixel_number 68 | input_image = input_image.resize([w_resize_new, h_resize_new], mode) 69 | 70 | if pad_to_max_side: 71 | res = np.ones([max_side, max_side, 3], dtype=np.uint8) * 255 72 | offset_x = (max_side - w_resize_new) // 2 73 | offset_y = (max_side - h_resize_new) // 2 74 | res[ 75 | offset_y : offset_y + h_resize_new, offset_x : offset_x + w_resize_new 76 | ] = np.array(input_image) 77 | input_image = Image.fromarray(res) 78 | return input_image 79 | 80 | def get_example(): 81 | case = [ 82 | [ 83 | "./assets/0.jpg", 84 | None, 85 | "a cat, masterpiece, best quality, high quality", 86 | 1.0, 87 | 0.0 88 | ], 89 | [ 90 | "./assets/1.jpg", 91 | None, 92 | "a cat, masterpiece, best quality, high quality", 93 | 1.0, 94 | 0.0 95 | ], 96 | [ 97 | "./assets/2.jpg", 98 | None, 99 | "a cat, masterpiece, best quality, high quality", 100 | 1.0, 101 | 0.0 102 | ], 103 | [ 104 | "./assets/3.jpg", 105 | None, 106 | "a cat, masterpiece, best quality, high quality", 107 | 1.0, 108 | 0.0 109 | ], 110 | [ 111 | "./assets/2.jpg", 112 | "./assets/yann-lecun.jpg", 113 | "a man, masterpiece, best quality, high quality", 114 | 1.0, 115 | 0.6 116 | ], 117 | ] 118 | return case 119 | 120 | def run_for_examples(style_image, source_image, prompt, scale, control_scale): 121 | 122 | return create_image( 123 | image_pil=style_image, 124 | input_image=source_image, 125 | prompt=prompt, 126 | n_prompt="text, watermark, lowres, low quality, worst quality, deformed, glitch, low contrast, noisy, saturation, blurry", 127 | scale=scale, 128 | control_scale=control_scale, 129 | guidance_scale=5, 130 | num_samples=1, 131 | num_inference_steps=20, 132 | seed=42, 133 | target="Load only style blocks", 134 | neg_content_prompt="", 135 | neg_content_scale=0, 136 | ) 137 | 138 | def create_image(image_pil, 139 | input_image, 140 | prompt, 141 | n_prompt, 142 | scale, 143 | control_scale, 144 | guidance_scale, 145 | num_samples, 146 | num_inference_steps, 147 | seed, 148 | target="Load only style blocks", 149 | neg_content_prompt=None, 150 | neg_content_scale=0): 151 | 152 | if target =="Load original IP-Adapter": 153 | # target_blocks=["blocks"] for original IP-Adapter 154 | ip_model = IPAdapterXL(pipe, image_encoder_path, ip_ckpt, device, target_blocks=["blocks"]) 155 | elif target=="Load only style blocks": 156 | # target_blocks=["up_blocks.0.attentions.1"] for style blocks only 157 | ip_model = IPAdapterXL(pipe, image_encoder_path, ip_ckpt, device, 
target_blocks=["up_blocks.0.attentions.1"]) 158 | elif target == "Load style+layout block": 159 | # target_blocks = ["up_blocks.0.attentions.1", "down_blocks.2.attentions.1"] # for style+layout blocks 160 | ip_model = IPAdapterXL(pipe, image_encoder_path, ip_ckpt, device, target_blocks=["up_blocks.0.attentions.1", "down_blocks.2.attentions.1"]) 161 | 162 | if input_image is not None: 163 | input_image = resize_img(input_image, max_side=1024) 164 | cv_input_image = pil_to_cv2(input_image) 165 | detected_map = cv2.Canny(cv_input_image, 50, 200) 166 | canny_map = Image.fromarray(cv2.cvtColor(detected_map, cv2.COLOR_BGR2RGB)) 167 | else: 168 | canny_map = Image.new('RGB', (1024, 1024), color=(255, 255, 255)) 169 | control_scale = 0 170 | 171 | if float(control_scale) == 0: 172 | canny_map = canny_map.resize((1024,1024)) 173 | 174 | if len(neg_content_prompt) > 0 and neg_content_scale != 0: 175 | images = ip_model.generate(pil_image=image_pil, 176 | prompt=prompt, 177 | negative_prompt=n_prompt, 178 | scale=scale, 179 | guidance_scale=guidance_scale, 180 | num_samples=num_samples, 181 | num_inference_steps=num_inference_steps, 182 | seed=seed, 183 | image=canny_map, 184 | controlnet_conditioning_scale=float(control_scale), 185 | neg_content_prompt=neg_content_prompt, 186 | neg_content_scale=neg_content_scale 187 | ) 188 | else: 189 | images = ip_model.generate(pil_image=image_pil, 190 | prompt=prompt, 191 | negative_prompt=n_prompt, 192 | scale=scale, 193 | guidance_scale=guidance_scale, 194 | num_samples=num_samples, 195 | num_inference_steps=num_inference_steps, 196 | seed=seed, 197 | image=canny_map, 198 | controlnet_conditioning_scale=float(control_scale), 199 | ) 200 | return images 201 | 202 | def pil_to_cv2(image_pil): 203 | image_np = np.array(image_pil) 204 | image_cv2 = cv2.cvtColor(image_np, cv2.COLOR_RGB2BGR) 205 | return image_cv2 206 | 207 | # Description 208 | title = r""" 209 |

InstantStyle: Free Lunch towards Style-Preserving in Text-to-Image Generation

210 | """ 211 | 212 | description = r""" 213 | Official 🤗 Gradio demo for InstantStyle: Free Lunch towards Style-Preserving in Text-to-Image Generation.
214 | How to use:
215 | 1. Upload a style image. 216 | 2. Set the stylization mode; only the style blocks are used by default. 217 | 3. Enter a text prompt, as done in normal text-to-image models. 218 | 4. Click the Submit button to begin customization. 219 | 5. Share your stylized photo with your friends and enjoy! 😊 220 | Advanced usage:<br>
221 | 1. Click advanced options. 222 | 2. Upload another source image for image-based stylization using ControlNet. 223 | 3. Enter negative content prompt to avoid content leakage. 224 | """ 225 | 226 | article = r""" 227 | --- 228 | 📝 **Citation** 229 |
230 | If our work is helpful for your research or applications, please cite us via: 231 | ```bibtex 232 | @article{wang2024instantstyle, 233 | title={InstantStyle: Free Lunch towards Style-Preserving in Text-to-Image Generation}, 234 | author={Wang, Haofan and Wang, Qixun and Bai, Xu and Qin, Zekui and Chen, Anthony}, 235 | journal={arXiv preprint arXiv:2404.02733}, 236 | year={2024} 237 | } 238 | ``` 239 | 📧 **Contact** 240 |
241 | If you have any questions, please feel free to open an issue or directly reach us out at haofanwang.ai@gmail.com. 242 | """ 243 | 244 | block = gr.Blocks(css="footer {visibility: hidden}").queue(max_size=10, api_open=False) 245 | with block: 246 | 247 | # description 248 | gr.Markdown(title) 249 | gr.Markdown(description) 250 | 251 | with gr.Tabs(): 252 | with gr.Row(): 253 | with gr.Column(): 254 | 255 | with gr.Row(): 256 | with gr.Column(): 257 | image_pil = gr.Image(label="Style Image", type='pil') 258 | 259 | target = gr.Radio(["Load only style blocks", "Load style+layout block", "Load original IP-Adapter"], 260 | value="Load only style blocks", 261 | label="Style mode") 262 | 263 | prompt = gr.Textbox(label="Prompt", 264 | value="a cat, masterpiece, best quality, high quality") 265 | 266 | scale = gr.Slider(minimum=0,maximum=2.0, step=0.01,value=1.0, label="Scale") 267 | 268 | with gr.Accordion(open=False, label="Advanced Options"): 269 | 270 | with gr.Column(): 271 | src_image_pil = gr.Image(label="Source Image (optional)", type='pil') 272 | control_scale = gr.Slider(minimum=0,maximum=1.0, step=0.01,value=0.5, label="Controlnet conditioning scale") 273 | 274 | n_prompt = gr.Textbox(label="Neg Prompt", value="text, watermark, lowres, low quality, worst quality, deformed, glitch, low contrast, noisy, saturation, blurry") 275 | 276 | neg_content_prompt = gr.Textbox(label="Neg Content Prompt", value="") 277 | neg_content_scale = gr.Slider(minimum=0, maximum=1.0, step=0.01,value=0.5, label="Neg Content Scale") 278 | 279 | guidance_scale = gr.Slider(minimum=1,maximum=15.0, step=0.01,value=5.0, label="guidance scale") 280 | num_samples= gr.Slider(minimum=1,maximum=4.0, step=1.0,value=1.0, label="num samples") 281 | num_inference_steps = gr.Slider(minimum=5,maximum=50.0, step=1.0,value=20, label="num inference steps") 282 | seed = gr.Slider(minimum=-1000000,maximum=1000000,value=1, step=1, label="Seed Value") 283 | randomize_seed = gr.Checkbox(label="Randomize seed", value=True) 284 | 285 | generate_button = gr.Button("Generate Image") 286 | 287 | with gr.Column(): 288 | generated_image = gr.Gallery(label="Generated Image") 289 | 290 | generate_button.click( 291 | fn=randomize_seed_fn, 292 | inputs=[seed, randomize_seed], 293 | outputs=seed, 294 | queue=False, 295 | api_name=False, 296 | ).then( 297 | fn=create_image, 298 | inputs=[image_pil, 299 | src_image_pil, 300 | prompt, 301 | n_prompt, 302 | scale, 303 | control_scale, 304 | guidance_scale, 305 | num_samples, 306 | num_inference_steps, 307 | seed, 308 | target, 309 | neg_content_prompt, 310 | neg_content_scale], 311 | outputs=[generated_image]) 312 | 313 | gr.Examples( 314 | examples=get_example(), 315 | inputs=[image_pil, src_image_pil, prompt, scale, control_scale], 316 | fn=run_for_examples, 317 | outputs=[generated_image], 318 | cache_examples=True, 319 | ) 320 | 321 | gr.Markdown(article) 322 | 323 | block.launch() 324 | -------------------------------------------------------------------------------- /gradio_demo/assets/0.jpg: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/instantX-research/InstantStyle/6b40588e263c958653353ec24eb7eb990cfa3da7/gradio_demo/assets/0.jpg -------------------------------------------------------------------------------- /gradio_demo/assets/1.jpg: -------------------------------------------------------------------------------- 
https://raw.githubusercontent.com/instantX-research/InstantStyle/6b40588e263c958653353ec24eb7eb990cfa3da7/gradio_demo/assets/1.jpg -------------------------------------------------------------------------------- /gradio_demo/assets/2.jpg: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/instantX-research/InstantStyle/6b40588e263c958653353ec24eb7eb990cfa3da7/gradio_demo/assets/2.jpg -------------------------------------------------------------------------------- /gradio_demo/assets/3.jpg: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/instantX-research/InstantStyle/6b40588e263c958653353ec24eb7eb990cfa3da7/gradio_demo/assets/3.jpg -------------------------------------------------------------------------------- /gradio_demo/assets/yann-lecun.jpg: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/instantX-research/InstantStyle/6b40588e263c958653353ec24eb7eb990cfa3da7/gradio_demo/assets/yann-lecun.jpg -------------------------------------------------------------------------------- /gradio_demo/requirements.txt: -------------------------------------------------------------------------------- 1 | diffusers>=0.25.1 2 | torch>=2.0.0 3 | torchvision>=0.15.1 4 | transformers>=4.37.1 5 | accelerate 6 | safetensors 7 | einops 8 | spaces>=0.19.4 9 | omegaconf 10 | peft 11 | huggingface-hub>=0.20.2 12 | opencv-python 13 | gradio 14 | controlnet_aux 15 | gdown 16 | peft -------------------------------------------------------------------------------- /infer_style.py: -------------------------------------------------------------------------------- 1 | import torch 2 | from diffusers import StableDiffusionXLPipeline 3 | from PIL import Image 4 | 5 | from ip_adapter import IPAdapterXL 6 | 7 | base_model_path = "stabilityai/stable-diffusion-xl-base-1.0" 8 | image_encoder_path = "sdxl_models/image_encoder" 9 | ip_ckpt = "sdxl_models/ip-adapter_sdxl.bin" 10 | device = "cuda" 11 | 12 | # load SDXL pipeline 13 | pipe = StableDiffusionXLPipeline.from_pretrained( 14 | base_model_path, 15 | torch_dtype=torch.float16, 16 | add_watermarker=False, 17 | ) 18 | pipe.enable_vae_tiling() 19 | 20 | # load ip-adapter 21 | # target_blocks=["block"] for original IP-Adapter 22 | # target_blocks=["up_blocks.0.attentions.1"] for style blocks only 23 | # target_blocks = ["up_blocks.0.attentions.1", "down_blocks.2.attentions.1"] # for style+layout blocks 24 | ip_model = IPAdapterXL(pipe, image_encoder_path, ip_ckpt, device, target_blocks=["up_blocks.0.attentions.1"]) 25 | 26 | image = "./assets/0.jpg" 27 | image = Image.open(image) 28 | image.resize((512, 512)) 29 | 30 | # generate image 31 | images = ip_model.generate(pil_image=image, 32 | prompt="a cat, masterpiece, best quality, high quality", 33 | negative_prompt= "text, watermark, lowres, low quality, worst quality, deformed, glitch, low contrast, noisy, saturation, blurry", 34 | scale=1.0, 35 | guidance_scale=5, 36 | num_samples=1, 37 | num_inference_steps=30, 38 | seed=42, 39 | #neg_content_prompt="a rabbit", 40 | #neg_content_scale=0.5, 41 | ) 42 | 43 | images[0].save("result.png") -------------------------------------------------------------------------------- /infer_style_controlnet.py: -------------------------------------------------------------------------------- 1 | import torch 2 | from diffusers import ControlNetModel, StableDiffusionXLControlNetPipeline 3 | 4 | import 
cv2 5 | from PIL import Image 6 | 7 | from ip_adapter import IPAdapterXL 8 | 9 | base_model_path = "stabilityai/stable-diffusion-xl-base-1.0" 10 | image_encoder_path = "sdxl_models/image_encoder" 11 | ip_ckpt = "sdxl_models/ip-adapter_sdxl.bin" 12 | device = "cuda" 13 | 14 | controlnet_path = "diffusers/controlnet-canny-sdxl-1.0" 15 | controlnet = ControlNetModel.from_pretrained(controlnet_path, use_safetensors=False, torch_dtype=torch.float16).to(device) 16 | 17 | # load SDXL pipeline 18 | pipe = StableDiffusionXLControlNetPipeline.from_pretrained( 19 | base_model_path, 20 | controlnet=controlnet, 21 | torch_dtype=torch.float16, 22 | add_watermarker=False, 23 | ) 24 | pipe.enable_vae_tiling() 25 | 26 | # load ip-adapter 27 | # target_blocks=["block"] for original IP-Adapter 28 | # target_blocks=["up_blocks.0.attentions.1"] for style blocks only 29 | # target_blocks = ["up_blocks.0.attentions.1", "down_blocks.2.attentions.1"] # for style+layout blocks 30 | ip_model = IPAdapterXL(pipe, image_encoder_path, ip_ckpt, device, target_blocks=["up_blocks.0.attentions.1"]) 31 | 32 | # style image 33 | image = "./assets/4.jpg" 34 | image = Image.open(image) 35 | image = image.resize((512, 512)) 36 | 37 | # control image 38 | input_image = cv2.imread("./assets/yann-lecun.jpg") 39 | detected_map = cv2.Canny(input_image, 50, 200) 40 | canny_map = Image.fromarray(cv2.cvtColor(detected_map, cv2.COLOR_BGR2RGB)) 41 | 42 | # generate image 43 | images = ip_model.generate(pil_image=image, 44 | prompt="a man, masterpiece, best quality, high quality", 45 | negative_prompt= "text, watermark, lowres, low quality, worst quality, deformed, glitch, low contrast, noisy, saturation, blurry", 46 | scale=1.0, 47 | guidance_scale=5, 48 | num_samples=1, 49 | num_inference_steps=30, 50 | seed=42, 51 | image=canny_map, 52 | controlnet_conditioning_scale=0.6, 53 | ) 54 | 55 | images[0].save("result.png") -------------------------------------------------------------------------------- /infer_style_inpainting.py: -------------------------------------------------------------------------------- 1 | import torch 2 | from diffusers import StableDiffusionXLInpaintPipeline 3 | from PIL import Image 4 | 5 | from ip_adapter import IPAdapterXL 6 | 7 | base_model_path = "diffusers/stable-diffusion-xl-1.0-inpainting-0.1" 8 | image_encoder_path = "sdxl_models/image_encoder" 9 | ip_ckpt = "sdxl_models/ip-adapter_sdxl.bin" 10 | device = "cuda" 11 | 12 | # load SDXL pipeline 13 | pipe = StableDiffusionXLInpaintPipeline.from_pretrained( 14 | base_model_path, 15 | torch_dtype=torch.float16, 16 | variant="fp16", 17 | use_safetensors=True, 18 | ) 19 | pipe.enable_vae_tiling() 20 | 21 | # load ip-adapter 22 | # target_blocks=["block"] for original IP-Adapter 23 | # target_blocks=["up_blocks.0.attentions.1"] for style blocks only 24 | # target_blocks = ["up_blocks.0.attentions.1", "down_blocks.2.attentions.1"] # for style+layout blocks 25 | ip_model = IPAdapterXL(pipe, image_encoder_path, ip_ckpt, device, target_blocks=["up_blocks.0.attentions.1"]) 26 | 27 | image = "./assets/5.jpg" 28 | image = Image.open(image) 29 | image = image.resize((512, 512)) 30 | 31 | init_image = Image.open("./assets/overture-creations-5sI6fQgYIuo.png").convert("RGB") 32 | mask_image = Image.open("./assets/overture-creations-5sI6fQgYIuo_mask_inverse.png").convert("RGB") 33 | 34 | # generate image 35 | images = ip_model.generate(pil_image=image, 36 | prompt="a dog sitting on, masterpiece, best quality, high quality", 37 | negative_prompt= "text, watermark, lowres, low quality, worst 
quality, deformed, glitch, low contrast, noisy, saturation, blurry", 38 | scale=2.0, 39 | guidance_scale=8, 40 | num_samples=1, 41 | num_inference_steps=30, 42 | image=init_image, 43 | mask_image=mask_image, 44 | strength=0.99 45 | ) 46 | 47 | images[0].save("result.png") -------------------------------------------------------------------------------- /infer_style_plus.py: -------------------------------------------------------------------------------- 1 | import torch 2 | from diffusers import StableDiffusionXLPipeline 3 | from PIL import Image 4 | 5 | from ip_adapter import IPAdapterPlusXL 6 | 7 | base_model_path = "stabilityai/stable-diffusion-xl-base-1.0" 8 | image_encoder_path = "models/image_encoder" 9 | ip_ckpt = "sdxl_models/ip-adapter-plus_sdxl_vit-h.bin" 10 | device = "cuda" 11 | 12 | # load SDXL pipeline 13 | pipe = StableDiffusionXLPipeline.from_pretrained( 14 | base_model_path, 15 | torch_dtype=torch.float16, 16 | add_watermarker=False, 17 | ) 18 | pipe.enable_vae_tiling() 19 | 20 | # load ip-adapter 21 | # target_blocks=["block"] for original IP-Adapter 22 | # target_blocks=["up_blocks.0.attentions.1"] for style blocks only 23 | # target_blocks = ["up_blocks.0.attentions.1", "down_blocks.2.attentions.1"] # for style+layout blocks 24 | ip_model = IPAdapterPlusXL(pipe, image_encoder_path, ip_ckpt, device, num_tokens=16, target_blocks=["up_blocks.0.attentions.1"]) 25 | 26 | image = "./assets/0.jpg" 27 | image = Image.open(image) 28 | image = image.resize((512, 512)) 29 | 30 | # generate image 31 | images = ip_model.generate(pil_image=image, 32 | prompt="a cat, masterpiece, best quality, high quality", 33 | negative_prompt= "text, watermark, lowres, low quality, worst quality, deformed, glitch, low contrast, noisy, saturation, blurry", 34 | scale=1.0, 35 | guidance_scale=5, 36 | num_samples=1, 37 | num_inference_steps=30, 38 | seed=42, 39 | ) 40 | 41 | images[0].save("result.png") -------------------------------------------------------------------------------- /infer_style_sd15.py: -------------------------------------------------------------------------------- 1 | import torch 2 | from diffusers import StableDiffusionPipeline, UniPCMultistepScheduler 3 | from PIL import Image 4 | 5 | from ip_adapter import IPAdapter 6 | 7 | base_model_path = "sd-legacy/stable-diffusion-v1-5" 8 | image_encoder_path = "models/image_encoder" 9 | ip_ckpt = "models/ip-adapter_sd15.bin" 10 | device = "cuda" 11 | 12 | # load SD1.5 pipeline 13 | pipe = StableDiffusionPipeline.from_pretrained( 14 | base_model_path, 15 | torch_dtype=torch.float16, 16 | ) 17 | pipe.scheduler = UniPCMultistepScheduler.from_config(pipe.scheduler.config) 18 | pipe.enable_vae_tiling() 19 | 20 | # load ip-adapter 21 | # target_blocks=["block"] for original IP-Adapter 22 | # target_blocks=["up_blocks.1"] for style blocks only (experimental, not as obvious as SDXL) 23 | # target_blocks = ["down_blocks.2", "mid_block", "up_blocks.1"] # for style+layout blocks (experimental, not as obvious as SDXL) 24 | ip_model = IPAdapter(pipe, image_encoder_path, ip_ckpt, device, target_blocks=["block"]) 25 | 26 | image = "./assets/3.jpg" 27 | image = Image.open(image) 28 | image = image.resize((512, 512)) 29 | 30 | # set negative content 31 | neg_content = "a girl" 32 | neg_content_scale = 0.8 33 | if neg_content is not None: 34 | from transformers import CLIPTextModelWithProjection, CLIPTokenizer 35 | text_encoder = CLIPTextModelWithProjection.from_pretrained("laion/CLIP-ViT-H-14-laion2B-s32B-b79K").to(pipe.device, 36 | dtype=pipe.dtype) 37 | tokenizer = 
CLIPTokenizer.from_pretrained("laion/CLIP-ViT-H-14-laion2B-s32B-b79K") 38 | 39 | tokens = tokenizer([neg_content], return_tensors='pt').to(pipe.device) 40 | neg_content_emb = text_encoder(**tokens).text_embeds 41 | neg_content_emb *= neg_content_scale 42 | else: 43 | neg_content_emb = None 44 | 45 | # generate image with content subtraction 46 | images = ip_model.generate(pil_image=image, 47 | prompt="a cat, masterpiece, best quality, high quality", 48 | negative_prompt= "text, watermark, lowres, low quality, worst quality, deformed, glitch, low contrast, noisy, saturation, blurry", 49 | scale=1.0, 50 | guidance_scale=5, 51 | num_samples=1, 52 | num_inference_steps=30, 53 | seed=42, 54 | neg_content_emb=neg_content_emb, 55 | ) 56 | 57 | images[0].save("result.png") 58 | -------------------------------------------------------------------------------- /ip_adapter/__init__.py: -------------------------------------------------------------------------------- 1 | from .ip_adapter import IPAdapter, IPAdapterPlus, IPAdapterPlusXL, IPAdapterXL, IPAdapterFull 2 | 3 | __all__ = [ 4 | "IPAdapter", 5 | "IPAdapterPlus", 6 | "IPAdapterPlusXL", 7 | "IPAdapterXL", 8 | "IPAdapterFull", 9 | ] 10 | -------------------------------------------------------------------------------- /ip_adapter/attention_processor.py: -------------------------------------------------------------------------------- 1 | # modified from https://github.com/huggingface/diffusers/blob/main/src/diffusers/models/attention_processor.py 2 | import torch 3 | import torch.nn as nn 4 | import torch.nn.functional as F 5 | 6 | 7 | class AttnProcessor(nn.Module): 8 | r""" 9 | Default processor for performing attention-related computations. 10 | """ 11 | 12 | def __init__( 13 | self, 14 | hidden_size=None, 15 | cross_attention_dim=None, 16 | ): 17 | super().__init__() 18 | 19 | def __call__( 20 | self, 21 | attn, 22 | hidden_states, 23 | encoder_hidden_states=None, 24 | attention_mask=None, 25 | temb=None, 26 | ): 27 | residual = hidden_states 28 | 29 | if attn.spatial_norm is not None: 30 | hidden_states = attn.spatial_norm(hidden_states, temb) 31 | 32 | input_ndim = hidden_states.ndim 33 | 34 | if input_ndim == 4: 35 | batch_size, channel, height, width = hidden_states.shape 36 | hidden_states = hidden_states.view(batch_size, channel, height * width).transpose(1, 2) 37 | 38 | batch_size, sequence_length, _ = ( 39 | hidden_states.shape if encoder_hidden_states is None else encoder_hidden_states.shape 40 | ) 41 | attention_mask = attn.prepare_attention_mask(attention_mask, sequence_length, batch_size) 42 | 43 | if attn.group_norm is not None: 44 | hidden_states = attn.group_norm(hidden_states.transpose(1, 2)).transpose(1, 2) 45 | 46 | query = attn.to_q(hidden_states) 47 | 48 | if encoder_hidden_states is None: 49 | encoder_hidden_states = hidden_states 50 | elif attn.norm_cross: 51 | encoder_hidden_states = attn.norm_encoder_hidden_states(encoder_hidden_states) 52 | 53 | key = attn.to_k(encoder_hidden_states) 54 | value = attn.to_v(encoder_hidden_states) 55 | 56 | query = attn.head_to_batch_dim(query) 57 | key = attn.head_to_batch_dim(key) 58 | value = attn.head_to_batch_dim(value) 59 | 60 | attention_probs = attn.get_attention_scores(query, key, attention_mask) 61 | hidden_states = torch.bmm(attention_probs, value) 62 | hidden_states = attn.batch_to_head_dim(hidden_states) 63 | 64 | # linear proj 65 | hidden_states = attn.to_out[0](hidden_states) 66 | # dropout 67 | hidden_states = attn.to_out[1](hidden_states) 68 | 69 | if input_ndim == 4: 70 | 
hidden_states = hidden_states.transpose(-1, -2).reshape(batch_size, channel, height, width) 71 | 72 | if attn.residual_connection: 73 | hidden_states = hidden_states + residual 74 | 75 | hidden_states = hidden_states / attn.rescale_output_factor 76 | 77 | return hidden_states 78 | 79 | 80 | class IPAttnProcessor(nn.Module): 81 | r""" 82 | Attention processor for IP-Adapater. 83 | Args: 84 | hidden_size (`int`): 85 | The hidden size of the attention layer. 86 | cross_attention_dim (`int`): 87 | The number of channels in the `encoder_hidden_states`. 88 | scale (`float`, defaults to 1.0): 89 | the weight scale of image prompt. 90 | num_tokens (`int`, defaults to 4 when do ip_adapter_plus it should be 16): 91 | The context length of the image features. 92 | """ 93 | 94 | def __init__(self, hidden_size, cross_attention_dim=None, scale=1.0, num_tokens=4, skip=False): 95 | super().__init__() 96 | 97 | self.hidden_size = hidden_size 98 | self.cross_attention_dim = cross_attention_dim 99 | self.scale = scale 100 | self.num_tokens = num_tokens 101 | self.skip = skip 102 | 103 | self.to_k_ip = nn.Linear(cross_attention_dim or hidden_size, hidden_size, bias=False) 104 | self.to_v_ip = nn.Linear(cross_attention_dim or hidden_size, hidden_size, bias=False) 105 | 106 | def __call__( 107 | self, 108 | attn, 109 | hidden_states, 110 | encoder_hidden_states=None, 111 | attention_mask=None, 112 | temb=None, 113 | ): 114 | residual = hidden_states 115 | 116 | if attn.spatial_norm is not None: 117 | hidden_states = attn.spatial_norm(hidden_states, temb) 118 | 119 | input_ndim = hidden_states.ndim 120 | 121 | if input_ndim == 4: 122 | batch_size, channel, height, width = hidden_states.shape 123 | hidden_states = hidden_states.view(batch_size, channel, height * width).transpose(1, 2) 124 | 125 | batch_size, sequence_length, _ = ( 126 | hidden_states.shape if encoder_hidden_states is None else encoder_hidden_states.shape 127 | ) 128 | attention_mask = attn.prepare_attention_mask(attention_mask, sequence_length, batch_size) 129 | 130 | if attn.group_norm is not None: 131 | hidden_states = attn.group_norm(hidden_states.transpose(1, 2)).transpose(1, 2) 132 | 133 | query = attn.to_q(hidden_states) 134 | 135 | if encoder_hidden_states is None: 136 | encoder_hidden_states = hidden_states 137 | else: 138 | # get encoder_hidden_states, ip_hidden_states 139 | end_pos = encoder_hidden_states.shape[1] - self.num_tokens 140 | encoder_hidden_states, ip_hidden_states = ( 141 | encoder_hidden_states[:, :end_pos, :], 142 | encoder_hidden_states[:, end_pos:, :], 143 | ) 144 | if attn.norm_cross: 145 | encoder_hidden_states = attn.norm_encoder_hidden_states(encoder_hidden_states) 146 | 147 | key = attn.to_k(encoder_hidden_states) 148 | value = attn.to_v(encoder_hidden_states) 149 | 150 | query = attn.head_to_batch_dim(query) 151 | key = attn.head_to_batch_dim(key) 152 | value = attn.head_to_batch_dim(value) 153 | 154 | attention_probs = attn.get_attention_scores(query, key, attention_mask) 155 | hidden_states = torch.bmm(attention_probs, value) 156 | hidden_states = attn.batch_to_head_dim(hidden_states) 157 | 158 | if not self.skip: 159 | # for ip-adapter 160 | ip_key = self.to_k_ip(ip_hidden_states) 161 | ip_value = self.to_v_ip(ip_hidden_states) 162 | 163 | ip_key = attn.head_to_batch_dim(ip_key) 164 | ip_value = attn.head_to_batch_dim(ip_value) 165 | 166 | ip_attention_probs = attn.get_attention_scores(query, ip_key, None) 167 | self.attn_map = ip_attention_probs 168 | ip_hidden_states = torch.bmm(ip_attention_probs, ip_value) 
169 | ip_hidden_states = attn.batch_to_head_dim(ip_hidden_states) 170 | 171 | hidden_states = hidden_states + self.scale * ip_hidden_states 172 | 173 | # linear proj 174 | hidden_states = attn.to_out[0](hidden_states) 175 | # dropout 176 | hidden_states = attn.to_out[1](hidden_states) 177 | 178 | if input_ndim == 4: 179 | hidden_states = hidden_states.transpose(-1, -2).reshape(batch_size, channel, height, width) 180 | 181 | if attn.residual_connection: 182 | hidden_states = hidden_states + residual 183 | 184 | hidden_states = hidden_states / attn.rescale_output_factor 185 | 186 | return hidden_states 187 | 188 | 189 | class AttnProcessor2_0(torch.nn.Module): 190 | r""" 191 | Processor for implementing scaled dot-product attention (enabled by default if you're using PyTorch 2.0). 192 | """ 193 | 194 | def __init__( 195 | self, 196 | hidden_size=None, 197 | cross_attention_dim=None, 198 | ): 199 | super().__init__() 200 | if not hasattr(F, "scaled_dot_product_attention"): 201 | raise ImportError("AttnProcessor2_0 requires PyTorch 2.0, to use it, please upgrade PyTorch to 2.0.") 202 | 203 | def __call__( 204 | self, 205 | attn, 206 | hidden_states, 207 | encoder_hidden_states=None, 208 | attention_mask=None, 209 | temb=None, 210 | ): 211 | residual = hidden_states 212 | 213 | if attn.spatial_norm is not None: 214 | hidden_states = attn.spatial_norm(hidden_states, temb) 215 | 216 | input_ndim = hidden_states.ndim 217 | 218 | if input_ndim == 4: 219 | batch_size, channel, height, width = hidden_states.shape 220 | hidden_states = hidden_states.view(batch_size, channel, height * width).transpose(1, 2) 221 | 222 | batch_size, sequence_length, _ = ( 223 | hidden_states.shape if encoder_hidden_states is None else encoder_hidden_states.shape 224 | ) 225 | 226 | if attention_mask is not None: 227 | attention_mask = attn.prepare_attention_mask(attention_mask, sequence_length, batch_size) 228 | # scaled_dot_product_attention expects attention_mask shape to be 229 | # (batch, heads, source_length, target_length) 230 | attention_mask = attention_mask.view(batch_size, attn.heads, -1, attention_mask.shape[-1]) 231 | 232 | if attn.group_norm is not None: 233 | hidden_states = attn.group_norm(hidden_states.transpose(1, 2)).transpose(1, 2) 234 | 235 | query = attn.to_q(hidden_states) 236 | 237 | if encoder_hidden_states is None: 238 | encoder_hidden_states = hidden_states 239 | elif attn.norm_cross: 240 | encoder_hidden_states = attn.norm_encoder_hidden_states(encoder_hidden_states) 241 | 242 | key = attn.to_k(encoder_hidden_states) 243 | value = attn.to_v(encoder_hidden_states) 244 | 245 | inner_dim = key.shape[-1] 246 | head_dim = inner_dim // attn.heads 247 | 248 | query = query.view(batch_size, -1, attn.heads, head_dim).transpose(1, 2) 249 | 250 | key = key.view(batch_size, -1, attn.heads, head_dim).transpose(1, 2) 251 | value = value.view(batch_size, -1, attn.heads, head_dim).transpose(1, 2) 252 | 253 | # the output of sdp = (batch, num_heads, seq_len, head_dim) 254 | # TODO: add support for attn.scale when we move to Torch 2.1 255 | hidden_states = F.scaled_dot_product_attention( 256 | query, key, value, attn_mask=attention_mask, dropout_p=0.0, is_causal=False 257 | ) 258 | 259 | hidden_states = hidden_states.transpose(1, 2).reshape(batch_size, -1, attn.heads * head_dim) 260 | hidden_states = hidden_states.to(query.dtype) 261 | 262 | # linear proj 263 | hidden_states = attn.to_out[0](hidden_states) 264 | # dropout 265 | hidden_states = attn.to_out[1](hidden_states) 266 | 267 | if input_ndim == 4: 268 | 
hidden_states = hidden_states.transpose(-1, -2).reshape(batch_size, channel, height, width) 269 | 270 | if attn.residual_connection: 271 | hidden_states = hidden_states + residual 272 | 273 | hidden_states = hidden_states / attn.rescale_output_factor 274 | 275 | return hidden_states 276 | 277 | 278 | class IPAttnProcessor2_0(torch.nn.Module): 279 | r""" 280 | Attention processor for IP-Adapater for PyTorch 2.0. 281 | Args: 282 | hidden_size (`int`): 283 | The hidden size of the attention layer. 284 | cross_attention_dim (`int`): 285 | The number of channels in the `encoder_hidden_states`. 286 | scale (`float`, defaults to 1.0): 287 | the weight scale of image prompt. 288 | num_tokens (`int`, defaults to 4 when do ip_adapter_plus it should be 16): 289 | The context length of the image features. 290 | """ 291 | 292 | def __init__(self, hidden_size, cross_attention_dim=None, scale=1.0, num_tokens=4, skip=False): 293 | super().__init__() 294 | 295 | if not hasattr(F, "scaled_dot_product_attention"): 296 | raise ImportError("AttnProcessor2_0 requires PyTorch 2.0, to use it, please upgrade PyTorch to 2.0.") 297 | 298 | self.hidden_size = hidden_size 299 | self.cross_attention_dim = cross_attention_dim 300 | self.scale = scale 301 | self.num_tokens = num_tokens 302 | self.skip = skip 303 | 304 | self.to_k_ip = nn.Linear(cross_attention_dim or hidden_size, hidden_size, bias=False) 305 | self.to_v_ip = nn.Linear(cross_attention_dim or hidden_size, hidden_size, bias=False) 306 | 307 | def __call__( 308 | self, 309 | attn, 310 | hidden_states, 311 | encoder_hidden_states=None, 312 | attention_mask=None, 313 | temb=None, 314 | ): 315 | residual = hidden_states 316 | 317 | if attn.spatial_norm is not None: 318 | hidden_states = attn.spatial_norm(hidden_states, temb) 319 | 320 | input_ndim = hidden_states.ndim 321 | 322 | if input_ndim == 4: 323 | batch_size, channel, height, width = hidden_states.shape 324 | hidden_states = hidden_states.view(batch_size, channel, height * width).transpose(1, 2) 325 | 326 | batch_size, sequence_length, _ = ( 327 | hidden_states.shape if encoder_hidden_states is None else encoder_hidden_states.shape 328 | ) 329 | 330 | if attention_mask is not None: 331 | attention_mask = attn.prepare_attention_mask(attention_mask, sequence_length, batch_size) 332 | # scaled_dot_product_attention expects attention_mask shape to be 333 | # (batch, heads, source_length, target_length) 334 | attention_mask = attention_mask.view(batch_size, attn.heads, -1, attention_mask.shape[-1]) 335 | 336 | if attn.group_norm is not None: 337 | hidden_states = attn.group_norm(hidden_states.transpose(1, 2)).transpose(1, 2) 338 | 339 | query = attn.to_q(hidden_states) 340 | 341 | if encoder_hidden_states is None: 342 | encoder_hidden_states = hidden_states 343 | else: 344 | # get encoder_hidden_states, ip_hidden_states 345 | end_pos = encoder_hidden_states.shape[1] - self.num_tokens 346 | encoder_hidden_states, ip_hidden_states = ( 347 | encoder_hidden_states[:, :end_pos, :], 348 | encoder_hidden_states[:, end_pos:, :], 349 | ) 350 | if attn.norm_cross: 351 | encoder_hidden_states = attn.norm_encoder_hidden_states(encoder_hidden_states) 352 | 353 | key = attn.to_k(encoder_hidden_states) 354 | value = attn.to_v(encoder_hidden_states) 355 | 356 | inner_dim = key.shape[-1] 357 | head_dim = inner_dim // attn.heads 358 | 359 | query = query.view(batch_size, -1, attn.heads, head_dim).transpose(1, 2) 360 | 361 | key = key.view(batch_size, -1, attn.heads, head_dim).transpose(1, 2) 362 | value = 
value.view(batch_size, -1, attn.heads, head_dim).transpose(1, 2) 363 | 364 | # the output of sdp = (batch, num_heads, seq_len, head_dim) 365 | # TODO: add support for attn.scale when we move to Torch 2.1 366 | hidden_states = F.scaled_dot_product_attention( 367 | query, key, value, attn_mask=attention_mask, dropout_p=0.0, is_causal=False 368 | ) 369 | 370 | hidden_states = hidden_states.transpose(1, 2).reshape(batch_size, -1, attn.heads * head_dim) 371 | hidden_states = hidden_states.to(query.dtype) 372 | 373 | if not self.skip: 374 | # for ip-adapter 375 | ip_key = self.to_k_ip(ip_hidden_states) 376 | ip_value = self.to_v_ip(ip_hidden_states) 377 | 378 | ip_key = ip_key.view(batch_size, -1, attn.heads, head_dim).transpose(1, 2) 379 | ip_value = ip_value.view(batch_size, -1, attn.heads, head_dim).transpose(1, 2) 380 | 381 | # the output of sdp = (batch, num_heads, seq_len, head_dim) 382 | # TODO: add support for attn.scale when we move to Torch 2.1 383 | ip_hidden_states = F.scaled_dot_product_attention( 384 | query, ip_key, ip_value, attn_mask=None, dropout_p=0.0, is_causal=False 385 | ) 386 | with torch.no_grad(): 387 | self.attn_map = query @ ip_key.transpose(-2, -1).softmax(dim=-1) 388 | #print(self.attn_map.shape) 389 | 390 | ip_hidden_states = ip_hidden_states.transpose(1, 2).reshape(batch_size, -1, attn.heads * head_dim) 391 | ip_hidden_states = ip_hidden_states.to(query.dtype) 392 | 393 | hidden_states = hidden_states + self.scale * ip_hidden_states 394 | 395 | # linear proj 396 | hidden_states = attn.to_out[0](hidden_states) 397 | # dropout 398 | hidden_states = attn.to_out[1](hidden_states) 399 | 400 | if input_ndim == 4: 401 | hidden_states = hidden_states.transpose(-1, -2).reshape(batch_size, channel, height, width) 402 | 403 | if attn.residual_connection: 404 | hidden_states = hidden_states + residual 405 | 406 | hidden_states = hidden_states / attn.rescale_output_factor 407 | 408 | return hidden_states 409 | 410 | 411 | ## for controlnet 412 | class CNAttnProcessor: 413 | r""" 414 | Default processor for performing attention-related computations. 
415 | """ 416 | 417 | def __init__(self, num_tokens=4): 418 | self.num_tokens = num_tokens 419 | 420 | def __call__(self, attn, hidden_states, encoder_hidden_states=None, attention_mask=None, temb=None): 421 | residual = hidden_states 422 | 423 | if attn.spatial_norm is not None: 424 | hidden_states = attn.spatial_norm(hidden_states, temb) 425 | 426 | input_ndim = hidden_states.ndim 427 | 428 | if input_ndim == 4: 429 | batch_size, channel, height, width = hidden_states.shape 430 | hidden_states = hidden_states.view(batch_size, channel, height * width).transpose(1, 2) 431 | 432 | batch_size, sequence_length, _ = ( 433 | hidden_states.shape if encoder_hidden_states is None else encoder_hidden_states.shape 434 | ) 435 | attention_mask = attn.prepare_attention_mask(attention_mask, sequence_length, batch_size) 436 | 437 | if attn.group_norm is not None: 438 | hidden_states = attn.group_norm(hidden_states.transpose(1, 2)).transpose(1, 2) 439 | 440 | query = attn.to_q(hidden_states) 441 | 442 | if encoder_hidden_states is None: 443 | encoder_hidden_states = hidden_states 444 | else: 445 | end_pos = encoder_hidden_states.shape[1] - self.num_tokens 446 | encoder_hidden_states = encoder_hidden_states[:, :end_pos] # only use text 447 | if attn.norm_cross: 448 | encoder_hidden_states = attn.norm_encoder_hidden_states(encoder_hidden_states) 449 | 450 | key = attn.to_k(encoder_hidden_states) 451 | value = attn.to_v(encoder_hidden_states) 452 | 453 | query = attn.head_to_batch_dim(query) 454 | key = attn.head_to_batch_dim(key) 455 | value = attn.head_to_batch_dim(value) 456 | 457 | attention_probs = attn.get_attention_scores(query, key, attention_mask) 458 | hidden_states = torch.bmm(attention_probs, value) 459 | hidden_states = attn.batch_to_head_dim(hidden_states) 460 | 461 | # linear proj 462 | hidden_states = attn.to_out[0](hidden_states) 463 | # dropout 464 | hidden_states = attn.to_out[1](hidden_states) 465 | 466 | if input_ndim == 4: 467 | hidden_states = hidden_states.transpose(-1, -2).reshape(batch_size, channel, height, width) 468 | 469 | if attn.residual_connection: 470 | hidden_states = hidden_states + residual 471 | 472 | hidden_states = hidden_states / attn.rescale_output_factor 473 | 474 | return hidden_states 475 | 476 | 477 | class CNAttnProcessor2_0: 478 | r""" 479 | Processor for implementing scaled dot-product attention (enabled by default if you're using PyTorch 2.0). 
480 | """ 481 | 482 | def __init__(self, num_tokens=4): 483 | if not hasattr(F, "scaled_dot_product_attention"): 484 | raise ImportError("AttnProcessor2_0 requires PyTorch 2.0, to use it, please upgrade PyTorch to 2.0.") 485 | self.num_tokens = num_tokens 486 | 487 | def __call__( 488 | self, 489 | attn, 490 | hidden_states, 491 | encoder_hidden_states=None, 492 | attention_mask=None, 493 | temb=None, 494 | ): 495 | residual = hidden_states 496 | 497 | if attn.spatial_norm is not None: 498 | hidden_states = attn.spatial_norm(hidden_states, temb) 499 | 500 | input_ndim = hidden_states.ndim 501 | 502 | if input_ndim == 4: 503 | batch_size, channel, height, width = hidden_states.shape 504 | hidden_states = hidden_states.view(batch_size, channel, height * width).transpose(1, 2) 505 | 506 | batch_size, sequence_length, _ = ( 507 | hidden_states.shape if encoder_hidden_states is None else encoder_hidden_states.shape 508 | ) 509 | 510 | if attention_mask is not None: 511 | attention_mask = attn.prepare_attention_mask(attention_mask, sequence_length, batch_size) 512 | # scaled_dot_product_attention expects attention_mask shape to be 513 | # (batch, heads, source_length, target_length) 514 | attention_mask = attention_mask.view(batch_size, attn.heads, -1, attention_mask.shape[-1]) 515 | 516 | if attn.group_norm is not None: 517 | hidden_states = attn.group_norm(hidden_states.transpose(1, 2)).transpose(1, 2) 518 | 519 | query = attn.to_q(hidden_states) 520 | 521 | if encoder_hidden_states is None: 522 | encoder_hidden_states = hidden_states 523 | else: 524 | end_pos = encoder_hidden_states.shape[1] - self.num_tokens 525 | encoder_hidden_states = encoder_hidden_states[:, :end_pos] # only use text 526 | if attn.norm_cross: 527 | encoder_hidden_states = attn.norm_encoder_hidden_states(encoder_hidden_states) 528 | 529 | key = attn.to_k(encoder_hidden_states) 530 | value = attn.to_v(encoder_hidden_states) 531 | 532 | inner_dim = key.shape[-1] 533 | head_dim = inner_dim // attn.heads 534 | 535 | query = query.view(batch_size, -1, attn.heads, head_dim).transpose(1, 2) 536 | 537 | key = key.view(batch_size, -1, attn.heads, head_dim).transpose(1, 2) 538 | value = value.view(batch_size, -1, attn.heads, head_dim).transpose(1, 2) 539 | 540 | # the output of sdp = (batch, num_heads, seq_len, head_dim) 541 | # TODO: add support for attn.scale when we move to Torch 2.1 542 | hidden_states = F.scaled_dot_product_attention( 543 | query, key, value, attn_mask=attention_mask, dropout_p=0.0, is_causal=False 544 | ) 545 | 546 | hidden_states = hidden_states.transpose(1, 2).reshape(batch_size, -1, attn.heads * head_dim) 547 | hidden_states = hidden_states.to(query.dtype) 548 | 549 | # linear proj 550 | hidden_states = attn.to_out[0](hidden_states) 551 | # dropout 552 | hidden_states = attn.to_out[1](hidden_states) 553 | 554 | if input_ndim == 4: 555 | hidden_states = hidden_states.transpose(-1, -2).reshape(batch_size, channel, height, width) 556 | 557 | if attn.residual_connection: 558 | hidden_states = hidden_states + residual 559 | 560 | hidden_states = hidden_states / attn.rescale_output_factor 561 | 562 | return hidden_states 563 | -------------------------------------------------------------------------------- /ip_adapter/ip_adapter.py: -------------------------------------------------------------------------------- 1 | import os 2 | from typing import List 3 | 4 | import torch 5 | from diffusers import StableDiffusionPipeline 6 | from diffusers.pipelines.controlnet import MultiControlNetModel 7 | from PIL import 
Image 8 | from safetensors import safe_open 9 | from transformers import CLIPImageProcessor, CLIPVisionModelWithProjection 10 | 11 | from .utils import is_torch2_available, get_generator 12 | 13 | if is_torch2_available(): 14 | from .attention_processor import ( 15 | AttnProcessor2_0 as AttnProcessor, 16 | ) 17 | from .attention_processor import ( 18 | CNAttnProcessor2_0 as CNAttnProcessor, 19 | ) 20 | from .attention_processor import ( 21 | IPAttnProcessor2_0 as IPAttnProcessor, 22 | ) 23 | else: 24 | from .attention_processor import AttnProcessor, CNAttnProcessor, IPAttnProcessor 25 | from .resampler import Resampler 26 | 27 | 28 | class ImageProjModel(torch.nn.Module): 29 | """Projection Model""" 30 | 31 | def __init__(self, cross_attention_dim=1024, clip_embeddings_dim=1024, clip_extra_context_tokens=4): 32 | super().__init__() 33 | 34 | self.generator = None 35 | self.cross_attention_dim = cross_attention_dim 36 | self.clip_extra_context_tokens = clip_extra_context_tokens 37 | self.proj = torch.nn.Linear(clip_embeddings_dim, self.clip_extra_context_tokens * cross_attention_dim) 38 | self.norm = torch.nn.LayerNorm(cross_attention_dim) 39 | 40 | def forward(self, image_embeds): 41 | embeds = image_embeds 42 | clip_extra_context_tokens = self.proj(embeds).reshape( 43 | -1, self.clip_extra_context_tokens, self.cross_attention_dim 44 | ) 45 | clip_extra_context_tokens = self.norm(clip_extra_context_tokens) 46 | return clip_extra_context_tokens 47 | 48 | 49 | class MLPProjModel(torch.nn.Module): 50 | """SD model with image prompt""" 51 | def __init__(self, cross_attention_dim=1024, clip_embeddings_dim=1024): 52 | super().__init__() 53 | 54 | self.proj = torch.nn.Sequential( 55 | torch.nn.Linear(clip_embeddings_dim, clip_embeddings_dim), 56 | torch.nn.GELU(), 57 | torch.nn.Linear(clip_embeddings_dim, cross_attention_dim), 58 | torch.nn.LayerNorm(cross_attention_dim) 59 | ) 60 | 61 | def forward(self, image_embeds): 62 | clip_extra_context_tokens = self.proj(image_embeds) 63 | return clip_extra_context_tokens 64 | 65 | 66 | class IPAdapter: 67 | def __init__(self, sd_pipe, image_encoder_path, ip_ckpt, device, num_tokens=4, target_blocks=["block"]): 68 | self.device = device 69 | self.image_encoder_path = image_encoder_path 70 | self.ip_ckpt = ip_ckpt 71 | self.num_tokens = num_tokens 72 | self.target_blocks = target_blocks 73 | 74 | self.pipe = sd_pipe.to(self.device) 75 | self.set_ip_adapter() 76 | 77 | # load image encoder 78 | self.image_encoder = CLIPVisionModelWithProjection.from_pretrained(self.image_encoder_path).to( 79 | self.device, dtype=torch.float16 80 | ) 81 | self.clip_image_processor = CLIPImageProcessor() 82 | # image proj model 83 | self.image_proj_model = self.init_proj() 84 | 85 | self.load_ip_adapter() 86 | 87 | def init_proj(self): 88 | image_proj_model = ImageProjModel( 89 | cross_attention_dim=self.pipe.unet.config.cross_attention_dim, 90 | clip_embeddings_dim=self.image_encoder.config.projection_dim, 91 | clip_extra_context_tokens=self.num_tokens, 92 | ).to(self.device, dtype=torch.float16) 93 | return image_proj_model 94 | 95 | def set_ip_adapter(self): 96 | unet = self.pipe.unet 97 | attn_procs = {} 98 | for name in unet.attn_processors.keys(): 99 | cross_attention_dim = None if name.endswith("attn1.processor") else unet.config.cross_attention_dim 100 | if name.startswith("mid_block"): 101 | hidden_size = unet.config.block_out_channels[-1] 102 | elif name.startswith("up_blocks"): 103 | block_id = int(name[len("up_blocks.")]) 104 | hidden_size = 
list(reversed(unet.config.block_out_channels))[block_id] 105 | elif name.startswith("down_blocks"): 106 | block_id = int(name[len("down_blocks.")]) 107 | hidden_size = unet.config.block_out_channels[block_id] 108 | if cross_attention_dim is None: 109 | attn_procs[name] = AttnProcessor() 110 | else: 111 | selected = False 112 | for block_name in self.target_blocks: 113 | if block_name in name: 114 | selected = True 115 | break 116 | if selected: 117 | attn_procs[name] = IPAttnProcessor( 118 | hidden_size=hidden_size, 119 | cross_attention_dim=cross_attention_dim, 120 | scale=1.0, 121 | num_tokens=self.num_tokens, 122 | ).to(self.device, dtype=torch.float16) 123 | else: 124 | attn_procs[name] = IPAttnProcessor( 125 | hidden_size=hidden_size, 126 | cross_attention_dim=cross_attention_dim, 127 | scale=1.0, 128 | num_tokens=self.num_tokens, 129 | skip=True 130 | ).to(self.device, dtype=torch.float16) 131 | unet.set_attn_processor(attn_procs) 132 | if hasattr(self.pipe, "controlnet"): 133 | if isinstance(self.pipe.controlnet, MultiControlNetModel): 134 | for controlnet in self.pipe.controlnet.nets: 135 | controlnet.set_attn_processor(CNAttnProcessor(num_tokens=self.num_tokens)) 136 | else: 137 | self.pipe.controlnet.set_attn_processor(CNAttnProcessor(num_tokens=self.num_tokens)) 138 | 139 | def load_ip_adapter(self): 140 | if os.path.splitext(self.ip_ckpt)[-1] == ".safetensors": 141 | state_dict = {"image_proj": {}, "ip_adapter": {}} 142 | with safe_open(self.ip_ckpt, framework="pt", device="cpu") as f: 143 | for key in f.keys(): 144 | if key.startswith("image_proj."): 145 | state_dict["image_proj"][key.replace("image_proj.", "")] = f.get_tensor(key) 146 | elif key.startswith("ip_adapter."): 147 | state_dict["ip_adapter"][key.replace("ip_adapter.", "")] = f.get_tensor(key) 148 | else: 149 | state_dict = torch.load(self.ip_ckpt, map_location="cpu") 150 | self.image_proj_model.load_state_dict(state_dict["image_proj"]) 151 | ip_layers = torch.nn.ModuleList(self.pipe.unet.attn_processors.values()) 152 | ip_layers.load_state_dict(state_dict["ip_adapter"], strict=False) 153 | 154 | @torch.inference_mode() 155 | def get_image_embeds(self, pil_image=None, clip_image_embeds=None, content_prompt_embeds=None): 156 | if pil_image is not None: 157 | if isinstance(pil_image, Image.Image): 158 | pil_image = [pil_image] 159 | clip_image = self.clip_image_processor(images=pil_image, return_tensors="pt").pixel_values 160 | clip_image_embeds = self.image_encoder(clip_image.to(self.device, dtype=torch.float16)).image_embeds 161 | else: 162 | clip_image_embeds = clip_image_embeds.to(self.device, dtype=torch.float16) 163 | 164 | if content_prompt_embeds is not None: 165 | clip_image_embeds = clip_image_embeds - content_prompt_embeds 166 | 167 | image_prompt_embeds = self.image_proj_model(clip_image_embeds) 168 | uncond_image_prompt_embeds = self.image_proj_model(torch.zeros_like(clip_image_embeds)) 169 | return image_prompt_embeds, uncond_image_prompt_embeds 170 | 171 | def set_scale(self, scale): 172 | for attn_processor in self.pipe.unet.attn_processors.values(): 173 | if isinstance(attn_processor, IPAttnProcessor): 174 | attn_processor.scale = scale 175 | 176 | def generate( 177 | self, 178 | pil_image=None, 179 | clip_image_embeds=None, 180 | prompt=None, 181 | negative_prompt=None, 182 | scale=1.0, 183 | num_samples=4, 184 | seed=None, 185 | guidance_scale=7.5, 186 | num_inference_steps=30, 187 | neg_content_emb=None, 188 | **kwargs, 189 | ): 190 | self.set_scale(scale) 191 | 192 | if pil_image is not None: 193 
| num_prompts = 1 if isinstance(pil_image, Image.Image) else len(pil_image) 194 | else: 195 | num_prompts = clip_image_embeds.size(0) 196 | 197 | if prompt is None: 198 | prompt = "best quality, high quality" 199 | if negative_prompt is None: 200 | negative_prompt = "monochrome, lowres, bad anatomy, worst quality, low quality" 201 | 202 | if not isinstance(prompt, List): 203 | prompt = [prompt] * num_prompts 204 | if not isinstance(negative_prompt, List): 205 | negative_prompt = [negative_prompt] * num_prompts 206 | 207 | image_prompt_embeds, uncond_image_prompt_embeds = self.get_image_embeds( 208 | pil_image=pil_image, clip_image_embeds=clip_image_embeds, content_prompt_embeds=neg_content_emb 209 | ) 210 | bs_embed, seq_len, _ = image_prompt_embeds.shape 211 | image_prompt_embeds = image_prompt_embeds.repeat(1, num_samples, 1) 212 | image_prompt_embeds = image_prompt_embeds.view(bs_embed * num_samples, seq_len, -1) 213 | uncond_image_prompt_embeds = uncond_image_prompt_embeds.repeat(1, num_samples, 1) 214 | uncond_image_prompt_embeds = uncond_image_prompt_embeds.view(bs_embed * num_samples, seq_len, -1) 215 | 216 | with torch.inference_mode(): 217 | prompt_embeds_, negative_prompt_embeds_ = self.pipe.encode_prompt( 218 | prompt, 219 | device=self.device, 220 | num_images_per_prompt=num_samples, 221 | do_classifier_free_guidance=True, 222 | negative_prompt=negative_prompt, 223 | ) 224 | prompt_embeds = torch.cat([prompt_embeds_, image_prompt_embeds], dim=1) 225 | negative_prompt_embeds = torch.cat([negative_prompt_embeds_, uncond_image_prompt_embeds], dim=1) 226 | 227 | generator = get_generator(seed, self.device) 228 | 229 | images = self.pipe( 230 | prompt_embeds=prompt_embeds, 231 | negative_prompt_embeds=negative_prompt_embeds, 232 | guidance_scale=guidance_scale, 233 | num_inference_steps=num_inference_steps, 234 | generator=generator, 235 | **kwargs, 236 | ).images 237 | 238 | return images 239 | 240 | 241 | class IPAdapterXL(IPAdapter): 242 | """SDXL""" 243 | 244 | def generate( 245 | self, 246 | pil_image, 247 | prompt=None, 248 | negative_prompt=None, 249 | scale=1.0, 250 | num_samples=4, 251 | seed=None, 252 | num_inference_steps=30, 253 | neg_content_emb=None, 254 | neg_content_prompt=None, 255 | neg_content_scale=1.0, 256 | **kwargs, 257 | ): 258 | self.set_scale(scale) 259 | 260 | num_prompts = 1 if isinstance(pil_image, Image.Image) else len(pil_image) 261 | 262 | if prompt is None: 263 | prompt = "best quality, high quality" 264 | if negative_prompt is None: 265 | negative_prompt = "monochrome, lowres, bad anatomy, worst quality, low quality" 266 | 267 | if not isinstance(prompt, List): 268 | prompt = [prompt] * num_prompts 269 | if not isinstance(negative_prompt, List): 270 | negative_prompt = [negative_prompt] * num_prompts 271 | 272 | if neg_content_emb is None: 273 | if neg_content_prompt is not None: 274 | with torch.inference_mode(): 275 | ( 276 | prompt_embeds_, # torch.Size([1, 77, 2048]) 277 | negative_prompt_embeds_, 278 | pooled_prompt_embeds_, # torch.Size([1, 1280]) 279 | negative_pooled_prompt_embeds_, 280 | ) = self.pipe.encode_prompt( 281 | neg_content_prompt, 282 | num_images_per_prompt=num_samples, 283 | do_classifier_free_guidance=True, 284 | negative_prompt=negative_prompt, 285 | ) 286 | pooled_prompt_embeds_ *= neg_content_scale 287 | else: 288 | pooled_prompt_embeds_ = neg_content_emb 289 | else: 290 | pooled_prompt_embeds_ = None 291 | 292 | image_prompt_embeds, uncond_image_prompt_embeds = self.get_image_embeds(pil_image, 
content_prompt_embeds=pooled_prompt_embeds_) 293 | bs_embed, seq_len, _ = image_prompt_embeds.shape 294 | image_prompt_embeds = image_prompt_embeds.repeat(1, num_samples, 1) 295 | image_prompt_embeds = image_prompt_embeds.view(bs_embed * num_samples, seq_len, -1) 296 | uncond_image_prompt_embeds = uncond_image_prompt_embeds.repeat(1, num_samples, 1) 297 | uncond_image_prompt_embeds = uncond_image_prompt_embeds.view(bs_embed * num_samples, seq_len, -1) 298 | 299 | with torch.inference_mode(): 300 | ( 301 | prompt_embeds, 302 | negative_prompt_embeds, 303 | pooled_prompt_embeds, 304 | negative_pooled_prompt_embeds, 305 | ) = self.pipe.encode_prompt( 306 | prompt, 307 | num_images_per_prompt=num_samples, 308 | do_classifier_free_guidance=True, 309 | negative_prompt=negative_prompt, 310 | ) 311 | prompt_embeds = torch.cat([prompt_embeds, image_prompt_embeds], dim=1) 312 | negative_prompt_embeds = torch.cat([negative_prompt_embeds, uncond_image_prompt_embeds], dim=1) 313 | 314 | self.generator = get_generator(seed, self.device) 315 | 316 | images = self.pipe( 317 | prompt_embeds=prompt_embeds, 318 | negative_prompt_embeds=negative_prompt_embeds, 319 | pooled_prompt_embeds=pooled_prompt_embeds, 320 | negative_pooled_prompt_embeds=negative_pooled_prompt_embeds, 321 | num_inference_steps=num_inference_steps, 322 | generator=self.generator, 323 | **kwargs, 324 | ).images 325 | 326 | return images 327 | 328 | 329 | class IPAdapterPlus(IPAdapter): 330 | """IP-Adapter with fine-grained features""" 331 | 332 | def init_proj(self): 333 | image_proj_model = Resampler( 334 | dim=self.pipe.unet.config.cross_attention_dim, 335 | depth=4, 336 | dim_head=64, 337 | heads=12, 338 | num_queries=self.num_tokens, 339 | embedding_dim=self.image_encoder.config.hidden_size, 340 | output_dim=self.pipe.unet.config.cross_attention_dim, 341 | ff_mult=4, 342 | ).to(self.device, dtype=torch.float16) 343 | return image_proj_model 344 | 345 | @torch.inference_mode() 346 | def get_image_embeds(self, pil_image=None, clip_image_embeds=None): 347 | if isinstance(pil_image, Image.Image): 348 | pil_image = [pil_image] 349 | clip_image = self.clip_image_processor(images=pil_image, return_tensors="pt").pixel_values 350 | clip_image = clip_image.to(self.device, dtype=torch.float16) 351 | clip_image_embeds = self.image_encoder(clip_image, output_hidden_states=True).hidden_states[-2] 352 | image_prompt_embeds = self.image_proj_model(clip_image_embeds) 353 | uncond_clip_image_embeds = self.image_encoder( 354 | torch.zeros_like(clip_image), output_hidden_states=True 355 | ).hidden_states[-2] 356 | uncond_image_prompt_embeds = self.image_proj_model(uncond_clip_image_embeds) 357 | return image_prompt_embeds, uncond_image_prompt_embeds 358 | 359 | 360 | class IPAdapterFull(IPAdapterPlus): 361 | """IP-Adapter with full features""" 362 | 363 | def init_proj(self): 364 | image_proj_model = MLPProjModel( 365 | cross_attention_dim=self.pipe.unet.config.cross_attention_dim, 366 | clip_embeddings_dim=self.image_encoder.config.hidden_size, 367 | ).to(self.device, dtype=torch.float16) 368 | return image_proj_model 369 | 370 | 371 | class IPAdapterPlusXL(IPAdapter): 372 | """SDXL""" 373 | 374 | def init_proj(self): 375 | image_proj_model = Resampler( 376 | dim=1280, 377 | depth=4, 378 | dim_head=64, 379 | heads=20, 380 | num_queries=self.num_tokens, 381 | embedding_dim=self.image_encoder.config.hidden_size, 382 | output_dim=self.pipe.unet.config.cross_attention_dim, 383 | ff_mult=4, 384 | ).to(self.device, dtype=torch.float16) 385 | return 
image_proj_model 386 | 387 | @torch.inference_mode() 388 | def get_image_embeds(self, pil_image): 389 | if isinstance(pil_image, Image.Image): 390 | pil_image = [pil_image] 391 | clip_image = self.clip_image_processor(images=pil_image, return_tensors="pt").pixel_values 392 | clip_image = clip_image.to(self.device, dtype=torch.float16) 393 | clip_image_embeds = self.image_encoder(clip_image, output_hidden_states=True).hidden_states[-2] 394 | image_prompt_embeds = self.image_proj_model(clip_image_embeds) 395 | uncond_clip_image_embeds = self.image_encoder( 396 | torch.zeros_like(clip_image), output_hidden_states=True 397 | ).hidden_states[-2] 398 | uncond_image_prompt_embeds = self.image_proj_model(uncond_clip_image_embeds) 399 | return image_prompt_embeds, uncond_image_prompt_embeds 400 | 401 | def generate( 402 | self, 403 | pil_image, 404 | prompt=None, 405 | negative_prompt=None, 406 | scale=1.0, 407 | num_samples=4, 408 | seed=None, 409 | num_inference_steps=30, 410 | **kwargs, 411 | ): 412 | self.set_scale(scale) 413 | 414 | num_prompts = 1 if isinstance(pil_image, Image.Image) else len(pil_image) 415 | 416 | if prompt is None: 417 | prompt = "best quality, high quality" 418 | if negative_prompt is None: 419 | negative_prompt = "monochrome, lowres, bad anatomy, worst quality, low quality" 420 | 421 | if not isinstance(prompt, List): 422 | prompt = [prompt] * num_prompts 423 | if not isinstance(negative_prompt, List): 424 | negative_prompt = [negative_prompt] * num_prompts 425 | 426 | image_prompt_embeds, uncond_image_prompt_embeds = self.get_image_embeds(pil_image) 427 | bs_embed, seq_len, _ = image_prompt_embeds.shape 428 | image_prompt_embeds = image_prompt_embeds.repeat(1, num_samples, 1) 429 | image_prompt_embeds = image_prompt_embeds.view(bs_embed * num_samples, seq_len, -1) 430 | uncond_image_prompt_embeds = uncond_image_prompt_embeds.repeat(1, num_samples, 1) 431 | uncond_image_prompt_embeds = uncond_image_prompt_embeds.view(bs_embed * num_samples, seq_len, -1) 432 | 433 | with torch.inference_mode(): 434 | ( 435 | prompt_embeds, 436 | negative_prompt_embeds, 437 | pooled_prompt_embeds, 438 | negative_pooled_prompt_embeds, 439 | ) = self.pipe.encode_prompt( 440 | prompt, 441 | num_images_per_prompt=num_samples, 442 | do_classifier_free_guidance=True, 443 | negative_prompt=negative_prompt, 444 | ) 445 | prompt_embeds = torch.cat([prompt_embeds, image_prompt_embeds], dim=1) 446 | negative_prompt_embeds = torch.cat([negative_prompt_embeds, uncond_image_prompt_embeds], dim=1) 447 | 448 | generator = get_generator(seed, self.device) 449 | 450 | images = self.pipe( 451 | prompt_embeds=prompt_embeds, 452 | negative_prompt_embeds=negative_prompt_embeds, 453 | pooled_prompt_embeds=pooled_prompt_embeds, 454 | negative_pooled_prompt_embeds=negative_pooled_prompt_embeds, 455 | num_inference_steps=num_inference_steps, 456 | generator=generator, 457 | **kwargs, 458 | ).images 459 | 460 | return images 461 | -------------------------------------------------------------------------------- /ip_adapter/resampler.py: -------------------------------------------------------------------------------- 1 | # modified from https://github.com/mlfoundations/open_flamingo/blob/main/open_flamingo/src/helpers.py 2 | # and https://github.com/lucidrains/imagen-pytorch/blob/main/imagen_pytorch/imagen_pytorch.py 3 | 4 | import math 5 | 6 | import torch 7 | import torch.nn as nn 8 | from einops import rearrange 9 | from einops.layers.torch import Rearrange 10 | 11 | 12 | # FFN 13 | def FeedForward(dim, mult=4): 14 
| inner_dim = int(dim * mult) 15 | return nn.Sequential( 16 | nn.LayerNorm(dim), 17 | nn.Linear(dim, inner_dim, bias=False), 18 | nn.GELU(), 19 | nn.Linear(inner_dim, dim, bias=False), 20 | ) 21 | 22 | 23 | def reshape_tensor(x, heads): 24 | bs, length, width = x.shape 25 | # (bs, length, width) --> (bs, length, n_heads, dim_per_head) 26 | x = x.view(bs, length, heads, -1) 27 | # (bs, length, n_heads, dim_per_head) --> (bs, n_heads, length, dim_per_head) 28 | x = x.transpose(1, 2) 29 | # (bs, n_heads, length, dim_per_head) --> (bs*n_heads, length, dim_per_head) 30 | x = x.reshape(bs, heads, length, -1) 31 | return x 32 | 33 | 34 | class PerceiverAttention(nn.Module): 35 | def __init__(self, *, dim, dim_head=64, heads=8): 36 | super().__init__() 37 | self.scale = dim_head**-0.5 38 | self.dim_head = dim_head 39 | self.heads = heads 40 | inner_dim = dim_head * heads 41 | 42 | self.norm1 = nn.LayerNorm(dim) 43 | self.norm2 = nn.LayerNorm(dim) 44 | 45 | self.to_q = nn.Linear(dim, inner_dim, bias=False) 46 | self.to_kv = nn.Linear(dim, inner_dim * 2, bias=False) 47 | self.to_out = nn.Linear(inner_dim, dim, bias=False) 48 | 49 | def forward(self, x, latents): 50 | """ 51 | Args: 52 | x (torch.Tensor): image features 53 | shape (b, n1, D) 54 | latent (torch.Tensor): latent features 55 | shape (b, n2, D) 56 | """ 57 | x = self.norm1(x) 58 | latents = self.norm2(latents) 59 | 60 | b, l, _ = latents.shape 61 | 62 | q = self.to_q(latents) 63 | kv_input = torch.cat((x, latents), dim=-2) 64 | k, v = self.to_kv(kv_input).chunk(2, dim=-1) 65 | 66 | q = reshape_tensor(q, self.heads) 67 | k = reshape_tensor(k, self.heads) 68 | v = reshape_tensor(v, self.heads) 69 | 70 | # attention 71 | scale = 1 / math.sqrt(math.sqrt(self.dim_head)) 72 | weight = (q * scale) @ (k * scale).transpose(-2, -1) # More stable with f16 than dividing afterwards 73 | weight = torch.softmax(weight.float(), dim=-1).type(weight.dtype) 74 | out = weight @ v 75 | 76 | out = out.permute(0, 2, 1, 3).reshape(b, l, -1) 77 | 78 | return self.to_out(out) 79 | 80 | 81 | class Resampler(nn.Module): 82 | def __init__( 83 | self, 84 | dim=1024, 85 | depth=8, 86 | dim_head=64, 87 | heads=16, 88 | num_queries=8, 89 | embedding_dim=768, 90 | output_dim=1024, 91 | ff_mult=4, 92 | max_seq_len: int = 257, # CLIP tokens + CLS token 93 | apply_pos_emb: bool = False, 94 | num_latents_mean_pooled: int = 0, # number of latents derived from mean pooled representation of the sequence 95 | ): 96 | super().__init__() 97 | self.pos_emb = nn.Embedding(max_seq_len, embedding_dim) if apply_pos_emb else None 98 | 99 | self.latents = nn.Parameter(torch.randn(1, num_queries, dim) / dim**0.5) 100 | 101 | self.proj_in = nn.Linear(embedding_dim, dim) 102 | 103 | self.proj_out = nn.Linear(dim, output_dim) 104 | self.norm_out = nn.LayerNorm(output_dim) 105 | 106 | self.to_latents_from_mean_pooled_seq = ( 107 | nn.Sequential( 108 | nn.LayerNorm(dim), 109 | nn.Linear(dim, dim * num_latents_mean_pooled), 110 | Rearrange("b (n d) -> b n d", n=num_latents_mean_pooled), 111 | ) 112 | if num_latents_mean_pooled > 0 113 | else None 114 | ) 115 | 116 | self.layers = nn.ModuleList([]) 117 | for _ in range(depth): 118 | self.layers.append( 119 | nn.ModuleList( 120 | [ 121 | PerceiverAttention(dim=dim, dim_head=dim_head, heads=heads), 122 | FeedForward(dim=dim, mult=ff_mult), 123 | ] 124 | ) 125 | ) 126 | 127 | def forward(self, x): 128 | if self.pos_emb is not None: 129 | n, device = x.shape[1], x.device 130 | pos_emb = self.pos_emb(torch.arange(n, device=device)) 131 | x = x + 
pos_emb 132 | 133 | latents = self.latents.repeat(x.size(0), 1, 1) 134 | 135 | x = self.proj_in(x) 136 | 137 | if self.to_latents_from_mean_pooled_seq: 138 | meanpooled_seq = masked_mean(x, dim=1, mask=torch.ones(x.shape[:2], device=x.device, dtype=torch.bool)) 139 | meanpooled_latents = self.to_latents_from_mean_pooled_seq(meanpooled_seq) 140 | latents = torch.cat((meanpooled_latents, latents), dim=-2) 141 | 142 | for attn, ff in self.layers: 143 | latents = attn(x, latents) + latents 144 | latents = ff(latents) + latents 145 | 146 | latents = self.proj_out(latents) 147 | return self.norm_out(latents) 148 | 149 | 150 | def masked_mean(t, *, dim, mask=None): 151 | if mask is None: 152 | return t.mean(dim=dim) 153 | 154 | denom = mask.sum(dim=dim, keepdim=True) 155 | mask = rearrange(mask, "b n -> b n 1") 156 | masked_t = t.masked_fill(~mask, 0.0) 157 | 158 | return masked_t.sum(dim=dim) / denom.clamp(min=1e-5) 159 | -------------------------------------------------------------------------------- /ip_adapter/utils.py: -------------------------------------------------------------------------------- 1 | import torch 2 | import torch.nn.functional as F 3 | import numpy as np 4 | from PIL import Image 5 | 6 | attn_maps = {} 7 | def hook_fn(name): 8 | def forward_hook(module, input, output): 9 | if hasattr(module.processor, "attn_map"): 10 | attn_maps[name] = module.processor.attn_map 11 | del module.processor.attn_map 12 | 13 | return forward_hook 14 | 15 | def register_cross_attention_hook(unet): 16 | for name, module in unet.named_modules(): 17 | if name.split('.')[-1].startswith('attn2'): 18 | module.register_forward_hook(hook_fn(name)) 19 | 20 | return unet 21 | 22 | def upscale(attn_map, target_size): 23 | attn_map = torch.mean(attn_map, dim=0) 24 | attn_map = attn_map.permute(1,0) 25 | temp_size = None 26 | 27 | for i in range(0,5): 28 | scale = 2 ** i 29 | if ( target_size[0] // scale ) * ( target_size[1] // scale) == attn_map.shape[1]*64: 30 | temp_size = (target_size[0]//(scale*8), target_size[1]//(scale*8)) 31 | break 32 | 33 | assert temp_size is not None, "temp_size cannot is None" 34 | 35 | attn_map = attn_map.view(attn_map.shape[0], *temp_size) 36 | 37 | attn_map = F.interpolate( 38 | attn_map.unsqueeze(0).to(dtype=torch.float32), 39 | size=target_size, 40 | mode='bilinear', 41 | align_corners=False 42 | )[0] 43 | 44 | attn_map = torch.softmax(attn_map, dim=0) 45 | return attn_map 46 | def get_net_attn_map(image_size, batch_size=2, instance_or_negative=False, detach=True): 47 | 48 | idx = 0 if instance_or_negative else 1 49 | net_attn_maps = [] 50 | 51 | for name, attn_map in attn_maps.items(): 52 | attn_map = attn_map.cpu() if detach else attn_map 53 | attn_map = torch.chunk(attn_map, batch_size)[idx].squeeze() 54 | attn_map = upscale(attn_map, image_size) 55 | net_attn_maps.append(attn_map) 56 | 57 | net_attn_maps = torch.mean(torch.stack(net_attn_maps,dim=0),dim=0) 58 | 59 | return net_attn_maps 60 | 61 | def attnmaps2images(net_attn_maps): 62 | 63 | #total_attn_scores = 0 64 | images = [] 65 | 66 | for attn_map in net_attn_maps: 67 | attn_map = attn_map.cpu().numpy() 68 | #total_attn_scores += attn_map.mean().item() 69 | 70 | normalized_attn_map = (attn_map - np.min(attn_map)) / (np.max(attn_map) - np.min(attn_map)) * 255 71 | normalized_attn_map = normalized_attn_map.astype(np.uint8) 72 | #print("norm: ", normalized_attn_map.shape) 73 | image = Image.fromarray(normalized_attn_map) 74 | 75 | #image = fix_save_attn_map(attn_map) 76 | images.append(image) 77 | 78 | 
#print(total_attn_scores) 79 | return images 80 | def is_torch2_available(): 81 | return hasattr(F, "scaled_dot_product_attention") 82 | 83 | def get_generator(seed, device): 84 | 85 | if seed is not None: 86 | if isinstance(seed, list): 87 | generator = [torch.Generator(device).manual_seed(seed_item) for seed_item in seed] 88 | else: 89 | generator = torch.Generator(device).manual_seed(seed) 90 | else: 91 | generator = None 92 | 93 | return generator --------------------------------------------------------------------------------
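The dump above includes the attention-map helpers in ip_adapter/utils.py (`register_cross_attention_hook`, `get_net_attn_map`, `attnmaps2images`), but none of the bundled scripts show them in use. The snippet below is a minimal sketch, not a file from the repository, of how these helpers could be combined with the infer_style.py setup to save one heatmap per image-prompt token; the checkpoint paths, the 1024x1024 output size, and the output filenames are assumptions.

```python
# Hypothetical visualization sketch -- not one of the repository scripts.
# It reuses the infer_style.py setup plus the hooks from ip_adapter/utils.py;
# paths, the assumed 1024x1024 resolution, and output names are illustrative only.
import torch
from diffusers import StableDiffusionXLPipeline
from PIL import Image

from ip_adapter import IPAdapterXL
from ip_adapter.utils import register_cross_attention_hook, get_net_attn_map, attnmaps2images

base_model_path = "stabilityai/stable-diffusion-xl-base-1.0"
image_encoder_path = "sdxl_models/image_encoder"
ip_ckpt = "sdxl_models/ip-adapter_sdxl.bin"
device = "cuda"

pipe = StableDiffusionXLPipeline.from_pretrained(
    base_model_path,
    torch_dtype=torch.float16,
    add_watermarker=False,
)

# Inject the style-only IP-Adapter, then hook every cross-attention module so that
# the attn_map recorded by the (non-skipped) style-block processors gets captured.
ip_model = IPAdapterXL(pipe, image_encoder_path, ip_ckpt, device,
                       target_blocks=["up_blocks.0.attentions.1"])
pipe.unet = register_cross_attention_hook(pipe.unet)

style_image = Image.open("./assets/0.jpg")
images = ip_model.generate(pil_image=style_image,
                           prompt="a cat, masterpiece, best quality, high quality",
                           scale=1.0,
                           guidance_scale=5,
                           num_samples=1,
                           num_inference_steps=30,
                           seed=42)
images[0].save("result.png")

# Average the recorded maps over layers and save one heatmap per image-prompt token.
# (1024, 1024) assumes the default SDXL output resolution.
net_attn_map = get_net_attn_map((1024, 1024))
for i, heatmap in enumerate(attnmaps2images(net_attn_map)):
    heatmap.save(f"attn_token_{i}.png")
```

Because only `up_blocks.0.attentions.1` is passed as a target block, every recorded map comes from the style block, so the saved heatmaps indicate where the image-prompt tokens attend while the style is being injected.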