├── .gitignore ├── Readme.md ├── __init__.py ├── app.py ├── download_models.py ├── install.bat ├── install.sh ├── instantidcpu-screenshot.jpg ├── paths.py ├── pipeline_stable_diffusion_xl_instantid.py ├── requirements.txt ├── start.bat ├── start.sh └── utils.py /.gitignore: -------------------------------------------------------------------------------- 1 | env 2 | *.bak 3 | *.pyc 4 | __pycache__ 5 | results 6 | models -------------------------------------------------------------------------------- /Readme.md: -------------------------------------------------------------------------------- 1 | # InstantID CPU 2 | 3 | InstantID CPU inference with a lower memory requirement (11 GB RAM). 4 | 5 | ![InstantID CPU Screenshot](https://raw.githubusercontent.com/rupeshs/instantidcpu/main/instantidcpu-screenshot.jpg) 6 | 7 | ## Prerequisites 8 | 9 | - Python 3.10 or higher 10 | - 11 GB system RAM 11 | - Active internet connection to install and download models 12 | 13 | ## How to run 14 | 15 | Follow these steps: 16 | 17 | ### Windows 18 | 19 | - Double click `install.bat` (installation may take some time, depending on your internet speed) 20 | - To start the desktop GUI, double click `start.bat` 21 | 22 | ### Linux/Mac 23 | 24 | Run the following commands in the terminal: 25 | 26 | - Clone/download this repo 27 | - In the terminal, enter the instantidcpu directory 28 | - `chmod +x install.sh` 29 | - `./install.sh` 30 | - `./start.sh` 31 | 32 | ## Disclaimer 33 | 34 | The code of InstantID is released under the [Apache License](https://github.com/InstantID/InstantID?tab=Apache-2.0-1-ov-file#readme) for both academic and commercial usage. **However, both manually and automatically downloaded face models from insightface are for non-commercial research purposes only** according to their [license](https://github.com/deepinsight/insightface?tab=readme-ov-file#license). Users are granted the freedom to create images using this tool, but they are obligated to comply with local laws and to use it responsibly. The developers will not assume any responsibility for potential misuse by users.
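## How it works (sketch)

Under the hood, `app.py` loads the InstantID SDXL pipeline with an LCM scheduler on top of SDXL Turbo and runs a single denoising step on the CPU in float32. The snippet below is a condensed sketch of that flow for scripted (non-GUI) use; `face.jpg` and `result.png` are placeholder file names, everything else mirrors what `app.py` already does.

```python
# Condensed sketch of the CPU inference flow implemented in app.py.
import cv2
import numpy as np
import torch
from diffusers import LCMScheduler
from diffusers.models import ControlNetModel
from insightface.app import FaceAnalysis
from PIL import Image

from download_models import download_instant_id_sdxl_models
from pipeline_stable_diffusion_xl_instantid import (
    StableDiffusionXLInstantIDPipeline,
    draw_kps,
)

download_instant_id_sdxl_models()  # fetches ControlNet, ip-adapter.bin and antelopev2

# Face embedding + keypoints via insightface (CPU execution provider).
face_app = FaceAnalysis(
    name="antelopev2", root="./", providers=["CPUExecutionProvider"]
)
face_app.prepare(ctx_id=0, det_size=(320, 320))

face_image = Image.open("face.jpg").convert("RGB").resize((960, 1024))
face_info = face_app.get(cv2.cvtColor(np.array(face_image), cv2.COLOR_RGB2BGR))[-1]
face_emb = face_info["embedding"]
face_kps = draw_kps(face_image, face_info["kps"])

# SDXL Turbo + InstantID ControlNet, one LCM step, float32 on CPU.
controlnet = ControlNetModel.from_pretrained(
    "./models/sdxl/ControlNetModel", torch_dtype=torch.float32
)
pipe = StableDiffusionXLInstantIDPipeline.from_pretrained(
    "stabilityai/sdxl-turbo", controlnet=controlnet, torch_dtype=torch.float32
)
pipe.load_ip_adapter_instantid("./models/sdxl/ip-adapter.bin")
pipe.to("cpu")
pipe.scheduler = LCMScheduler.from_config(
    pipe.scheduler.config, beta_start=0.001, beta_end=0.012
)
pipe.set_ip_adapter_scale(0.8)

image = pipe(
    "Photo of a person, high quality",
    image_embeds=face_emb,
    image=face_kps,
    controlnet_conditioning_scale=0.8,
    num_inference_steps=1,
    guidance_scale=0.0,
).images[0]
image.save("result.png")
```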
35 | -------------------------------------------------------------------------------- /__init__.py: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/rupeshs/instantidcpu/fbf506585cc1d11862d1c7706e36a20ba06adb05/__init__.py -------------------------------------------------------------------------------- /app.py: -------------------------------------------------------------------------------- 1 | from concurrent.futures import ThreadPoolExecutor 2 | from time import time 3 | 4 | import cv2 5 | import gradio as gr 6 | import numpy as np 7 | import torch 8 | from diffusers import LCMScheduler 9 | from diffusers.models import ControlNetModel 10 | from insightface.app import FaceAnalysis 11 | 12 | from download_models import download_instant_id_sdxl_models 13 | from pipeline_stable_diffusion_xl_instantid import ( 14 | StableDiffusionXLInstantIDPipeline, 15 | draw_kps, 16 | ) 17 | 18 | face_adapter = "./models/sdxl/ip-adapter.bin" 19 | controlnet_path = "./models/sdxl/ControlNetModel" 20 | base_model = "stabilityai/sdxl-turbo" 21 | APP_VERSION = "0.1.0" 22 | device = "cpu" 23 | 24 | download_instant_id_sdxl_models() 25 | 26 | app = FaceAnalysis( 27 | name="antelopev2", 28 | root="./", 29 | providers=["CPUExecutionProvider"], 30 | ) 31 | app.prepare(ctx_id=0, det_size=(320, 320)) 32 | torch_dtype = torch.float32 33 | 34 | controlnet = ControlNetModel.from_pretrained( 35 | controlnet_path, 36 | torch_dtype=torch_dtype, 37 | resume_download=True, 38 | ) 39 | pipe = StableDiffusionXLInstantIDPipeline.from_pretrained( 40 | base_model, 41 | controlnet=controlnet, 42 | torch_dtype=torch_dtype, 43 | resume_download=True, 44 | ) 45 | pipe.load_ip_adapter_instantid(face_adapter) 46 | pipe.to(device) 47 | pipe.unet = pipe.unet.to(memory_format=torch.channels_last) 48 | pipe.controlnet = pipe.controlnet.to(memory_format=torch.channels_last) 49 | 50 | pipe.text_encoder = pipe.text_encoder.to(memory_format=torch.channels_last) 51 | pipe.scheduler = LCMScheduler.from_config( 52 | pipe.scheduler.config, 53 | beta_start=0.001, 54 | beta_end=0.012, 55 | ) 56 | 57 | # pipe.scheduler = DEISMultistepScheduler.from_config( 58 | # pipe.scheduler.config, 59 | # ) 60 | pipe.enable_freeu( 61 | s1=0.6, 62 | s2=0.4, 63 | b1=1.1, 64 | b2=1.2, 65 | ) 66 | print(pipe) 67 | 68 | 69 | def generate_image( 70 | face_image, 71 | prompt, 72 | identitynet_strength_ratio, 73 | adapter_strength_ratio, 74 | ): 75 | print(f"identitynet_strength_ratio :{identitynet_strength_ratio}") 76 | print(f"adapter_strength_ratio :{adapter_strength_ratio}") 77 | if prompt == "": 78 | prompt = "Photo of a person,high quality" 79 | face_image = face_image.resize((960, 1024)) 80 | face_info = app.get(cv2.cvtColor(np.array(face_image), cv2.COLOR_RGB2BGR)) 81 | face_info = sorted( 82 | face_info, 83 | key=lambda x: (x["bbox"][2] - x["bbox"][0]) * (x["bbox"][3] - x["bbox"][1]), 84 | )[ 85 | -1 86 | ] # only use the largest detected face (by bounding-box area) 87 | face_emb = face_info["embedding"] 88 | face_kps = draw_kps(face_image, face_info["kps"]) 89 | 90 | # generate image 91 | pipe.set_ip_adapter_scale(adapter_strength_ratio) 92 | print("start") 93 | images = pipe( 94 | prompt, 95 | image_embeds=face_emb, 96 | image=face_kps, 97 | controlnet_conditioning_scale=identitynet_strength_ratio, 98 | num_inference_steps=1, 99 | guidance_scale=0.0, 100 | ).images 101 | return images 102 | 103 | 104 | def process_image( 105 | face_image, 106 | prompt, 107 | identitynet_strength_ratio, 108 | adapter_strength_ratio, 109 | ): 110
| tick = time() 111 | with ThreadPoolExecutor(max_workers=1) as executor: 112 | future = executor.submit( 113 | generate_image, 114 | face_image, 115 | prompt, 116 | identitynet_strength_ratio, 117 | adapter_strength_ratio, 118 | ) 119 | images = future.result() 120 | elapsed = time() - tick 121 | print(f"Latency : {elapsed:.2f} seconds") 122 | return images[0] 123 | 124 | 125 | def _get_footer_message() -> str: 126 | version = f"v{APP_VERSION}" 127 | footer_msg = version 128 | return footer_msg 129 | 130 | 131 | css = """ 132 | #generate_button { 133 | color: white; 134 | border-color: #007bff; 135 | background: #2563eb; 136 | 137 | } 138 | """ 139 | 140 | 141 | def get_web_ui() -> gr.Blocks: 142 | with gr.Blocks( 143 | css=css, 144 | title="InstantID CPU", 145 | ) as web_ui: 146 | gr.HTML("InstantID CPU
") 147 | with gr.Row(): 148 | with gr.Column(): 149 | input_image = gr.Image( 150 | label="Face image", 151 | type="pil", 152 | height=512, 153 | ) 154 | with gr.Row(): 155 | prompt = gr.Textbox( 156 | show_label=False, 157 | lines=3, 158 | placeholder="Oil painting", 159 | container=False, 160 | ) 161 | 162 | generate_btn = gr.Button( 163 | "Generate", 164 | elem_id="generate_button", 165 | scale=0, 166 | ) 167 | identitynet_strength_ratio = gr.Slider( 168 | label="IdentityNet strength (for fidelity)", 169 | minimum=0, 170 | maximum=1.5, 171 | step=0.05, 172 | value=0.80, 173 | ) 174 | adapter_strength_ratio = gr.Slider( 175 | label="Image adapter strength (for detail)", 176 | minimum=0, 177 | maximum=1.5, 178 | step=0.05, 179 | value=0.80, 180 | ) 181 | 182 | input_params = [ 183 | input_image, 184 | prompt, 185 | identitynet_strength_ratio, 186 | adapter_strength_ratio, 187 | ] 188 | 189 | with gr.Column(): 190 | gallery = gr.Image( 191 | label="Generated Image", 192 | ) 193 | generate_btn.click( 194 | fn=process_image, 195 | inputs=input_params, 196 | outputs=gallery, 197 | ) 198 | 199 | gr.HTML(_get_footer_message()) 200 | 201 | return web_ui 202 | 203 | 204 | def start_webui( 205 | share: bool = False, 206 | ): 207 | webui = get_web_ui() 208 | webui.queue() 209 | webui.launch(share=share) 210 | 211 | 212 | start_webui() 213 | -------------------------------------------------------------------------------- /download_models.py: -------------------------------------------------------------------------------- 1 | from huggingface_hub import hf_hub_download 2 | from huggingface_hub import snapshot_download 3 | 4 | 5 | def download_instant_id_sdxl_models(): 6 | hf_hub_download( 7 | repo_id="InstantX/InstantID", 8 | filename="ControlNetModel/config.json", 9 | local_dir="./models/sdxl/", 10 | ) 11 | hf_hub_download( 12 | repo_id="InstantX/InstantID", 13 | filename="ControlNetModel/diffusion_pytorch_model.safetensors", 14 | local_dir="./models/sdxl/", 15 | ) 16 | hf_hub_download( 17 | repo_id="InstantX/InstantID", 18 | filename="ip-adapter.bin", 19 | local_dir="./models/sdxl/", 20 | ) 21 | snapshot_download( 22 | repo_id="rupeshs/antelopev2", 23 | local_dir="./models/antelopev2", 24 | ) 25 | -------------------------------------------------------------------------------- /install.bat: -------------------------------------------------------------------------------- 1 | 2 | @echo off 3 | setlocal 4 | echo Starting InstantID CPU env installation... 5 | 6 | set "PYTHON_COMMAND=python" 7 | 8 | call python --version > nul 2>&1 9 | if %errorlevel% equ 0 ( 10 | echo Python command check :OK 11 | ) else ( 12 | echo "Error: Python command not found,please install Python(Recommended : Python 3.10 or higher) and try again." 13 | pause 14 | exit /b 1 15 | 16 | ) 17 | 18 | :check_python_version 19 | for /f "tokens=2" %%I in ('%PYTHON_COMMAND% --version 2^>^&1') do ( 20 | set "python_version=%%I" 21 | ) 22 | 23 | echo Python version: %python_version% 24 | 25 | %PYTHON_COMMAND% -m venv "%~dp0env" 26 | call "%~dp0env\Scripts\activate.bat" && pip install torch==2.2.0 --index-url https://download.pytorch.org/whl/cpu 27 | call "%~dp0env\Scripts\activate.bat" && pip install -r "%~dp0requirements.txt" 28 | echo InstantID CPU env installation completed. 29 | pause -------------------------------------------------------------------------------- /install.sh: -------------------------------------------------------------------------------- 1 | #!/usr/bin/env bash 2 | echo Starting InstantID CPU env installation... 
3 | set -e 4 | PYTHON_COMMAND="python3" 5 | 6 | if ! command -v python3 &>/dev/null; then 7 | if ! command -v python &>/dev/null; then 8 | echo "Error: Python not found, please install python 3.10 or higher and try again" 9 | exit 1 10 | fi 11 | fi 12 | 13 | if command -v python &>/dev/null; then 14 | PYTHON_COMMAND="python" 15 | fi 16 | 17 | echo "Found $PYTHON_COMMAND command" 18 | 19 | python_version=$($PYTHON_COMMAND --version 2>&1 | awk '{print $2}') 20 | echo "Python version : $python_version" 21 | 22 | BASEDIR=$(pwd) 23 | 24 | $PYTHON_COMMAND -m venv "$BASEDIR/env" 25 | # shellcheck disable=SC1091 26 | source "$BASEDIR/env/bin/activate" 27 | pip install torch==2.2.0 --index-url https://download.pytorch.org/whl/cpu 28 | pip install -r "$BASEDIR/requirements.txt" 29 | chmod +x "start.sh" 30 | read -n1 -r -p "InstantID CPU installation completed,press any key to continue..." key -------------------------------------------------------------------------------- /instantidcpu-screenshot.jpg: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/rupeshs/instantidcpu/fbf506585cc1d11862d1c7706e36a20ba06adb05/instantidcpu-screenshot.jpg -------------------------------------------------------------------------------- /paths.py: -------------------------------------------------------------------------------- 1 | import os 2 | from pathlib import Path 3 | 4 | 5 | def ensure_file(path: str) -> bool: 6 | return os.path.exists(path) 7 | 8 | 9 | def join_paths( 10 | first_path: str, 11 | second_path: str, 12 | ) -> str: 13 | return os.path.join(first_path, second_path) 14 | 15 | 16 | def get_file_name(file_path: str) -> str: 17 | return Path(file_path).stem 18 | 19 | 20 | def get_app_path() -> str: 21 | app_dir = os.path.dirname(__file__) 22 | work_dir = os.path.dirname(app_dir) 23 | return work_dir 24 | -------------------------------------------------------------------------------- /pipeline_stable_diffusion_xl_instantid.py: -------------------------------------------------------------------------------- 1 | # Copyright 2024 The InstantX Team. All rights reserved. 2 | # 3 | # Licensed under the Apache License, Version 2.0 (the "License"); 4 | # you may not use this file except in compliance with the License. 5 | # You may obtain a copy of the License at 6 | # 7 | # http://www.apache.org/licenses/LICENSE-2.0 8 | # 9 | # Unless required by applicable law or agreed to in writing, software 10 | # distributed under the License is distributed on an "AS IS" BASIS, 11 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 12 | # See the License for the specific language governing permissions and 13 | # limitations under the License. 
14 | 15 | 16 | import math 17 | from typing import Any, Callable, Dict, List, Optional, Tuple, Union 18 | 19 | import cv2 20 | import numpy as np 21 | import PIL.Image 22 | import torch 23 | import torch.nn as nn 24 | 25 | from diffusers import StableDiffusionXLControlNetPipeline 26 | from diffusers.image_processor import PipelineImageInput 27 | from diffusers.models import ControlNetModel 28 | from diffusers.pipelines.controlnet.multicontrolnet import MultiControlNetModel 29 | from diffusers.pipelines.stable_diffusion_xl import StableDiffusionXLPipelineOutput 30 | from diffusers.utils import ( 31 | deprecate, 32 | logging, 33 | replace_example_docstring, 34 | ) 35 | from diffusers.utils.import_utils import is_xformers_available 36 | from diffusers.utils.torch_utils import is_compiled_module, is_torch_version 37 | 38 | 39 | try: 40 | import xformers 41 | import xformers.ops 42 | 43 | xformers_available = True 44 | except Exception: 45 | xformers_available = False 46 | 47 | logger = logging.get_logger(__name__) # pylint: disable=invalid-name 48 | 49 | 50 | def FeedForward(dim, mult=4): 51 | inner_dim = int(dim * mult) 52 | return nn.Sequential( 53 | nn.LayerNorm(dim), 54 | nn.Linear(dim, inner_dim, bias=False), 55 | nn.GELU(), 56 | nn.Linear(inner_dim, dim, bias=False), 57 | ) 58 | 59 | 60 | def reshape_tensor(x, heads): 61 | bs, length, width = x.shape 62 | # (bs, length, width) --> (bs, length, n_heads, dim_per_head) 63 | x = x.view(bs, length, heads, -1) 64 | # (bs, length, n_heads, dim_per_head) --> (bs, n_heads, length, dim_per_head) 65 | x = x.transpose(1, 2) 66 | # (bs, n_heads, length, dim_per_head) --> (bs*n_heads, length, dim_per_head) 67 | x = x.reshape(bs, heads, length, -1) 68 | return x 69 | 70 | 71 | class PerceiverAttention(nn.Module): 72 | def __init__(self, *, dim, dim_head=64, heads=8): 73 | super().__init__() 74 | self.scale = dim_head**-0.5 75 | self.dim_head = dim_head 76 | self.heads = heads 77 | inner_dim = dim_head * heads 78 | 79 | self.norm1 = nn.LayerNorm(dim) 80 | self.norm2 = nn.LayerNorm(dim) 81 | 82 | self.to_q = nn.Linear(dim, inner_dim, bias=False) 83 | self.to_kv = nn.Linear(dim, inner_dim * 2, bias=False) 84 | self.to_out = nn.Linear(inner_dim, dim, bias=False) 85 | 86 | def forward(self, x, latents): 87 | """ 88 | Args: 89 | x (torch.Tensor): image features 90 | shape (b, n1, D) 91 | latent (torch.Tensor): latent features 92 | shape (b, n2, D) 93 | """ 94 | x = self.norm1(x) 95 | latents = self.norm2(latents) 96 | 97 | b, l, _ = latents.shape 98 | 99 | q = self.to_q(latents) 100 | kv_input = torch.cat((x, latents), dim=-2) 101 | k, v = self.to_kv(kv_input).chunk(2, dim=-1) 102 | 103 | q = reshape_tensor(q, self.heads) 104 | k = reshape_tensor(k, self.heads) 105 | v = reshape_tensor(v, self.heads) 106 | 107 | # attention 108 | scale = 1 / math.sqrt(math.sqrt(self.dim_head)) 109 | weight = (q * scale) @ (k * scale).transpose(-2, -1) # More stable with f16 than dividing afterwards 110 | weight = torch.softmax(weight.float(), dim=-1).type(weight.dtype) 111 | out = weight @ v 112 | 113 | out = out.permute(0, 2, 1, 3).reshape(b, l, -1) 114 | 115 | return self.to_out(out) 116 | 117 | 118 | class Resampler(nn.Module): 119 | def __init__( 120 | self, 121 | dim=1024, 122 | depth=8, 123 | dim_head=64, 124 | heads=16, 125 | num_queries=8, 126 | embedding_dim=768, 127 | output_dim=1024, 128 | ff_mult=4, 129 | ): 130 | super().__init__() 131 | 132 | self.latents = nn.Parameter(torch.randn(1, num_queries, dim) / dim**0.5) 133 | 134 | self.proj_in = 
nn.Linear(embedding_dim, dim) 135 | 136 | self.proj_out = nn.Linear(dim, output_dim) 137 | self.norm_out = nn.LayerNorm(output_dim) 138 | 139 | self.layers = nn.ModuleList([]) 140 | for _ in range(depth): 141 | self.layers.append( 142 | nn.ModuleList( 143 | [ 144 | PerceiverAttention(dim=dim, dim_head=dim_head, heads=heads), 145 | FeedForward(dim=dim, mult=ff_mult), 146 | ] 147 | ) 148 | ) 149 | 150 | def forward(self, x): 151 | latents = self.latents.repeat(x.size(0), 1, 1) 152 | x = self.proj_in(x) 153 | 154 | for attn, ff in self.layers: 155 | latents = attn(x, latents) + latents 156 | latents = ff(latents) + latents 157 | 158 | latents = self.proj_out(latents) 159 | return self.norm_out(latents) 160 | 161 | 162 | class AttnProcessor(nn.Module): 163 | r""" 164 | Default processor for performing attention-related computations. 165 | """ 166 | 167 | def __init__( 168 | self, 169 | hidden_size=None, 170 | cross_attention_dim=None, 171 | ): 172 | super().__init__() 173 | 174 | def __call__( 175 | self, 176 | attn, 177 | hidden_states, 178 | encoder_hidden_states=None, 179 | attention_mask=None, 180 | temb=None, 181 | ): 182 | residual = hidden_states 183 | 184 | if attn.spatial_norm is not None: 185 | hidden_states = attn.spatial_norm(hidden_states, temb) 186 | 187 | input_ndim = hidden_states.ndim 188 | 189 | if input_ndim == 4: 190 | batch_size, channel, height, width = hidden_states.shape 191 | hidden_states = hidden_states.view(batch_size, channel, height * width).transpose(1, 2) 192 | 193 | batch_size, sequence_length, _ = ( 194 | hidden_states.shape if encoder_hidden_states is None else encoder_hidden_states.shape 195 | ) 196 | attention_mask = attn.prepare_attention_mask(attention_mask, sequence_length, batch_size) 197 | 198 | if attn.group_norm is not None: 199 | hidden_states = attn.group_norm(hidden_states.transpose(1, 2)).transpose(1, 2) 200 | 201 | query = attn.to_q(hidden_states) 202 | 203 | if encoder_hidden_states is None: 204 | encoder_hidden_states = hidden_states 205 | elif attn.norm_cross: 206 | encoder_hidden_states = attn.norm_encoder_hidden_states(encoder_hidden_states) 207 | 208 | key = attn.to_k(encoder_hidden_states) 209 | value = attn.to_v(encoder_hidden_states) 210 | 211 | query = attn.head_to_batch_dim(query) 212 | key = attn.head_to_batch_dim(key) 213 | value = attn.head_to_batch_dim(value) 214 | 215 | attention_probs = attn.get_attention_scores(query, key, attention_mask) 216 | hidden_states = torch.bmm(attention_probs, value) 217 | hidden_states = attn.batch_to_head_dim(hidden_states) 218 | 219 | # linear proj 220 | hidden_states = attn.to_out[0](hidden_states) 221 | # dropout 222 | hidden_states = attn.to_out[1](hidden_states) 223 | 224 | if input_ndim == 4: 225 | hidden_states = hidden_states.transpose(-1, -2).reshape(batch_size, channel, height, width) 226 | 227 | if attn.residual_connection: 228 | hidden_states = hidden_states + residual 229 | 230 | hidden_states = hidden_states / attn.rescale_output_factor 231 | 232 | return hidden_states 233 | 234 | 235 | class IPAttnProcessor(nn.Module): 236 | r""" 237 | Attention processor for IP-Adapater. 238 | Args: 239 | hidden_size (`int`): 240 | The hidden size of the attention layer. 241 | cross_attention_dim (`int`): 242 | The number of channels in the `encoder_hidden_states`. 243 | scale (`float`, defaults to 1.0): 244 | the weight scale of image prompt. 245 | num_tokens (`int`, defaults to 4 when do ip_adapter_plus it should be 16): 246 | The context length of the image features. 
247 | """ 248 | 249 | def __init__(self, hidden_size, cross_attention_dim=None, scale=1.0, num_tokens=4): 250 | super().__init__() 251 | 252 | self.hidden_size = hidden_size 253 | self.cross_attention_dim = cross_attention_dim 254 | self.scale = scale 255 | self.num_tokens = num_tokens 256 | 257 | self.to_k_ip = nn.Linear(cross_attention_dim or hidden_size, hidden_size, bias=False) 258 | self.to_v_ip = nn.Linear(cross_attention_dim or hidden_size, hidden_size, bias=False) 259 | 260 | def __call__( 261 | self, 262 | attn, 263 | hidden_states, 264 | encoder_hidden_states=None, 265 | attention_mask=None, 266 | temb=None, 267 | ): 268 | residual = hidden_states 269 | 270 | if attn.spatial_norm is not None: 271 | hidden_states = attn.spatial_norm(hidden_states, temb) 272 | 273 | input_ndim = hidden_states.ndim 274 | 275 | if input_ndim == 4: 276 | batch_size, channel, height, width = hidden_states.shape 277 | hidden_states = hidden_states.view(batch_size, channel, height * width).transpose(1, 2) 278 | 279 | batch_size, sequence_length, _ = ( 280 | hidden_states.shape if encoder_hidden_states is None else encoder_hidden_states.shape 281 | ) 282 | attention_mask = attn.prepare_attention_mask(attention_mask, sequence_length, batch_size) 283 | 284 | if attn.group_norm is not None: 285 | hidden_states = attn.group_norm(hidden_states.transpose(1, 2)).transpose(1, 2) 286 | 287 | query = attn.to_q(hidden_states) 288 | 289 | if encoder_hidden_states is None: 290 | encoder_hidden_states = hidden_states 291 | else: 292 | # get encoder_hidden_states, ip_hidden_states 293 | end_pos = encoder_hidden_states.shape[1] - self.num_tokens 294 | encoder_hidden_states, ip_hidden_states = ( 295 | encoder_hidden_states[:, :end_pos, :], 296 | encoder_hidden_states[:, end_pos:, :], 297 | ) 298 | if attn.norm_cross: 299 | encoder_hidden_states = attn.norm_encoder_hidden_states(encoder_hidden_states) 300 | 301 | key = attn.to_k(encoder_hidden_states) 302 | value = attn.to_v(encoder_hidden_states) 303 | 304 | query = attn.head_to_batch_dim(query) 305 | key = attn.head_to_batch_dim(key) 306 | value = attn.head_to_batch_dim(value) 307 | 308 | if xformers_available: 309 | hidden_states = self._memory_efficient_attention_xformers(query, key, value, attention_mask) 310 | else: 311 | attention_probs = attn.get_attention_scores(query, key, attention_mask) 312 | hidden_states = torch.bmm(attention_probs, value) 313 | hidden_states = attn.batch_to_head_dim(hidden_states) 314 | 315 | # for ip-adapter 316 | ip_key = self.to_k_ip(ip_hidden_states) 317 | ip_value = self.to_v_ip(ip_hidden_states) 318 | 319 | ip_key = attn.head_to_batch_dim(ip_key) 320 | ip_value = attn.head_to_batch_dim(ip_value) 321 | 322 | if xformers_available: 323 | ip_hidden_states = self._memory_efficient_attention_xformers(query, ip_key, ip_value, None) 324 | else: 325 | ip_attention_probs = attn.get_attention_scores(query, ip_key, None) 326 | ip_hidden_states = torch.bmm(ip_attention_probs, ip_value) 327 | ip_hidden_states = attn.batch_to_head_dim(ip_hidden_states) 328 | 329 | hidden_states = hidden_states + self.scale * ip_hidden_states 330 | 331 | # linear proj 332 | hidden_states = attn.to_out[0](hidden_states) 333 | # dropout 334 | hidden_states = attn.to_out[1](hidden_states) 335 | 336 | if input_ndim == 4: 337 | hidden_states = hidden_states.transpose(-1, -2).reshape(batch_size, channel, height, width) 338 | 339 | if attn.residual_connection: 340 | hidden_states = hidden_states + residual 341 | 342 | hidden_states = hidden_states / 
attn.rescale_output_factor 343 | 344 | return hidden_states 345 | 346 | def _memory_efficient_attention_xformers(self, query, key, value, attention_mask): 347 | # TODO attention_mask 348 | query = query.contiguous() 349 | key = key.contiguous() 350 | value = value.contiguous() 351 | hidden_states = xformers.ops.memory_efficient_attention(query, key, value, attn_bias=attention_mask) 352 | return hidden_states 353 | 354 | 355 | EXAMPLE_DOC_STRING = """ 356 | Examples: 357 | ```py 358 | >>> # !pip install opencv-python transformers accelerate insightface 359 | >>> import diffusers 360 | >>> from diffusers.utils import load_image 361 | >>> from diffusers.models import ControlNetModel 362 | 363 | >>> import cv2 364 | >>> import torch 365 | >>> import numpy as np 366 | >>> from PIL import Image 367 | 368 | >>> from insightface.app import FaceAnalysis 369 | >>> from pipeline_stable_diffusion_xl_instantid import StableDiffusionXLInstantIDPipeline, draw_kps 370 | 371 | >>> # download 'antelopev2' under ./models 372 | >>> app = FaceAnalysis(name='antelopev2', root='./', providers=['CUDAExecutionProvider', 'CPUExecutionProvider']) 373 | >>> app.prepare(ctx_id=0, det_size=(640, 640)) 374 | 375 | >>> # download models under ./checkpoints 376 | >>> face_adapter = f'./checkpoints/ip-adapter.bin' 377 | >>> controlnet_path = f'./checkpoints/ControlNetModel' 378 | 379 | >>> # load IdentityNet 380 | >>> controlnet = ControlNetModel.from_pretrained(controlnet_path, torch_dtype=torch.float16) 381 | 382 | >>> pipe = StableDiffusionXLInstantIDPipeline.from_pretrained( 383 | ... "stabilityai/stable-diffusion-xl-base-1.0", controlnet=controlnet, torch_dtype=torch.float16 384 | ... ) 385 | >>> pipe.cuda() 386 | 387 | >>> # load adapter 388 | >>> pipe.load_ip_adapter_instantid(face_adapter) 389 | 390 | >>> prompt = "analog film photo of a man. faded film, desaturated, 35mm photo, grainy, vignette, vintage, Kodachrome, Lomography, stained, highly detailed, found footage, masterpiece, best quality" 391 | >>> negative_prompt = "(lowres, low quality, worst quality:1.2), (text:1.2), watermark, painting, drawing, illustration, glitch, deformed, mutated, cross-eyed, ugly, disfigured (lowres, low quality, worst quality:1.2), (text:1.2), watermark, painting, drawing, illustration, glitch,deformed, mutated, cross-eyed, ugly, disfigured" 392 | 393 | >>> # load an image 394 | >>> image = load_image("your-example.jpg") 395 | 396 | >>> face_info = app.get(cv2.cvtColor(np.array(face_image), cv2.COLOR_RGB2BGR))[-1] 397 | >>> face_emb = face_info['embedding'] 398 | >>> face_kps = draw_kps(face_image, face_info['kps']) 399 | 400 | >>> pipe.set_ip_adapter_scale(0.8) 401 | 402 | >>> # generate image 403 | >>> image = pipe( 404 | ... prompt, image_embeds=face_emb, image=face_kps, controlnet_conditioning_scale=0.8 405 | ... 
).images[0] 406 | ``` 407 | """ 408 | 409 | 410 | def draw_kps(image_pil, kps, color_list=[(255, 0, 0), (0, 255, 0), (0, 0, 255), (255, 255, 0), (255, 0, 255)]): 411 | stickwidth = 4 412 | limbSeq = np.array([[0, 2], [1, 2], [3, 2], [4, 2]]) 413 | kps = np.array(kps) 414 | 415 | w, h = image_pil.size 416 | out_img = np.zeros([h, w, 3]) 417 | 418 | for i in range(len(limbSeq)): 419 | index = limbSeq[i] 420 | color = color_list[index[0]] 421 | 422 | x = kps[index][:, 0] 423 | y = kps[index][:, 1] 424 | length = ((x[0] - x[1]) ** 2 + (y[0] - y[1]) ** 2) ** 0.5 425 | angle = math.degrees(math.atan2(y[0] - y[1], x[0] - x[1])) 426 | polygon = cv2.ellipse2Poly( 427 | (int(np.mean(x)), int(np.mean(y))), (int(length / 2), stickwidth), int(angle), 0, 360, 1 428 | ) 429 | out_img = cv2.fillConvexPoly(out_img.copy(), polygon, color) 430 | out_img = (out_img * 0.6).astype(np.uint8) 431 | 432 | for idx_kp, kp in enumerate(kps): 433 | color = color_list[idx_kp] 434 | x, y = kp 435 | out_img = cv2.circle(out_img.copy(), (int(x), int(y)), 10, color, -1) 436 | 437 | out_img_pil = PIL.Image.fromarray(out_img.astype(np.uint8)) 438 | return out_img_pil 439 | 440 | 441 | class StableDiffusionXLInstantIDPipeline(StableDiffusionXLControlNetPipeline): 442 | def cuda(self, dtype=torch.float16, use_xformers=False): 443 | self.to("cuda", dtype) 444 | 445 | if hasattr(self, "image_proj_model"): 446 | self.image_proj_model.to(self.unet.device).to(self.unet.dtype) 447 | 448 | if use_xformers: 449 | if is_xformers_available(): 450 | import xformers 451 | from packaging import version 452 | 453 | xformers_version = version.parse(xformers.__version__) 454 | if xformers_version == version.parse("0.0.16"): 455 | logger.warn( 456 | "xFormers 0.0.16 cannot be used for training in some GPUs. If you observe problems during training, please update xFormers to at least 0.0.17. See https://huggingface.co/docs/diffusers/main/en/optimization/xformers for more details." 457 | ) 458 | self.enable_xformers_memory_efficient_attention() 459 | else: 460 | raise ValueError("xformers is not available. 
Make sure it is installed correctly") 461 | 462 | def load_ip_adapter_instantid(self, model_ckpt, image_emb_dim=512, num_tokens=16, scale=0.5): 463 | self.set_image_proj_model(model_ckpt, image_emb_dim, num_tokens) 464 | self.set_ip_adapter(model_ckpt, num_tokens, scale) 465 | 466 | def set_image_proj_model(self, model_ckpt, image_emb_dim=512, num_tokens=16): 467 | image_proj_model = Resampler( 468 | dim=1280, 469 | depth=4, 470 | dim_head=64, 471 | heads=20, 472 | num_queries=num_tokens, 473 | embedding_dim=image_emb_dim, 474 | output_dim=self.unet.config.cross_attention_dim, 475 | ff_mult=4, 476 | ) 477 | 478 | image_proj_model.eval() 479 | 480 | self.image_proj_model = image_proj_model.to(self.device, dtype=self.dtype) 481 | state_dict = torch.load(model_ckpt, map_location="cpu") 482 | if "image_proj" in state_dict: 483 | state_dict = state_dict["image_proj"] 484 | self.image_proj_model.load_state_dict(state_dict) 485 | 486 | self.image_proj_model_in_features = image_emb_dim 487 | 488 | def set_ip_adapter(self, model_ckpt, num_tokens, scale): 489 | unet = self.unet 490 | attn_procs = {} 491 | for name in unet.attn_processors.keys(): 492 | cross_attention_dim = None if name.endswith("attn1.processor") else unet.config.cross_attention_dim 493 | if name.startswith("mid_block"): 494 | hidden_size = unet.config.block_out_channels[-1] 495 | elif name.startswith("up_blocks"): 496 | block_id = int(name[len("up_blocks.")]) 497 | hidden_size = list(reversed(unet.config.block_out_channels))[block_id] 498 | elif name.startswith("down_blocks"): 499 | block_id = int(name[len("down_blocks.")]) 500 | hidden_size = unet.config.block_out_channels[block_id] 501 | if cross_attention_dim is None: 502 | attn_procs[name] = AttnProcessor().to(unet.device, dtype=unet.dtype) 503 | else: 504 | attn_procs[name] = IPAttnProcessor( 505 | hidden_size=hidden_size, 506 | cross_attention_dim=cross_attention_dim, 507 | scale=scale, 508 | num_tokens=num_tokens, 509 | ).to(unet.device, dtype=unet.dtype) 510 | unet.set_attn_processor(attn_procs) 511 | 512 | state_dict = torch.load(model_ckpt, map_location="cpu") 513 | ip_layers = torch.nn.ModuleList(self.unet.attn_processors.values()) 514 | if "ip_adapter" in state_dict: 515 | state_dict = state_dict["ip_adapter"] 516 | ip_layers.load_state_dict(state_dict) 517 | 518 | def set_ip_adapter_scale(self, scale): 519 | unet = getattr(self, self.unet_name) if not hasattr(self, "unet") else self.unet 520 | for attn_processor in unet.attn_processors.values(): 521 | if isinstance(attn_processor, IPAttnProcessor): 522 | attn_processor.scale = scale 523 | 524 | def _encode_prompt_image_emb(self, prompt_image_emb, device, dtype, do_classifier_free_guidance): 525 | if isinstance(prompt_image_emb, torch.Tensor): 526 | prompt_image_emb = prompt_image_emb.clone().detach() 527 | else: 528 | prompt_image_emb = torch.tensor(prompt_image_emb) 529 | 530 | prompt_image_emb = prompt_image_emb.to(device=device, dtype=dtype) 531 | prompt_image_emb = prompt_image_emb.reshape([1, -1, self.image_proj_model_in_features]) 532 | 533 | if do_classifier_free_guidance: 534 | prompt_image_emb = torch.cat([torch.zeros_like(prompt_image_emb), prompt_image_emb], dim=0) 535 | else: 536 | prompt_image_emb = torch.cat([prompt_image_emb], dim=0) 537 | 538 | prompt_image_emb = self.image_proj_model(prompt_image_emb) 539 | return prompt_image_emb 540 | 541 | @torch.no_grad() 542 | @replace_example_docstring(EXAMPLE_DOC_STRING) 543 | def __call__( 544 | self, 545 | prompt: Union[str, List[str]] = None, 546 | prompt_2: 
Optional[Union[str, List[str]]] = None, 547 | image: PipelineImageInput = None, 548 | height: Optional[int] = None, 549 | width: Optional[int] = None, 550 | num_inference_steps: int = 50, 551 | guidance_scale: float = 5.0, 552 | negative_prompt: Optional[Union[str, List[str]]] = None, 553 | negative_prompt_2: Optional[Union[str, List[str]]] = None, 554 | num_images_per_prompt: Optional[int] = 1, 555 | eta: float = 0.0, 556 | generator: Optional[Union[torch.Generator, List[torch.Generator]]] = None, 557 | latents: Optional[torch.FloatTensor] = None, 558 | prompt_embeds: Optional[torch.FloatTensor] = None, 559 | negative_prompt_embeds: Optional[torch.FloatTensor] = None, 560 | pooled_prompt_embeds: Optional[torch.FloatTensor] = None, 561 | negative_pooled_prompt_embeds: Optional[torch.FloatTensor] = None, 562 | image_embeds: Optional[torch.FloatTensor] = None, 563 | output_type: Optional[str] = "pil", 564 | return_dict: bool = True, 565 | cross_attention_kwargs: Optional[Dict[str, Any]] = None, 566 | controlnet_conditioning_scale: Union[float, List[float]] = 1.0, 567 | guess_mode: bool = False, 568 | control_guidance_start: Union[float, List[float]] = 0.0, 569 | control_guidance_end: Union[float, List[float]] = 1.0, 570 | original_size: Tuple[int, int] = None, 571 | crops_coords_top_left: Tuple[int, int] = (0, 0), 572 | target_size: Tuple[int, int] = None, 573 | negative_original_size: Optional[Tuple[int, int]] = None, 574 | negative_crops_coords_top_left: Tuple[int, int] = (0, 0), 575 | negative_target_size: Optional[Tuple[int, int]] = None, 576 | clip_skip: Optional[int] = None, 577 | callback_on_step_end: Optional[Callable[[int, int, Dict], None]] = None, 578 | callback_on_step_end_tensor_inputs: List[str] = ["latents"], 579 | **kwargs, 580 | ): 581 | r""" 582 | The call function to the pipeline for generation. 583 | 584 | Args: 585 | prompt (`str` or `List[str]`, *optional*): 586 | The prompt or prompts to guide image generation. If not defined, you need to pass `prompt_embeds`. 587 | prompt_2 (`str` or `List[str]`, *optional*): 588 | The prompt or prompts to be sent to `tokenizer_2` and `text_encoder_2`. If not defined, `prompt` is 589 | used in both text-encoders. 590 | image (`torch.FloatTensor`, `PIL.Image.Image`, `np.ndarray`, `List[torch.FloatTensor]`, `List[PIL.Image.Image]`, `List[np.ndarray]`,: 591 | `List[List[torch.FloatTensor]]`, `List[List[np.ndarray]]` or `List[List[PIL.Image.Image]]`): 592 | The ControlNet input condition to provide guidance to the `unet` for generation. If the type is 593 | specified as `torch.FloatTensor`, it is passed to ControlNet as is. `PIL.Image.Image` can also be 594 | accepted as an image. The dimensions of the output image defaults to `image`'s dimensions. If height 595 | and/or width are passed, `image` is resized accordingly. If multiple ControlNets are specified in 596 | `init`, images must be passed as a list such that each element of the list can be correctly batched for 597 | input to a single ControlNet. 598 | height (`int`, *optional*, defaults to `self.unet.config.sample_size * self.vae_scale_factor`): 599 | The height in pixels of the generated image. Anything below 512 pixels won't work well for 600 | [stabilityai/stable-diffusion-xl-base-1.0](https://huggingface.co/stabilityai/stable-diffusion-xl-base-1.0) 601 | and checkpoints that are not specifically fine-tuned on low resolutions. 602 | width (`int`, *optional*, defaults to `self.unet.config.sample_size * self.vae_scale_factor`): 603 | The width in pixels of the generated image. 
Anything below 512 pixels won't work well for 604 | [stabilityai/stable-diffusion-xl-base-1.0](https://huggingface.co/stabilityai/stable-diffusion-xl-base-1.0) 605 | and checkpoints that are not specifically fine-tuned on low resolutions. 606 | num_inference_steps (`int`, *optional*, defaults to 50): 607 | The number of denoising steps. More denoising steps usually lead to a higher quality image at the 608 | expense of slower inference. 609 | guidance_scale (`float`, *optional*, defaults to 5.0): 610 | A higher guidance scale value encourages the model to generate images closely linked to the text 611 | `prompt` at the expense of lower image quality. Guidance scale is enabled when `guidance_scale > 1`. 612 | negative_prompt (`str` or `List[str]`, *optional*): 613 | The prompt or prompts to guide what to not include in image generation. If not defined, you need to 614 | pass `negative_prompt_embeds` instead. Ignored when not using guidance (`guidance_scale < 1`). 615 | negative_prompt_2 (`str` or `List[str]`, *optional*): 616 | The prompt or prompts to guide what to not include in image generation. This is sent to `tokenizer_2` 617 | and `text_encoder_2`. If not defined, `negative_prompt` is used in both text-encoders. 618 | num_images_per_prompt (`int`, *optional*, defaults to 1): 619 | The number of images to generate per prompt. 620 | eta (`float`, *optional*, defaults to 0.0): 621 | Corresponds to parameter eta (η) from the [DDIM](https://arxiv.org/abs/2010.02502) paper. Only applies 622 | to the [`~schedulers.DDIMScheduler`], and is ignored in other schedulers. 623 | generator (`torch.Generator` or `List[torch.Generator]`, *optional*): 624 | A [`torch.Generator`](https://pytorch.org/docs/stable/generated/torch.Generator.html) to make 625 | generation deterministic. 626 | latents (`torch.FloatTensor`, *optional*): 627 | Pre-generated noisy latents sampled from a Gaussian distribution, to be used as inputs for image 628 | generation. Can be used to tweak the same generation with different prompts. If not provided, a latents 629 | tensor is generated by sampling using the supplied random `generator`. 630 | prompt_embeds (`torch.FloatTensor`, *optional*): 631 | Pre-generated text embeddings. Can be used to easily tweak text inputs (prompt weighting). If not 632 | provided, text embeddings are generated from the `prompt` input argument. 633 | negative_prompt_embeds (`torch.FloatTensor`, *optional*): 634 | Pre-generated negative text embeddings. Can be used to easily tweak text inputs (prompt weighting). If 635 | not provided, `negative_prompt_embeds` are generated from the `negative_prompt` input argument. 636 | pooled_prompt_embeds (`torch.FloatTensor`, *optional*): 637 | Pre-generated pooled text embeddings. Can be used to easily tweak text inputs (prompt weighting). If 638 | not provided, pooled text embeddings are generated from `prompt` input argument. 639 | negative_pooled_prompt_embeds (`torch.FloatTensor`, *optional*): 640 | Pre-generated negative pooled text embeddings. Can be used to easily tweak text inputs (prompt 641 | weighting). If not provided, pooled `negative_prompt_embeds` are generated from `negative_prompt` input 642 | argument. 643 | image_embeds (`torch.FloatTensor`, *optional*): 644 | Pre-generated image embeddings. 645 | output_type (`str`, *optional*, defaults to `"pil"`): 646 | The output format of the generated image. Choose between `PIL.Image` or `np.array`. 
647 | return_dict (`bool`, *optional*, defaults to `True`): 648 | Whether or not to return a [`~pipelines.stable_diffusion.StableDiffusionPipelineOutput`] instead of a 649 | plain tuple. 650 | cross_attention_kwargs (`dict`, *optional*): 651 | A kwargs dictionary that if specified is passed along to the [`AttentionProcessor`] as defined in 652 | [`self.processor`](https://github.com/huggingface/diffusers/blob/main/src/diffusers/models/attention_processor.py). 653 | controlnet_conditioning_scale (`float` or `List[float]`, *optional*, defaults to 1.0): 654 | The outputs of the ControlNet are multiplied by `controlnet_conditioning_scale` before they are added 655 | to the residual in the original `unet`. If multiple ControlNets are specified in `init`, you can set 656 | the corresponding scale as a list. 657 | guess_mode (`bool`, *optional*, defaults to `False`): 658 | The ControlNet encoder tries to recognize the content of the input image even if you remove all 659 | prompts. A `guidance_scale` value between 3.0 and 5.0 is recommended. 660 | control_guidance_start (`float` or `List[float]`, *optional*, defaults to 0.0): 661 | The percentage of total steps at which the ControlNet starts applying. 662 | control_guidance_end (`float` or `List[float]`, *optional*, defaults to 1.0): 663 | The percentage of total steps at which the ControlNet stops applying. 664 | original_size (`Tuple[int]`, *optional*, defaults to (1024, 1024)): 665 | If `original_size` is not the same as `target_size` the image will appear to be down- or upsampled. 666 | `original_size` defaults to `(height, width)` if not specified. Part of SDXL's micro-conditioning as 667 | explained in section 2.2 of 668 | [https://huggingface.co/papers/2307.01952](https://huggingface.co/papers/2307.01952). 669 | crops_coords_top_left (`Tuple[int]`, *optional*, defaults to (0, 0)): 670 | `crops_coords_top_left` can be used to generate an image that appears to be "cropped" from the position 671 | `crops_coords_top_left` downwards. Favorable, well-centered images are usually achieved by setting 672 | `crops_coords_top_left` to (0, 0). Part of SDXL's micro-conditioning as explained in section 2.2 of 673 | [https://huggingface.co/papers/2307.01952](https://huggingface.co/papers/2307.01952). 674 | target_size (`Tuple[int]`, *optional*, defaults to (1024, 1024)): 675 | For most cases, `target_size` should be set to the desired height and width of the generated image. If 676 | not specified it will default to `(height, width)`. Part of SDXL's micro-conditioning as explained in 677 | section 2.2 of [https://huggingface.co/papers/2307.01952](https://huggingface.co/papers/2307.01952). 678 | negative_original_size (`Tuple[int]`, *optional*, defaults to (1024, 1024)): 679 | To negatively condition the generation process based on a specific image resolution. Part of SDXL's 680 | micro-conditioning as explained in section 2.2 of 681 | [https://huggingface.co/papers/2307.01952](https://huggingface.co/papers/2307.01952). For more 682 | information, refer to this issue thread: https://github.com/huggingface/diffusers/issues/4208. 683 | negative_crops_coords_top_left (`Tuple[int]`, *optional*, defaults to (0, 0)): 684 | To negatively condition the generation process based on a specific crop coordinates. Part of SDXL's 685 | micro-conditioning as explained in section 2.2 of 686 | [https://huggingface.co/papers/2307.01952](https://huggingface.co/papers/2307.01952). 
For more 687 | information, refer to this issue thread: https://github.com/huggingface/diffusers/issues/4208. 688 | negative_target_size (`Tuple[int]`, *optional*, defaults to (1024, 1024)): 689 | To negatively condition the generation process based on a target image resolution. It should be as same 690 | as the `target_size` for most cases. Part of SDXL's micro-conditioning as explained in section 2.2 of 691 | [https://huggingface.co/papers/2307.01952](https://huggingface.co/papers/2307.01952). For more 692 | information, refer to this issue thread: https://github.com/huggingface/diffusers/issues/4208. 693 | clip_skip (`int`, *optional*): 694 | Number of layers to be skipped from CLIP while computing the prompt embeddings. A value of 1 means that 695 | the output of the pre-final layer will be used for computing the prompt embeddings. 696 | callback_on_step_end (`Callable`, *optional*): 697 | A function that calls at the end of each denoising steps during the inference. The function is called 698 | with the following arguments: `callback_on_step_end(self: DiffusionPipeline, step: int, timestep: int, 699 | callback_kwargs: Dict)`. `callback_kwargs` will include a list of all tensors as specified by 700 | `callback_on_step_end_tensor_inputs`. 701 | callback_on_step_end_tensor_inputs (`List`, *optional*): 702 | The list of tensor inputs for the `callback_on_step_end` function. The tensors specified in the list 703 | will be passed as `callback_kwargs` argument. You will only be able to include variables listed in the 704 | `._callback_tensor_inputs` attribute of your pipeine class. 705 | 706 | Examples: 707 | 708 | Returns: 709 | [`~pipelines.stable_diffusion.StableDiffusionPipelineOutput`] or `tuple`: 710 | If `return_dict` is `True`, [`~pipelines.stable_diffusion.StableDiffusionPipelineOutput`] is returned, 711 | otherwise a `tuple` is returned containing the output images. 712 | """ 713 | 714 | callback = kwargs.pop("callback", None) 715 | callback_steps = kwargs.pop("callback_steps", None) 716 | 717 | if callback is not None: 718 | deprecate( 719 | "callback", 720 | "1.0.0", 721 | "Passing `callback` as an input argument to `__call__` is deprecated, consider using `callback_on_step_end`", 722 | ) 723 | if callback_steps is not None: 724 | deprecate( 725 | "callback_steps", 726 | "1.0.0", 727 | "Passing `callback_steps` as an input argument to `__call__` is deprecated, consider using `callback_on_step_end`", 728 | ) 729 | 730 | controlnet = self.controlnet._orig_mod if is_compiled_module(self.controlnet) else self.controlnet 731 | 732 | # align format for control guidance 733 | if not isinstance(control_guidance_start, list) and isinstance(control_guidance_end, list): 734 | control_guidance_start = len(control_guidance_end) * [control_guidance_start] 735 | elif not isinstance(control_guidance_end, list) and isinstance(control_guidance_start, list): 736 | control_guidance_end = len(control_guidance_start) * [control_guidance_end] 737 | elif not isinstance(control_guidance_start, list) and not isinstance(control_guidance_end, list): 738 | mult = len(controlnet.nets) if isinstance(controlnet, MultiControlNetModel) else 1 739 | control_guidance_start, control_guidance_end = ( 740 | mult * [control_guidance_start], 741 | mult * [control_guidance_end], 742 | ) 743 | 744 | # 1. Check inputs. 
Raise error if not correct 745 | self.check_inputs( 746 | prompt, 747 | prompt_2, 748 | image, 749 | callback_steps, 750 | negative_prompt, 751 | negative_prompt_2, 752 | prompt_embeds, 753 | negative_prompt_embeds, 754 | pooled_prompt_embeds, 755 | negative_pooled_prompt_embeds, 756 | controlnet_conditioning_scale, 757 | control_guidance_start, 758 | control_guidance_end, 759 | callback_on_step_end_tensor_inputs, 760 | ) 761 | 762 | self._guidance_scale = guidance_scale 763 | self._clip_skip = clip_skip 764 | self._cross_attention_kwargs = cross_attention_kwargs 765 | 766 | # 2. Define call parameters 767 | if prompt is not None and isinstance(prompt, str): 768 | batch_size = 1 769 | elif prompt is not None and isinstance(prompt, list): 770 | batch_size = len(prompt) 771 | else: 772 | batch_size = prompt_embeds.shape[0] 773 | 774 | device = self._execution_device 775 | 776 | if isinstance(controlnet, MultiControlNetModel) and isinstance(controlnet_conditioning_scale, float): 777 | controlnet_conditioning_scale = [controlnet_conditioning_scale] * len(controlnet.nets) 778 | 779 | global_pool_conditions = ( 780 | controlnet.config.global_pool_conditions 781 | if isinstance(controlnet, ControlNetModel) 782 | else controlnet.nets[0].config.global_pool_conditions 783 | ) 784 | guess_mode = guess_mode or global_pool_conditions 785 | 786 | # 3.1 Encode input prompt 787 | text_encoder_lora_scale = ( 788 | self.cross_attention_kwargs.get("scale", None) if self.cross_attention_kwargs is not None else None 789 | ) 790 | ( 791 | prompt_embeds, 792 | negative_prompt_embeds, 793 | pooled_prompt_embeds, 794 | negative_pooled_prompt_embeds, 795 | ) = self.encode_prompt( 796 | prompt, 797 | prompt_2, 798 | device, 799 | num_images_per_prompt, 800 | self.do_classifier_free_guidance, 801 | negative_prompt, 802 | negative_prompt_2, 803 | prompt_embeds=prompt_embeds, 804 | negative_prompt_embeds=negative_prompt_embeds, 805 | pooled_prompt_embeds=pooled_prompt_embeds, 806 | negative_pooled_prompt_embeds=negative_pooled_prompt_embeds, 807 | lora_scale=text_encoder_lora_scale, 808 | clip_skip=self.clip_skip, 809 | ) 810 | 811 | # 3.2 Encode image prompt 812 | prompt_image_emb = self._encode_prompt_image_emb( 813 | image_embeds, device, self.unet.dtype, self.do_classifier_free_guidance 814 | ) 815 | bs_embed, seq_len, _ = prompt_image_emb.shape 816 | prompt_image_emb = prompt_image_emb.repeat(1, num_images_per_prompt, 1) 817 | prompt_image_emb = prompt_image_emb.view(bs_embed * num_images_per_prompt, seq_len, -1) 818 | 819 | # 4. 
Prepare image 820 | if isinstance(controlnet, ControlNetModel): 821 | image = self.prepare_image( 822 | image=image, 823 | width=width, 824 | height=height, 825 | batch_size=batch_size * num_images_per_prompt, 826 | num_images_per_prompt=num_images_per_prompt, 827 | device=device, 828 | dtype=controlnet.dtype, 829 | do_classifier_free_guidance=self.do_classifier_free_guidance, 830 | guess_mode=guess_mode, 831 | ) 832 | height, width = image.shape[-2:] 833 | elif isinstance(controlnet, MultiControlNetModel): 834 | images = [] 835 | 836 | for image_ in image: 837 | image_ = self.prepare_image( 838 | image=image_, 839 | width=width, 840 | height=height, 841 | batch_size=batch_size * num_images_per_prompt, 842 | num_images_per_prompt=num_images_per_prompt, 843 | device=device, 844 | dtype=controlnet.dtype, 845 | do_classifier_free_guidance=self.do_classifier_free_guidance, 846 | guess_mode=guess_mode, 847 | ) 848 | 849 | images.append(image_) 850 | 851 | image = images 852 | height, width = image[0].shape[-2:] 853 | else: 854 | assert False 855 | 856 | # 5. Prepare timesteps 857 | self.scheduler.set_timesteps(num_inference_steps, device=device) 858 | timesteps = self.scheduler.timesteps 859 | self._num_timesteps = len(timesteps) 860 | 861 | # 6. Prepare latent variables 862 | num_channels_latents = self.unet.config.in_channels 863 | latents = self.prepare_latents( 864 | batch_size * num_images_per_prompt, 865 | num_channels_latents, 866 | height, 867 | width, 868 | prompt_embeds.dtype, 869 | device, 870 | generator, 871 | latents, 872 | ) 873 | 874 | # 6.5 Optionally get Guidance Scale Embedding 875 | timestep_cond = None 876 | if self.unet.config.time_cond_proj_dim is not None: 877 | guidance_scale_tensor = torch.tensor(self.guidance_scale - 1).repeat(batch_size * num_images_per_prompt) 878 | timestep_cond = self.get_guidance_scale_embedding( 879 | guidance_scale_tensor, embedding_dim=self.unet.config.time_cond_proj_dim 880 | ).to(device=device, dtype=latents.dtype) 881 | 882 | # 7. Prepare extra step kwargs. 
TODO: Logic should ideally just be moved out of the pipeline 883 | extra_step_kwargs = self.prepare_extra_step_kwargs(generator, eta) 884 | 885 | # 7.1 Create tensor stating which controlnets to keep 886 | controlnet_keep = [] 887 | for i in range(len(timesteps)): 888 | keeps = [ 889 | 1.0 - float(i / len(timesteps) < s or (i + 1) / len(timesteps) > e) 890 | for s, e in zip(control_guidance_start, control_guidance_end) 891 | ] 892 | controlnet_keep.append(keeps[0] if isinstance(controlnet, ControlNetModel) else keeps) 893 | 894 | # 7.2 Prepare added time ids & embeddings 895 | if isinstance(image, list): 896 | original_size = original_size or image[0].shape[-2:] 897 | else: 898 | original_size = original_size or image.shape[-2:] 899 | target_size = target_size or (height, width) 900 | 901 | add_text_embeds = pooled_prompt_embeds 902 | if self.text_encoder_2 is None: 903 | text_encoder_projection_dim = int(pooled_prompt_embeds.shape[-1]) 904 | else: 905 | text_encoder_projection_dim = self.text_encoder_2.config.projection_dim 906 | 907 | add_time_ids = self._get_add_time_ids( 908 | original_size, 909 | crops_coords_top_left, 910 | target_size, 911 | dtype=prompt_embeds.dtype, 912 | text_encoder_projection_dim=text_encoder_projection_dim, 913 | ) 914 | 915 | if negative_original_size is not None and negative_target_size is not None: 916 | negative_add_time_ids = self._get_add_time_ids( 917 | negative_original_size, 918 | negative_crops_coords_top_left, 919 | negative_target_size, 920 | dtype=prompt_embeds.dtype, 921 | text_encoder_projection_dim=text_encoder_projection_dim, 922 | ) 923 | else: 924 | negative_add_time_ids = add_time_ids 925 | 926 | if self.do_classifier_free_guidance: 927 | prompt_embeds = torch.cat([negative_prompt_embeds, prompt_embeds], dim=0) 928 | add_text_embeds = torch.cat([negative_pooled_prompt_embeds, add_text_embeds], dim=0) 929 | add_time_ids = torch.cat([negative_add_time_ids, add_time_ids], dim=0) 930 | 931 | prompt_embeds = prompt_embeds.to(device) 932 | add_text_embeds = add_text_embeds.to(device) 933 | add_time_ids = add_time_ids.to(device).repeat(batch_size * num_images_per_prompt, 1) 934 | encoder_hidden_states = torch.cat([prompt_embeds, prompt_image_emb], dim=1) 935 | 936 | # 8. Denoising loop 937 | num_warmup_steps = len(timesteps) - num_inference_steps * self.scheduler.order 938 | is_unet_compiled = is_compiled_module(self.unet) 939 | is_controlnet_compiled = is_compiled_module(self.controlnet) 940 | is_torch_higher_equal_2_1 = is_torch_version(">=", "2.1") 941 | 942 | with self.progress_bar(total=num_inference_steps) as progress_bar: 943 | for i, t in enumerate(timesteps): 944 | # Relevant thread: 945 | # https://dev-discuss.pytorch.org/t/cudagraphs-in-pytorch-2-0/1428 946 | if (is_unet_compiled and is_controlnet_compiled) and is_torch_higher_equal_2_1: 947 | torch._inductor.cudagraph_mark_step_begin() 948 | # expand the latents if we are doing classifier free guidance 949 | latent_model_input = torch.cat([latents] * 2) if self.do_classifier_free_guidance else latents 950 | latent_model_input = self.scheduler.scale_model_input(latent_model_input, t) 951 | 952 | added_cond_kwargs = {"text_embeds": add_text_embeds, "time_ids": add_time_ids} 953 | 954 | # controlnet(s) inference 955 | if guess_mode and self.do_classifier_free_guidance: 956 | # Infer ControlNet only for the conditional batch. 
957 | control_model_input = latents 958 | control_model_input = self.scheduler.scale_model_input(control_model_input, t) 959 | controlnet_prompt_embeds = prompt_embeds.chunk(2)[1] 960 | controlnet_added_cond_kwargs = { 961 | "text_embeds": add_text_embeds.chunk(2)[1], 962 | "time_ids": add_time_ids.chunk(2)[1], 963 | } 964 | else: 965 | control_model_input = latent_model_input 966 | controlnet_prompt_embeds = prompt_embeds 967 | controlnet_added_cond_kwargs = added_cond_kwargs 968 | 969 | if isinstance(controlnet_keep[i], list): 970 | cond_scale = [c * s for c, s in zip(controlnet_conditioning_scale, controlnet_keep[i])] 971 | else: 972 | controlnet_cond_scale = controlnet_conditioning_scale 973 | if isinstance(controlnet_cond_scale, list): 974 | controlnet_cond_scale = controlnet_cond_scale[0] 975 | cond_scale = controlnet_cond_scale * controlnet_keep[i] 976 | 977 | down_block_res_samples, mid_block_res_sample = self.controlnet( 978 | control_model_input, 979 | t, 980 | encoder_hidden_states=prompt_image_emb, 981 | controlnet_cond=image, 982 | conditioning_scale=cond_scale, 983 | guess_mode=guess_mode, 984 | added_cond_kwargs=controlnet_added_cond_kwargs, 985 | return_dict=False, 986 | ) 987 | 988 | if guess_mode and self.do_classifier_free_guidance: 989 | # Infered ControlNet only for the conditional batch. 990 | # To apply the output of ControlNet to both the unconditional and conditional batches, 991 | # add 0 to the unconditional batch to keep it unchanged. 992 | down_block_res_samples = [torch.cat([torch.zeros_like(d), d]) for d in down_block_res_samples] 993 | mid_block_res_sample = torch.cat([torch.zeros_like(mid_block_res_sample), mid_block_res_sample]) 994 | 995 | # predict the noise residual 996 | noise_pred = self.unet( 997 | latent_model_input, 998 | t, 999 | encoder_hidden_states=encoder_hidden_states, 1000 | timestep_cond=timestep_cond, 1001 | cross_attention_kwargs=self.cross_attention_kwargs, 1002 | down_block_additional_residuals=down_block_res_samples, 1003 | mid_block_additional_residual=mid_block_res_sample, 1004 | added_cond_kwargs=added_cond_kwargs, 1005 | return_dict=False, 1006 | )[0] 1007 | 1008 | # perform guidance 1009 | if self.do_classifier_free_guidance: 1010 | noise_pred_uncond, noise_pred_text = noise_pred.chunk(2) 1011 | noise_pred = noise_pred_uncond + guidance_scale * (noise_pred_text - noise_pred_uncond) 1012 | 1013 | # compute the previous noisy sample x_t -> x_t-1 1014 | latents = self.scheduler.step(noise_pred, t, latents, **extra_step_kwargs, return_dict=False)[0] 1015 | 1016 | if callback_on_step_end is not None: 1017 | callback_kwargs = {} 1018 | for k in callback_on_step_end_tensor_inputs: 1019 | callback_kwargs[k] = locals()[k] 1020 | callback_outputs = callback_on_step_end(self, i, t, callback_kwargs) 1021 | 1022 | latents = callback_outputs.pop("latents", latents) 1023 | prompt_embeds = callback_outputs.pop("prompt_embeds", prompt_embeds) 1024 | negative_prompt_embeds = callback_outputs.pop("negative_prompt_embeds", negative_prompt_embeds) 1025 | 1026 | # call the callback, if provided 1027 | if i == len(timesteps) - 1 or ((i + 1) > num_warmup_steps and (i + 1) % self.scheduler.order == 0): 1028 | progress_bar.update() 1029 | if callback is not None and i % callback_steps == 0: 1030 | step_idx = i // getattr(self.scheduler, "order", 1) 1031 | callback(step_idx, t, latents) 1032 | 1033 | if not output_type == "latent": 1034 | # make sure the VAE is in float32 mode, as it overflows in float16 1035 | needs_upcasting = self.vae.dtype == 
torch.float16 and self.vae.config.force_upcast 1036 | if needs_upcasting: 1037 | self.upcast_vae() 1038 | latents = latents.to(next(iter(self.vae.post_quant_conv.parameters())).dtype) 1039 | 1040 | image = self.vae.decode(latents / self.vae.config.scaling_factor, return_dict=False)[0] 1041 | 1042 | # cast back to fp16 if needed 1043 | if needs_upcasting: 1044 | self.vae.to(dtype=torch.float16) 1045 | else: 1046 | image = latents 1047 | 1048 | if not output_type == "latent": 1049 | # apply watermark if available 1050 | if self.watermark is not None: 1051 | image = self.watermark.apply_watermark(image) 1052 | 1053 | image = self.image_processor.postprocess(image, output_type=output_type) 1054 | 1055 | # Offload all models 1056 | self.maybe_free_model_hooks() 1057 | 1058 | if not return_dict: 1059 | return (image,) 1060 | 1061 | return StableDiffusionXLPipelineOutput(images=image) 1062 | -------------------------------------------------------------------------------- /requirements.txt: -------------------------------------------------------------------------------- 1 | accelerate==0.27.2 2 | diffusers==0.26.3 3 | gradio==4.19.0 4 | transformers==4.37.2 5 | insightface==0.7.3 6 | opencv-python==4.9.0.80 7 | onnxruntime==1.17.0 -------------------------------------------------------------------------------- /start.bat: -------------------------------------------------------------------------------- 1 | @echo off 2 | setlocal 3 | echo Starting InstantID CPU... 4 | 5 | set "PYTHON_COMMAND=python" 6 | 7 | call python --version > nul 2>&1 8 | if %errorlevel% equ 0 ( 9 | echo Python command check :OK 10 | ) else ( 11 | echo "Error: Python command not found, please install Python (Recommended : Python 3.10 or Python 3.11) and try again" 12 | pause 13 | exit /b 1 14 | 15 | ) 16 | 17 | :check_python_version 18 | for /f "tokens=2" %%I in ('%PYTHON_COMMAND% --version 2^>^&1') do ( 19 | set "python_version=%%I" 20 | ) 21 | 22 | echo Python version: %python_version% 23 | 24 | set PATH=%PATH%;%~dp0env\Lib\site-packages\openvino\libs 25 | call "%~dp0env\Scripts\activate.bat" && %PYTHON_COMMAND% "%~dp0\app.py" -------------------------------------------------------------------------------- /start.sh: -------------------------------------------------------------------------------- 1 | #!/usr/bin/env bash 2 | echo Starting InstantID CPU please wait... 3 | set -e 4 | PYTHON_COMMAND="python3" 5 | 6 | if ! command -v python3 &>/dev/null; then 7 | if ! command -v python &>/dev/null; then 8 | echo "Error: Python not found, please install python 3.10 or higher and try again" 9 | exit 1 10 | fi 11 | fi 12 | 13 | if command -v python &>/dev/null; then 14 | PYTHON_COMMAND="python" 15 | fi 16 | 17 | echo "Found $PYTHON_COMMAND command" 18 | 19 | python_version=$($PYTHON_COMMAND --version 2>&1 | awk '{print $2}') 20 | echo "Python version : $python_version" 21 | 22 | BASEDIR=$(pwd) 23 | # shellcheck disable=SC1091 24 | source "$BASEDIR/env/bin/activate" 25 | $PYTHON_COMMAND app.py -------------------------------------------------------------------------------- /utils.py: -------------------------------------------------------------------------------- 1 | import os 2 | 3 | 4 | def check_file(path: str) -> bool: 5 | return os.path.exists(path) 6 | --------------------------------------------------------------------------------
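Note: `paths.py` and `utils.py` are small helpers that are not imported by `app.py`. Below is a hypothetical pre-flight check that combines them with `download_models.py` to confirm the InstantID weights are on disk before launching the UI; it is illustrative only and not part of the repo.

```python
# Hypothetical pre-flight check built from the repo's helper modules.
from download_models import download_instant_id_sdxl_models
from paths import join_paths
from utils import check_file

download_instant_id_sdxl_models()

required_files = [
    join_paths("./models/sdxl", "ip-adapter.bin"),
    join_paths("./models/sdxl/ControlNetModel", "config.json"),
    join_paths("./models/sdxl/ControlNetModel", "diffusion_pytorch_model.safetensors"),
]
missing = [path for path in required_files if not check_file(path)]
if missing:
    raise FileNotFoundError(f"Missing InstantID model files: {missing}")
```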