├── README.md
├── assets
│   ├── image3_1.jpg
│   ├── img_0.png
│   ├── img_1.png
│   ├── img_2.png
│   ├── overview.jpg
│   ├── page1.png
│   ├── page10.png
│   ├── page11.jpg
│   ├── page4.png
│   ├── page8.png
│   └── vis.jpg
├── gradio
│   ├── app.py
│   ├── assets
│   │   ├── img_0.png
│   │   ├── img_1.png
│   │   └── img_2.png
│   └── requirements.txt
├── infer
│   ├── infer_CSGO.py
│   └── infer_csgo.ipynb
├── ip_adapter
│   ├── __init__.py
│   ├── attention_processor.py
│   ├── ip_adapter.py
│   ├── resampler.py
│   └── utils.py
└── requirements.txt

/README.md:
--------------------------------------------------------------------------------
1 | 
2 | # CSGO: Content-Style Composition in Text-to-Image Generation
3 | 
4 | [**Peng Xing**](https://github.com/xingp-ng)<sup>1,2*</sup> · [**Haofan Wang**](https://haofanwang.github.io/)<sup>1*</sup> · [**Yanpeng Sun**](https://scholar.google.com.hk/citations?user=a3FI8c4AAAAJ&hl=zh-CN&oi=ao/)<sup>2</sup> · [**Qixun Wang**](https://github.com/wangqixun)<sup>1</sup> · [**Xu Bai**](https://huggingface.co/baymin0220)<sup>1,3</sup> · [**Hao Ai**](https://github.com/aihao2000)<sup>1,4</sup> · [**Renyuan Huang**](https://github.com/DannHuang)<sup>1,5</sup>
5 | [**Zechao Li**](https://zechao-li.github.io/)<sup>2✉</sup>
6 | 
7 | <sup>1</sup>InstantX Team · <sup>2</sup>Nanjing University of Science and Technology · <sup>3</sup>Xiaohongshu · <sup>4</sup>Beihang University · <sup>5</sup>Peking University
8 | 
9 | <sup>*</sup>equal contributions, <sup>✉</sup>corresponding authors
10 | 
11 | 
12 | 
13 | [![Hugging Face](https://img.shields.io/badge/%F0%9F%A4%97%20Hugging%20Face-Models-blue)](https://huggingface.co/InstantX/CSGO)
14 | [![Hugging Face](https://img.shields.io/badge/%F0%9F%A4%97%20Hugging%20Face-App-red)](https://huggingface.co/spaces/xingpng/CSGO/)
15 | [![GitHub](https://img.shields.io/github/stars/instantX-research/CSGO?style=social)](https://github.com/instantX-research/CSGO)
16 | 
17 | 
18 | 
19 | ## Updates 🔥
20 | 
21 | [//]: # (- **`2024/07/19`**: ✨ We support 🎞️ portrait video editing (aka v2v)! More to see [here](assets/docs/changelog/2024-07-19.md).)
22 | 
23 | [//]: # (- **`2024/07/17`**: 🍎 We support macOS with Apple Silicon, modified from [jeethu](https://github.com/jeethu)'s PR [#143](https://github.com/KwaiVGI/LivePortrait/pull/143).)
24 | 
25 | [//]: # (- **`2024/07/10`**: 💪 We support audio and video concatenating, driving video auto-cropping, and template making to protect privacy. More to see [here](assets/docs/changelog/2024-07-10.md).)
26 | 
27 | [//]: # (- **`2024/07/09`**: 🤗 We released the [HuggingFace Space](https://huggingface.co/spaces/KwaiVGI/liveportrait), thanks to the HF team and [Gradio](https://github.com/gradio-app/gradio)!)
28 | [//]: # (Continuous updates, stay tuned!)
29 | - **`2024/09/04`**: 🔥 We released the Gradio code. You can simply configure it and use it directly.
30 | - **`2024/09/03`**: 🔥 We released the online demo on [Hugging Face](https://huggingface.co/spaces/xingpng/CSGO/).
31 | - **`2024/09/03`**: 🔥 We released the [pre-trained weights](https://huggingface.co/InstantX/CSGO).
32 | - **`2024/09/03`**: 🔥 We released the initial version of the inference code.
33 | - **`2024/08/30`**: 🔥 We released the technical report on [arXiv](https://arxiv.org/pdf/2408.16766).
34 | - **`2024/07/15`**: 🔥 We released the [homepage](https://csgo-gen.github.io).
35 | 
36 | ## Plan 💪
37 | - [x] technical report
38 | - [x] inference code
39 | - [x] pre-trained weights [4_16]
40 | - [x] pre-trained weights [4_32]
41 | - [x] online demo
42 | - [ ] pre-trained weights_v2 [4_32]
43 | - [ ] IMAGStyle dataset
44 | - [ ] training code
45 | - [ ] more pre-trained weights
46 | 
47 | ## Introduction 📖
48 | This repo, named **CSGO**, contains the official PyTorch implementation of our paper [CSGO: Content-Style Composition in Text-to-Image Generation](https://arxiv.org/pdf/2408.16766).
49 | We are actively updating and improving this repository. If you find any bugs or have suggestions, feel free to open an issue or submit a pull request (PR) 💖.
50 | 
51 | ## Pipeline 💻
52 | 

53 | 54 |

55 | 
56 | ## Capabilities 🚅
57 | 
58 | 🔥 Our CSGO achieves **image-driven style transfer, text-driven stylized synthesis, and text editing-driven stylized synthesis**.
59 | 
60 | 🔥 For more results, visit our [homepage](https://csgo-gen.github.io) 🔥
61 | 
62 | 

63 | 64 |

65 | 
66 | 
67 | ## Getting Started 🏁
68 | ### 1. Clone the code and prepare the environment
69 | ```bash
70 | git clone https://github.com/instantX-research/CSGO
71 | cd CSGO
72 | 
73 | # create env using conda
74 | conda create -n CSGO python=3.9
75 | conda activate CSGO
76 | 
77 | # install dependencies with pip
78 | # for Linux and Windows users
79 | pip install -r requirements.txt
80 | ```
81 | 
82 | ### 2. Download pretrained weights
83 | 
84 | We currently release two sets of model weights; a third (csgo_4_32_v2.bin) is coming soon.
85 | 
86 | | Mode             | content tokens | style tokens | Other                              |
87 | |:----------------:|:--------------:|:------------:|:----------------------------------:|
88 | | csgo.bin         | 4              | 16           | -                                  |
89 | | csgo_4_32.bin    | 4              | 32           | DeepSpeed ZeRO-2                   |
90 | | csgo_4_32_v2.bin | 4              | 32           | DeepSpeed ZeRO-2 + more (coming soon) |
91 | 
92 | The easiest way to download the pretrained weights is from Hugging Face:
93 | ```bash
94 | # first, ensure git-lfs is installed, see: https://docs.github.com/en/repositories/working-with-files/managing-large-files/installing-git-large-file-storage
95 | git lfs install
96 | # clone and move the weights
97 | git clone https://huggingface.co/InstantX/CSGO
98 | ```
99 | Our method is fully compatible with [SDXL](https://huggingface.co/stabilityai/stable-diffusion-xl-base-1.0), [VAE](https://huggingface.co/madebyollin/sdxl-vae-fp16-fix), [ControlNet](https://huggingface.co/TTPlanet/TTPLanet_SDXL_Controlnet_Tile_Realistic), and [Image Encoder](https://huggingface.co/h94/IP-Adapter/tree/main/sdxl_models/image_encoder).
100 | Please download them and place them in the ./base_models folder.
101 | 
102 | Tip: if you want to load the ControlNet weights directly with `ControlNetModel.from_pretrained` as CSGO does, rename the checkpoint file first:
103 | ```bash
104 | git clone https://huggingface.co/TTPlanet/TTPLanet_SDXL_Controlnet_Tile_Realistic
105 | mv TTPLanet_SDXL_Controlnet_Tile_Realistic/TTPLANET_Controlnet_Tile_realistic_v2_fp16.safetensors TTPLanet_SDXL_Controlnet_Tile_Realistic/diffusion_pytorch_model.safetensors
106 | ```
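If you prefer not to use git-lfs, the same repositories can also be fetched from Python with the `huggingface_hub` package. This is a minimal optional sketch; the local folder names are assumptions that simply mirror the paths used in the inference example below:

```python
# optional: download the weights with huggingface_hub instead of git-lfs
from huggingface_hub import snapshot_download

# CSGO adapter weights (csgo.bin / csgo_4_32.bin)
snapshot_download(repo_id="InstantX/CSGO", local_dir="./CSGO")

# base models expected under ./base_models
snapshot_download(repo_id="stabilityai/stable-diffusion-xl-base-1.0", local_dir="./base_models/stable-diffusion-xl-base-1.0")
snapshot_download(repo_id="madebyollin/sdxl-vae-fp16-fix", local_dir="./base_models/sdxl-vae-fp16-fix")
snapshot_download(repo_id="h94/IP-Adapter", local_dir="./base_models/IP-Adapter", allow_patterns=["sdxl_models/image_encoder/*"])
snapshot_download(repo_id="TTPlanet/TTPLanet_SDXL_Controlnet_Tile_Realistic", local_dir="./base_models/TTPLanet_SDXL_Controlnet_Tile_Realistic")
```

Note that the ControlNet repository still needs the rename shown above so that `ControlNetModel.from_pretrained` can find `diffusion_pytorch_model.safetensors`.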
107 | ### 3. Inference 🚀
108 | 
109 | ```python
110 | import torch
111 | from ip_adapter.utils import BLOCKS as BLOCKS
112 | from ip_adapter.utils import controlnet_BLOCKS as controlnet_BLOCKS
113 | from PIL import Image
114 | from diffusers import (
115 |     AutoencoderKL,
116 |     ControlNetModel,
117 |     StableDiffusionXLControlNetPipeline,
118 | 
119 | )
120 | from ip_adapter import CSGO
121 | 
122 | 
123 | device = torch.device("cuda:0" if torch.cuda.is_available() else "cpu")
124 | 
125 | base_model_path = "./base_models/stable-diffusion-xl-base-1.0"
126 | image_encoder_path = "./base_models/IP-Adapter/sdxl_models/image_encoder"
127 | csgo_ckpt = "./CSGO/csgo.bin"
128 | pretrained_vae_name_or_path = './base_models/sdxl-vae-fp16-fix'
129 | controlnet_path = "./base_models/TTPLanet_SDXL_Controlnet_Tile_Realistic"
130 | weight_dtype = torch.float16
131 | 
132 | 
133 | vae = AutoencoderKL.from_pretrained(pretrained_vae_name_or_path, torch_dtype=torch.float16)
134 | controlnet = ControlNetModel.from_pretrained(controlnet_path, torch_dtype=torch.float16, use_safetensors=True)
135 | pipe = StableDiffusionXLControlNetPipeline.from_pretrained(
136 |     base_model_path,
137 |     controlnet=controlnet,
138 |     torch_dtype=torch.float16,
139 |     add_watermarker=False,
140 |     vae=vae
141 | )
142 | pipe.enable_vae_tiling()
143 | 
144 | 
145 | target_content_blocks = BLOCKS['content']
146 | target_style_blocks = BLOCKS['style']
147 | controlnet_target_content_blocks = controlnet_BLOCKS['content']
148 | controlnet_target_style_blocks = controlnet_BLOCKS['style']
149 | 
150 | csgo = CSGO(pipe, image_encoder_path, csgo_ckpt, device, num_content_tokens=4, num_style_tokens=32,
151 |             target_content_blocks=target_content_blocks, target_style_blocks=target_style_blocks, controlnet=False, controlnet_adapter=True,
152 |             controlnet_target_content_blocks=controlnet_target_content_blocks,
153 |             controlnet_target_style_blocks=controlnet_target_style_blocks,
154 |             content_model_resampler=True,
155 |             style_model_resampler=True,
156 |             load_controlnet=False,
157 | 
158 |             )
159 | 
160 | style_name = 'img_1.png'
161 | content_name = 'img_0.png'
162 | style_image = Image.open("../assets/{}".format(style_name)).convert('RGB')
163 | content_image = Image.open('../assets/{}'.format(content_name)).convert('RGB')
164 | 
165 | caption = 'a small house with a sheep statue on top of it'
166 | 
167 | num_sample = 4
168 | 
169 | # image-driven style transfer
170 | images = csgo.generate(pil_content_image=content_image, pil_style_image=style_image,
171 |                        prompt=caption,
172 |                        negative_prompt="text, watermark, lowres, low quality, worst quality, deformed, glitch, low contrast, noisy, saturation, blurry",
173 |                        content_scale=1.0,
174 |                        style_scale=1.0,
175 |                        guidance_scale=10,
176 |                        num_images_per_prompt=num_sample,
177 |                        num_samples=1,
178 |                        num_inference_steps=50,
179 |                        seed=42,
180 |                        image=content_image.convert('RGB'),
181 |                        controlnet_conditioning_scale=0.6,
182 |                        )
183 | 
184 | # text-driven stylized synthesis
185 | caption = 'a cat'
186 | images = csgo.generate(pil_content_image=content_image, pil_style_image=style_image,
187 |                        prompt=caption,
188 |                        negative_prompt="text, watermark, lowres, low quality, worst quality, deformed, glitch, low contrast, noisy, saturation, blurry",
189 |                        content_scale=1.0,
190 |                        style_scale=1.0,
191 |                        guidance_scale=10,
192 |                        num_images_per_prompt=num_sample,
193 |                        num_samples=1,
194 |                        num_inference_steps=50,
195 |                        seed=42,
196 |                        image=content_image.convert('RGB'),
197 |                        controlnet_conditioning_scale=0.01,
198 |                        )
199 | 
200 | # text editing-driven stylized synthesis
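# Note: the three modes in this example differ mainly in the prompt and in
# controlnet_conditioning_scale -- roughly 0.6 preserves the content layout for
# image-driven transfer, ~0.01 largely ignores it for text-driven synthesis, and
# ~0.4 sits in between for text editing-driven synthesis, matching the values
# used in these calls and in the examples of gradio/app.py.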
201 | caption = 'a small house'
202 | images = csgo.generate(pil_content_image=content_image, pil_style_image=style_image,
203 |                        prompt=caption,
204 |                        negative_prompt="text, watermark, lowres, low quality, worst quality, deformed, glitch, low contrast, noisy, saturation, blurry",
205 |                        content_scale=1.0,
206 |                        style_scale=1.0,
207 |                        guidance_scale=10,
208 |                        num_images_per_prompt=num_sample,
209 |                        num_samples=1,
210 |                        num_inference_steps=50,
211 |                        seed=42,
212 |                        image=content_image.convert('RGB'),
213 |                        controlnet_conditioning_scale=0.4,
214 |                        )
215 | ```
216 | ### 4. Gradio interface ⚙️
217 | 
218 | We also provide a Gradio interface for a better experience; just run:
219 | 
220 | ```bash
221 | # For Linux and Windows users (and macOS)
222 | python gradio/app.py
223 | ```
224 | If you don't have the resources to run it locally, we also provide an online [demo](https://huggingface.co/spaces/xingpng/CSGO/).
225 | ## Demos
226 | 

227 |
228 | 🔥 For more results, visit our [homepage](https://csgo-gen.github.io) 🔥
229 | 

230 | 231 | ### Content-Style Composition 232 |

233 | 234 |

235 | 236 |

237 | 238 |

239 | 240 | ### Cycle Translation 241 |

242 | 243 |

244 | 245 | ### Text-Driven Style Synthesis 246 |

247 | 248 |

249 | 250 | ### Text Editing-Driven Style Synthesis 251 |

252 | 253 |

254 | 
255 | ## Star History
256 | [![Star History Chart](https://api.star-history.com/svg?repos=instantX-research/CSGO&type=Date)](https://star-history.com/#instantX-research/CSGO&Date)
257 | 
258 | 
259 | 
260 | ## Acknowledgements
261 | This project is developed by the InstantX Team and Xiaohongshu; all rights reserved.
262 | Sincere thanks to Xiaohongshu for providing the computing resources.
263 | 
264 | ## Citation 💖
265 | If you find CSGO useful for your research, feel free to 🌟 this repo and cite our work using the following BibTeX:
266 | ```bibtex
267 | @article{xing2024csgo,
268 |   title={CSGO: Content-Style Composition in Text-to-Image Generation},
269 |   author={Peng Xing and Haofan Wang and Yanpeng Sun and Qixun Wang and Xu Bai and Hao Ai and Renyuan Huang and Zechao Li},
270 |   year={2024},
271 |   journal = {arXiv 2408.16766},
272 | }
273 | ```
--------------------------------------------------------------------------------
/assets/image3_1.jpg:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/instantX-research/CSGO/fefec09cf680d9796b92a64282ceb12501b9f977/assets/image3_1.jpg
--------------------------------------------------------------------------------
/assets/img_0.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/instantX-research/CSGO/fefec09cf680d9796b92a64282ceb12501b9f977/assets/img_0.png
--------------------------------------------------------------------------------
/assets/img_1.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/instantX-research/CSGO/fefec09cf680d9796b92a64282ceb12501b9f977/assets/img_1.png
--------------------------------------------------------------------------------
/assets/img_2.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/instantX-research/CSGO/fefec09cf680d9796b92a64282ceb12501b9f977/assets/img_2.png
--------------------------------------------------------------------------------
/assets/overview.jpg:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/instantX-research/CSGO/fefec09cf680d9796b92a64282ceb12501b9f977/assets/overview.jpg
--------------------------------------------------------------------------------
/assets/page1.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/instantX-research/CSGO/fefec09cf680d9796b92a64282ceb12501b9f977/assets/page1.png
--------------------------------------------------------------------------------
/assets/page10.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/instantX-research/CSGO/fefec09cf680d9796b92a64282ceb12501b9f977/assets/page10.png
--------------------------------------------------------------------------------
/assets/page11.jpg:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/instantX-research/CSGO/fefec09cf680d9796b92a64282ceb12501b9f977/assets/page11.jpg
--------------------------------------------------------------------------------
/assets/page4.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/instantX-research/CSGO/fefec09cf680d9796b92a64282ceb12501b9f977/assets/page4.png
-------------------------------------------------------------------------------- /assets/page8.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/instantX-research/CSGO/fefec09cf680d9796b92a64282ceb12501b9f977/assets/page8.png -------------------------------------------------------------------------------- /assets/vis.jpg: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/instantX-research/CSGO/fefec09cf680d9796b92a64282ceb12501b9f977/assets/vis.jpg -------------------------------------------------------------------------------- /gradio/app.py: -------------------------------------------------------------------------------- 1 | import sys 2 | # sys.path.append("../") 3 | sys.path.append("./") 4 | import gradio as gr 5 | import torch 6 | from ip_adapter.utils import BLOCKS as BLOCKS 7 | from ip_adapter.utils import controlnet_BLOCKS as controlnet_BLOCKS 8 | from ip_adapter.utils import resize_content 9 | import cv2 10 | import numpy as np 11 | import random 12 | from PIL import Image 13 | from transformers import AutoImageProcessor, AutoModel 14 | from diffusers import ( 15 | AutoencoderKL, 16 | ControlNetModel, 17 | StableDiffusionXLControlNetPipeline, 18 | 19 | ) 20 | from ip_adapter import CSGO 21 | from transformers import BlipProcessor, BlipForConditionalGeneration 22 | 23 | device = torch.device("cuda:0" if torch.cuda.is_available() else "cpu") 24 | 25 | base_model_path = "stabilityai/stable-diffusion-xl-base-1.0" 26 | image_encoder_path = "h94/IP-Adapter/sdxl_models/image_encoder" 27 | csgo_ckpt ='InstantX/CSGO/csgo_4_32.bin' 28 | pretrained_vae_name_or_path ='madebyollin/sdxl-vae-fp16-fix' 29 | controlnet_path = "TTPlanet/TTPLanet_SDXL_Controlnet_Tile_Realistic" 30 | weight_dtype = torch.float16 31 | 32 | 33 | 34 | 35 | vae = AutoencoderKL.from_pretrained(pretrained_vae_name_or_path,torch_dtype=torch.float16) 36 | controlnet = ControlNetModel.from_pretrained(controlnet_path, torch_dtype=torch.float16,use_safetensors=True) 37 | pipe = StableDiffusionXLControlNetPipeline.from_pretrained( 38 | base_model_path, 39 | controlnet=controlnet, 40 | torch_dtype=torch.float16, 41 | add_watermarker=False, 42 | vae=vae 43 | ) 44 | pipe.enable_vae_tiling() 45 | blip_processor = BlipProcessor.from_pretrained("Salesforce/blip-image-captioning-large") 46 | blip_model = BlipForConditionalGeneration.from_pretrained("Salesforce/blip-image-captioning-large").to(device) 47 | 48 | target_content_blocks = BLOCKS['content'] 49 | target_style_blocks = BLOCKS['style'] 50 | controlnet_target_content_blocks = controlnet_BLOCKS['content'] 51 | controlnet_target_style_blocks = controlnet_BLOCKS['style'] 52 | 53 | csgo = CSGO(pipe, image_encoder_path, csgo_ckpt, device, num_content_tokens=4, num_style_tokens=32, 54 | target_content_blocks=target_content_blocks, target_style_blocks=target_style_blocks, 55 | controlnet_adapter=True, 56 | controlnet_target_content_blocks=controlnet_target_content_blocks, 57 | controlnet_target_style_blocks=controlnet_target_style_blocks, 58 | content_model_resampler=True, 59 | style_model_resampler=True, 60 | ) 61 | 62 | MAX_SEED = np.iinfo(np.int32).max 63 | 64 | def randomize_seed_fn(seed: int, randomize_seed: bool) -> int: 65 | if randomize_seed: 66 | seed = random.randint(0, MAX_SEED) 67 | return seed 68 | 69 | 70 | 71 | 72 | 73 | def get_example(): 74 | case = [ 75 | [ 76 | "./assets/img_0.png", 77 | './assets/img_1.png', 78 | "Image-Driven Style 
Transfer", 79 | "there is a small house with a sheep statue on top of it", 80 | 1.0, 81 | 0.6, 82 | 1.0, 83 | ], 84 | [ 85 | None, 86 | './assets/img_1.png', 87 | "Text-Driven Style Synthesis", 88 | "a cat", 89 | 1.0, 90 | 0.01, 91 | 1.0 92 | ], 93 | [ 94 | None, 95 | './assets/img_2.png', 96 | "Text-Driven Style Synthesis", 97 | "a building", 98 | 0.5, 99 | 0.01, 100 | 1.0 101 | ], 102 | [ 103 | "./assets/img_0.png", 104 | './assets/img_1.png', 105 | "Text Edit-Driven Style Synthesis", 106 | "there is a small house", 107 | 1.0, 108 | 0.4, 109 | 1.0 110 | ], 111 | ] 112 | return case 113 | 114 | 115 | def run_for_examples(content_image_pil,style_image_pil,target, prompt,scale_c_controlnet, scale_c, scale_s): 116 | return create_image( 117 | content_image_pil=content_image_pil, 118 | style_image_pil=style_image_pil, 119 | prompt=prompt, 120 | scale_c_controlnet=scale_c_controlnet, 121 | scale_c=scale_c, 122 | scale_s=scale_s, 123 | guidance_scale=7.0, 124 | num_samples=3, 125 | num_inference_steps=50, 126 | seed=42, 127 | target=target, 128 | ) 129 | def randomize_seed_fn(seed: int, randomize_seed: bool) -> int: 130 | if randomize_seed: 131 | seed = random.randint(0, MAX_SEED) 132 | return seed 133 | 134 | def image_grid(imgs, rows, cols): 135 | assert len(imgs) == rows * cols 136 | 137 | w, h = imgs[0].size 138 | grid = Image.new('RGB', size=(cols * w, rows * h)) 139 | grid_w, grid_h = grid.size 140 | 141 | for i, img in enumerate(imgs): 142 | grid.paste(img, box=(i % cols * w, i // cols * h)) 143 | return grid 144 | def create_image(content_image_pil, 145 | style_image_pil, 146 | prompt, 147 | scale_c_controlnet, 148 | scale_c, 149 | scale_s, 150 | guidance_scale, 151 | num_samples, 152 | num_inference_steps, 153 | seed, 154 | target="Image-Driven Style Transfer", 155 | ): 156 | 157 | if content_image_pil is None: 158 | content_image_pil = Image.fromarray( 159 | np.zeros((1024, 1024, 3), dtype=np.uint8)).convert('RGB') 160 | 161 | if prompt is None or prompt == '': 162 | with torch.no_grad(): 163 | inputs = blip_processor(content_image_pil, return_tensors="pt").to(device) 164 | out = blip_model.generate(**inputs) 165 | prompt = blip_processor.decode(out[0], skip_special_tokens=True) 166 | width, height, content_image = resize_content(content_image_pil) 167 | style_image = style_image_pil 168 | neg_content_prompt='text, watermark, lowres, low quality, worst quality, deformed, glitch, low contrast, noisy, saturation, blurry' 169 | if target =="Image-Driven Style Transfer": 170 | with torch.no_grad(): 171 | images = csgo.generate(pil_content_image=content_image, pil_style_image=style_image, 172 | prompt=prompt, 173 | negative_prompt=neg_content_prompt, 174 | height=height, 175 | width=width, 176 | content_scale=scale_c, 177 | style_scale=scale_s, 178 | guidance_scale=guidance_scale, 179 | num_images_per_prompt=num_samples, 180 | num_inference_steps=num_inference_steps, 181 | num_samples=1, 182 | seed=seed, 183 | image=content_image.convert('RGB'), 184 | controlnet_conditioning_scale=scale_c_controlnet, 185 | ) 186 | 187 | elif target =="Text-Driven Style Synthesis": 188 | content_image = Image.fromarray( 189 | np.zeros((1024, 1024, 3), dtype=np.uint8)).convert('RGB') 190 | with torch.no_grad(): 191 | images = csgo.generate(pil_content_image=content_image, pil_style_image=style_image, 192 | prompt=prompt, 193 | negative_prompt="text, watermark, lowres, low quality, worst quality, deformed, glitch, low contrast, noisy, saturation, blurry", 194 | height=height, 195 | width=width, 196 | 
content_scale=scale_c, 197 | style_scale=scale_s, 198 | guidance_scale=7, 199 | num_images_per_prompt=num_samples, 200 | num_inference_steps=num_inference_steps, 201 | num_samples=1, 202 | seed=42, 203 | image=content_image.convert('RGB'), 204 | controlnet_conditioning_scale=scale_c_controlnet, 205 | ) 206 | elif target =="Text Edit-Driven Style Synthesis": 207 | 208 | with torch.no_grad(): 209 | images = csgo.generate(pil_content_image=content_image, pil_style_image=style_image, 210 | prompt=prompt, 211 | negative_prompt=neg_content_prompt, 212 | height=height, 213 | width=width, 214 | content_scale=scale_c, 215 | style_scale=scale_s, 216 | guidance_scale=guidance_scale, 217 | num_images_per_prompt=num_samples, 218 | num_inference_steps=num_inference_steps, 219 | num_samples=1, 220 | seed=seed, 221 | image=content_image.convert('RGB'), 222 | controlnet_conditioning_scale=scale_c_controlnet, 223 | ) 224 | 225 | return [image_grid(images, 1, num_samples)] 226 | 227 | 228 | def pil_to_cv2(image_pil): 229 | image_np = np.array(image_pil) 230 | image_cv2 = cv2.cvtColor(image_np, cv2.COLOR_RGB2BGR) 231 | return image_cv2 232 | 233 | 234 | # Description 235 | title = r""" 236 |

CSGO: Content-Style Composition in Text-to-Image Generation

237 | """ 238 | 239 | description = r""" 240 | Official 🤗 Gradio demo for CSGO: Content-Style Composition in Text-to-Image Generation.
241 | How to use:
242 | 1. Upload a content image if you want to use image-driven style transfer.
243 | 2. Upload a style image.
244 | 3. Set the type of task to perform; by default, image-driven style transfer is performed. The options are image-driven style transfer, text-driven style synthesis, and text editing-driven style synthesis.
245 | 4. If you choose a text-driven task, enter your desired prompt.
246 | 5. If you don't provide a prompt, a caption is generated with the BLIP model by default.
247 | 6. Click the Submit button to begin customization.
248 | 7. Share your stylized photo with your friends and enjoy! 😊
249 | 
250 | Advanced usage:
251 | 1. Click advanced options.
252 | 2. Choose a different guidance scale and number of inference steps.
253 | """
254 | 
255 | article = r"""
256 | ---
257 | 📝 **Tips**
258 | In CSGO, the more accurate the text prompt for the content image, the better the content is retained.
259 | Text-driven style synthesis and text editing-driven style synthesis are expected to be more stable in the next release.
260 | ---
261 | 📝 **Citation**
262 | 
263 | If our work is helpful for your research or applications, please cite us via: 264 | ```bibtex 265 | @article{xing2024csgo, 266 | title={CSGO: Content-Style Composition in Text-to-Image Generation}, 267 | author={Peng Xing and Haofan Wang and Yanpeng Sun and Qixun Wang and Xu Bai and Hao Ai and Renyuan Huang and Zechao Li}, 268 | year={2024}, 269 | journal = {arXiv 2408.16766}, 270 | } 271 | ``` 272 | 📧 **Contact** 273 |
274 | If you have any questions, please feel free to open an issue or directly reach us out at xingp_ng@njust.edu.cn. 275 | """ 276 | 277 | block = gr.Blocks(css="footer {visibility: hidden}").queue(max_size=10, api_open=False) 278 | with block: 279 | # description 280 | gr.Markdown(title) 281 | gr.Markdown(description) 282 | 283 | with gr.Tabs(): 284 | with gr.Row(): 285 | with gr.Column(): 286 | with gr.Row(): 287 | with gr.Column(): 288 | content_image_pil = gr.Image(label="Content Image (optional)", type='pil') 289 | style_image_pil = gr.Image(label="Style Image", type='pil') 290 | 291 | target = gr.Radio(["Image-Driven Style Transfer", "Text-Driven Style Synthesis", "Text Edit-Driven Style Synthesis"], 292 | value="Image-Driven Style Transfer", 293 | label="task") 294 | 295 | prompt = gr.Textbox(label="Prompt", 296 | value="there is a small house with a sheep statue on top of it") 297 | 298 | scale_c_controlnet = gr.Slider(minimum=0, maximum=2.0, step=0.01, value=0.6, 299 | label="Content Scale for controlnet") 300 | scale_c = gr.Slider(minimum=0, maximum=2.0, step=0.01, value=0.6, label="Content Scale for IPA") 301 | 302 | scale_s = gr.Slider(minimum=0, maximum=2.0, step=0.01, value=1.0, label="Style Scale") 303 | with gr.Accordion(open=False, label="Advanced Options"): 304 | 305 | guidance_scale = gr.Slider(minimum=1, maximum=15.0, step=0.01, value=7.0, label="guidance scale") 306 | num_samples = gr.Slider(minimum=1, maximum=4.0, step=1.0, value=1.0, label="num samples") 307 | num_inference_steps = gr.Slider(minimum=5, maximum=100.0, step=1.0, value=50, 308 | label="num inference steps") 309 | seed = gr.Slider(minimum=-1000000, maximum=1000000, value=1, step=1, label="Seed Value") 310 | randomize_seed = gr.Checkbox(label="Randomize seed", value=True) 311 | 312 | generate_button = gr.Button("Generate Image") 313 | 314 | with gr.Column(): 315 | generated_image = gr.Gallery(label="Generated Image") 316 | 317 | generate_button.click( 318 | fn=randomize_seed_fn, 319 | inputs=[seed, randomize_seed], 320 | outputs=seed, 321 | queue=False, 322 | api_name=False, 323 | ).then( 324 | fn=create_image, 325 | inputs=[content_image_pil, 326 | style_image_pil, 327 | prompt, 328 | scale_c_controlnet, 329 | scale_c, 330 | scale_s, 331 | guidance_scale, 332 | num_samples, 333 | num_inference_steps, 334 | seed, 335 | target,], 336 | outputs=[generated_image]) 337 | 338 | gr.Examples( 339 | examples=get_example(), 340 | inputs=[content_image_pil,style_image_pil,target, prompt,scale_c_controlnet, scale_c, scale_s], 341 | fn=run_for_examples, 342 | outputs=[generated_image], 343 | cache_examples=True, 344 | ) 345 | 346 | gr.Markdown(article) 347 | 348 | block.launch(server_name="0.0.0.0", server_port=1234) 349 | -------------------------------------------------------------------------------- /gradio/assets/img_0.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/instantX-research/CSGO/fefec09cf680d9796b92a64282ceb12501b9f977/gradio/assets/img_0.png -------------------------------------------------------------------------------- /gradio/assets/img_1.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/instantX-research/CSGO/fefec09cf680d9796b92a64282ceb12501b9f977/gradio/assets/img_1.png -------------------------------------------------------------------------------- /gradio/assets/img_2.png: -------------------------------------------------------------------------------- 
https://raw.githubusercontent.com/instantX-research/CSGO/fefec09cf680d9796b92a64282ceb12501b9f977/gradio/assets/img_2.png -------------------------------------------------------------------------------- /gradio/requirements.txt: -------------------------------------------------------------------------------- 1 | diffusers==0.25.1 2 | torch==2.0.1 3 | torchaudio==2.0.2 4 | torchvision==0.15.2 5 | transformers==4.40.2 6 | accelerate 7 | safetensors 8 | einops 9 | spaces==0.19.4 10 | omegaconf 11 | peft 12 | huggingface-hub==0.24.5 13 | opencv-python 14 | insightface 15 | gradio 16 | controlnet_aux 17 | gdown 18 | peft 19 | -------------------------------------------------------------------------------- /infer/infer_CSGO.py: -------------------------------------------------------------------------------- 1 | import os 2 | os.environ['HF_ENDPOINT']='https://hf-mirror.com' 3 | import torch 4 | from ip_adapter.utils import BLOCKS as BLOCKS 5 | from ip_adapter.utils import controlnet_BLOCKS as controlnet_BLOCKS 6 | from ip_adapter.utils import resize_content 7 | import cv2 8 | from PIL import Image 9 | from transformers import AutoImageProcessor, AutoModel 10 | from diffusers import ( 11 | AutoencoderKL, 12 | ControlNetModel, 13 | StableDiffusionXLControlNetPipeline, 14 | 15 | ) 16 | from ip_adapter import CSGO 17 | from transformers import BlipProcessor, BlipForConditionalGeneration 18 | 19 | device = torch.device("cuda:1" if torch.cuda.is_available() else "cpu") 20 | 21 | base_model_path = "../../../base_models/stable-diffusion-xl-base-1.0" 22 | image_encoder_path = "../../../base_models/IP-Adapter/sdxl_models/image_encoder" 23 | csgo_ckpt = "/share2/xingpeng/DATA/blora/outputs/content_style_checkpoints_2/base_train_free_controlnet_S12_alldata_C_0_S_I_zero2_style_res_32_content_res4_drop/checkpoint-220000/ip_adapter.bin" 24 | pretrained_vae_name_or_path ='../../../base_models/sdxl-vae-fp16-fix' 25 | controlnet_path = "../../../base_models/TTPLanet_SDXL_Controlnet_Tile_Realistic" 26 | weight_dtype = torch.float16 27 | 28 | blip_processor = BlipProcessor.from_pretrained("Salesforce/blip-image-captioning-large") 29 | blip_model = BlipForConditionalGeneration.from_pretrained("Salesforce/blip-image-captioning-large").to(device) 30 | 31 | 32 | vae = AutoencoderKL.from_pretrained(pretrained_vae_name_or_path,torch_dtype=torch.float16) 33 | controlnet = ControlNetModel.from_pretrained(controlnet_path, torch_dtype=torch.float16,use_safetensors=True) 34 | pipe = StableDiffusionXLControlNetPipeline.from_pretrained( 35 | base_model_path, 36 | controlnet=controlnet, 37 | torch_dtype=torch.float16, 38 | add_watermarker=False, 39 | vae=vae 40 | ) 41 | pipe.enable_vae_tiling() 42 | 43 | 44 | target_content_blocks = BLOCKS['content'] 45 | target_style_blocks = BLOCKS['style'] 46 | controlnet_target_content_blocks = controlnet_BLOCKS['content'] 47 | controlnet_target_style_blocks = controlnet_BLOCKS['style'] 48 | 49 | csgo = CSGO(pipe, image_encoder_path, csgo_ckpt, device, num_content_tokens=4,num_style_tokens=32, 50 | target_content_blocks=target_content_blocks, target_style_blocks=target_style_blocks,controlnet_adapter=True, 51 | controlnet_target_content_blocks=controlnet_target_content_blocks, 52 | controlnet_target_style_blocks=controlnet_target_style_blocks, 53 | content_model_resampler=True, 54 | style_model_resampler=True, 55 | ) 56 | 57 | style_name = 'img_1.png' 58 | content_name = 'img_0.png' 59 | style_image = Image.open("../assets/{}".format(style_name)).convert('RGB') 60 | content_image = 
Image.open('../assets/{}'.format(content_name)).convert('RGB') 61 | 62 | 63 | with torch.no_grad(): 64 | inputs = blip_processor(content_image, return_tensors="pt").to(device) 65 | out = blip_model.generate(**inputs) 66 | caption = blip_processor.decode(out[0], skip_special_tokens=True) 67 | 68 | num_sample=1 69 | 70 | width,height,content_image = resize_content(content_image) 71 | images = csgo.generate(pil_content_image= content_image, pil_style_image=style_image, 72 | prompt=caption, 73 | negative_prompt= "text, watermark, lowres, low quality, worst quality, deformed, glitch, low contrast, noisy, saturation, blurry", 74 | height=height, 75 | width=width, 76 | content_scale=0.5, 77 | style_scale=1.0, 78 | guidance_scale=10, 79 | num_images_per_prompt=num_sample, 80 | num_samples=1, 81 | num_inference_steps=50, 82 | seed=42, 83 | image=content_image.convert('RGB'), 84 | controlnet_conditioning_scale=0.6, 85 | ) 86 | images[0].save("../assets/content_img_0_style_imag_1.png") -------------------------------------------------------------------------------- /ip_adapter/__init__.py: -------------------------------------------------------------------------------- 1 | from .ip_adapter import IPAdapter, IPAdapterPlus, IPAdapterPlusXL, IPAdapterXL, IPAdapterFull,IPAdapterXL_CS,IPAdapter_CS 2 | from .ip_adapter import CSGO 3 | __all__ = [ 4 | "IPAdapter", 5 | "IPAdapterPlus", 6 | "IPAdapterPlusXL", 7 | "IPAdapterXL", 8 | "CSGO" 9 | "IPAdapterFull", 10 | ] 11 | -------------------------------------------------------------------------------- /ip_adapter/attention_processor.py: -------------------------------------------------------------------------------- 1 | # modified from https://github.com/huggingface/diffusers/blob/main/src/diffusers/models/attention_processor.py 2 | import torch 3 | import torch.nn as nn 4 | import torch.nn.functional as F 5 | 6 | 7 | class AttnProcessor(nn.Module): 8 | r""" 9 | Default processor for performing attention-related computations. 
10 | """ 11 | 12 | def __init__( 13 | self, 14 | hidden_size=None, 15 | cross_attention_dim=None, 16 | save_in_unet='down', 17 | atten_control=None, 18 | ): 19 | super().__init__() 20 | self.atten_control = atten_control 21 | self.save_in_unet = save_in_unet 22 | 23 | def __call__( 24 | self, 25 | attn, 26 | hidden_states, 27 | encoder_hidden_states=None, 28 | attention_mask=None, 29 | temb=None, 30 | ): 31 | residual = hidden_states 32 | 33 | if attn.spatial_norm is not None: 34 | hidden_states = attn.spatial_norm(hidden_states, temb) 35 | 36 | input_ndim = hidden_states.ndim 37 | 38 | if input_ndim == 4: 39 | batch_size, channel, height, width = hidden_states.shape 40 | hidden_states = hidden_states.view(batch_size, channel, height * width).transpose(1, 2) 41 | 42 | batch_size, sequence_length, _ = ( 43 | hidden_states.shape if encoder_hidden_states is None else encoder_hidden_states.shape 44 | ) 45 | attention_mask = attn.prepare_attention_mask(attention_mask, sequence_length, batch_size) 46 | 47 | if attn.group_norm is not None: 48 | hidden_states = attn.group_norm(hidden_states.transpose(1, 2)).transpose(1, 2) 49 | 50 | query = attn.to_q(hidden_states) 51 | 52 | if encoder_hidden_states is None: 53 | encoder_hidden_states = hidden_states 54 | elif attn.norm_cross: 55 | encoder_hidden_states = attn.norm_encoder_hidden_states(encoder_hidden_states) 56 | 57 | key = attn.to_k(encoder_hidden_states) 58 | value = attn.to_v(encoder_hidden_states) 59 | 60 | query = attn.head_to_batch_dim(query) 61 | key = attn.head_to_batch_dim(key) 62 | value = attn.head_to_batch_dim(value) 63 | 64 | attention_probs = attn.get_attention_scores(query, key, attention_mask) 65 | hidden_states = torch.bmm(attention_probs, value) 66 | hidden_states = attn.batch_to_head_dim(hidden_states) 67 | 68 | # linear proj 69 | hidden_states = attn.to_out[0](hidden_states) 70 | # dropout 71 | hidden_states = attn.to_out[1](hidden_states) 72 | 73 | if input_ndim == 4: 74 | hidden_states = hidden_states.transpose(-1, -2).reshape(batch_size, channel, height, width) 75 | 76 | if attn.residual_connection: 77 | hidden_states = hidden_states + residual 78 | 79 | hidden_states = hidden_states / attn.rescale_output_factor 80 | 81 | return hidden_states 82 | 83 | 84 | class IPAttnProcessor(nn.Module): 85 | r""" 86 | Attention processor for IP-Adapater. 87 | Args: 88 | hidden_size (`int`): 89 | The hidden size of the attention layer. 90 | cross_attention_dim (`int`): 91 | The number of channels in the `encoder_hidden_states`. 92 | scale (`float`, defaults to 1.0): 93 | the weight scale of image prompt. 94 | num_tokens (`int`, defaults to 4 when do ip_adapter_plus it should be 16): 95 | The context length of the image features. 
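        skip (`bool`, defaults to `False`):
            If True, the image-prompt branch is skipped and only the standard text cross-attention output is used.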
96 | """ 97 | 98 | def __init__(self, hidden_size, cross_attention_dim=None, scale=1.0, num_tokens=4, skip=False,save_in_unet='down', atten_control=None): 99 | super().__init__() 100 | 101 | self.hidden_size = hidden_size 102 | self.cross_attention_dim = cross_attention_dim 103 | self.scale = scale 104 | self.num_tokens = num_tokens 105 | self.skip = skip 106 | 107 | self.atten_control = atten_control 108 | self.save_in_unet = save_in_unet 109 | 110 | self.to_k_ip = nn.Linear(cross_attention_dim or hidden_size, hidden_size, bias=False) 111 | self.to_v_ip = nn.Linear(cross_attention_dim or hidden_size, hidden_size, bias=False) 112 | 113 | def __call__( 114 | self, 115 | attn, 116 | hidden_states, 117 | encoder_hidden_states=None, 118 | attention_mask=None, 119 | temb=None, 120 | ): 121 | residual = hidden_states 122 | 123 | if attn.spatial_norm is not None: 124 | hidden_states = attn.spatial_norm(hidden_states, temb) 125 | 126 | input_ndim = hidden_states.ndim 127 | 128 | if input_ndim == 4: 129 | batch_size, channel, height, width = hidden_states.shape 130 | hidden_states = hidden_states.view(batch_size, channel, height * width).transpose(1, 2) 131 | 132 | batch_size, sequence_length, _ = ( 133 | hidden_states.shape if encoder_hidden_states is None else encoder_hidden_states.shape 134 | ) 135 | attention_mask = attn.prepare_attention_mask(attention_mask, sequence_length, batch_size) 136 | 137 | if attn.group_norm is not None: 138 | hidden_states = attn.group_norm(hidden_states.transpose(1, 2)).transpose(1, 2) 139 | 140 | query = attn.to_q(hidden_states) 141 | 142 | if encoder_hidden_states is None: 143 | encoder_hidden_states = hidden_states 144 | else: 145 | # get encoder_hidden_states, ip_hidden_states 146 | end_pos = encoder_hidden_states.shape[1] - self.num_tokens 147 | encoder_hidden_states, ip_hidden_states = ( 148 | encoder_hidden_states[:, :end_pos, :], 149 | encoder_hidden_states[:, end_pos:, :], 150 | ) 151 | if attn.norm_cross: 152 | encoder_hidden_states = attn.norm_encoder_hidden_states(encoder_hidden_states) 153 | 154 | key = attn.to_k(encoder_hidden_states) 155 | value = attn.to_v(encoder_hidden_states) 156 | 157 | query = attn.head_to_batch_dim(query) 158 | key = attn.head_to_batch_dim(key) 159 | value = attn.head_to_batch_dim(value) 160 | 161 | attention_probs = attn.get_attention_scores(query, key, attention_mask) 162 | hidden_states = torch.bmm(attention_probs, value) 163 | hidden_states = attn.batch_to_head_dim(hidden_states) 164 | 165 | if not self.skip: 166 | # for ip-adapter 167 | ip_key = self.to_k_ip(ip_hidden_states) 168 | ip_value = self.to_v_ip(ip_hidden_states) 169 | 170 | ip_key = attn.head_to_batch_dim(ip_key) 171 | ip_value = attn.head_to_batch_dim(ip_value) 172 | 173 | ip_attention_probs = attn.get_attention_scores(query, ip_key, None) 174 | self.attn_map = ip_attention_probs 175 | ip_hidden_states = torch.bmm(ip_attention_probs, ip_value) 176 | ip_hidden_states = attn.batch_to_head_dim(ip_hidden_states) 177 | 178 | hidden_states = hidden_states + self.scale * ip_hidden_states 179 | 180 | # linear proj 181 | hidden_states = attn.to_out[0](hidden_states) 182 | # dropout 183 | hidden_states = attn.to_out[1](hidden_states) 184 | 185 | if input_ndim == 4: 186 | hidden_states = hidden_states.transpose(-1, -2).reshape(batch_size, channel, height, width) 187 | 188 | if attn.residual_connection: 189 | hidden_states = hidden_states + residual 190 | 191 | hidden_states = hidden_states / attn.rescale_output_factor 192 | 193 | return hidden_states 194 | 195 | 196 | 
class AttnProcessor2_0(torch.nn.Module): 197 | r""" 198 | Processor for implementing scaled dot-product attention (enabled by default if you're using PyTorch 2.0). 199 | """ 200 | 201 | def __init__( 202 | self, 203 | hidden_size=None, 204 | cross_attention_dim=None, 205 | save_in_unet='down', 206 | atten_control=None, 207 | ): 208 | super().__init__() 209 | if not hasattr(F, "scaled_dot_product_attention"): 210 | raise ImportError("AttnProcessor2_0 requires PyTorch 2.0, to use it, please upgrade PyTorch to 2.0.") 211 | self.atten_control = atten_control 212 | self.save_in_unet = save_in_unet 213 | 214 | def __call__( 215 | self, 216 | attn, 217 | hidden_states, 218 | encoder_hidden_states=None, 219 | attention_mask=None, 220 | temb=None, 221 | ): 222 | residual = hidden_states 223 | 224 | if attn.spatial_norm is not None: 225 | hidden_states = attn.spatial_norm(hidden_states, temb) 226 | 227 | input_ndim = hidden_states.ndim 228 | 229 | if input_ndim == 4: 230 | batch_size, channel, height, width = hidden_states.shape 231 | hidden_states = hidden_states.view(batch_size, channel, height * width).transpose(1, 2) 232 | 233 | batch_size, sequence_length, _ = ( 234 | hidden_states.shape if encoder_hidden_states is None else encoder_hidden_states.shape 235 | ) 236 | 237 | if attention_mask is not None: 238 | attention_mask = attn.prepare_attention_mask(attention_mask, sequence_length, batch_size) 239 | # scaled_dot_product_attention expects attention_mask shape to be 240 | # (batch, heads, source_length, target_length) 241 | attention_mask = attention_mask.view(batch_size, attn.heads, -1, attention_mask.shape[-1]) 242 | 243 | if attn.group_norm is not None: 244 | hidden_states = attn.group_norm(hidden_states.transpose(1, 2)).transpose(1, 2) 245 | 246 | query = attn.to_q(hidden_states) 247 | 248 | if encoder_hidden_states is None: 249 | encoder_hidden_states = hidden_states 250 | elif attn.norm_cross: 251 | encoder_hidden_states = attn.norm_encoder_hidden_states(encoder_hidden_states) 252 | 253 | key = attn.to_k(encoder_hidden_states) 254 | value = attn.to_v(encoder_hidden_states) 255 | 256 | inner_dim = key.shape[-1] 257 | head_dim = inner_dim // attn.heads 258 | 259 | query = query.view(batch_size, -1, attn.heads, head_dim).transpose(1, 2) 260 | 261 | key = key.view(batch_size, -1, attn.heads, head_dim).transpose(1, 2) 262 | value = value.view(batch_size, -1, attn.heads, head_dim).transpose(1, 2) 263 | 264 | # the output of sdp = (batch, num_heads, seq_len, head_dim) 265 | # TODO: add support for attn.scale when we move to Torch 2.1 266 | hidden_states = F.scaled_dot_product_attention( 267 | query, key, value, attn_mask=attention_mask, dropout_p=0.0, is_causal=False 268 | ) 269 | 270 | hidden_states = hidden_states.transpose(1, 2).reshape(batch_size, -1, attn.heads * head_dim) 271 | hidden_states = hidden_states.to(query.dtype) 272 | 273 | # linear proj 274 | hidden_states = attn.to_out[0](hidden_states) 275 | # dropout 276 | hidden_states = attn.to_out[1](hidden_states) 277 | 278 | if input_ndim == 4: 279 | hidden_states = hidden_states.transpose(-1, -2).reshape(batch_size, channel, height, width) 280 | 281 | if attn.residual_connection: 282 | hidden_states = hidden_states + residual 283 | 284 | hidden_states = hidden_states / attn.rescale_output_factor 285 | 286 | return hidden_states 287 | 288 | 289 | class IPAttnProcessor2_0(torch.nn.Module): 290 | r""" 291 | Attention processor for IP-Adapater for PyTorch 2.0. 
292 | Args: 293 | hidden_size (`int`): 294 | The hidden size of the attention layer. 295 | cross_attention_dim (`int`): 296 | The number of channels in the `encoder_hidden_states`. 297 | scale (`float`, defaults to 1.0): 298 | the weight scale of image prompt. 299 | num_tokens (`int`, defaults to 4 when do ip_adapter_plus it should be 16): 300 | The context length of the image features. 301 | """ 302 | 303 | def __init__(self, hidden_size, cross_attention_dim=None, scale=1.0, num_tokens=4, skip=False,save_in_unet='down', atten_control=None): 304 | super().__init__() 305 | 306 | if not hasattr(F, "scaled_dot_product_attention"): 307 | raise ImportError("AttnProcessor2_0 requires PyTorch 2.0, to use it, please upgrade PyTorch to 2.0.") 308 | 309 | self.hidden_size = hidden_size 310 | self.cross_attention_dim = cross_attention_dim 311 | self.scale = scale 312 | self.num_tokens = num_tokens 313 | self.skip = skip 314 | 315 | self.atten_control = atten_control 316 | self.save_in_unet = save_in_unet 317 | 318 | self.to_k_ip = nn.Linear(cross_attention_dim or hidden_size, hidden_size, bias=False) 319 | self.to_v_ip = nn.Linear(cross_attention_dim or hidden_size, hidden_size, bias=False) 320 | 321 | def __call__( 322 | self, 323 | attn, 324 | hidden_states, 325 | encoder_hidden_states=None, 326 | attention_mask=None, 327 | temb=None, 328 | ): 329 | residual = hidden_states 330 | 331 | if attn.spatial_norm is not None: 332 | hidden_states = attn.spatial_norm(hidden_states, temb) 333 | 334 | input_ndim = hidden_states.ndim 335 | 336 | if input_ndim == 4: 337 | batch_size, channel, height, width = hidden_states.shape 338 | hidden_states = hidden_states.view(batch_size, channel, height * width).transpose(1, 2) 339 | 340 | batch_size, sequence_length, _ = ( 341 | hidden_states.shape if encoder_hidden_states is None else encoder_hidden_states.shape 342 | ) 343 | 344 | if attention_mask is not None: 345 | attention_mask = attn.prepare_attention_mask(attention_mask, sequence_length, batch_size) 346 | # scaled_dot_product_attention expects attention_mask shape to be 347 | # (batch, heads, source_length, target_length) 348 | attention_mask = attention_mask.view(batch_size, attn.heads, -1, attention_mask.shape[-1]) 349 | 350 | if attn.group_norm is not None: 351 | hidden_states = attn.group_norm(hidden_states.transpose(1, 2)).transpose(1, 2) 352 | 353 | query = attn.to_q(hidden_states) 354 | 355 | if encoder_hidden_states is None: 356 | encoder_hidden_states = hidden_states 357 | else: 358 | # get encoder_hidden_states, ip_hidden_states 359 | end_pos = encoder_hidden_states.shape[1] - self.num_tokens 360 | encoder_hidden_states, ip_hidden_states = ( 361 | encoder_hidden_states[:, :end_pos, :], 362 | encoder_hidden_states[:, end_pos:, :], 363 | ) 364 | if attn.norm_cross: 365 | encoder_hidden_states = attn.norm_encoder_hidden_states(encoder_hidden_states) 366 | 367 | key = attn.to_k(encoder_hidden_states) 368 | value = attn.to_v(encoder_hidden_states) 369 | 370 | inner_dim = key.shape[-1] 371 | head_dim = inner_dim // attn.heads 372 | 373 | query = query.view(batch_size, -1, attn.heads, head_dim).transpose(1, 2) 374 | 375 | key = key.view(batch_size, -1, attn.heads, head_dim).transpose(1, 2) 376 | value = value.view(batch_size, -1, attn.heads, head_dim).transpose(1, 2) 377 | 378 | # the output of sdp = (batch, num_heads, seq_len, head_dim) 379 | # TODO: add support for attn.scale when we move to Torch 2.1 380 | hidden_states = F.scaled_dot_product_attention( 381 | query, key, value, attn_mask=attention_mask, 
dropout_p=0.0, is_causal=False 382 | ) 383 | 384 | hidden_states = hidden_states.transpose(1, 2).reshape(batch_size, -1, attn.heads * head_dim) 385 | hidden_states = hidden_states.to(query.dtype) 386 | 387 | if not self.skip: 388 | # for ip-adapter 389 | ip_key = self.to_k_ip(ip_hidden_states) 390 | ip_value = self.to_v_ip(ip_hidden_states) 391 | 392 | ip_key = ip_key.view(batch_size, -1, attn.heads, head_dim).transpose(1, 2) 393 | ip_value = ip_value.view(batch_size, -1, attn.heads, head_dim).transpose(1, 2) 394 | 395 | # the output of sdp = (batch, num_heads, seq_len, head_dim) 396 | # TODO: add support for attn.scale when we move to Torch 2.1 397 | ip_hidden_states = F.scaled_dot_product_attention( 398 | query, ip_key, ip_value, attn_mask=None, dropout_p=0.0, is_causal=False 399 | ) 400 | with torch.no_grad(): 401 | self.attn_map = query @ ip_key.transpose(-2, -1).softmax(dim=-1) 402 | #print(self.attn_map.shape) 403 | 404 | ip_hidden_states = ip_hidden_states.transpose(1, 2).reshape(batch_size, -1, attn.heads * head_dim) 405 | ip_hidden_states = ip_hidden_states.to(query.dtype) 406 | 407 | hidden_states = hidden_states + self.scale * ip_hidden_states 408 | 409 | # linear proj 410 | hidden_states = attn.to_out[0](hidden_states) 411 | # dropout 412 | hidden_states = attn.to_out[1](hidden_states) 413 | 414 | if input_ndim == 4: 415 | hidden_states = hidden_states.transpose(-1, -2).reshape(batch_size, channel, height, width) 416 | 417 | if attn.residual_connection: 418 | hidden_states = hidden_states + residual 419 | 420 | hidden_states = hidden_states / attn.rescale_output_factor 421 | 422 | return hidden_states 423 | 424 | 425 | class IP_CS_AttnProcessor2_0(torch.nn.Module): 426 | r""" 427 | Attention processor for IP-Adapater for PyTorch 2.0. 428 | Args: 429 | hidden_size (`int`): 430 | The hidden size of the attention layer. 431 | cross_attention_dim (`int`): 432 | The number of channels in the `encoder_hidden_states`. 433 | scale (`float`, defaults to 1.0): 434 | the weight scale of image prompt. 435 | num_tokens (`int`, defaults to 4 when do ip_adapter_plus it should be 16): 436 | The context length of the image features. 
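        num_content_tokens / num_style_tokens (`int`):
            Number of trailing tokens in `encoder_hidden_states` that are treated as content / style image tokens.
        content / style (`bool`):
            Enable the extra cross-attention branch over the content / style tokens; its output is added to the
            text cross-attention result weighted by `content_scale` / `style_scale`.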
437 | """ 438 | 439 | def __init__(self, hidden_size, cross_attention_dim=None, content_scale=1.0,style_scale=1.0, num_content_tokens=4,num_style_tokens=4, 440 | skip=False,content=False, style=False): 441 | super().__init__() 442 | 443 | if not hasattr(F, "scaled_dot_product_attention"): 444 | raise ImportError("AttnProcessor2_0 requires PyTorch 2.0, to use it, please upgrade PyTorch to 2.0.") 445 | 446 | self.hidden_size = hidden_size 447 | self.cross_attention_dim = cross_attention_dim 448 | self.content_scale = content_scale 449 | self.style_scale = style_scale 450 | self.num_content_tokens = num_content_tokens 451 | self.num_style_tokens = num_style_tokens 452 | self.skip = skip 453 | 454 | self.content = content 455 | self.style = style 456 | 457 | if self.content or self.style: 458 | self.to_k_ip = nn.Linear(cross_attention_dim or hidden_size, hidden_size, bias=False) 459 | self.to_v_ip = nn.Linear(cross_attention_dim or hidden_size, hidden_size, bias=False) 460 | self.to_k_ip_content =None 461 | self.to_v_ip_content =None 462 | 463 | def set_content_ipa(self,content_scale=1.0): 464 | 465 | self.to_k_ip_content = nn.Linear(self.cross_attention_dim or self.hidden_size, self.hidden_size, bias=False) 466 | self.to_v_ip_content = nn.Linear(self.cross_attention_dim or self.hidden_size, self.hidden_size, bias=False) 467 | self.content_scale=content_scale 468 | self.content =True 469 | 470 | def __call__( 471 | self, 472 | attn, 473 | hidden_states, 474 | encoder_hidden_states=None, 475 | attention_mask=None, 476 | temb=None, 477 | ): 478 | residual = hidden_states 479 | 480 | if attn.spatial_norm is not None: 481 | hidden_states = attn.spatial_norm(hidden_states, temb) 482 | 483 | input_ndim = hidden_states.ndim 484 | 485 | if input_ndim == 4: 486 | batch_size, channel, height, width = hidden_states.shape 487 | hidden_states = hidden_states.view(batch_size, channel, height * width).transpose(1, 2) 488 | 489 | batch_size, sequence_length, _ = ( 490 | hidden_states.shape if encoder_hidden_states is None else encoder_hidden_states.shape 491 | ) 492 | 493 | if attention_mask is not None: 494 | attention_mask = attn.prepare_attention_mask(attention_mask, sequence_length, batch_size) 495 | # scaled_dot_product_attention expects attention_mask shape to be 496 | # (batch, heads, source_length, target_length) 497 | attention_mask = attention_mask.view(batch_size, attn.heads, -1, attention_mask.shape[-1]) 498 | 499 | if attn.group_norm is not None: 500 | hidden_states = attn.group_norm(hidden_states.transpose(1, 2)).transpose(1, 2) 501 | 502 | query = attn.to_q(hidden_states) 503 | 504 | if encoder_hidden_states is None: 505 | encoder_hidden_states = hidden_states 506 | else: 507 | # get encoder_hidden_states, ip_hidden_states 508 | end_pos = encoder_hidden_states.shape[1] - self.num_content_tokens-self.num_style_tokens 509 | encoder_hidden_states, ip_content_hidden_states,ip_style_hidden_states = ( 510 | encoder_hidden_states[:, :end_pos, :], 511 | encoder_hidden_states[:, end_pos:end_pos + self.num_content_tokens, :], 512 | encoder_hidden_states[:, end_pos + self.num_content_tokens:, :], 513 | ) 514 | if attn.norm_cross: 515 | encoder_hidden_states = attn.norm_encoder_hidden_states(encoder_hidden_states) 516 | 517 | key = attn.to_k(encoder_hidden_states) 518 | value = attn.to_v(encoder_hidden_states) 519 | 520 | inner_dim = key.shape[-1] 521 | head_dim = inner_dim // attn.heads 522 | 523 | query = query.view(batch_size, -1, attn.heads, head_dim).transpose(1, 2) 524 | 525 | key = 
key.view(batch_size, -1, attn.heads, head_dim).transpose(1, 2) 526 | value = value.view(batch_size, -1, attn.heads, head_dim).transpose(1, 2) 527 | 528 | # the output of sdp = (batch, num_heads, seq_len, head_dim) 529 | # TODO: add support for attn.scale when we move to Torch 2.1 530 | hidden_states = F.scaled_dot_product_attention( 531 | query, key, value, attn_mask=attention_mask, dropout_p=0.0, is_causal=False 532 | ) 533 | 534 | hidden_states = hidden_states.transpose(1, 2).reshape(batch_size, -1, attn.heads * head_dim) 535 | hidden_states = hidden_states.to(query.dtype) 536 | 537 | if not self.skip and self.content is True: 538 | # print('content#####################################################') 539 | # for ip-content-adapter 540 | if self.to_k_ip_content is None: 541 | 542 | ip_content_key = self.to_k_ip(ip_content_hidden_states) 543 | ip_content_value = self.to_v_ip(ip_content_hidden_states) 544 | else: 545 | ip_content_key = self.to_k_ip_content(ip_content_hidden_states) 546 | ip_content_value = self.to_v_ip_content(ip_content_hidden_states) 547 | 548 | ip_content_key = ip_content_key.view(batch_size, -1, attn.heads, head_dim).transpose(1, 2) 549 | ip_content_value = ip_content_value.view(batch_size, -1, attn.heads, head_dim).transpose(1, 2) 550 | 551 | # the output of sdp = (batch, num_heads, seq_len, head_dim) 552 | # TODO: add support for attn.scale when we move to Torch 2.1 553 | ip_content_hidden_states = F.scaled_dot_product_attention( 554 | query, ip_content_key, ip_content_value, attn_mask=None, dropout_p=0.0, is_causal=False 555 | ) 556 | 557 | 558 | ip_content_hidden_states = ip_content_hidden_states.transpose(1, 2).reshape(batch_size, -1, attn.heads * head_dim) 559 | ip_content_hidden_states = ip_content_hidden_states.to(query.dtype) 560 | 561 | 562 | hidden_states = hidden_states + self.content_scale * ip_content_hidden_states 563 | 564 | if not self.skip and self.style is True: 565 | # for ip-style-adapter 566 | ip_style_key = self.to_k_ip(ip_style_hidden_states) 567 | ip_style_value = self.to_v_ip(ip_style_hidden_states) 568 | 569 | ip_style_key = ip_style_key.view(batch_size, -1, attn.heads, head_dim).transpose(1, 2) 570 | ip_style_value = ip_style_value.view(batch_size, -1, attn.heads, head_dim).transpose(1, 2) 571 | 572 | # the output of sdp = (batch, num_heads, seq_len, head_dim) 573 | # TODO: add support for attn.scale when we move to Torch 2.1 574 | ip_style_hidden_states = F.scaled_dot_product_attention( 575 | query, ip_style_key, ip_style_value, attn_mask=None, dropout_p=0.0, is_causal=False 576 | ) 577 | 578 | ip_style_hidden_states = ip_style_hidden_states.transpose(1, 2).reshape(batch_size, -1, 579 | attn.heads * head_dim) 580 | ip_style_hidden_states = ip_style_hidden_states.to(query.dtype) 581 | 582 | hidden_states = hidden_states + self.style_scale * ip_style_hidden_states 583 | 584 | # linear proj 585 | hidden_states = attn.to_out[0](hidden_states) 586 | # dropout 587 | hidden_states = attn.to_out[1](hidden_states) 588 | 589 | if input_ndim == 4: 590 | hidden_states = hidden_states.transpose(-1, -2).reshape(batch_size, channel, height, width) 591 | 592 | if attn.residual_connection: 593 | hidden_states = hidden_states + residual 594 | 595 | hidden_states = hidden_states / attn.rescale_output_factor 596 | 597 | return hidden_states 598 | 599 | ## for controlnet 600 | class CNAttnProcessor: 601 | r""" 602 | Default processor for performing attention-related computations. 
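    Used for the ControlNet branch: any image-prompt tokens appended to `encoder_hidden_states` are stripped off
    (only the first `sequence_length - num_tokens` text tokens are kept), so ControlNet is conditioned on text alone.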
603 | """ 604 | 605 | def __init__(self, num_tokens=4,save_in_unet='down',atten_control=None): 606 | self.num_tokens = num_tokens 607 | self.atten_control = atten_control 608 | self.save_in_unet = save_in_unet 609 | 610 | def __call__(self, attn, hidden_states, encoder_hidden_states=None, attention_mask=None, temb=None): 611 | residual = hidden_states 612 | 613 | if attn.spatial_norm is not None: 614 | hidden_states = attn.spatial_norm(hidden_states, temb) 615 | 616 | input_ndim = hidden_states.ndim 617 | 618 | if input_ndim == 4: 619 | batch_size, channel, height, width = hidden_states.shape 620 | hidden_states = hidden_states.view(batch_size, channel, height * width).transpose(1, 2) 621 | 622 | batch_size, sequence_length, _ = ( 623 | hidden_states.shape if encoder_hidden_states is None else encoder_hidden_states.shape 624 | ) 625 | attention_mask = attn.prepare_attention_mask(attention_mask, sequence_length, batch_size) 626 | 627 | if attn.group_norm is not None: 628 | hidden_states = attn.group_norm(hidden_states.transpose(1, 2)).transpose(1, 2) 629 | 630 | query = attn.to_q(hidden_states) 631 | 632 | if encoder_hidden_states is None: 633 | encoder_hidden_states = hidden_states 634 | else: 635 | end_pos = encoder_hidden_states.shape[1] - self.num_tokens 636 | encoder_hidden_states = encoder_hidden_states[:, :end_pos] # only use text 637 | if attn.norm_cross: 638 | encoder_hidden_states = attn.norm_encoder_hidden_states(encoder_hidden_states) 639 | 640 | key = attn.to_k(encoder_hidden_states) 641 | value = attn.to_v(encoder_hidden_states) 642 | 643 | query = attn.head_to_batch_dim(query) 644 | key = attn.head_to_batch_dim(key) 645 | value = attn.head_to_batch_dim(value) 646 | 647 | attention_probs = attn.get_attention_scores(query, key, attention_mask) 648 | hidden_states = torch.bmm(attention_probs, value) 649 | hidden_states = attn.batch_to_head_dim(hidden_states) 650 | 651 | # linear proj 652 | hidden_states = attn.to_out[0](hidden_states) 653 | # dropout 654 | hidden_states = attn.to_out[1](hidden_states) 655 | 656 | if input_ndim == 4: 657 | hidden_states = hidden_states.transpose(-1, -2).reshape(batch_size, channel, height, width) 658 | 659 | if attn.residual_connection: 660 | hidden_states = hidden_states + residual 661 | 662 | hidden_states = hidden_states / attn.rescale_output_factor 663 | 664 | return hidden_states 665 | 666 | 667 | class CNAttnProcessor2_0: 668 | r""" 669 | Processor for implementing scaled dot-product attention (enabled by default if you're using PyTorch 2.0). 
670 | """ 671 | 672 | def __init__(self, num_tokens=4, save_in_unet='down', atten_control=None): 673 | if not hasattr(F, "scaled_dot_product_attention"): 674 | raise ImportError("AttnProcessor2_0 requires PyTorch 2.0, to use it, please upgrade PyTorch to 2.0.") 675 | self.num_tokens = num_tokens 676 | self.atten_control = atten_control 677 | self.save_in_unet = save_in_unet 678 | 679 | def __call__( 680 | self, 681 | attn, 682 | hidden_states, 683 | encoder_hidden_states=None, 684 | attention_mask=None, 685 | temb=None, 686 | ): 687 | residual = hidden_states 688 | 689 | if attn.spatial_norm is not None: 690 | hidden_states = attn.spatial_norm(hidden_states, temb) 691 | 692 | input_ndim = hidden_states.ndim 693 | 694 | if input_ndim == 4: 695 | batch_size, channel, height, width = hidden_states.shape 696 | hidden_states = hidden_states.view(batch_size, channel, height * width).transpose(1, 2) 697 | 698 | batch_size, sequence_length, _ = ( 699 | hidden_states.shape if encoder_hidden_states is None else encoder_hidden_states.shape 700 | ) 701 | 702 | if attention_mask is not None: 703 | attention_mask = attn.prepare_attention_mask(attention_mask, sequence_length, batch_size) 704 | # scaled_dot_product_attention expects attention_mask shape to be 705 | # (batch, heads, source_length, target_length) 706 | attention_mask = attention_mask.view(batch_size, attn.heads, -1, attention_mask.shape[-1]) 707 | 708 | if attn.group_norm is not None: 709 | hidden_states = attn.group_norm(hidden_states.transpose(1, 2)).transpose(1, 2) 710 | 711 | query = attn.to_q(hidden_states) 712 | 713 | if encoder_hidden_states is None: 714 | encoder_hidden_states = hidden_states 715 | else: 716 | end_pos = encoder_hidden_states.shape[1] - self.num_tokens 717 | encoder_hidden_states = encoder_hidden_states[:, :end_pos] # only use text 718 | if attn.norm_cross: 719 | encoder_hidden_states = attn.norm_encoder_hidden_states(encoder_hidden_states) 720 | 721 | key = attn.to_k(encoder_hidden_states) 722 | value = attn.to_v(encoder_hidden_states) 723 | 724 | inner_dim = key.shape[-1] 725 | head_dim = inner_dim // attn.heads 726 | 727 | query = query.view(batch_size, -1, attn.heads, head_dim).transpose(1, 2) 728 | 729 | key = key.view(batch_size, -1, attn.heads, head_dim).transpose(1, 2) 730 | value = value.view(batch_size, -1, attn.heads, head_dim).transpose(1, 2) 731 | 732 | # the output of sdp = (batch, num_heads, seq_len, head_dim) 733 | # TODO: add support for attn.scale when we move to Torch 2.1 734 | hidden_states = F.scaled_dot_product_attention( 735 | query, key, value, attn_mask=attention_mask, dropout_p=0.0, is_causal=False 736 | ) 737 | 738 | hidden_states = hidden_states.transpose(1, 2).reshape(batch_size, -1, attn.heads * head_dim) 739 | hidden_states = hidden_states.to(query.dtype) 740 | 741 | # linear proj 742 | hidden_states = attn.to_out[0](hidden_states) 743 | # dropout 744 | hidden_states = attn.to_out[1](hidden_states) 745 | 746 | if input_ndim == 4: 747 | hidden_states = hidden_states.transpose(-1, -2).reshape(batch_size, channel, height, width) 748 | 749 | if attn.residual_connection: 750 | hidden_states = hidden_states + residual 751 | 752 | hidden_states = hidden_states / attn.rescale_output_factor 753 | 754 | return hidden_states 755 | -------------------------------------------------------------------------------- /ip_adapter/ip_adapter.py: -------------------------------------------------------------------------------- 1 | import os 2 | from typing import List 3 | 4 | import torch 5 | from diffusers 
import StableDiffusionPipeline 6 | from diffusers.pipelines.controlnet import MultiControlNetModel 7 | from PIL import Image 8 | from safetensors import safe_open 9 | from transformers import CLIPImageProcessor, CLIPVisionModelWithProjection 10 | from torchvision import transforms 11 | from .utils import is_torch2_available, get_generator 12 | 13 | # import torchvision.transforms.functional as Func 14 | 15 | # from .clip_style_models import CSD_CLIP, convert_state_dict 16 | 17 | if is_torch2_available(): 18 | from .attention_processor import ( 19 | AttnProcessor2_0 as AttnProcessor, 20 | ) 21 | from .attention_processor import ( 22 | CNAttnProcessor2_0 as CNAttnProcessor, 23 | ) 24 | from .attention_processor import ( 25 | IPAttnProcessor2_0 as IPAttnProcessor, 26 | ) 27 | from .attention_processor import IP_CS_AttnProcessor2_0 as IP_CS_AttnProcessor 28 | else: 29 | from .attention_processor import AttnProcessor, CNAttnProcessor, IPAttnProcessor 30 | from .resampler import Resampler 31 | 32 | from transformers import AutoImageProcessor, AutoModel 33 | 34 | 35 | class ImageProjModel(torch.nn.Module): 36 | """Projection Model""" 37 | 38 | def __init__(self, cross_attention_dim=1024, clip_embeddings_dim=1024, clip_extra_context_tokens=4): 39 | super().__init__() 40 | 41 | self.generator = None 42 | self.cross_attention_dim = cross_attention_dim 43 | self.clip_extra_context_tokens = clip_extra_context_tokens 44 | # print(clip_embeddings_dim, self.clip_extra_context_tokens, cross_attention_dim) 45 | self.proj = torch.nn.Linear(clip_embeddings_dim, self.clip_extra_context_tokens * cross_attention_dim) 46 | self.norm = torch.nn.LayerNorm(cross_attention_dim) 47 | 48 | def forward(self, image_embeds): 49 | embeds = image_embeds 50 | clip_extra_context_tokens = self.proj(embeds).reshape( 51 | -1, self.clip_extra_context_tokens, self.cross_attention_dim 52 | ) 53 | clip_extra_context_tokens = self.norm(clip_extra_context_tokens) 54 | return clip_extra_context_tokens 55 | 56 | 57 | class MLPProjModel(torch.nn.Module): 58 | """SD model with image prompt""" 59 | 60 | def __init__(self, cross_attention_dim=1024, clip_embeddings_dim=1024): 61 | super().__init__() 62 | 63 | self.proj = torch.nn.Sequential( 64 | torch.nn.Linear(clip_embeddings_dim, clip_embeddings_dim), 65 | torch.nn.GELU(), 66 | torch.nn.Linear(clip_embeddings_dim, cross_attention_dim), 67 | torch.nn.LayerNorm(cross_attention_dim) 68 | ) 69 | 70 | def forward(self, image_embeds): 71 | clip_extra_context_tokens = self.proj(image_embeds) 72 | return clip_extra_context_tokens 73 | 74 | 75 | class IPAdapter: 76 | def __init__(self, sd_pipe, image_encoder_path, ip_ckpt, device, num_tokens=4, target_blocks=["block"]): 77 | self.device = device 78 | self.image_encoder_path = image_encoder_path 79 | self.ip_ckpt = ip_ckpt 80 | self.num_tokens = num_tokens 81 | self.target_blocks = target_blocks 82 | 83 | self.pipe = sd_pipe.to(self.device) 84 | self.set_ip_adapter() 85 | 86 | # load image encoder 87 | self.image_encoder = CLIPVisionModelWithProjection.from_pretrained(self.image_encoder_path).to( 88 | self.device, dtype=torch.float16 89 | ) 90 | self.clip_image_processor = CLIPImageProcessor() 91 | # image proj model 92 | self.image_proj_model = self.init_proj() 93 | 94 | self.load_ip_adapter() 95 | 96 | def init_proj(self): 97 | image_proj_model = ImageProjModel( 98 | cross_attention_dim=self.pipe.unet.config.cross_attention_dim, 99 | clip_embeddings_dim=self.image_encoder.config.projection_dim, 100 | clip_extra_context_tokens=self.num_tokens, 
101 | ).to(self.device, dtype=torch.float16) 102 | return image_proj_model 103 | 104 | def set_ip_adapter(self): 105 | unet = self.pipe.unet 106 | attn_procs = {} 107 | for name in unet.attn_processors.keys(): 108 | cross_attention_dim = None if name.endswith("attn1.processor") else unet.config.cross_attention_dim 109 | if name.startswith("mid_block"): 110 | hidden_size = unet.config.block_out_channels[-1] 111 | elif name.startswith("up_blocks"): 112 | block_id = int(name[len("up_blocks.")]) 113 | hidden_size = list(reversed(unet.config.block_out_channels))[block_id] 114 | elif name.startswith("down_blocks"): 115 | block_id = int(name[len("down_blocks.")]) 116 | hidden_size = unet.config.block_out_channels[block_id] 117 | if cross_attention_dim is None: 118 | attn_procs[name] = AttnProcessor() 119 | else: 120 | selected = False 121 | for block_name in self.target_blocks: 122 | if block_name in name: 123 | selected = True 124 | break 125 | if selected: 126 | attn_procs[name] = IPAttnProcessor( 127 | hidden_size=hidden_size, 128 | cross_attention_dim=cross_attention_dim, 129 | scale=1.0, 130 | num_tokens=self.num_tokens, 131 | ).to(self.device, dtype=torch.float16) 132 | else: 133 | attn_procs[name] = IPAttnProcessor( 134 | hidden_size=hidden_size, 135 | cross_attention_dim=cross_attention_dim, 136 | scale=1.0, 137 | num_tokens=self.num_tokens, 138 | skip=True 139 | ).to(self.device, dtype=torch.float16) 140 | unet.set_attn_processor(attn_procs) 141 | if hasattr(self.pipe, "controlnet"): 142 | if isinstance(self.pipe.controlnet, MultiControlNetModel): 143 | for controlnet in self.pipe.controlnet.nets: 144 | controlnet.set_attn_processor(CNAttnProcessor(num_tokens=self.num_tokens)) 145 | else: 146 | self.pipe.controlnet.set_attn_processor(CNAttnProcessor(num_tokens=self.num_tokens)) 147 | 148 | def load_ip_adapter(self): 149 | if os.path.splitext(self.ip_ckpt)[-1] == ".safetensors": 150 | state_dict = {"image_proj": {}, "ip_adapter": {}} 151 | with safe_open(self.ip_ckpt, framework="pt", device="cpu") as f: 152 | for key in f.keys(): 153 | if key.startswith("image_proj."): 154 | state_dict["image_proj"][key.replace("image_proj.", "")] = f.get_tensor(key) 155 | elif key.startswith("ip_adapter."): 156 | state_dict["ip_adapter"][key.replace("ip_adapter.", "")] = f.get_tensor(key) 157 | else: 158 | state_dict = torch.load(self.ip_ckpt, map_location="cpu") 159 | self.image_proj_model.load_state_dict(state_dict["image_proj"]) 160 | ip_layers = torch.nn.ModuleList(self.pipe.unet.attn_processors.values()) 161 | ip_layers.load_state_dict(state_dict["ip_adapter"], strict=False) 162 | 163 | @torch.inference_mode() 164 | def get_image_embeds(self, pil_image=None, clip_image_embeds=None, content_prompt_embeds=None): 165 | if pil_image is not None: 166 | if isinstance(pil_image, Image.Image): 167 | pil_image = [pil_image] 168 | clip_image = self.clip_image_processor(images=pil_image, return_tensors="pt").pixel_values 169 | clip_image_embeds = self.image_encoder(clip_image.to(self.device, dtype=torch.float16)).image_embeds 170 | else: 171 | clip_image_embeds = clip_image_embeds.to(self.device, dtype=torch.float16) 172 | 173 | if content_prompt_embeds is not None: 174 | clip_image_embeds = clip_image_embeds - content_prompt_embeds 175 | 176 | image_prompt_embeds = self.image_proj_model(clip_image_embeds) 177 | uncond_image_prompt_embeds = self.image_proj_model(torch.zeros_like(clip_image_embeds)) 178 | return image_prompt_embeds, uncond_image_prompt_embeds 179 | 180 | def set_scale(self, scale): 181 | for 
attn_processor in self.pipe.unet.attn_processors.values(): 182 | if isinstance(attn_processor, IPAttnProcessor): 183 | attn_processor.scale = scale 184 | 185 | def generate( 186 | self, 187 | pil_image=None, 188 | clip_image_embeds=None, 189 | prompt=None, 190 | negative_prompt=None, 191 | scale=1.0, 192 | num_samples=4, 193 | seed=None, 194 | guidance_scale=7.5, 195 | num_inference_steps=30, 196 | neg_content_emb=None, 197 | **kwargs, 198 | ): 199 | self.set_scale(scale) 200 | 201 | if pil_image is not None: 202 | num_prompts = 1 if isinstance(pil_image, Image.Image) else len(pil_image) 203 | else: 204 | num_prompts = clip_image_embeds.size(0) 205 | 206 | if prompt is None: 207 | prompt = "best quality, high quality" 208 | if negative_prompt is None: 209 | negative_prompt = "monochrome, lowres, bad anatomy, worst quality, low quality" 210 | 211 | if not isinstance(prompt, List): 212 | prompt = [prompt] * num_prompts 213 | if not isinstance(negative_prompt, List): 214 | negative_prompt = [negative_prompt] * num_prompts 215 | 216 | image_prompt_embeds, uncond_image_prompt_embeds = self.get_image_embeds( 217 | pil_image=pil_image, clip_image_embeds=clip_image_embeds, content_prompt_embeds=neg_content_emb 218 | ) 219 | bs_embed, seq_len, _ = image_prompt_embeds.shape 220 | image_prompt_embeds = image_prompt_embeds.repeat(1, num_samples, 1) 221 | image_prompt_embeds = image_prompt_embeds.view(bs_embed * num_samples, seq_len, -1) 222 | uncond_image_prompt_embeds = uncond_image_prompt_embeds.repeat(1, num_samples, 1) 223 | uncond_image_prompt_embeds = uncond_image_prompt_embeds.view(bs_embed * num_samples, seq_len, -1) 224 | 225 | with torch.inference_mode(): 226 | prompt_embeds_, negative_prompt_embeds_ = self.pipe.encode_prompt( 227 | prompt, 228 | device=self.device, 229 | num_images_per_prompt=num_samples, 230 | do_classifier_free_guidance=True, 231 | negative_prompt=negative_prompt, 232 | ) 233 | prompt_embeds = torch.cat([prompt_embeds_, image_prompt_embeds], dim=1) 234 | negative_prompt_embeds = torch.cat([negative_prompt_embeds_, uncond_image_prompt_embeds], dim=1) 235 | 236 | generator = get_generator(seed, self.device) 237 | 238 | images = self.pipe( 239 | prompt_embeds=prompt_embeds, 240 | negative_prompt_embeds=negative_prompt_embeds, 241 | guidance_scale=guidance_scale, 242 | num_inference_steps=num_inference_steps, 243 | generator=generator, 244 | **kwargs, 245 | ).images 246 | 247 | return images 248 | 249 | 250 | class IPAdapter_CS: 251 | def __init__(self, sd_pipe, image_encoder_path, ip_ckpt, device, num_content_tokens=4, 252 | num_style_tokens=4, 253 | target_content_blocks=["block"], target_style_blocks=["block"], content_image_encoder_path=None, 254 | controlnet_adapter=False, 255 | controlnet_target_content_blocks=None, 256 | controlnet_target_style_blocks=None, 257 | content_model_resampler=False, 258 | style_model_resampler=False, 259 | ): 260 | self.device = device 261 | self.image_encoder_path = image_encoder_path 262 | self.ip_ckpt = ip_ckpt 263 | self.num_content_tokens = num_content_tokens 264 | self.num_style_tokens = num_style_tokens 265 | self.content_target_blocks = target_content_blocks 266 | self.style_target_blocks = target_style_blocks 267 | 268 | self.content_model_resampler = content_model_resampler 269 | self.style_model_resampler = style_model_resampler 270 | 271 | self.controlnet_adapter = controlnet_adapter 272 | self.controlnet_target_content_blocks = controlnet_target_content_blocks 273 | self.controlnet_target_style_blocks = 
controlnet_target_style_blocks 274 | 275 | self.pipe = sd_pipe.to(self.device) 276 | self.set_ip_adapter() 277 | self.content_image_encoder_path = content_image_encoder_path 278 | 279 | 280 | # load image encoder 281 | if content_image_encoder_path is not None: 282 | self.content_image_encoder = AutoModel.from_pretrained(content_image_encoder_path).to(self.device, 283 | dtype=torch.float16) 284 | self.content_image_processor = AutoImageProcessor.from_pretrained(content_image_encoder_path) 285 | else: 286 | self.content_image_encoder = CLIPVisionModelWithProjection.from_pretrained(self.image_encoder_path).to( 287 | self.device, dtype=torch.float16 288 | ) 289 | self.content_image_processor = CLIPImageProcessor() 290 | # model.requires_grad_(False) 291 | 292 | self.image_encoder = CLIPVisionModelWithProjection.from_pretrained(self.image_encoder_path).to( 293 | self.device, dtype=torch.float16 294 | ) 295 | # if self.use_CSD is not None: 296 | # self.style_image_encoder = CSD_CLIP("vit_large", "default",self.use_CSD+"/ViT-L-14.pt") 297 | # model_path = self.use_CSD+"/checkpoint.pth" 298 | # checkpoint = torch.load(model_path, map_location="cpu") 299 | # state_dict = convert_state_dict(checkpoint['model_state_dict']) 300 | # self.style_image_encoder.load_state_dict(state_dict, strict=False) 301 | # 302 | # normalize = transforms.Normalize((0.48145466, 0.4578275, 0.40821073), (0.26862954, 0.26130258, 0.27577711)) 303 | # self.style_preprocess = transforms.Compose([ 304 | # transforms.Resize(size=224, interpolation=Func.InterpolationMode.BICUBIC), 305 | # transforms.CenterCrop(224), 306 | # transforms.ToTensor(), 307 | # normalize, 308 | # ]) 309 | 310 | self.clip_image_processor = CLIPImageProcessor() 311 | # image proj model 312 | self.content_image_proj_model = self.init_proj(self.num_content_tokens, content_or_style_='content', 313 | model_resampler=self.content_model_resampler) 314 | self.style_image_proj_model = self.init_proj(self.num_style_tokens, content_or_style_='style', 315 | model_resampler=self.style_model_resampler) 316 | 317 | self.load_ip_adapter() 318 | 319 | def init_proj(self, num_tokens, content_or_style_='content', model_resampler=False): 320 | 321 | # print('@@@@',self.pipe.unet.config.cross_attention_dim,self.image_encoder.config.projection_dim) 322 | if content_or_style_ == 'content' and self.content_image_encoder_path is not None: 323 | image_proj_model = ImageProjModel( 324 | cross_attention_dim=self.pipe.unet.config.cross_attention_dim, 325 | clip_embeddings_dim=self.content_image_encoder.config.projection_dim, 326 | clip_extra_context_tokens=num_tokens, 327 | ).to(self.device, dtype=torch.float16) 328 | return image_proj_model 329 | 330 | image_proj_model = ImageProjModel( 331 | cross_attention_dim=self.pipe.unet.config.cross_attention_dim, 332 | clip_embeddings_dim=self.image_encoder.config.projection_dim, 333 | clip_extra_context_tokens=num_tokens, 334 | ).to(self.device, dtype=torch.float16) 335 | return image_proj_model 336 | 337 | def set_ip_adapter(self): 338 | unet = self.pipe.unet 339 | attn_procs = {} 340 | for name in unet.attn_processors.keys(): 341 | cross_attention_dim = None if name.endswith("attn1.processor") else unet.config.cross_attention_dim 342 | if name.startswith("mid_block"): 343 | hidden_size = unet.config.block_out_channels[-1] 344 | elif name.startswith("up_blocks"): 345 | block_id = int(name[len("up_blocks.")]) 346 | hidden_size = list(reversed(unet.config.block_out_channels))[block_id] 347 | elif name.startswith("down_blocks"): 348 | 
block_id = int(name[len("down_blocks.")]) 349 | hidden_size = unet.config.block_out_channels[block_id] 350 | if cross_attention_dim is None: 351 | attn_procs[name] = AttnProcessor() 352 | else: 353 | # layername_id += 1 354 | selected = False 355 | for block_name in self.style_target_blocks: 356 | if block_name in name: 357 | selected = True 358 | # print(name) 359 | attn_procs[name] = IP_CS_AttnProcessor( 360 | hidden_size=hidden_size, 361 | cross_attention_dim=cross_attention_dim, 362 | style_scale=1.0, 363 | style=True, 364 | num_content_tokens=self.num_content_tokens, 365 | num_style_tokens=self.num_style_tokens, 366 | ) 367 | for block_name in self.content_target_blocks: 368 | if block_name in name: 369 | # selected = True 370 | if selected is False: 371 | attn_procs[name] = IP_CS_AttnProcessor( 372 | hidden_size=hidden_size, 373 | cross_attention_dim=cross_attention_dim, 374 | content_scale=1.0, 375 | content=True, 376 | num_content_tokens=self.num_content_tokens, 377 | num_style_tokens=self.num_style_tokens, 378 | ) 379 | else: 380 | attn_procs[name].set_content_ipa(content_scale=1.0) 381 | # attn_procs[name].content=True 382 | 383 | if selected is False: 384 | attn_procs[name] = IP_CS_AttnProcessor( 385 | hidden_size=hidden_size, 386 | cross_attention_dim=cross_attention_dim, 387 | num_content_tokens=self.num_content_tokens, 388 | num_style_tokens=self.num_style_tokens, 389 | skip=True, 390 | ) 391 | 392 | attn_procs[name].to(self.device, dtype=torch.float16) 393 | unet.set_attn_processor(attn_procs) 394 | if hasattr(self.pipe, "controlnet"): 395 | if self.controlnet_adapter is False: 396 | if isinstance(self.pipe.controlnet, MultiControlNetModel): 397 | for controlnet in self.pipe.controlnet.nets: 398 | controlnet.set_attn_processor(CNAttnProcessor( 399 | num_tokens=self.num_content_tokens + self.num_style_tokens)) 400 | else: 401 | self.pipe.controlnet.set_attn_processor(CNAttnProcessor( 402 | num_tokens=self.num_content_tokens + self.num_style_tokens)) 403 | 404 | else: 405 | controlnet_attn_procs = {} 406 | controlnet_style_target_blocks = self.controlnet_target_style_blocks 407 | controlnet_content_target_blocks = self.controlnet_target_content_blocks 408 | for name in self.pipe.controlnet.attn_processors.keys(): 409 | # print(name) 410 | cross_attention_dim = None if name.endswith( 411 | "attn1.processor") else self.pipe.controlnet.config.cross_attention_dim 412 | if name.startswith("mid_block"): 413 | hidden_size = self.pipe.controlnet.config.block_out_channels[-1] 414 | elif name.startswith("up_blocks"): 415 | block_id = int(name[len("up_blocks.")]) 416 | hidden_size = list(reversed(self.pipe.controlnet.config.block_out_channels))[block_id] 417 | elif name.startswith("down_blocks"): 418 | block_id = int(name[len("down_blocks.")]) 419 | hidden_size = self.pipe.controlnet.config.block_out_channels[block_id] 420 | if cross_attention_dim is None: 421 | # layername_id += 1 422 | controlnet_attn_procs[name] = AttnProcessor() 423 | 424 | else: 425 | # layername_id += 1 426 | selected = False 427 | for block_name in controlnet_style_target_blocks: 428 | if block_name in name: 429 | selected = True 430 | # print(name) 431 | controlnet_attn_procs[name] = IP_CS_AttnProcessor( 432 | hidden_size=hidden_size, 433 | cross_attention_dim=cross_attention_dim, 434 | style_scale=1.0, 435 | style=True, 436 | num_content_tokens=self.num_content_tokens, 437 | num_style_tokens=self.num_style_tokens, 438 | ) 439 | 440 | for block_name in controlnet_content_target_blocks: 441 | if block_name in 
name: 442 | if selected is False: 443 | controlnet_attn_procs[name] = IP_CS_AttnProcessor( 444 | hidden_size=hidden_size, 445 | cross_attention_dim=cross_attention_dim, 446 | content_scale=1.0, 447 | content=True, 448 | num_content_tokens=self.num_content_tokens, 449 | num_style_tokens=self.num_style_tokens, 450 | ) 451 | 452 | selected = True 453 | elif selected is True: 454 | controlnet_attn_procs[name].set_content_ipa(content_scale=1.0) 455 | 456 | # if args.content_image_encoder_type !='dinov2': 457 | # weights = { 458 | # "to_k_ip.weight": state_dict["ip_adapter"][str(layername_id) + ".to_k_ip.weight"], 459 | # "to_v_ip.weight": state_dict["ip_adapter"][str(layername_id) + ".to_v_ip.weight"], 460 | # } 461 | # attn_procs[name].load_state_dict(weights) 462 | if selected is False: 463 | controlnet_attn_procs[name] = IP_CS_AttnProcessor( 464 | hidden_size=hidden_size, 465 | cross_attention_dim=cross_attention_dim, 466 | num_content_tokens=self.num_content_tokens, 467 | num_style_tokens=self.num_style_tokens, 468 | skip=True, 469 | ) 470 | controlnet_attn_procs[name].to(self.device, dtype=torch.float16) 471 | # layer_name = name.split(".processor")[0] 472 | # # print(state_dict["ip_adapter"].keys()) 473 | # weights = { 474 | # "to_k_ip.weight": state_dict["ip_adapter"][str(layername_id) + ".to_k_ip.weight"], 475 | # "to_v_ip.weight": state_dict["ip_adapter"][str(layername_id) + ".to_v_ip.weight"], 476 | # } 477 | # attn_procs[name].load_state_dict(weights) 478 | self.pipe.controlnet.set_attn_processor(controlnet_attn_procs) 479 | 480 | def load_ip_adapter(self): 481 | if os.path.splitext(self.ip_ckpt)[-1] == ".safetensors": 482 | state_dict = {"content_image_proj": {}, "style_image_proj": {}, "ip_adapter": {}} 483 | with safe_open(self.ip_ckpt, framework="pt", device="cpu") as f: 484 | for key in f.keys(): 485 | if key.startswith("content_image_proj."): 486 | state_dict["content_image_proj"][key.replace("content_image_proj.", "")] = f.get_tensor(key) 487 | elif key.startswith("style_image_proj."): 488 | state_dict["style_image_proj"][key.replace("style_image_proj.", "")] = f.get_tensor(key) 489 | elif key.startswith("ip_adapter."): 490 | state_dict["ip_adapter"][key.replace("ip_adapter.", "")] = f.get_tensor(key) 491 | else: 492 | state_dict = torch.load(self.ip_ckpt, map_location="cpu") 493 | self.content_image_proj_model.load_state_dict(state_dict["content_image_proj"]) 494 | self.style_image_proj_model.load_state_dict(state_dict["style_image_proj"]) 495 | 496 | if 'conv_in_unet_sd' in state_dict.keys(): 497 | self.pipe.unet.conv_in.load_state_dict(state_dict["conv_in_unet_sd"], strict=True) 498 | ip_layers = torch.nn.ModuleList(self.pipe.unet.attn_processors.values()) 499 | ip_layers.load_state_dict(state_dict["ip_adapter"], strict=False) 500 | 501 | if self.controlnet_adapter is True: 502 | print('loading controlnet_adapter') 503 | self.pipe.controlnet.load_state_dict(state_dict["controlnet_adapter_modules"], strict=False) 504 | 505 | @torch.inference_mode() 506 | def get_image_embeds(self, pil_image=None, clip_image_embeds=None, content_prompt_embeds=None, 507 | content_or_style_=''): 508 | # if pil_image is not None: 509 | # if isinstance(pil_image, Image.Image): 510 | # pil_image = [pil_image] 511 | # clip_image = self.clip_image_processor(images=pil_image, return_tensors="pt").pixel_values 512 | # clip_image_embeds = self.image_encoder(clip_image.to(self.device, dtype=torch.float16)).image_embeds 513 | # else: 514 | # clip_image_embeds = clip_image_embeds.to(self.device, 
dtype=torch.float16) 515 | 516 | # if content_prompt_embeds is not None: 517 | # clip_image_embeds = clip_image_embeds - content_prompt_embeds 518 | 519 | if content_or_style_ == 'content': 520 | if pil_image is not None: 521 | if isinstance(pil_image, Image.Image): 522 | pil_image = [pil_image] 523 | if self.content_image_proj_model is not None: 524 | clip_image = self.content_image_processor(images=pil_image, return_tensors="pt").pixel_values 525 | clip_image_embeds = self.content_image_encoder( 526 | clip_image.to(self.device, dtype=torch.float16)).image_embeds 527 | else: 528 | clip_image = self.clip_image_processor(images=pil_image, return_tensors="pt").pixel_values 529 | clip_image_embeds = self.image_encoder(clip_image.to(self.device, dtype=torch.float16)).image_embeds 530 | else: 531 | clip_image_embeds = clip_image_embeds.to(self.device, dtype=torch.float16) 532 | 533 | image_prompt_embeds = self.content_image_proj_model(clip_image_embeds) 534 | uncond_image_prompt_embeds = self.content_image_proj_model(torch.zeros_like(clip_image_embeds)) 535 | return image_prompt_embeds, uncond_image_prompt_embeds 536 | if content_or_style_ == 'style': 537 | if pil_image is not None: 538 | if self.use_CSD is not None: 539 | clip_image = self.style_preprocess(pil_image).unsqueeze(0).to(self.device, dtype=torch.float32) 540 | clip_image_embeds = self.style_image_encoder(clip_image) 541 | else: 542 | if isinstance(pil_image, Image.Image): 543 | pil_image = [pil_image] 544 | clip_image = self.clip_image_processor(images=pil_image, return_tensors="pt").pixel_values 545 | clip_image_embeds = self.image_encoder(clip_image.to(self.device, dtype=torch.float16)).image_embeds 546 | 547 | 548 | else: 549 | clip_image_embeds = clip_image_embeds.to(self.device, dtype=torch.float16) 550 | image_prompt_embeds = self.style_image_proj_model(clip_image_embeds) 551 | uncond_image_prompt_embeds = self.style_image_proj_model(torch.zeros_like(clip_image_embeds)) 552 | return image_prompt_embeds, uncond_image_prompt_embeds 553 | 554 | def set_scale(self, content_scale, style_scale): 555 | for attn_processor in self.pipe.unet.attn_processors.values(): 556 | if isinstance(attn_processor, IP_CS_AttnProcessor): 557 | if attn_processor.content is True: 558 | attn_processor.content_scale = content_scale 559 | 560 | if attn_processor.style is True: 561 | attn_processor.style_scale = style_scale 562 | # print('style_scale:',style_scale) 563 | if self.controlnet_adapter is not None: 564 | for attn_processor in self.pipe.controlnet.attn_processors.values(): 565 | 566 | if isinstance(attn_processor, IP_CS_AttnProcessor): 567 | if attn_processor.content is True: 568 | attn_processor.content_scale = content_scale 569 | # print(content_scale) 570 | 571 | if attn_processor.style is True: 572 | attn_processor.style_scale = style_scale 573 | 574 | def generate( 575 | self, 576 | pil_content_image=None, 577 | pil_style_image=None, 578 | clip_content_image_embeds=None, 579 | clip_style_image_embeds=None, 580 | prompt=None, 581 | negative_prompt=None, 582 | content_scale=1.0, 583 | style_scale=1.0, 584 | num_samples=4, 585 | seed=None, 586 | guidance_scale=7.5, 587 | num_inference_steps=30, 588 | neg_content_emb=None, 589 | **kwargs, 590 | ): 591 | self.set_scale(content_scale, style_scale) 592 | 593 | if pil_content_image is not None: 594 | num_prompts = 1 if isinstance(pil_content_image, Image.Image) else len(pil_content_image) 595 | else: 596 | num_prompts = clip_content_image_embeds.size(0) 597 | 598 | if prompt is None: 599 | prompt 
= "best quality, high quality" 600 | if negative_prompt is None: 601 | negative_prompt = "monochrome, lowres, bad anatomy, worst quality, low quality" 602 | 603 | if not isinstance(prompt, List): 604 | prompt = [prompt] * num_prompts 605 | if not isinstance(negative_prompt, List): 606 | negative_prompt = [negative_prompt] * num_prompts 607 | 608 | content_image_prompt_embeds, uncond_content_image_prompt_embeds = self.get_image_embeds( 609 | pil_image=pil_content_image, clip_image_embeds=clip_content_image_embeds 610 | ) 611 | style_image_prompt_embeds, uncond_style_image_prompt_embeds = self.get_image_embeds( 612 | pil_image=pil_style_image, clip_image_embeds=clip_style_image_embeds 613 | ) 614 | 615 | bs_embed, seq_len, _ = content_image_prompt_embeds.shape 616 | content_image_prompt_embeds = content_image_prompt_embeds.repeat(1, num_samples, 1) 617 | content_image_prompt_embeds = content_image_prompt_embeds.view(bs_embed * num_samples, seq_len, -1) 618 | uncond_content_image_prompt_embeds = uncond_content_image_prompt_embeds.repeat(1, num_samples, 1) 619 | uncond_content_image_prompt_embeds = uncond_content_image_prompt_embeds.view(bs_embed * num_samples, seq_len, 620 | -1) 621 | 622 | bs_style_embed, seq_style_len, _ = content_image_prompt_embeds.shape 623 | style_image_prompt_embeds = style_image_prompt_embeds.repeat(1, num_samples, 1) 624 | style_image_prompt_embeds = style_image_prompt_embeds.view(bs_embed * num_samples, seq_style_len, -1) 625 | uncond_style_image_prompt_embeds = uncond_style_image_prompt_embeds.repeat(1, num_samples, 1) 626 | uncond_style_image_prompt_embeds = uncond_style_image_prompt_embeds.view(bs_embed * num_samples, seq_style_len, 627 | -1) 628 | 629 | with torch.inference_mode(): 630 | prompt_embeds_, negative_prompt_embeds_ = self.pipe.encode_prompt( 631 | prompt, 632 | device=self.device, 633 | num_images_per_prompt=num_samples, 634 | do_classifier_free_guidance=True, 635 | negative_prompt=negative_prompt, 636 | ) 637 | prompt_embeds = torch.cat([prompt_embeds_, content_image_prompt_embeds, style_image_prompt_embeds], dim=1) 638 | negative_prompt_embeds = torch.cat([negative_prompt_embeds_, 639 | uncond_content_image_prompt_embeds, uncond_style_image_prompt_embeds], 640 | dim=1) 641 | 642 | generator = get_generator(seed, self.device) 643 | 644 | images = self.pipe( 645 | prompt_embeds=prompt_embeds, 646 | negative_prompt_embeds=negative_prompt_embeds, 647 | guidance_scale=guidance_scale, 648 | num_inference_steps=num_inference_steps, 649 | generator=generator, 650 | **kwargs, 651 | ).images 652 | 653 | return images 654 | 655 | 656 | class IPAdapterXL_CS(IPAdapter_CS): 657 | """SDXL""" 658 | 659 | def generate( 660 | self, 661 | pil_content_image, 662 | pil_style_image, 663 | prompt=None, 664 | negative_prompt=None, 665 | content_scale=1.0, 666 | style_scale=1.0, 667 | num_samples=4, 668 | seed=None, 669 | content_image_embeds=None, 670 | style_image_embeds=None, 671 | num_inference_steps=30, 672 | neg_content_emb=None, 673 | neg_content_prompt=None, 674 | neg_content_scale=1.0, 675 | **kwargs, 676 | ): 677 | self.set_scale(content_scale, style_scale) 678 | 679 | num_prompts = 1 if isinstance(pil_content_image, Image.Image) else len(pil_content_image) 680 | 681 | if prompt is None: 682 | prompt = "best quality, high quality" 683 | if negative_prompt is None: 684 | negative_prompt = "monochrome, lowres, bad anatomy, worst quality, low quality" 685 | 686 | if not isinstance(prompt, List): 687 | prompt = [prompt] * num_prompts 688 | if not 
isinstance(negative_prompt, List): 689 | negative_prompt = [negative_prompt] * num_prompts 690 | 691 | content_image_prompt_embeds, uncond_content_image_prompt_embeds = self.get_image_embeds(pil_content_image, 692 | content_image_embeds, 693 | content_or_style_='content') 694 | 695 | 696 | 697 | style_image_prompt_embeds, uncond_style_image_prompt_embeds = self.get_image_embeds(pil_style_image, 698 | style_image_embeds, 699 | content_or_style_='style') 700 | 701 | bs_embed, seq_len, _ = content_image_prompt_embeds.shape 702 | 703 | content_image_prompt_embeds = content_image_prompt_embeds.repeat(1, num_samples, 1) 704 | content_image_prompt_embeds = content_image_prompt_embeds.view(bs_embed * num_samples, seq_len, -1) 705 | 706 | uncond_content_image_prompt_embeds = uncond_content_image_prompt_embeds.repeat(1, num_samples, 1) 707 | uncond_content_image_prompt_embeds = uncond_content_image_prompt_embeds.view(bs_embed * num_samples, seq_len, 708 | -1) 709 | bs_style_embed, seq_style_len, _ = style_image_prompt_embeds.shape 710 | style_image_prompt_embeds = style_image_prompt_embeds.repeat(1, num_samples, 1) 711 | style_image_prompt_embeds = style_image_prompt_embeds.view(bs_embed * num_samples, seq_style_len, -1) 712 | uncond_style_image_prompt_embeds = uncond_style_image_prompt_embeds.repeat(1, num_samples, 1) 713 | uncond_style_image_prompt_embeds = uncond_style_image_prompt_embeds.view(bs_embed * num_samples, seq_style_len, 714 | -1) 715 | 716 | with torch.inference_mode(): 717 | ( 718 | prompt_embeds, 719 | negative_prompt_embeds, 720 | pooled_prompt_embeds, 721 | negative_pooled_prompt_embeds, 722 | ) = self.pipe.encode_prompt( 723 | prompt, 724 | num_images_per_prompt=num_samples, 725 | do_classifier_free_guidance=True, 726 | negative_prompt=negative_prompt, 727 | ) 728 | prompt_embeds = torch.cat([prompt_embeds, content_image_prompt_embeds, style_image_prompt_embeds], dim=1) 729 | negative_prompt_embeds = torch.cat([negative_prompt_embeds, 730 | uncond_content_image_prompt_embeds, uncond_style_image_prompt_embeds], 731 | dim=1) 732 | 733 | self.generator = get_generator(seed, self.device) 734 | 735 | images = self.pipe( 736 | prompt_embeds=prompt_embeds, 737 | negative_prompt_embeds=negative_prompt_embeds, 738 | pooled_prompt_embeds=pooled_prompt_embeds, 739 | negative_pooled_prompt_embeds=negative_pooled_prompt_embeds, 740 | num_inference_steps=num_inference_steps, 741 | generator=self.generator, 742 | **kwargs, 743 | ).images 744 | return images 745 | 746 | 747 | class CSGO(IPAdapterXL_CS): 748 | """SDXL""" 749 | 750 | def init_proj(self, num_tokens, content_or_style_='content', model_resampler=False): 751 | if content_or_style_ == 'content': 752 | if model_resampler: 753 | image_proj_model = Resampler( 754 | dim=self.pipe.unet.config.cross_attention_dim, 755 | depth=4, 756 | dim_head=64, 757 | heads=12, 758 | num_queries=num_tokens, 759 | embedding_dim=self.content_image_encoder.config.hidden_size, 760 | output_dim=self.pipe.unet.config.cross_attention_dim, 761 | ff_mult=4, 762 | ).to(self.device, dtype=torch.float16) 763 | else: 764 | image_proj_model = ImageProjModel( 765 | cross_attention_dim=self.pipe.unet.config.cross_attention_dim, 766 | clip_embeddings_dim=self.image_encoder.config.projection_dim, 767 | clip_extra_context_tokens=num_tokens, 768 | ).to(self.device, dtype=torch.float16) 769 | if content_or_style_ == 'style': 770 | if model_resampler: 771 | image_proj_model = Resampler( 772 | dim=self.pipe.unet.config.cross_attention_dim, 773 | depth=4, 774 | dim_head=64, 775 | 
heads=12, 776 | num_queries=num_tokens, 777 | embedding_dim=self.content_image_encoder.config.hidden_size, 778 | output_dim=self.pipe.unet.config.cross_attention_dim, 779 | ff_mult=4, 780 | ).to(self.device, dtype=torch.float16) 781 | else: 782 | image_proj_model = ImageProjModel( 783 | cross_attention_dim=self.pipe.unet.config.cross_attention_dim, 784 | clip_embeddings_dim=self.image_encoder.config.projection_dim, 785 | clip_extra_context_tokens=num_tokens, 786 | ).to(self.device, dtype=torch.float16) 787 | return image_proj_model 788 | 789 | @torch.inference_mode() 790 | def get_image_embeds(self, pil_image=None, clip_image_embeds=None, content_or_style_=''): 791 | if isinstance(pil_image, Image.Image): 792 | pil_image = [pil_image] 793 | if content_or_style_ == 'style': 794 | 795 | if self.style_model_resampler: 796 | clip_image = self.clip_image_processor(images=pil_image, return_tensors="pt").pixel_values 797 | clip_image_embeds = self.image_encoder(clip_image.to(self.device, dtype=torch.float16), 798 | output_hidden_states=True).hidden_states[-2] 799 | image_prompt_embeds = self.style_image_proj_model(clip_image_embeds) 800 | uncond_image_prompt_embeds = self.style_image_proj_model(torch.zeros_like(clip_image_embeds)) 801 | else: 802 | 803 | 804 | clip_image = self.clip_image_processor(images=pil_image, return_tensors="pt").pixel_values 805 | clip_image_embeds = self.image_encoder(clip_image.to(self.device, dtype=torch.float16)).image_embeds 806 | image_prompt_embeds = self.style_image_proj_model(clip_image_embeds) 807 | uncond_image_prompt_embeds = self.style_image_proj_model(torch.zeros_like(clip_image_embeds)) 808 | return image_prompt_embeds, uncond_image_prompt_embeds 809 | 810 | 811 | else: 812 | 813 | if self.content_image_encoder_path is not None: 814 | clip_image = self.content_image_processor(images=pil_image, return_tensors="pt").pixel_values 815 | outputs = self.content_image_encoder(clip_image.to(self.device, dtype=torch.float16), 816 | output_hidden_states=True) 817 | clip_image_embeds = outputs.last_hidden_state 818 | image_prompt_embeds = self.content_image_proj_model(clip_image_embeds) 819 | 820 | # uncond_clip_image_embeds = self.image_encoder( 821 | # torch.zeros_like(clip_image), output_hidden_states=True 822 | # ).last_hidden_state 823 | uncond_image_prompt_embeds = self.content_image_proj_model(torch.zeros_like(clip_image_embeds)) 824 | return image_prompt_embeds, uncond_image_prompt_embeds 825 | 826 | else: 827 | if self.content_model_resampler: 828 | 829 | clip_image = self.clip_image_processor(images=pil_image, return_tensors="pt").pixel_values 830 | 831 | clip_image = clip_image.to(self.device, dtype=torch.float16) 832 | clip_image_embeds = self.image_encoder(clip_image, output_hidden_states=True).hidden_states[-2] 833 | # clip_image_embeds = clip_image_embeds.to(self.device, dtype=torch.float16) 834 | image_prompt_embeds = self.content_image_proj_model(clip_image_embeds) 835 | # uncond_clip_image_embeds = self.image_encoder( 836 | # torch.zeros_like(clip_image), output_hidden_states=True 837 | # ).hidden_states[-2] 838 | uncond_image_prompt_embeds = self.content_image_proj_model(torch.zeros_like(clip_image_embeds)) 839 | else: 840 | clip_image = self.clip_image_processor(images=pil_image, return_tensors="pt").pixel_values 841 | clip_image_embeds = self.image_encoder(clip_image.to(self.device, dtype=torch.float16)).image_embeds 842 | image_prompt_embeds = self.content_image_proj_model(clip_image_embeds) 843 | uncond_image_prompt_embeds = 
self.content_image_proj_model(torch.zeros_like(clip_image_embeds)) 844 | 845 | return image_prompt_embeds, uncond_image_prompt_embeds 846 | 847 | # # clip_image = self.clip_image_processor(images=pil_image, return_tensors="pt").pixel_values 848 | # clip_image = clip_image.to(self.device, dtype=torch.float16) 849 | # clip_image_embeds = self.image_encoder(clip_image, output_hidden_states=True).hidden_states[-2] 850 | # image_prompt_embeds = self.content_image_proj_model(clip_image_embeds) 851 | # uncond_clip_image_embeds = self.image_encoder( 852 | # torch.zeros_like(clip_image), output_hidden_states=True 853 | # ).hidden_states[-2] 854 | # uncond_image_prompt_embeds = self.content_image_proj_model(uncond_clip_image_embeds) 855 | # return image_prompt_embeds, uncond_image_prompt_embeds 856 | 857 | 858 | class IPAdapterXL(IPAdapter): 859 | """SDXL""" 860 | 861 | def generate( 862 | self, 863 | pil_image, 864 | prompt=None, 865 | negative_prompt=None, 866 | scale=1.0, 867 | num_samples=4, 868 | seed=None, 869 | num_inference_steps=30, 870 | neg_content_emb=None, 871 | neg_content_prompt=None, 872 | neg_content_scale=1.0, 873 | **kwargs, 874 | ): 875 | self.set_scale(scale) 876 | 877 | num_prompts = 1 if isinstance(pil_image, Image.Image) else len(pil_image) 878 | 879 | if prompt is None: 880 | prompt = "best quality, high quality" 881 | if negative_prompt is None: 882 | negative_prompt = "monochrome, lowres, bad anatomy, worst quality, low quality" 883 | 884 | if not isinstance(prompt, List): 885 | prompt = [prompt] * num_prompts 886 | if not isinstance(negative_prompt, List): 887 | negative_prompt = [negative_prompt] * num_prompts 888 | 889 | if neg_content_emb is None: 890 | if neg_content_prompt is not None: 891 | with torch.inference_mode(): 892 | ( 893 | prompt_embeds_, # torch.Size([1, 77, 2048]) 894 | negative_prompt_embeds_, 895 | pooled_prompt_embeds_, # torch.Size([1, 1280]) 896 | negative_pooled_prompt_embeds_, 897 | ) = self.pipe.encode_prompt( 898 | neg_content_prompt, 899 | num_images_per_prompt=num_samples, 900 | do_classifier_free_guidance=True, 901 | negative_prompt=negative_prompt, 902 | ) 903 | pooled_prompt_embeds_ *= neg_content_scale 904 | else: 905 | pooled_prompt_embeds_ = neg_content_emb 906 | else: 907 | pooled_prompt_embeds_ = None 908 | 909 | image_prompt_embeds, uncond_image_prompt_embeds = self.get_image_embeds(pil_image, 910 | content_prompt_embeds=pooled_prompt_embeds_) 911 | bs_embed, seq_len, _ = image_prompt_embeds.shape 912 | image_prompt_embeds = image_prompt_embeds.repeat(1, num_samples, 1) 913 | image_prompt_embeds = image_prompt_embeds.view(bs_embed * num_samples, seq_len, -1) 914 | uncond_image_prompt_embeds = uncond_image_prompt_embeds.repeat(1, num_samples, 1) 915 | uncond_image_prompt_embeds = uncond_image_prompt_embeds.view(bs_embed * num_samples, seq_len, -1) 916 | 917 | with torch.inference_mode(): 918 | ( 919 | prompt_embeds, 920 | negative_prompt_embeds, 921 | pooled_prompt_embeds, 922 | negative_pooled_prompt_embeds, 923 | ) = self.pipe.encode_prompt( 924 | prompt, 925 | num_images_per_prompt=num_samples, 926 | do_classifier_free_guidance=True, 927 | negative_prompt=negative_prompt, 928 | ) 929 | prompt_embeds = torch.cat([prompt_embeds, image_prompt_embeds], dim=1) 930 | negative_prompt_embeds = torch.cat([negative_prompt_embeds, uncond_image_prompt_embeds], dim=1) 931 | 932 | self.generator = get_generator(seed, self.device) 933 | 934 | images = self.pipe( 935 | prompt_embeds=prompt_embeds, 936 | 
negative_prompt_embeds=negative_prompt_embeds, 937 | pooled_prompt_embeds=pooled_prompt_embeds, 938 | negative_pooled_prompt_embeds=negative_pooled_prompt_embeds, 939 | num_inference_steps=num_inference_steps, 940 | generator=self.generator, 941 | **kwargs, 942 | ).images 943 | 944 | return images 945 | 946 | 947 | class IPAdapterPlus(IPAdapter): 948 | """IP-Adapter with fine-grained features""" 949 | 950 | def init_proj(self): 951 | image_proj_model = Resampler( 952 | dim=self.pipe.unet.config.cross_attention_dim, 953 | depth=4, 954 | dim_head=64, 955 | heads=12, 956 | num_queries=self.num_tokens, 957 | embedding_dim=self.image_encoder.config.hidden_size, 958 | output_dim=self.pipe.unet.config.cross_attention_dim, 959 | ff_mult=4, 960 | ).to(self.device, dtype=torch.float16) 961 | return image_proj_model 962 | 963 | @torch.inference_mode() 964 | def get_image_embeds(self, pil_image=None, clip_image_embeds=None): 965 | if isinstance(pil_image, Image.Image): 966 | pil_image = [pil_image] 967 | clip_image = self.clip_image_processor(images=pil_image, return_tensors="pt").pixel_values 968 | clip_image = clip_image.to(self.device, dtype=torch.float16) 969 | clip_image_embeds = self.image_encoder(clip_image, output_hidden_states=True).hidden_states[-2] 970 | image_prompt_embeds = self.image_proj_model(clip_image_embeds) 971 | uncond_clip_image_embeds = self.image_encoder( 972 | torch.zeros_like(clip_image), output_hidden_states=True 973 | ).hidden_states[-2] 974 | uncond_image_prompt_embeds = self.image_proj_model(uncond_clip_image_embeds) 975 | return image_prompt_embeds, uncond_image_prompt_embeds 976 | 977 | 978 | class IPAdapterFull(IPAdapterPlus): 979 | """IP-Adapter with full features""" 980 | 981 | def init_proj(self): 982 | image_proj_model = MLPProjModel( 983 | cross_attention_dim=self.pipe.unet.config.cross_attention_dim, 984 | clip_embeddings_dim=self.image_encoder.config.hidden_size, 985 | ).to(self.device, dtype=torch.float16) 986 | return image_proj_model 987 | 988 | 989 | class IPAdapterPlusXL(IPAdapter): 990 | """SDXL""" 991 | 992 | def init_proj(self): 993 | image_proj_model = Resampler( 994 | dim=1280, 995 | depth=4, 996 | dim_head=64, 997 | heads=20, 998 | num_queries=self.num_tokens, 999 | embedding_dim=self.image_encoder.config.hidden_size, 1000 | output_dim=self.pipe.unet.config.cross_attention_dim, 1001 | ff_mult=4, 1002 | ).to(self.device, dtype=torch.float16) 1003 | return image_proj_model 1004 | 1005 | @torch.inference_mode() 1006 | def get_image_embeds(self, pil_image): 1007 | if isinstance(pil_image, Image.Image): 1008 | pil_image = [pil_image] 1009 | clip_image = self.clip_image_processor(images=pil_image, return_tensors="pt").pixel_values 1010 | clip_image = clip_image.to(self.device, dtype=torch.float16) 1011 | clip_image_embeds = self.image_encoder(clip_image, output_hidden_states=True).hidden_states[-2] 1012 | image_prompt_embeds = self.image_proj_model(clip_image_embeds) 1013 | uncond_clip_image_embeds = self.image_encoder( 1014 | torch.zeros_like(clip_image), output_hidden_states=True 1015 | ).hidden_states[-2] 1016 | uncond_image_prompt_embeds = self.image_proj_model(uncond_clip_image_embeds) 1017 | return image_prompt_embeds, uncond_image_prompt_embeds 1018 | 1019 | def generate( 1020 | self, 1021 | pil_image, 1022 | prompt=None, 1023 | negative_prompt=None, 1024 | scale=1.0, 1025 | num_samples=4, 1026 | seed=None, 1027 | num_inference_steps=30, 1028 | **kwargs, 1029 | ): 1030 | self.set_scale(scale) 1031 | 1032 | num_prompts = 1 if isinstance(pil_image, 
Image.Image) else len(pil_image) 1033 | 1034 | if prompt is None: 1035 | prompt = "best quality, high quality" 1036 | if negative_prompt is None: 1037 | negative_prompt = "monochrome, lowres, bad anatomy, worst quality, low quality" 1038 | 1039 | if not isinstance(prompt, List): 1040 | prompt = [prompt] * num_prompts 1041 | if not isinstance(negative_prompt, List): 1042 | negative_prompt = [negative_prompt] * num_prompts 1043 | 1044 | image_prompt_embeds, uncond_image_prompt_embeds = self.get_image_embeds(pil_image) 1045 | bs_embed, seq_len, _ = image_prompt_embeds.shape 1046 | image_prompt_embeds = image_prompt_embeds.repeat(1, num_samples, 1) 1047 | image_prompt_embeds = image_prompt_embeds.view(bs_embed * num_samples, seq_len, -1) 1048 | uncond_image_prompt_embeds = uncond_image_prompt_embeds.repeat(1, num_samples, 1) 1049 | uncond_image_prompt_embeds = uncond_image_prompt_embeds.view(bs_embed * num_samples, seq_len, -1) 1050 | 1051 | with torch.inference_mode(): 1052 | ( 1053 | prompt_embeds, 1054 | negative_prompt_embeds, 1055 | pooled_prompt_embeds, 1056 | negative_pooled_prompt_embeds, 1057 | ) = self.pipe.encode_prompt( 1058 | prompt, 1059 | num_images_per_prompt=num_samples, 1060 | do_classifier_free_guidance=True, 1061 | negative_prompt=negative_prompt, 1062 | ) 1063 | prompt_embeds = torch.cat([prompt_embeds, image_prompt_embeds], dim=1) 1064 | negative_prompt_embeds = torch.cat([negative_prompt_embeds, uncond_image_prompt_embeds], dim=1) 1065 | 1066 | generator = get_generator(seed, self.device) 1067 | 1068 | images = self.pipe( 1069 | prompt_embeds=prompt_embeds, 1070 | negative_prompt_embeds=negative_prompt_embeds, 1071 | pooled_prompt_embeds=pooled_prompt_embeds, 1072 | negative_pooled_prompt_embeds=negative_pooled_prompt_embeds, 1073 | num_inference_steps=num_inference_steps, 1074 | generator=generator, 1075 | **kwargs, 1076 | ).images 1077 | 1078 | return images 1079 | -------------------------------------------------------------------------------- /ip_adapter/resampler.py: -------------------------------------------------------------------------------- 1 | # modified from https://github.com/mlfoundations/open_flamingo/blob/main/open_flamingo/src/helpers.py 2 | # and https://github.com/lucidrains/imagen-pytorch/blob/main/imagen_pytorch/imagen_pytorch.py 3 | 4 | import math 5 | 6 | import torch 7 | import torch.nn as nn 8 | from einops import rearrange 9 | from einops.layers.torch import Rearrange 10 | 11 | 12 | # FFN 13 | def FeedForward(dim, mult=4): 14 | inner_dim = int(dim * mult) 15 | return nn.Sequential( 16 | nn.LayerNorm(dim), 17 | nn.Linear(dim, inner_dim, bias=False), 18 | nn.GELU(), 19 | nn.Linear(inner_dim, dim, bias=False), 20 | ) 21 | 22 | 23 | def reshape_tensor(x, heads): 24 | bs, length, width = x.shape 25 | # (bs, length, width) --> (bs, length, n_heads, dim_per_head) 26 | x = x.view(bs, length, heads, -1) 27 | # (bs, length, n_heads, dim_per_head) --> (bs, n_heads, length, dim_per_head) 28 | x = x.transpose(1, 2) 29 | # (bs, n_heads, length, dim_per_head) --> (bs*n_heads, length, dim_per_head) 30 | x = x.reshape(bs, heads, length, -1) 31 | return x 32 | 33 | 34 | class PerceiverAttention(nn.Module): 35 | def __init__(self, *, dim, dim_head=64, heads=8): 36 | super().__init__() 37 | self.scale = dim_head**-0.5 38 | self.dim_head = dim_head 39 | self.heads = heads 40 | inner_dim = dim_head * heads 41 | 42 | self.norm1 = nn.LayerNorm(dim) 43 | self.norm2 = nn.LayerNorm(dim) 44 | 45 | self.to_q = nn.Linear(dim, inner_dim, bias=False) 46 | self.to_kv = 
nn.Linear(dim, inner_dim * 2, bias=False) 47 | self.to_out = nn.Linear(inner_dim, dim, bias=False) 48 | 49 | def forward(self, x, latents): 50 | """ 51 | Args: 52 | x (torch.Tensor): image features 53 | shape (b, n1, D) 54 | latent (torch.Tensor): latent features 55 | shape (b, n2, D) 56 | """ 57 | x = self.norm1(x) 58 | latents = self.norm2(latents) 59 | 60 | b, l, _ = latents.shape 61 | 62 | q = self.to_q(latents) 63 | kv_input = torch.cat((x, latents), dim=-2) 64 | k, v = self.to_kv(kv_input).chunk(2, dim=-1) 65 | 66 | q = reshape_tensor(q, self.heads) 67 | k = reshape_tensor(k, self.heads) 68 | v = reshape_tensor(v, self.heads) 69 | 70 | # attention 71 | scale = 1 / math.sqrt(math.sqrt(self.dim_head)) 72 | weight = (q * scale) @ (k * scale).transpose(-2, -1) # More stable with f16 than dividing afterwards 73 | weight = torch.softmax(weight.float(), dim=-1).type(weight.dtype) 74 | out = weight @ v 75 | 76 | out = out.permute(0, 2, 1, 3).reshape(b, l, -1) 77 | 78 | return self.to_out(out) 79 | 80 | 81 | class Resampler(nn.Module): 82 | def __init__( 83 | self, 84 | dim=1024, 85 | depth=8, 86 | dim_head=64, 87 | heads=16, 88 | num_queries=8, 89 | embedding_dim=768, 90 | output_dim=1024, 91 | ff_mult=4, 92 | max_seq_len: int = 257, # CLIP tokens + CLS token 93 | apply_pos_emb: bool = False, 94 | num_latents_mean_pooled: int = 0, # number of latents derived from mean pooled representation of the sequence 95 | ): 96 | super().__init__() 97 | self.pos_emb = nn.Embedding(max_seq_len, embedding_dim) if apply_pos_emb else None 98 | 99 | self.latents = nn.Parameter(torch.randn(1, num_queries, dim) / dim**0.5) 100 | 101 | self.proj_in = nn.Linear(embedding_dim, dim) 102 | 103 | self.proj_out = nn.Linear(dim, output_dim) 104 | self.norm_out = nn.LayerNorm(output_dim) 105 | 106 | self.to_latents_from_mean_pooled_seq = ( 107 | nn.Sequential( 108 | nn.LayerNorm(dim), 109 | nn.Linear(dim, dim * num_latents_mean_pooled), 110 | Rearrange("b (n d) -> b n d", n=num_latents_mean_pooled), 111 | ) 112 | if num_latents_mean_pooled > 0 113 | else None 114 | ) 115 | 116 | self.layers = nn.ModuleList([]) 117 | for _ in range(depth): 118 | self.layers.append( 119 | nn.ModuleList( 120 | [ 121 | PerceiverAttention(dim=dim, dim_head=dim_head, heads=heads), 122 | FeedForward(dim=dim, mult=ff_mult), 123 | ] 124 | ) 125 | ) 126 | 127 | def forward(self, x): 128 | if self.pos_emb is not None: 129 | n, device = x.shape[1], x.device 130 | pos_emb = self.pos_emb(torch.arange(n, device=device)) 131 | x = x + pos_emb 132 | 133 | latents = self.latents.repeat(x.size(0), 1, 1) 134 | 135 | x = self.proj_in(x) 136 | 137 | if self.to_latents_from_mean_pooled_seq: 138 | meanpooled_seq = masked_mean(x, dim=1, mask=torch.ones(x.shape[:2], device=x.device, dtype=torch.bool)) 139 | meanpooled_latents = self.to_latents_from_mean_pooled_seq(meanpooled_seq) 140 | latents = torch.cat((meanpooled_latents, latents), dim=-2) 141 | 142 | for attn, ff in self.layers: 143 | latents = attn(x, latents) + latents 144 | latents = ff(latents) + latents 145 | 146 | latents = self.proj_out(latents) 147 | return self.norm_out(latents) 148 | 149 | 150 | def masked_mean(t, *, dim, mask=None): 151 | if mask is None: 152 | return t.mean(dim=dim) 153 | 154 | denom = mask.sum(dim=dim, keepdim=True) 155 | mask = rearrange(mask, "b n -> b n 1") 156 | masked_t = t.masked_fill(~mask, 0.0) 157 | 158 | return masked_t.sum(dim=dim) / denom.clamp(min=1e-5) 159 | -------------------------------------------------------------------------------- /ip_adapter/utils.py: 
-------------------------------------------------------------------------------- 1 | import torch 2 | import torch.nn.functional as F 3 | import numpy as np 4 | from PIL import Image 5 | 6 | BLOCKS = { 7 | 'content': ['down_blocks'], 8 | 'style': ["up_blocks"], 9 | 10 | } 11 | 12 | controlnet_BLOCKS = { 13 | 'content': [], 14 | 'style': ["down_blocks"], 15 | } 16 | 17 | 18 | def resize_width_height(width, height, min_short_side=512, max_long_side=1024): 19 | 20 | if width < height: 21 | 22 | if width < min_short_side: 23 | scale_factor = min_short_side / width 24 | new_width = min_short_side 25 | new_height = int(height * scale_factor) 26 | else: 27 | new_width, new_height = width, height 28 | else: 29 | 30 | if height < min_short_side: 31 | scale_factor = min_short_side / height 32 | new_width = int(width * scale_factor) 33 | new_height = min_short_side 34 | else: 35 | new_width, new_height = width, height 36 | 37 | if max(new_width, new_height) > max_long_side: 38 | scale_factor = max_long_side / max(new_width, new_height) 39 | new_width = int(new_width * scale_factor) 40 | new_height = int(new_height * scale_factor) 41 | return new_width, new_height 42 | 43 | def resize_content(content_image): 44 | max_long_side = 1024 45 | min_short_side = 1024 46 | 47 | new_width, new_height = resize_width_height(content_image.size[0], content_image.size[1], 48 | min_short_side=min_short_side, max_long_side=max_long_side) 49 | height = new_height // 16 * 16 50 | width = new_width // 16 * 16 51 | content_image = content_image.resize((width, height)) 52 | 53 | return width,height,content_image 54 | 55 | attn_maps = {} 56 | def hook_fn(name): 57 | def forward_hook(module, input, output): 58 | if hasattr(module.processor, "attn_map"): 59 | attn_maps[name] = module.processor.attn_map 60 | del module.processor.attn_map 61 | 62 | return forward_hook 63 | 64 | def register_cross_attention_hook(unet): 65 | for name, module in unet.named_modules(): 66 | if name.split('.')[-1].startswith('attn2'): 67 | module.register_forward_hook(hook_fn(name)) 68 | 69 | return unet 70 | 71 | def upscale(attn_map, target_size): 72 | attn_map = torch.mean(attn_map, dim=0) 73 | attn_map = attn_map.permute(1,0) 74 | temp_size = None 75 | 76 | for i in range(0,5): 77 | scale = 2 ** i 78 | if ( target_size[0] // scale ) * ( target_size[1] // scale) == attn_map.shape[1]*64: 79 | temp_size = (target_size[0]//(scale*8), target_size[1]//(scale*8)) 80 | break 81 | 82 | assert temp_size is not None, "temp_size cannot is None" 83 | 84 | attn_map = attn_map.view(attn_map.shape[0], *temp_size) 85 | 86 | attn_map = F.interpolate( 87 | attn_map.unsqueeze(0).to(dtype=torch.float32), 88 | size=target_size, 89 | mode='bilinear', 90 | align_corners=False 91 | )[0] 92 | 93 | attn_map = torch.softmax(attn_map, dim=0) 94 | return attn_map 95 | def get_net_attn_map(image_size, batch_size=2, instance_or_negative=False, detach=True): 96 | 97 | idx = 0 if instance_or_negative else 1 98 | net_attn_maps = [] 99 | 100 | for name, attn_map in attn_maps.items(): 101 | attn_map = attn_map.cpu() if detach else attn_map 102 | attn_map = torch.chunk(attn_map, batch_size)[idx].squeeze() 103 | attn_map = upscale(attn_map, image_size) 104 | net_attn_maps.append(attn_map) 105 | 106 | net_attn_maps = torch.mean(torch.stack(net_attn_maps,dim=0),dim=0) 107 | 108 | return net_attn_maps 109 | 110 | def attnmaps2images(net_attn_maps): 111 | 112 | #total_attn_scores = 0 113 | images = [] 114 | 115 | for attn_map in net_attn_maps: 116 | attn_map = attn_map.cpu().numpy() 117 
| #total_attn_scores += attn_map.mean().item() 118 | 119 | normalized_attn_map = (attn_map - np.min(attn_map)) / (np.max(attn_map) - np.min(attn_map)) * 255 120 | normalized_attn_map = normalized_attn_map.astype(np.uint8) 121 | #print("norm: ", normalized_attn_map.shape) 122 | image = Image.fromarray(normalized_attn_map) 123 | 124 | #image = fix_save_attn_map(attn_map) 125 | images.append(image) 126 | 127 | #print(total_attn_scores) 128 | return images 129 | def is_torch2_available(): 130 | return hasattr(F, "scaled_dot_product_attention") 131 | 132 | def get_generator(seed, device): 133 | 134 | if seed is not None: 135 | if isinstance(seed, list): 136 | generator = [torch.Generator(device).manual_seed(seed_item) for seed_item in seed] 137 | else: 138 | generator = torch.Generator(device).manual_seed(seed) 139 | else: 140 | generator = None 141 | 142 | return generator -------------------------------------------------------------------------------- /requirements.txt: -------------------------------------------------------------------------------- 1 | diffusers==0.25.1 2 | torch==2.0.1 3 | torchaudio==2.0.2 4 | torchvision==0.15.2 5 | transformers==4.40.2 6 | accelerate 7 | safetensors 8 | einops 9 | spaces==0.19.4 10 | omegaconf 11 | peft 12 | huggingface-hub==0.24.5 13 | opencv-python 14 | insightface 15 | gradio 16 | controlnet_aux 17 | gdown 18 | peft 19 | --------------------------------------------------------------------------------
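The adapter classes above are meant to be driven from a diffusers SDXL pipeline. The following is a minimal usage sketch showing how `CSGO` from `ip_adapter/ip_adapter.py` is wired together with the block presets and the `resize_content` helper from `ip_adapter/utils.py`. Every model id, file path, token count, and scale in it is an illustrative assumption rather than the repository's official configuration; `infer/infer_CSGO.py` and `gradio/app.py` hold the authoritative settings.

```python
# Minimal sketch (assumptions marked below), not the official inference script.
import torch
from diffusers import ControlNetModel, StableDiffusionXLControlNetPipeline
from PIL import Image

from ip_adapter.ip_adapter import CSGO
from ip_adapter.utils import BLOCKS, controlnet_BLOCKS, resize_content

device = "cuda"
base_model_path = "stabilityai/stable-diffusion-xl-base-1.0"  # assumed base model
image_encoder_path = "path/to/clip_vision_image_encoder"      # placeholder path
controlnet_path = "path/to/sdxl_tile_controlnet"              # placeholder path
csgo_ckpt = "path/to/csgo_checkpoint.bin"                     # placeholder checkpoint

controlnet = ControlNetModel.from_pretrained(controlnet_path, torch_dtype=torch.float16)
pipe = StableDiffusionXLControlNetPipeline.from_pretrained(
    base_model_path, controlnet=controlnet, torch_dtype=torch.float16, add_watermarker=False
).to(device)

# Token counts must match the trained checkpoint (e.g. a "4_32" weight means
# 4 content tokens and 32 style tokens); the values here are assumptions.
csgo = CSGO(
    pipe, image_encoder_path, csgo_ckpt, device,
    num_content_tokens=4, num_style_tokens=32,
    target_content_blocks=BLOCKS["content"],   # content adapters installed in down_blocks
    target_style_blocks=BLOCKS["style"],       # style adapters installed in up_blocks
    controlnet_adapter=True,
    controlnet_target_content_blocks=controlnet_BLOCKS["content"],
    controlnet_target_style_blocks=controlnet_BLOCKS["style"],
    content_model_resampler=True, style_model_resampler=True,
)

content_image = Image.open("assets/img_0.png").convert("RGB")
style_image = Image.open("assets/img_1.png").convert("RGB")
width, height, content_image = resize_content(content_image)  # snap to multiples of 16

images = csgo.generate(
    pil_content_image=content_image, pil_style_image=style_image,
    prompt="a cat", content_scale=1.0, style_scale=1.0,
    num_samples=1, num_inference_steps=50, seed=42,
    # forwarded through **kwargs to the ControlNet pipeline call:
    image=content_image, controlnet_conditioning_scale=0.6,  # assumed guidance strength
    width=width, height=height,
)
images[0].save("result.png")
```

`content_scale` and `style_scale` are forwarded by `set_scale` to the `IP_CS_AttnProcessor` instances installed by `set_ip_adapter`, so they can be tuned per call without reloading the adapter weights.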
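The content and style tokens consumed by those attention processors come from either the linear `ImageProjModel` or the perceiver-style `Resampler` defined in `ip_adapter/resampler.py`. The short check below illustrates the Resampler's shape contract; the dimensions are assumptions chosen to resemble an SDXL-style setup, not values read from a released checkpoint.

```python
import torch

from ip_adapter.resampler import Resampler

# Assumed dimensions: 2048 mimics unet.config.cross_attention_dim for SDXL,
# 1280 mimics the hidden size of a large CLIP vision tower, 32 style tokens.
resampler = Resampler(
    dim=2048,
    depth=4,
    dim_head=64,
    heads=12,
    num_queries=32,      # number of learned query tokens produced
    embedding_dim=1280,  # width of the incoming CLIP hidden states
    output_dim=2048,     # must match the text-embedding width it is concatenated with
    ff_mult=4,
)

clip_tokens = torch.randn(2, 257, 1280)  # batch of 2, CLS + patch tokens from a CLIP vision model
style_tokens = resampler(clip_tokens)
print(style_tokens.shape)                # torch.Size([2, 32, 2048])
```

With the defaults `apply_pos_emb=False` and `num_latents_mean_pooled=0`, the module reduces to `proj_in`, the stacked `PerceiverAttention`/`FeedForward` layers, and `proj_out`, so the output length is always `num_queries` regardless of how many input tokens are supplied.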