├── .gitignore
├── Readme.md
├── __init__.py
├── app.py
├── download_models.py
├── install.bat
├── install.sh
├── instantidcpu-screenshot.jpg
├── paths.py
├── pipeline_stable_diffusion_xl_instantid.py
├── requirements.txt
├── start.bat
├── start.sh
└── utils.py
/.gitignore:
--------------------------------------------------------------------------------
1 | env
2 | *.bak
3 | *.pyc
4 | __pycache__
5 | results
6 | models
--------------------------------------------------------------------------------
/Readme.md:
--------------------------------------------------------------------------------
1 | # InstantID CPU
2 |
3 | InstantID CPU inference with a lower memory requirement (11 GB RAM).
4 |
5 | 
6 |
7 | ## Prerequisites
8 |
9 | - Python 3.10 or higher
10 | - 11 GB system RAM
11 | - Active internet connection to install and download models
12 |
13 | ## How to run
14 |
15 | Follow these steps :
16 |
17 | ### Windows
18 |
19 | - Double-click install.bat (installation takes a while, depending on your internet speed).
20 | - To start the web UI, double-click start.bat.
21 |
22 | ### Linux/Mac
23 |
24 | Run the following commands in the terminal:
25 |
26 | - Clone/download this repo
27 | - In the terminal, change into the instantidcpu directory
28 | - `chmod +x install.sh`
29 | - `./install.sh`
30 | - `./start.sh`
31 |
32 | ## Disclaimer
33 |
34 | The code of InstantID is released under [Apache License](https://github.com/InstantID/InstantID?tab=Apache-2.0-1-ov-file#readme) for both academic and commercial usage. **However, both manual-downloading and auto-downloading face models from insightface are for non-commercial research purposes only** according to their [license](https://github.com/deepinsight/insightface?tab=readme-ov-file#license). Users are granted the freedom to create images using this tool, but they are obligated to comply with local laws and utilize it responsibly. The developers will not assume any responsibility for potential misuse by users.
35 |
--------------------------------------------------------------------------------
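The Readme's "How to run" steps cover the bundled scripts. For readers who prefer to drive the pipeline from Python directly, here is a minimal sketch assembled from app.py. It is a sketch under assumptions, not part of the repo: the input photo `face.jpg` and the output name `result.png` are placeholders, and the first run downloads several GB of model weights.

```python
# Minimal CPU sketch based on app.py (not part of the repo).
# Assumptions: "face.jpg" is your input photo, "result.png" the output name.
import cv2
import numpy as np
import torch
from diffusers import LCMScheduler
from diffusers.models import ControlNetModel
from insightface.app import FaceAnalysis
from PIL import Image

from download_models import download_instant_id_sdxl_models
from pipeline_stable_diffusion_xl_instantid import (
    StableDiffusionXLInstantIDPipeline,
    draw_kps,
)

download_instant_id_sdxl_models()  # fetches ControlNet, ip-adapter and antelopev2

# Face embedding / keypoint extraction, CPU execution provider only
face_app = FaceAnalysis(name="antelopev2", root="./", providers=["CPUExecutionProvider"])
face_app.prepare(ctx_id=0, det_size=(320, 320))

controlnet = ControlNetModel.from_pretrained(
    "./models/sdxl/ControlNetModel", torch_dtype=torch.float32
)
pipe = StableDiffusionXLInstantIDPipeline.from_pretrained(
    "stabilityai/sdxl-turbo", controlnet=controlnet, torch_dtype=torch.float32
)
pipe.load_ip_adapter_instantid("./models/sdxl/ip-adapter.bin")
pipe.to("cpu")
pipe.scheduler = LCMScheduler.from_config(
    pipe.scheduler.config, beta_start=0.001, beta_end=0.012
)

face_image = Image.open("face.jpg").convert("RGB").resize((960, 1024))
faces = face_app.get(cv2.cvtColor(np.array(face_image), cv2.COLOR_RGB2BGR))
# keep the largest detected face (by bounding-box area)
face = max(faces, key=lambda f: (f["bbox"][2] - f["bbox"][0]) * (f["bbox"][3] - f["bbox"][1]))

pipe.set_ip_adapter_scale(0.8)  # image adapter strength
result = pipe(
    "Photo of a person, high quality",
    image_embeds=face["embedding"],
    image=draw_kps(face_image, face["kps"]),
    controlnet_conditioning_scale=0.8,  # IdentityNet strength
    num_inference_steps=1,  # sdxl-turbo + LCM scheduler: one step is enough
    guidance_scale=0.0,
).images[0]
result.save("result.png")
```

As in app.py, everything stays in float32 on the CPU; expect roughly the 11 GB RAM usage mentioned in the Readme.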
/__init__.py:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/rupeshs/instantidcpu/fbf506585cc1d11862d1c7706e36a20ba06adb05/__init__.py
--------------------------------------------------------------------------------
/app.py:
--------------------------------------------------------------------------------
1 | from concurrent.futures import ThreadPoolExecutor
2 | from time import time
3 |
4 | import cv2
5 | import gradio as gr
6 | import numpy as np
7 | import torch
8 | from diffusers import LCMScheduler
9 | from diffusers.models import ControlNetModel
10 | from insightface.app import FaceAnalysis
11 |
12 | from download_models import download_instant_id_sdxl_models
13 | from pipeline_stable_diffusion_xl_instantid import (
14 | StableDiffusionXLInstantIDPipeline,
15 | draw_kps,
16 | )
17 |
18 | face_adapter = "./models/sdxl/ip-adapter.bin"
19 | controlnet_path = "./models/sdxl/ControlNetModel"
20 | base_model = "stabilityai/sdxl-turbo"
21 | APP_VERSION = "0.1.0"
22 | device = "cpu"
23 |
24 | download_instant_id_sdxl_models()
25 |
26 | app = FaceAnalysis(
27 | name="antelopev2",
28 | root="./",
29 | providers=["CPUExecutionProvider"],
30 | )
31 | app.prepare(ctx_id=0, det_size=(320, 320))
32 | torch_dtype = torch.float32
33 |
34 | controlnet = ControlNetModel.from_pretrained(
35 | controlnet_path,
36 | torch_dtype=torch_dtype,
37 | resume_download=True,
38 | )
39 | pipe = StableDiffusionXLInstantIDPipeline.from_pretrained(
40 | base_model,
41 | controlnet=controlnet,
42 | torch_dtype=torch_dtype,
43 | resume_download=True,
44 | )
45 | pipe.load_ip_adapter_instantid(face_adapter)
46 | pipe.to(device)
47 | pipe.unet = pipe.unet.to(memory_format=torch.channels_last)
48 | pipe.controlnet = pipe.controlnet.to(memory_format=torch.channels_last)
49 |
50 | pipe.text_encoder = pipe.text_encoder.to(memory_format=torch.channels_last)
51 | pipe.scheduler = LCMScheduler.from_config(
52 | pipe.scheduler.config,
53 | beta_start=0.001,
54 | beta_end=0.012,
55 | )
56 |
57 | # pipe.scheduler = DEISMultistepScheduler.from_config(
58 | # pipe.scheduler.config,
59 | # )
60 | pipe.enable_freeu(
61 | s1=0.6,
62 | s2=0.4,
63 | b1=1.1,
64 | b2=1.2,
65 | )
66 | print(pipe)
67 |
68 |
69 | def generate_image(
70 | face_image,
71 | prompt,
72 | identitynet_strength_ratio,
73 | adapter_strength_ratio,
74 | ):
75 | print(f"identitynet_strength_ratio :{identitynet_strength_ratio}")
76 | print(f"adapter_strength_ratio :{adapter_strength_ratio}")
77 | if prompt == "":
78 | prompt = "Photo of a person,high quality"
79 | face_image = face_image.resize((960, 1024))
80 | face_info = app.get(cv2.cvtColor(np.array(face_image), cv2.COLOR_RGB2BGR))
81 | face_info = sorted(
82 | face_info,
83 | key=lambda x: (x["bbox"][2] - x["bbox"][0]) * (x["bbox"][3] - x["bbox"][1]),
84 | )[
85 | -1
86 | ] # only use the maximum face
87 | face_emb = face_info["embedding"]
88 | face_kps = draw_kps(face_image, face_info["kps"])
89 |
90 | # generate image
91 | pipe.set_ip_adapter_scale(adapter_strength_ratio)
92 | print("start")
93 | images = pipe(
94 | prompt,
95 | image_embeds=face_emb,
96 | image=face_kps,
97 | controlnet_conditioning_scale=identitynet_strength_ratio,
98 | num_inference_steps=1,
99 | guidance_scale=0.0,
100 | ).images
101 | return images
102 |
103 |
104 | def process_image(
105 | face_image,
106 | prompt,
107 | identitynet_strength_ratio,
108 | adapter_strength_ratio,
109 | ):
110 | tick = time()
111 | with ThreadPoolExecutor(max_workers=1) as executor:
112 | future = executor.submit(
113 | generate_image,
114 | face_image,
115 | prompt,
116 | identitynet_strength_ratio,
117 | adapter_strength_ratio,
118 | )
119 | images = future.result()
120 | elapsed = time() - tick
121 | print(f"Latency : {elapsed:.2f} seconds")
122 | return images[0]
123 |
124 |
125 | def _get_footer_message() -> str:
126 | version = f"
v{APP_VERSION}"
127 | footer_msg = version
128 | return footer_msg
129 |
130 |
131 | css = """
132 | #generate_button {
133 | color: white;
134 | border-color: #007bff;
135 | background: #2563eb;
136 |
137 | }
138 | """
139 |
140 |
141 | def get_web_ui() -> gr.Blocks:
142 | with gr.Blocks(
143 | css=css,
144 | title="InstantID CPU",
145 | ) as web_ui:
146 | gr.HTML("
InstantID CPU
")
147 | with gr.Row():
148 | with gr.Column():
149 | input_image = gr.Image(
150 | label="Face image",
151 | type="pil",
152 | height=512,
153 | )
154 | with gr.Row():
155 | prompt = gr.Textbox(
156 | show_label=False,
157 | lines=3,
158 | placeholder="Oil painting",
159 | container=False,
160 | )
161 |
162 | generate_btn = gr.Button(
163 | "Generate",
164 | elem_id="generate_button",
165 | scale=0,
166 | )
167 | identitynet_strength_ratio = gr.Slider(
168 | label="IdentityNet strength (for fidelity)",
169 | minimum=0,
170 | maximum=1.5,
171 | step=0.05,
172 | value=0.80,
173 | )
174 | adapter_strength_ratio = gr.Slider(
175 | label="Image adapter strength (for detail)",
176 | minimum=0,
177 | maximum=1.5,
178 | step=0.05,
179 | value=0.80,
180 | )
181 |
182 | input_params = [
183 | input_image,
184 | prompt,
185 | identitynet_strength_ratio,
186 | adapter_strength_ratio,
187 | ]
188 |
189 | with gr.Column():
190 | gallery = gr.Image(
191 | label="Generated Image",
192 | )
193 | generate_btn.click(
194 | fn=process_image,
195 | inputs=input_params,
196 | outputs=gallery,
197 | )
198 |
199 | gr.HTML(_get_footer_message())
200 |
201 | return web_ui
202 |
203 |
204 | def start_webui(
205 | share: bool = False,
206 | ):
207 | webui = get_web_ui()
208 | webui.queue()
209 | webui.launch(share=share)
210 |
211 |
212 | start_webui()
213 |
--------------------------------------------------------------------------------
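One usage note on app.py above: `start_webui()` already accepts a `share` flag, so a temporary public Gradio link can be exposed by editing the final call. This is a hedged suggestion, not something the repo does by default:

```python
# Replace the last line of app.py with this call to also get a temporary
# public *.gradio.live URL (only do this if you intend to expose the UI).
start_webui(share=True)
```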
/download_models.py:
--------------------------------------------------------------------------------
1 | from huggingface_hub import hf_hub_download
2 | from huggingface_hub import snapshot_download
3 |
4 |
5 | def download_instant_id_sdxl_models():
6 | hf_hub_download(
7 | repo_id="InstantX/InstantID",
8 | filename="ControlNetModel/config.json",
9 | local_dir="./models/sdxl/",
10 | )
11 | hf_hub_download(
12 | repo_id="InstantX/InstantID",
13 | filename="ControlNetModel/diffusion_pytorch_model.safetensors",
14 | local_dir="./models/sdxl/",
15 | )
16 | hf_hub_download(
17 | repo_id="InstantX/InstantID",
18 | filename="ip-adapter.bin",
19 | local_dir="./models/sdxl/",
20 | )
21 | snapshot_download(
22 | repo_id="rupeshs/antelopev2",
23 | local_dir="./models/antelopev2",
24 | )
25 |
--------------------------------------------------------------------------------
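A small, hypothetical sanity check (not part of the repo) that pairs with download_models.py above: it uses the repo's own `join_paths` and `check_file` helpers to confirm the SDXL files landed where app.py expects them.

```python
# Hypothetical helper: verify the files fetched by download_instant_id_sdxl_models().
from paths import join_paths
from utils import check_file

SDXL_DIR = "./models/sdxl"
expected_files = [
    join_paths(SDXL_DIR, "ControlNetModel/config.json"),
    join_paths(SDXL_DIR, "ControlNetModel/diffusion_pytorch_model.safetensors"),
    join_paths(SDXL_DIR, "ip-adapter.bin"),
]

missing = [path for path in expected_files if not check_file(path)]
if missing:
    print("Missing model files:", missing)
else:
    print("All InstantID SDXL model files are in place.")
```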
/install.bat:
--------------------------------------------------------------------------------
1 |
2 | @echo off
3 | setlocal
4 | echo Starting InstantID CPU env installation...
5 |
6 | set "PYTHON_COMMAND=python"
7 |
8 | call python --version > nul 2>&1
9 | if %errorlevel% equ 0 (
10 | echo Python command check :OK
11 | ) else (
12 | echo "Error: Python command not found,please install Python(Recommended : Python 3.10 or higher) and try again."
13 | pause
14 | exit /b 1
15 |
16 | )
17 |
18 | :check_python_version
19 | for /f "tokens=2" %%I in ('%PYTHON_COMMAND% --version 2^>^&1') do (
20 | set "python_version=%%I"
21 | )
22 |
23 | echo Python version: %python_version%
24 |
25 | %PYTHON_COMMAND% -m venv "%~dp0env"
26 | call "%~dp0env\Scripts\activate.bat" && pip install torch==2.2.0 --index-url https://download.pytorch.org/whl/cpu
27 | call "%~dp0env\Scripts\activate.bat" && pip install -r "%~dp0requirements.txt"
28 | echo InstantID CPU env installation completed.
29 | pause
--------------------------------------------------------------------------------
/install.sh:
--------------------------------------------------------------------------------
1 | #!/usr/bin/env bash
2 | echo Starting InstantID CPU env installation...
3 | set -e
4 | PYTHON_COMMAND="python3"
5 |
6 | if ! command -v python3 &>/dev/null; then
7 | if ! command -v python &>/dev/null; then
8 | echo "Error: Python not found, please install python 3.10 or higher and try again"
9 | exit 1
10 | fi
11 | fi
12 |
13 | if command -v python &>/dev/null; then
14 | PYTHON_COMMAND="python"
15 | fi
16 |
17 | echo "Found $PYTHON_COMMAND command"
18 |
19 | python_version=$($PYTHON_COMMAND --version 2>&1 | awk '{print $2}')
20 | echo "Python version : $python_version"
21 |
22 | BASEDIR=$(pwd)
23 |
24 | $PYTHON_COMMAND -m venv "$BASEDIR/env"
25 | # shellcheck disable=SC1091
26 | source "$BASEDIR/env/bin/activate"
27 | pip install torch==2.2.0 --index-url https://download.pytorch.org/whl/cpu
28 | pip install -r "$BASEDIR/requirements.txt"
29 | chmod +x "start.sh"
30 | read -n1 -r -p "InstantID CPU installation completed, press any key to continue..." key
--------------------------------------------------------------------------------
/instantidcpu-screenshot.jpg:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/rupeshs/instantidcpu/fbf506585cc1d11862d1c7706e36a20ba06adb05/instantidcpu-screenshot.jpg
--------------------------------------------------------------------------------
/paths.py:
--------------------------------------------------------------------------------
1 | import os
2 | from pathlib import Path
3 |
4 |
5 | def ensure_file(path: str) -> bool:
6 | return os.path.exists(path)
7 |
8 |
9 | def join_paths(
10 | first_path: str,
11 | second_path: str,
12 | ) -> str:
13 | return os.path.join(first_path, second_path)
14 |
15 |
16 | def get_file_name(file_path: str) -> str:
17 | return Path(file_path).stem
18 |
19 |
20 | def get_app_path() -> str:
21 | app_dir = os.path.dirname(__file__)
22 | work_dir = os.path.dirname(app_dir)
23 | return work_dir
24 |
--------------------------------------------------------------------------------
/pipeline_stable_diffusion_xl_instantid.py:
--------------------------------------------------------------------------------
1 | # Copyright 2024 The InstantX Team. All rights reserved.
2 | #
3 | # Licensed under the Apache License, Version 2.0 (the "License");
4 | # you may not use this file except in compliance with the License.
5 | # You may obtain a copy of the License at
6 | #
7 | # http://www.apache.org/licenses/LICENSE-2.0
8 | #
9 | # Unless required by applicable law or agreed to in writing, software
10 | # distributed under the License is distributed on an "AS IS" BASIS,
11 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12 | # See the License for the specific language governing permissions and
13 | # limitations under the License.
14 |
15 |
16 | import math
17 | from typing import Any, Callable, Dict, List, Optional, Tuple, Union
18 |
19 | import cv2
20 | import numpy as np
21 | import PIL.Image
22 | import torch
23 | import torch.nn as nn
24 |
25 | from diffusers import StableDiffusionXLControlNetPipeline
26 | from diffusers.image_processor import PipelineImageInput
27 | from diffusers.models import ControlNetModel
28 | from diffusers.pipelines.controlnet.multicontrolnet import MultiControlNetModel
29 | from diffusers.pipelines.stable_diffusion_xl import StableDiffusionXLPipelineOutput
30 | from diffusers.utils import (
31 | deprecate,
32 | logging,
33 | replace_example_docstring,
34 | )
35 | from diffusers.utils.import_utils import is_xformers_available
36 | from diffusers.utils.torch_utils import is_compiled_module, is_torch_version
37 |
38 |
39 | try:
40 | import xformers
41 | import xformers.ops
42 |
43 | xformers_available = True
44 | except Exception:
45 | xformers_available = False
46 |
47 | logger = logging.get_logger(__name__) # pylint: disable=invalid-name
48 |
49 |
50 | def FeedForward(dim, mult=4):
51 | inner_dim = int(dim * mult)
52 | return nn.Sequential(
53 | nn.LayerNorm(dim),
54 | nn.Linear(dim, inner_dim, bias=False),
55 | nn.GELU(),
56 | nn.Linear(inner_dim, dim, bias=False),
57 | )
58 |
59 |
60 | def reshape_tensor(x, heads):
61 | bs, length, width = x.shape
62 | # (bs, length, width) --> (bs, length, n_heads, dim_per_head)
63 | x = x.view(bs, length, heads, -1)
64 | # (bs, length, n_heads, dim_per_head) --> (bs, n_heads, length, dim_per_head)
65 | x = x.transpose(1, 2)
66 | # (bs, n_heads, length, dim_per_head) --> (bs*n_heads, length, dim_per_head)
67 | x = x.reshape(bs, heads, length, -1)
68 | return x
69 |
70 |
71 | class PerceiverAttention(nn.Module):
72 | def __init__(self, *, dim, dim_head=64, heads=8):
73 | super().__init__()
74 | self.scale = dim_head**-0.5
75 | self.dim_head = dim_head
76 | self.heads = heads
77 | inner_dim = dim_head * heads
78 |
79 | self.norm1 = nn.LayerNorm(dim)
80 | self.norm2 = nn.LayerNorm(dim)
81 |
82 | self.to_q = nn.Linear(dim, inner_dim, bias=False)
83 | self.to_kv = nn.Linear(dim, inner_dim * 2, bias=False)
84 | self.to_out = nn.Linear(inner_dim, dim, bias=False)
85 |
86 | def forward(self, x, latents):
87 | """
88 | Args:
89 | x (torch.Tensor): image features
90 | shape (b, n1, D)
91 | latent (torch.Tensor): latent features
92 | shape (b, n2, D)
93 | """
94 | x = self.norm1(x)
95 | latents = self.norm2(latents)
96 |
97 | b, l, _ = latents.shape
98 |
99 | q = self.to_q(latents)
100 | kv_input = torch.cat((x, latents), dim=-2)
101 | k, v = self.to_kv(kv_input).chunk(2, dim=-1)
102 |
103 | q = reshape_tensor(q, self.heads)
104 | k = reshape_tensor(k, self.heads)
105 | v = reshape_tensor(v, self.heads)
106 |
107 | # attention
108 | scale = 1 / math.sqrt(math.sqrt(self.dim_head))
109 | weight = (q * scale) @ (k * scale).transpose(-2, -1) # More stable with f16 than dividing afterwards
110 | weight = torch.softmax(weight.float(), dim=-1).type(weight.dtype)
111 | out = weight @ v
112 |
113 | out = out.permute(0, 2, 1, 3).reshape(b, l, -1)
114 |
115 | return self.to_out(out)
116 |
117 |
118 | class Resampler(nn.Module):
119 | def __init__(
120 | self,
121 | dim=1024,
122 | depth=8,
123 | dim_head=64,
124 | heads=16,
125 | num_queries=8,
126 | embedding_dim=768,
127 | output_dim=1024,
128 | ff_mult=4,
129 | ):
130 | super().__init__()
131 |
132 | self.latents = nn.Parameter(torch.randn(1, num_queries, dim) / dim**0.5)
133 |
134 | self.proj_in = nn.Linear(embedding_dim, dim)
135 |
136 | self.proj_out = nn.Linear(dim, output_dim)
137 | self.norm_out = nn.LayerNorm(output_dim)
138 |
139 | self.layers = nn.ModuleList([])
140 | for _ in range(depth):
141 | self.layers.append(
142 | nn.ModuleList(
143 | [
144 | PerceiverAttention(dim=dim, dim_head=dim_head, heads=heads),
145 | FeedForward(dim=dim, mult=ff_mult),
146 | ]
147 | )
148 | )
149 |
150 | def forward(self, x):
151 | latents = self.latents.repeat(x.size(0), 1, 1)
152 | x = self.proj_in(x)
153 |
154 | for attn, ff in self.layers:
155 | latents = attn(x, latents) + latents
156 | latents = ff(latents) + latents
157 |
158 | latents = self.proj_out(latents)
159 | return self.norm_out(latents)
160 |
161 |
162 | class AttnProcessor(nn.Module):
163 | r"""
164 | Default processor for performing attention-related computations.
165 | """
166 |
167 | def __init__(
168 | self,
169 | hidden_size=None,
170 | cross_attention_dim=None,
171 | ):
172 | super().__init__()
173 |
174 | def __call__(
175 | self,
176 | attn,
177 | hidden_states,
178 | encoder_hidden_states=None,
179 | attention_mask=None,
180 | temb=None,
181 | ):
182 | residual = hidden_states
183 |
184 | if attn.spatial_norm is not None:
185 | hidden_states = attn.spatial_norm(hidden_states, temb)
186 |
187 | input_ndim = hidden_states.ndim
188 |
189 | if input_ndim == 4:
190 | batch_size, channel, height, width = hidden_states.shape
191 | hidden_states = hidden_states.view(batch_size, channel, height * width).transpose(1, 2)
192 |
193 | batch_size, sequence_length, _ = (
194 | hidden_states.shape if encoder_hidden_states is None else encoder_hidden_states.shape
195 | )
196 | attention_mask = attn.prepare_attention_mask(attention_mask, sequence_length, batch_size)
197 |
198 | if attn.group_norm is not None:
199 | hidden_states = attn.group_norm(hidden_states.transpose(1, 2)).transpose(1, 2)
200 |
201 | query = attn.to_q(hidden_states)
202 |
203 | if encoder_hidden_states is None:
204 | encoder_hidden_states = hidden_states
205 | elif attn.norm_cross:
206 | encoder_hidden_states = attn.norm_encoder_hidden_states(encoder_hidden_states)
207 |
208 | key = attn.to_k(encoder_hidden_states)
209 | value = attn.to_v(encoder_hidden_states)
210 |
211 | query = attn.head_to_batch_dim(query)
212 | key = attn.head_to_batch_dim(key)
213 | value = attn.head_to_batch_dim(value)
214 |
215 | attention_probs = attn.get_attention_scores(query, key, attention_mask)
216 | hidden_states = torch.bmm(attention_probs, value)
217 | hidden_states = attn.batch_to_head_dim(hidden_states)
218 |
219 | # linear proj
220 | hidden_states = attn.to_out[0](hidden_states)
221 | # dropout
222 | hidden_states = attn.to_out[1](hidden_states)
223 |
224 | if input_ndim == 4:
225 | hidden_states = hidden_states.transpose(-1, -2).reshape(batch_size, channel, height, width)
226 |
227 | if attn.residual_connection:
228 | hidden_states = hidden_states + residual
229 |
230 | hidden_states = hidden_states / attn.rescale_output_factor
231 |
232 | return hidden_states
233 |
234 |
235 | class IPAttnProcessor(nn.Module):
236 | r"""
237 | Attention processor for IP-Adapter.
238 | Args:
239 | hidden_size (`int`):
240 | The hidden size of the attention layer.
241 | cross_attention_dim (`int`):
242 | The number of channels in the `encoder_hidden_states`.
243 | scale (`float`, defaults to 1.0):
244 | the weight scale of image prompt.
245 | num_tokens (`int`, defaults to 4; when using ip_adapter_plus it should be 16):
246 | The context length of the image features.
247 | """
248 |
249 | def __init__(self, hidden_size, cross_attention_dim=None, scale=1.0, num_tokens=4):
250 | super().__init__()
251 |
252 | self.hidden_size = hidden_size
253 | self.cross_attention_dim = cross_attention_dim
254 | self.scale = scale
255 | self.num_tokens = num_tokens
256 |
257 | self.to_k_ip = nn.Linear(cross_attention_dim or hidden_size, hidden_size, bias=False)
258 | self.to_v_ip = nn.Linear(cross_attention_dim or hidden_size, hidden_size, bias=False)
259 |
260 | def __call__(
261 | self,
262 | attn,
263 | hidden_states,
264 | encoder_hidden_states=None,
265 | attention_mask=None,
266 | temb=None,
267 | ):
268 | residual = hidden_states
269 |
270 | if attn.spatial_norm is not None:
271 | hidden_states = attn.spatial_norm(hidden_states, temb)
272 |
273 | input_ndim = hidden_states.ndim
274 |
275 | if input_ndim == 4:
276 | batch_size, channel, height, width = hidden_states.shape
277 | hidden_states = hidden_states.view(batch_size, channel, height * width).transpose(1, 2)
278 |
279 | batch_size, sequence_length, _ = (
280 | hidden_states.shape if encoder_hidden_states is None else encoder_hidden_states.shape
281 | )
282 | attention_mask = attn.prepare_attention_mask(attention_mask, sequence_length, batch_size)
283 |
284 | if attn.group_norm is not None:
285 | hidden_states = attn.group_norm(hidden_states.transpose(1, 2)).transpose(1, 2)
286 |
287 | query = attn.to_q(hidden_states)
288 |
289 | if encoder_hidden_states is None:
290 | encoder_hidden_states = hidden_states
291 | else:
292 | # get encoder_hidden_states, ip_hidden_states
293 | end_pos = encoder_hidden_states.shape[1] - self.num_tokens
294 | encoder_hidden_states, ip_hidden_states = (
295 | encoder_hidden_states[:, :end_pos, :],
296 | encoder_hidden_states[:, end_pos:, :],
297 | )
298 | if attn.norm_cross:
299 | encoder_hidden_states = attn.norm_encoder_hidden_states(encoder_hidden_states)
300 |
301 | key = attn.to_k(encoder_hidden_states)
302 | value = attn.to_v(encoder_hidden_states)
303 |
304 | query = attn.head_to_batch_dim(query)
305 | key = attn.head_to_batch_dim(key)
306 | value = attn.head_to_batch_dim(value)
307 |
308 | if xformers_available:
309 | hidden_states = self._memory_efficient_attention_xformers(query, key, value, attention_mask)
310 | else:
311 | attention_probs = attn.get_attention_scores(query, key, attention_mask)
312 | hidden_states = torch.bmm(attention_probs, value)
313 | hidden_states = attn.batch_to_head_dim(hidden_states)
314 |
315 | # for ip-adapter
316 | ip_key = self.to_k_ip(ip_hidden_states)
317 | ip_value = self.to_v_ip(ip_hidden_states)
318 |
319 | ip_key = attn.head_to_batch_dim(ip_key)
320 | ip_value = attn.head_to_batch_dim(ip_value)
321 |
322 | if xformers_available:
323 | ip_hidden_states = self._memory_efficient_attention_xformers(query, ip_key, ip_value, None)
324 | else:
325 | ip_attention_probs = attn.get_attention_scores(query, ip_key, None)
326 | ip_hidden_states = torch.bmm(ip_attention_probs, ip_value)
327 | ip_hidden_states = attn.batch_to_head_dim(ip_hidden_states)
328 |
329 | hidden_states = hidden_states + self.scale * ip_hidden_states
330 |
331 | # linear proj
332 | hidden_states = attn.to_out[0](hidden_states)
333 | # dropout
334 | hidden_states = attn.to_out[1](hidden_states)
335 |
336 | if input_ndim == 4:
337 | hidden_states = hidden_states.transpose(-1, -2).reshape(batch_size, channel, height, width)
338 |
339 | if attn.residual_connection:
340 | hidden_states = hidden_states + residual
341 |
342 | hidden_states = hidden_states / attn.rescale_output_factor
343 |
344 | return hidden_states
345 |
346 | def _memory_efficient_attention_xformers(self, query, key, value, attention_mask):
347 | # TODO attention_mask
348 | query = query.contiguous()
349 | key = key.contiguous()
350 | value = value.contiguous()
351 | hidden_states = xformers.ops.memory_efficient_attention(query, key, value, attn_bias=attention_mask)
352 | return hidden_states
353 |
354 |
355 | EXAMPLE_DOC_STRING = """
356 | Examples:
357 | ```py
358 | >>> # !pip install opencv-python transformers accelerate insightface
359 | >>> import diffusers
360 | >>> from diffusers.utils import load_image
361 | >>> from diffusers.models import ControlNetModel
362 |
363 | >>> import cv2
364 | >>> import torch
365 | >>> import numpy as np
366 | >>> from PIL import Image
367 |
368 | >>> from insightface.app import FaceAnalysis
369 | >>> from pipeline_stable_diffusion_xl_instantid import StableDiffusionXLInstantIDPipeline, draw_kps
370 |
371 | >>> # download 'antelopev2' under ./models
372 | >>> app = FaceAnalysis(name='antelopev2', root='./', providers=['CUDAExecutionProvider', 'CPUExecutionProvider'])
373 | >>> app.prepare(ctx_id=0, det_size=(640, 640))
374 |
375 | >>> # download models under ./checkpoints
376 | >>> face_adapter = f'./checkpoints/ip-adapter.bin'
377 | >>> controlnet_path = f'./checkpoints/ControlNetModel'
378 |
379 | >>> # load IdentityNet
380 | >>> controlnet = ControlNetModel.from_pretrained(controlnet_path, torch_dtype=torch.float16)
381 |
382 | >>> pipe = StableDiffusionXLInstantIDPipeline.from_pretrained(
383 | ... "stabilityai/stable-diffusion-xl-base-1.0", controlnet=controlnet, torch_dtype=torch.float16
384 | ... )
385 | >>> pipe.cuda()
386 |
387 | >>> # load adapter
388 | >>> pipe.load_ip_adapter_instantid(face_adapter)
389 |
390 | >>> prompt = "analog film photo of a man. faded film, desaturated, 35mm photo, grainy, vignette, vintage, Kodachrome, Lomography, stained, highly detailed, found footage, masterpiece, best quality"
391 | >>> negative_prompt = "(lowres, low quality, worst quality:1.2), (text:1.2), watermark, painting, drawing, illustration, glitch, deformed, mutated, cross-eyed, ugly, disfigured (lowres, low quality, worst quality:1.2), (text:1.2), watermark, painting, drawing, illustration, glitch,deformed, mutated, cross-eyed, ugly, disfigured"
392 |
393 | >>> # load an image
394 | >>> face_image = load_image("your-example.jpg")
395 |
396 | >>> face_info = app.get(cv2.cvtColor(np.array(face_image), cv2.COLOR_RGB2BGR))[-1]
397 | >>> face_emb = face_info['embedding']
398 | >>> face_kps = draw_kps(face_image, face_info['kps'])
399 |
400 | >>> pipe.set_ip_adapter_scale(0.8)
401 |
402 | >>> # generate image
403 | >>> image = pipe(
404 | ... prompt, image_embeds=face_emb, image=face_kps, controlnet_conditioning_scale=0.8
405 | ... ).images[0]
406 | ```
407 | """
408 |
409 |
410 | def draw_kps(image_pil, kps, color_list=[(255, 0, 0), (0, 255, 0), (0, 0, 255), (255, 255, 0), (255, 0, 255)]):
411 | stickwidth = 4
412 | limbSeq = np.array([[0, 2], [1, 2], [3, 2], [4, 2]])
413 | kps = np.array(kps)
414 |
415 | w, h = image_pil.size
416 | out_img = np.zeros([h, w, 3])
417 |
418 | for i in range(len(limbSeq)):
419 | index = limbSeq[i]
420 | color = color_list[index[0]]
421 |
422 | x = kps[index][:, 0]
423 | y = kps[index][:, 1]
424 | length = ((x[0] - x[1]) ** 2 + (y[0] - y[1]) ** 2) ** 0.5
425 | angle = math.degrees(math.atan2(y[0] - y[1], x[0] - x[1]))
426 | polygon = cv2.ellipse2Poly(
427 | (int(np.mean(x)), int(np.mean(y))), (int(length / 2), stickwidth), int(angle), 0, 360, 1
428 | )
429 | out_img = cv2.fillConvexPoly(out_img.copy(), polygon, color)
430 | out_img = (out_img * 0.6).astype(np.uint8)
431 |
432 | for idx_kp, kp in enumerate(kps):
433 | color = color_list[idx_kp]
434 | x, y = kp
435 | out_img = cv2.circle(out_img.copy(), (int(x), int(y)), 10, color, -1)
436 |
437 | out_img_pil = PIL.Image.fromarray(out_img.astype(np.uint8))
438 | return out_img_pil
439 |
440 |
441 | class StableDiffusionXLInstantIDPipeline(StableDiffusionXLControlNetPipeline):
442 | def cuda(self, dtype=torch.float16, use_xformers=False):
443 | self.to("cuda", dtype)
444 |
445 | if hasattr(self, "image_proj_model"):
446 | self.image_proj_model.to(self.unet.device).to(self.unet.dtype)
447 |
448 | if use_xformers:
449 | if is_xformers_available():
450 | import xformers
451 | from packaging import version
452 |
453 | xformers_version = version.parse(xformers.__version__)
454 | if xformers_version == version.parse("0.0.16"):
455 | logger.warn(
456 | "xFormers 0.0.16 cannot be used for training in some GPUs. If you observe problems during training, please update xFormers to at least 0.0.17. See https://huggingface.co/docs/diffusers/main/en/optimization/xformers for more details."
457 | )
458 | self.enable_xformers_memory_efficient_attention()
459 | else:
460 | raise ValueError("xformers is not available. Make sure it is installed correctly")
461 |
462 | def load_ip_adapter_instantid(self, model_ckpt, image_emb_dim=512, num_tokens=16, scale=0.5):
463 | self.set_image_proj_model(model_ckpt, image_emb_dim, num_tokens)
464 | self.set_ip_adapter(model_ckpt, num_tokens, scale)
465 |
466 | def set_image_proj_model(self, model_ckpt, image_emb_dim=512, num_tokens=16):
467 | image_proj_model = Resampler(
468 | dim=1280,
469 | depth=4,
470 | dim_head=64,
471 | heads=20,
472 | num_queries=num_tokens,
473 | embedding_dim=image_emb_dim,
474 | output_dim=self.unet.config.cross_attention_dim,
475 | ff_mult=4,
476 | )
477 |
478 | image_proj_model.eval()
479 |
480 | self.image_proj_model = image_proj_model.to(self.device, dtype=self.dtype)
481 | state_dict = torch.load(model_ckpt, map_location="cpu")
482 | if "image_proj" in state_dict:
483 | state_dict = state_dict["image_proj"]
484 | self.image_proj_model.load_state_dict(state_dict)
485 |
486 | self.image_proj_model_in_features = image_emb_dim
487 |
488 | def set_ip_adapter(self, model_ckpt, num_tokens, scale):
489 | unet = self.unet
490 | attn_procs = {}
491 | for name in unet.attn_processors.keys():
492 | cross_attention_dim = None if name.endswith("attn1.processor") else unet.config.cross_attention_dim
493 | if name.startswith("mid_block"):
494 | hidden_size = unet.config.block_out_channels[-1]
495 | elif name.startswith("up_blocks"):
496 | block_id = int(name[len("up_blocks.")])
497 | hidden_size = list(reversed(unet.config.block_out_channels))[block_id]
498 | elif name.startswith("down_blocks"):
499 | block_id = int(name[len("down_blocks.")])
500 | hidden_size = unet.config.block_out_channels[block_id]
501 | if cross_attention_dim is None:
502 | attn_procs[name] = AttnProcessor().to(unet.device, dtype=unet.dtype)
503 | else:
504 | attn_procs[name] = IPAttnProcessor(
505 | hidden_size=hidden_size,
506 | cross_attention_dim=cross_attention_dim,
507 | scale=scale,
508 | num_tokens=num_tokens,
509 | ).to(unet.device, dtype=unet.dtype)
510 | unet.set_attn_processor(attn_procs)
511 |
512 | state_dict = torch.load(model_ckpt, map_location="cpu")
513 | ip_layers = torch.nn.ModuleList(self.unet.attn_processors.values())
514 | if "ip_adapter" in state_dict:
515 | state_dict = state_dict["ip_adapter"]
516 | ip_layers.load_state_dict(state_dict)
517 |
518 | def set_ip_adapter_scale(self, scale):
519 | unet = getattr(self, self.unet_name) if not hasattr(self, "unet") else self.unet
520 | for attn_processor in unet.attn_processors.values():
521 | if isinstance(attn_processor, IPAttnProcessor):
522 | attn_processor.scale = scale
523 |
524 | def _encode_prompt_image_emb(self, prompt_image_emb, device, dtype, do_classifier_free_guidance):
525 | if isinstance(prompt_image_emb, torch.Tensor):
526 | prompt_image_emb = prompt_image_emb.clone().detach()
527 | else:
528 | prompt_image_emb = torch.tensor(prompt_image_emb)
529 |
530 | prompt_image_emb = prompt_image_emb.to(device=device, dtype=dtype)
531 | prompt_image_emb = prompt_image_emb.reshape([1, -1, self.image_proj_model_in_features])
532 |
533 | if do_classifier_free_guidance:
534 | prompt_image_emb = torch.cat([torch.zeros_like(prompt_image_emb), prompt_image_emb], dim=0)
535 | else:
536 | prompt_image_emb = torch.cat([prompt_image_emb], dim=0)
537 |
538 | prompt_image_emb = self.image_proj_model(prompt_image_emb)
539 | return prompt_image_emb
540 |
541 | @torch.no_grad()
542 | @replace_example_docstring(EXAMPLE_DOC_STRING)
543 | def __call__(
544 | self,
545 | prompt: Union[str, List[str]] = None,
546 | prompt_2: Optional[Union[str, List[str]]] = None,
547 | image: PipelineImageInput = None,
548 | height: Optional[int] = None,
549 | width: Optional[int] = None,
550 | num_inference_steps: int = 50,
551 | guidance_scale: float = 5.0,
552 | negative_prompt: Optional[Union[str, List[str]]] = None,
553 | negative_prompt_2: Optional[Union[str, List[str]]] = None,
554 | num_images_per_prompt: Optional[int] = 1,
555 | eta: float = 0.0,
556 | generator: Optional[Union[torch.Generator, List[torch.Generator]]] = None,
557 | latents: Optional[torch.FloatTensor] = None,
558 | prompt_embeds: Optional[torch.FloatTensor] = None,
559 | negative_prompt_embeds: Optional[torch.FloatTensor] = None,
560 | pooled_prompt_embeds: Optional[torch.FloatTensor] = None,
561 | negative_pooled_prompt_embeds: Optional[torch.FloatTensor] = None,
562 | image_embeds: Optional[torch.FloatTensor] = None,
563 | output_type: Optional[str] = "pil",
564 | return_dict: bool = True,
565 | cross_attention_kwargs: Optional[Dict[str, Any]] = None,
566 | controlnet_conditioning_scale: Union[float, List[float]] = 1.0,
567 | guess_mode: bool = False,
568 | control_guidance_start: Union[float, List[float]] = 0.0,
569 | control_guidance_end: Union[float, List[float]] = 1.0,
570 | original_size: Tuple[int, int] = None,
571 | crops_coords_top_left: Tuple[int, int] = (0, 0),
572 | target_size: Tuple[int, int] = None,
573 | negative_original_size: Optional[Tuple[int, int]] = None,
574 | negative_crops_coords_top_left: Tuple[int, int] = (0, 0),
575 | negative_target_size: Optional[Tuple[int, int]] = None,
576 | clip_skip: Optional[int] = None,
577 | callback_on_step_end: Optional[Callable[[int, int, Dict], None]] = None,
578 | callback_on_step_end_tensor_inputs: List[str] = ["latents"],
579 | **kwargs,
580 | ):
581 | r"""
582 | The call function to the pipeline for generation.
583 |
584 | Args:
585 | prompt (`str` or `List[str]`, *optional*):
586 | The prompt or prompts to guide image generation. If not defined, you need to pass `prompt_embeds`.
587 | prompt_2 (`str` or `List[str]`, *optional*):
588 | The prompt or prompts to be sent to `tokenizer_2` and `text_encoder_2`. If not defined, `prompt` is
589 | used in both text-encoders.
590 | image (`torch.FloatTensor`, `PIL.Image.Image`, `np.ndarray`, `List[torch.FloatTensor]`, `List[PIL.Image.Image]`, `List[np.ndarray]`,:
591 | `List[List[torch.FloatTensor]]`, `List[List[np.ndarray]]` or `List[List[PIL.Image.Image]]`):
592 | The ControlNet input condition to provide guidance to the `unet` for generation. If the type is
593 | specified as `torch.FloatTensor`, it is passed to ControlNet as is. `PIL.Image.Image` can also be
594 | accepted as an image. The dimensions of the output image defaults to `image`'s dimensions. If height
595 | and/or width are passed, `image` is resized accordingly. If multiple ControlNets are specified in
596 | `init`, images must be passed as a list such that each element of the list can be correctly batched for
597 | input to a single ControlNet.
598 | height (`int`, *optional*, defaults to `self.unet.config.sample_size * self.vae_scale_factor`):
599 | The height in pixels of the generated image. Anything below 512 pixels won't work well for
600 | [stabilityai/stable-diffusion-xl-base-1.0](https://huggingface.co/stabilityai/stable-diffusion-xl-base-1.0)
601 | and checkpoints that are not specifically fine-tuned on low resolutions.
602 | width (`int`, *optional*, defaults to `self.unet.config.sample_size * self.vae_scale_factor`):
603 | The width in pixels of the generated image. Anything below 512 pixels won't work well for
604 | [stabilityai/stable-diffusion-xl-base-1.0](https://huggingface.co/stabilityai/stable-diffusion-xl-base-1.0)
605 | and checkpoints that are not specifically fine-tuned on low resolutions.
606 | num_inference_steps (`int`, *optional*, defaults to 50):
607 | The number of denoising steps. More denoising steps usually lead to a higher quality image at the
608 | expense of slower inference.
609 | guidance_scale (`float`, *optional*, defaults to 5.0):
610 | A higher guidance scale value encourages the model to generate images closely linked to the text
611 | `prompt` at the expense of lower image quality. Guidance scale is enabled when `guidance_scale > 1`.
612 | negative_prompt (`str` or `List[str]`, *optional*):
613 | The prompt or prompts to guide what to not include in image generation. If not defined, you need to
614 | pass `negative_prompt_embeds` instead. Ignored when not using guidance (`guidance_scale < 1`).
615 | negative_prompt_2 (`str` or `List[str]`, *optional*):
616 | The prompt or prompts to guide what to not include in image generation. This is sent to `tokenizer_2`
617 | and `text_encoder_2`. If not defined, `negative_prompt` is used in both text-encoders.
618 | num_images_per_prompt (`int`, *optional*, defaults to 1):
619 | The number of images to generate per prompt.
620 | eta (`float`, *optional*, defaults to 0.0):
621 | Corresponds to parameter eta (η) from the [DDIM](https://arxiv.org/abs/2010.02502) paper. Only applies
622 | to the [`~schedulers.DDIMScheduler`], and is ignored in other schedulers.
623 | generator (`torch.Generator` or `List[torch.Generator]`, *optional*):
624 | A [`torch.Generator`](https://pytorch.org/docs/stable/generated/torch.Generator.html) to make
625 | generation deterministic.
626 | latents (`torch.FloatTensor`, *optional*):
627 | Pre-generated noisy latents sampled from a Gaussian distribution, to be used as inputs for image
628 | generation. Can be used to tweak the same generation with different prompts. If not provided, a latents
629 | tensor is generated by sampling using the supplied random `generator`.
630 | prompt_embeds (`torch.FloatTensor`, *optional*):
631 | Pre-generated text embeddings. Can be used to easily tweak text inputs (prompt weighting). If not
632 | provided, text embeddings are generated from the `prompt` input argument.
633 | negative_prompt_embeds (`torch.FloatTensor`, *optional*):
634 | Pre-generated negative text embeddings. Can be used to easily tweak text inputs (prompt weighting). If
635 | not provided, `negative_prompt_embeds` are generated from the `negative_prompt` input argument.
636 | pooled_prompt_embeds (`torch.FloatTensor`, *optional*):
637 | Pre-generated pooled text embeddings. Can be used to easily tweak text inputs (prompt weighting). If
638 | not provided, pooled text embeddings are generated from `prompt` input argument.
639 | negative_pooled_prompt_embeds (`torch.FloatTensor`, *optional*):
640 | Pre-generated negative pooled text embeddings. Can be used to easily tweak text inputs (prompt
641 | weighting). If not provided, pooled `negative_prompt_embeds` are generated from `negative_prompt` input
642 | argument.
643 | image_embeds (`torch.FloatTensor`, *optional*):
644 | Pre-generated image embeddings.
645 | output_type (`str`, *optional*, defaults to `"pil"`):
646 | The output format of the generated image. Choose between `PIL.Image` or `np.array`.
647 | return_dict (`bool`, *optional*, defaults to `True`):
648 | Whether or not to return a [`~pipelines.stable_diffusion.StableDiffusionPipelineOutput`] instead of a
649 | plain tuple.
650 | cross_attention_kwargs (`dict`, *optional*):
651 | A kwargs dictionary that if specified is passed along to the [`AttentionProcessor`] as defined in
652 | [`self.processor`](https://github.com/huggingface/diffusers/blob/main/src/diffusers/models/attention_processor.py).
653 | controlnet_conditioning_scale (`float` or `List[float]`, *optional*, defaults to 1.0):
654 | The outputs of the ControlNet are multiplied by `controlnet_conditioning_scale` before they are added
655 | to the residual in the original `unet`. If multiple ControlNets are specified in `init`, you can set
656 | the corresponding scale as a list.
657 | guess_mode (`bool`, *optional*, defaults to `False`):
658 | The ControlNet encoder tries to recognize the content of the input image even if you remove all
659 | prompts. A `guidance_scale` value between 3.0 and 5.0 is recommended.
660 | control_guidance_start (`float` or `List[float]`, *optional*, defaults to 0.0):
661 | The percentage of total steps at which the ControlNet starts applying.
662 | control_guidance_end (`float` or `List[float]`, *optional*, defaults to 1.0):
663 | The percentage of total steps at which the ControlNet stops applying.
664 | original_size (`Tuple[int]`, *optional*, defaults to (1024, 1024)):
665 | If `original_size` is not the same as `target_size` the image will appear to be down- or upsampled.
666 | `original_size` defaults to `(height, width)` if not specified. Part of SDXL's micro-conditioning as
667 | explained in section 2.2 of
668 | [https://huggingface.co/papers/2307.01952](https://huggingface.co/papers/2307.01952).
669 | crops_coords_top_left (`Tuple[int]`, *optional*, defaults to (0, 0)):
670 | `crops_coords_top_left` can be used to generate an image that appears to be "cropped" from the position
671 | `crops_coords_top_left` downwards. Favorable, well-centered images are usually achieved by setting
672 | `crops_coords_top_left` to (0, 0). Part of SDXL's micro-conditioning as explained in section 2.2 of
673 | [https://huggingface.co/papers/2307.01952](https://huggingface.co/papers/2307.01952).
674 | target_size (`Tuple[int]`, *optional*, defaults to (1024, 1024)):
675 | For most cases, `target_size` should be set to the desired height and width of the generated image. If
676 | not specified it will default to `(height, width)`. Part of SDXL's micro-conditioning as explained in
677 | section 2.2 of [https://huggingface.co/papers/2307.01952](https://huggingface.co/papers/2307.01952).
678 | negative_original_size (`Tuple[int]`, *optional*, defaults to (1024, 1024)):
679 | To negatively condition the generation process based on a specific image resolution. Part of SDXL's
680 | micro-conditioning as explained in section 2.2 of
681 | [https://huggingface.co/papers/2307.01952](https://huggingface.co/papers/2307.01952). For more
682 | information, refer to this issue thread: https://github.com/huggingface/diffusers/issues/4208.
683 | negative_crops_coords_top_left (`Tuple[int]`, *optional*, defaults to (0, 0)):
684 | To negatively condition the generation process based on specific crop coordinates. Part of SDXL's
685 | micro-conditioning as explained in section 2.2 of
686 | [https://huggingface.co/papers/2307.01952](https://huggingface.co/papers/2307.01952). For more
687 | information, refer to this issue thread: https://github.com/huggingface/diffusers/issues/4208.
688 | negative_target_size (`Tuple[int]`, *optional*, defaults to (1024, 1024)):
689 | To negatively condition the generation process based on a target image resolution. It should be the same
690 | as the `target_size` for most cases. Part of SDXL's micro-conditioning as explained in section 2.2 of
691 | [https://huggingface.co/papers/2307.01952](https://huggingface.co/papers/2307.01952). For more
692 | information, refer to this issue thread: https://github.com/huggingface/diffusers/issues/4208.
693 | clip_skip (`int`, *optional*):
694 | Number of layers to be skipped from CLIP while computing the prompt embeddings. A value of 1 means that
695 | the output of the pre-final layer will be used for computing the prompt embeddings.
696 | callback_on_step_end (`Callable`, *optional*):
697 | A function that is called at the end of each denoising step during inference. The function is called
698 | with the following arguments: `callback_on_step_end(self: DiffusionPipeline, step: int, timestep: int,
699 | callback_kwargs: Dict)`. `callback_kwargs` will include a list of all tensors as specified by
700 | `callback_on_step_end_tensor_inputs`.
701 | callback_on_step_end_tensor_inputs (`List`, *optional*):
702 | The list of tensor inputs for the `callback_on_step_end` function. The tensors specified in the list
703 | will be passed as `callback_kwargs` argument. You will only be able to include variables listed in the
704 | `._callback_tensor_inputs` attribute of your pipeline class.
705 |
706 | Examples:
707 |
708 | Returns:
709 | [`~pipelines.stable_diffusion.StableDiffusionPipelineOutput`] or `tuple`:
710 | If `return_dict` is `True`, [`~pipelines.stable_diffusion.StableDiffusionPipelineOutput`] is returned,
711 | otherwise a `tuple` is returned containing the output images.
712 | """
713 |
714 | callback = kwargs.pop("callback", None)
715 | callback_steps = kwargs.pop("callback_steps", None)
716 |
717 | if callback is not None:
718 | deprecate(
719 | "callback",
720 | "1.0.0",
721 | "Passing `callback` as an input argument to `__call__` is deprecated, consider using `callback_on_step_end`",
722 | )
723 | if callback_steps is not None:
724 | deprecate(
725 | "callback_steps",
726 | "1.0.0",
727 | "Passing `callback_steps` as an input argument to `__call__` is deprecated, consider using `callback_on_step_end`",
728 | )
729 |
730 | controlnet = self.controlnet._orig_mod if is_compiled_module(self.controlnet) else self.controlnet
731 |
732 | # align format for control guidance
733 | if not isinstance(control_guidance_start, list) and isinstance(control_guidance_end, list):
734 | control_guidance_start = len(control_guidance_end) * [control_guidance_start]
735 | elif not isinstance(control_guidance_end, list) and isinstance(control_guidance_start, list):
736 | control_guidance_end = len(control_guidance_start) * [control_guidance_end]
737 | elif not isinstance(control_guidance_start, list) and not isinstance(control_guidance_end, list):
738 | mult = len(controlnet.nets) if isinstance(controlnet, MultiControlNetModel) else 1
739 | control_guidance_start, control_guidance_end = (
740 | mult * [control_guidance_start],
741 | mult * [control_guidance_end],
742 | )
743 |
744 | # 1. Check inputs. Raise error if not correct
745 | self.check_inputs(
746 | prompt,
747 | prompt_2,
748 | image,
749 | callback_steps,
750 | negative_prompt,
751 | negative_prompt_2,
752 | prompt_embeds,
753 | negative_prompt_embeds,
754 | pooled_prompt_embeds,
755 | negative_pooled_prompt_embeds,
756 | controlnet_conditioning_scale,
757 | control_guidance_start,
758 | control_guidance_end,
759 | callback_on_step_end_tensor_inputs,
760 | )
761 |
762 | self._guidance_scale = guidance_scale
763 | self._clip_skip = clip_skip
764 | self._cross_attention_kwargs = cross_attention_kwargs
765 |
766 | # 2. Define call parameters
767 | if prompt is not None and isinstance(prompt, str):
768 | batch_size = 1
769 | elif prompt is not None and isinstance(prompt, list):
770 | batch_size = len(prompt)
771 | else:
772 | batch_size = prompt_embeds.shape[0]
773 |
774 | device = self._execution_device
775 |
776 | if isinstance(controlnet, MultiControlNetModel) and isinstance(controlnet_conditioning_scale, float):
777 | controlnet_conditioning_scale = [controlnet_conditioning_scale] * len(controlnet.nets)
778 |
779 | global_pool_conditions = (
780 | controlnet.config.global_pool_conditions
781 | if isinstance(controlnet, ControlNetModel)
782 | else controlnet.nets[0].config.global_pool_conditions
783 | )
784 | guess_mode = guess_mode or global_pool_conditions
785 |
786 | # 3.1 Encode input prompt
787 | text_encoder_lora_scale = (
788 | self.cross_attention_kwargs.get("scale", None) if self.cross_attention_kwargs is not None else None
789 | )
790 | (
791 | prompt_embeds,
792 | negative_prompt_embeds,
793 | pooled_prompt_embeds,
794 | negative_pooled_prompt_embeds,
795 | ) = self.encode_prompt(
796 | prompt,
797 | prompt_2,
798 | device,
799 | num_images_per_prompt,
800 | self.do_classifier_free_guidance,
801 | negative_prompt,
802 | negative_prompt_2,
803 | prompt_embeds=prompt_embeds,
804 | negative_prompt_embeds=negative_prompt_embeds,
805 | pooled_prompt_embeds=pooled_prompt_embeds,
806 | negative_pooled_prompt_embeds=negative_pooled_prompt_embeds,
807 | lora_scale=text_encoder_lora_scale,
808 | clip_skip=self.clip_skip,
809 | )
810 |
811 | # 3.2 Encode image prompt
812 | prompt_image_emb = self._encode_prompt_image_emb(
813 | image_embeds, device, self.unet.dtype, self.do_classifier_free_guidance
814 | )
815 | bs_embed, seq_len, _ = prompt_image_emb.shape
816 | prompt_image_emb = prompt_image_emb.repeat(1, num_images_per_prompt, 1)
817 | prompt_image_emb = prompt_image_emb.view(bs_embed * num_images_per_prompt, seq_len, -1)
818 |
819 | # 4. Prepare image
820 | if isinstance(controlnet, ControlNetModel):
821 | image = self.prepare_image(
822 | image=image,
823 | width=width,
824 | height=height,
825 | batch_size=batch_size * num_images_per_prompt,
826 | num_images_per_prompt=num_images_per_prompt,
827 | device=device,
828 | dtype=controlnet.dtype,
829 | do_classifier_free_guidance=self.do_classifier_free_guidance,
830 | guess_mode=guess_mode,
831 | )
832 | height, width = image.shape[-2:]
833 | elif isinstance(controlnet, MultiControlNetModel):
834 | images = []
835 |
836 | for image_ in image:
837 | image_ = self.prepare_image(
838 | image=image_,
839 | width=width,
840 | height=height,
841 | batch_size=batch_size * num_images_per_prompt,
842 | num_images_per_prompt=num_images_per_prompt,
843 | device=device,
844 | dtype=controlnet.dtype,
845 | do_classifier_free_guidance=self.do_classifier_free_guidance,
846 | guess_mode=guess_mode,
847 | )
848 |
849 | images.append(image_)
850 |
851 | image = images
852 | height, width = image[0].shape[-2:]
853 | else:
854 | assert False
855 |
856 | # 5. Prepare timesteps
857 | self.scheduler.set_timesteps(num_inference_steps, device=device)
858 | timesteps = self.scheduler.timesteps
859 | self._num_timesteps = len(timesteps)
860 |
861 | # 6. Prepare latent variables
862 | num_channels_latents = self.unet.config.in_channels
863 | latents = self.prepare_latents(
864 | batch_size * num_images_per_prompt,
865 | num_channels_latents,
866 | height,
867 | width,
868 | prompt_embeds.dtype,
869 | device,
870 | generator,
871 | latents,
872 | )
873 |
874 | # 6.5 Optionally get Guidance Scale Embedding
875 | timestep_cond = None
876 | if self.unet.config.time_cond_proj_dim is not None:
877 | guidance_scale_tensor = torch.tensor(self.guidance_scale - 1).repeat(batch_size * num_images_per_prompt)
878 | timestep_cond = self.get_guidance_scale_embedding(
879 | guidance_scale_tensor, embedding_dim=self.unet.config.time_cond_proj_dim
880 | ).to(device=device, dtype=latents.dtype)
881 |
882 | # 7. Prepare extra step kwargs. TODO: Logic should ideally just be moved out of the pipeline
883 | extra_step_kwargs = self.prepare_extra_step_kwargs(generator, eta)
884 |
885 | # 7.1 Create tensor stating which controlnets to keep
886 | controlnet_keep = []
887 | for i in range(len(timesteps)):
888 | keeps = [
889 | 1.0 - float(i / len(timesteps) < s or (i + 1) / len(timesteps) > e)
890 | for s, e in zip(control_guidance_start, control_guidance_end)
891 | ]
892 | controlnet_keep.append(keeps[0] if isinstance(controlnet, ControlNetModel) else keeps)
893 |
894 | # 7.2 Prepare added time ids & embeddings
895 | if isinstance(image, list):
896 | original_size = original_size or image[0].shape[-2:]
897 | else:
898 | original_size = original_size or image.shape[-2:]
899 | target_size = target_size or (height, width)
900 |
901 | add_text_embeds = pooled_prompt_embeds
902 | if self.text_encoder_2 is None:
903 | text_encoder_projection_dim = int(pooled_prompt_embeds.shape[-1])
904 | else:
905 | text_encoder_projection_dim = self.text_encoder_2.config.projection_dim
906 |
907 | add_time_ids = self._get_add_time_ids(
908 | original_size,
909 | crops_coords_top_left,
910 | target_size,
911 | dtype=prompt_embeds.dtype,
912 | text_encoder_projection_dim=text_encoder_projection_dim,
913 | )
914 |
915 | if negative_original_size is not None and negative_target_size is not None:
916 | negative_add_time_ids = self._get_add_time_ids(
917 | negative_original_size,
918 | negative_crops_coords_top_left,
919 | negative_target_size,
920 | dtype=prompt_embeds.dtype,
921 | text_encoder_projection_dim=text_encoder_projection_dim,
922 | )
923 | else:
924 | negative_add_time_ids = add_time_ids
925 |
926 | if self.do_classifier_free_guidance:
927 | prompt_embeds = torch.cat([negative_prompt_embeds, prompt_embeds], dim=0)
928 | add_text_embeds = torch.cat([negative_pooled_prompt_embeds, add_text_embeds], dim=0)
929 | add_time_ids = torch.cat([negative_add_time_ids, add_time_ids], dim=0)
930 |
931 | prompt_embeds = prompt_embeds.to(device)
932 | add_text_embeds = add_text_embeds.to(device)
933 | add_time_ids = add_time_ids.to(device).repeat(batch_size * num_images_per_prompt, 1)
934 | encoder_hidden_states = torch.cat([prompt_embeds, prompt_image_emb], dim=1)
935 |
936 | # 8. Denoising loop
937 | num_warmup_steps = len(timesteps) - num_inference_steps * self.scheduler.order
938 | is_unet_compiled = is_compiled_module(self.unet)
939 | is_controlnet_compiled = is_compiled_module(self.controlnet)
940 | is_torch_higher_equal_2_1 = is_torch_version(">=", "2.1")
941 |
942 | with self.progress_bar(total=num_inference_steps) as progress_bar:
943 | for i, t in enumerate(timesteps):
944 | # Relevant thread:
945 | # https://dev-discuss.pytorch.org/t/cudagraphs-in-pytorch-2-0/1428
946 | if (is_unet_compiled and is_controlnet_compiled) and is_torch_higher_equal_2_1:
947 | torch._inductor.cudagraph_mark_step_begin()
948 | # expand the latents if we are doing classifier free guidance
949 | latent_model_input = torch.cat([latents] * 2) if self.do_classifier_free_guidance else latents
950 | latent_model_input = self.scheduler.scale_model_input(latent_model_input, t)
951 |
952 | added_cond_kwargs = {"text_embeds": add_text_embeds, "time_ids": add_time_ids}
953 |
954 | # controlnet(s) inference
955 | if guess_mode and self.do_classifier_free_guidance:
956 | # Infer ControlNet only for the conditional batch.
957 | control_model_input = latents
958 | control_model_input = self.scheduler.scale_model_input(control_model_input, t)
959 | controlnet_prompt_embeds = prompt_embeds.chunk(2)[1]
960 | controlnet_added_cond_kwargs = {
961 | "text_embeds": add_text_embeds.chunk(2)[1],
962 | "time_ids": add_time_ids.chunk(2)[1],
963 | }
964 | else:
965 | control_model_input = latent_model_input
966 | controlnet_prompt_embeds = prompt_embeds
967 | controlnet_added_cond_kwargs = added_cond_kwargs
968 |
969 | if isinstance(controlnet_keep[i], list):
970 | cond_scale = [c * s for c, s in zip(controlnet_conditioning_scale, controlnet_keep[i])]
971 | else:
972 | controlnet_cond_scale = controlnet_conditioning_scale
973 | if isinstance(controlnet_cond_scale, list):
974 | controlnet_cond_scale = controlnet_cond_scale[0]
975 | cond_scale = controlnet_cond_scale * controlnet_keep[i]
976 |
977 | down_block_res_samples, mid_block_res_sample = self.controlnet(
978 | control_model_input,
979 | t,
980 | encoder_hidden_states=prompt_image_emb,
981 | controlnet_cond=image,
982 | conditioning_scale=cond_scale,
983 | guess_mode=guess_mode,
984 | added_cond_kwargs=controlnet_added_cond_kwargs,
985 | return_dict=False,
986 | )
987 |
988 | if guess_mode and self.do_classifier_free_guidance:
989 | # Inferred ControlNet only for the conditional batch.
990 | # To apply the output of ControlNet to both the unconditional and conditional batches,
991 | # add 0 to the unconditional batch to keep it unchanged.
992 | down_block_res_samples = [torch.cat([torch.zeros_like(d), d]) for d in down_block_res_samples]
993 | mid_block_res_sample = torch.cat([torch.zeros_like(mid_block_res_sample), mid_block_res_sample])
994 |
995 | # predict the noise residual
996 | noise_pred = self.unet(
997 | latent_model_input,
998 | t,
999 | encoder_hidden_states=encoder_hidden_states,
1000 | timestep_cond=timestep_cond,
1001 | cross_attention_kwargs=self.cross_attention_kwargs,
1002 | down_block_additional_residuals=down_block_res_samples,
1003 | mid_block_additional_residual=mid_block_res_sample,
1004 | added_cond_kwargs=added_cond_kwargs,
1005 | return_dict=False,
1006 | )[0]
1007 |
1008 | # perform guidance
1009 | if self.do_classifier_free_guidance:
1010 | noise_pred_uncond, noise_pred_text = noise_pred.chunk(2)
1011 | noise_pred = noise_pred_uncond + guidance_scale * (noise_pred_text - noise_pred_uncond)
1012 |
1013 | # compute the previous noisy sample x_t -> x_t-1
1014 | latents = self.scheduler.step(noise_pred, t, latents, **extra_step_kwargs, return_dict=False)[0]
1015 |
1016 | if callback_on_step_end is not None:
1017 | callback_kwargs = {}
1018 | for k in callback_on_step_end_tensor_inputs:
1019 | callback_kwargs[k] = locals()[k]
1020 | callback_outputs = callback_on_step_end(self, i, t, callback_kwargs)
1021 |
1022 | latents = callback_outputs.pop("latents", latents)
1023 | prompt_embeds = callback_outputs.pop("prompt_embeds", prompt_embeds)
1024 | negative_prompt_embeds = callback_outputs.pop("negative_prompt_embeds", negative_prompt_embeds)
1025 |
1026 | # call the callback, if provided
1027 | if i == len(timesteps) - 1 or ((i + 1) > num_warmup_steps and (i + 1) % self.scheduler.order == 0):
1028 | progress_bar.update()
1029 | if callback is not None and i % callback_steps == 0:
1030 | step_idx = i // getattr(self.scheduler, "order", 1)
1031 | callback(step_idx, t, latents)
1032 |
1033 | if not output_type == "latent":
1034 | # make sure the VAE is in float32 mode, as it overflows in float16
1035 | needs_upcasting = self.vae.dtype == torch.float16 and self.vae.config.force_upcast
1036 | if needs_upcasting:
1037 | self.upcast_vae()
1038 | latents = latents.to(next(iter(self.vae.post_quant_conv.parameters())).dtype)
1039 |
1040 | image = self.vae.decode(latents / self.vae.config.scaling_factor, return_dict=False)[0]
1041 |
1042 | # cast back to fp16 if needed
1043 | if needs_upcasting:
1044 | self.vae.to(dtype=torch.float16)
1045 | else:
1046 | image = latents
1047 |
1048 | if not output_type == "latent":
1049 | # apply watermark if available
1050 | if self.watermark is not None:
1051 | image = self.watermark.apply_watermark(image)
1052 |
1053 | image = self.image_processor.postprocess(image, output_type=output_type)
1054 |
1055 | # Offload all models
1056 | self.maybe_free_model_hooks()
1057 |
1058 | if not return_dict:
1059 | return (image,)
1060 |
1061 | return StableDiffusionXLPipelineOutput(images=image)
1062 |
--------------------------------------------------------------------------------
/requirements.txt:
--------------------------------------------------------------------------------
1 | accelerate==0.27.2
2 | diffusers==0.26.3
3 | gradio==4.19.0
4 | transformers==4.37.2
5 | insightface==0.7.3
6 | opencv-python==4.9.0.80
7 | onnxruntime==1.17.0
--------------------------------------------------------------------------------
/start.bat:
--------------------------------------------------------------------------------
1 | @echo off
2 | setlocal
3 | echo Starting InstantID CPU...
4 |
5 | set "PYTHON_COMMAND=python"
6 |
7 | call python --version > nul 2>&1
8 | if %errorlevel% equ 0 (
9 | echo Python command check :OK
10 | ) else (
11 | echo "Error: Python command not found, please install Python (Recommended : Python 3.10 or Python 3.11) and try again"
12 | pause
13 | exit /b 1
14 |
15 | )
16 |
17 | :check_python_version
18 | for /f "tokens=2" %%I in ('%PYTHON_COMMAND% --version 2^>^&1') do (
19 | set "python_version=%%I"
20 | )
21 |
22 | echo Python version: %python_version%
23 |
24 | set PATH=%PATH%;%~dp0env\Lib\site-packages\openvino\libs
25 | call "%~dp0env\Scripts\activate.bat" && %PYTHON_COMMAND% "%~dp0\app.py"
--------------------------------------------------------------------------------
/start.sh:
--------------------------------------------------------------------------------
1 | #!/usr/bin/env bash
2 | echo "Starting InstantID CPU, please wait..."
3 | set -e
4 | PYTHON_COMMAND="python3"
5 |
6 | if ! command -v python3 &>/dev/null; then
7 | if ! command -v python &>/dev/null; then
8 | echo "Error: Python not found, please install python 3.10 or higher and try again"
9 | exit 1
10 | fi
11 | fi
12 |
13 | if command -v python &>/dev/null; then
14 | PYTHON_COMMAND="python"
15 | fi
16 |
17 | echo "Found $PYTHON_COMMAND command"
18 |
19 | python_version=$($PYTHON_COMMAND --version 2>&1 | awk '{print $2}')
20 | echo "Python version : $python_version"
21 |
22 | BASEDIR=$(pwd)
23 | # shellcheck disable=SC1091
24 | source "$BASEDIR/env/bin/activate"
25 | $PYTHON_COMMAND app.py
--------------------------------------------------------------------------------
/utils.py:
--------------------------------------------------------------------------------
1 | import os
2 |
3 |
4 | def check_file(path: str) -> bool:
5 | return os.path.exists(path)
6 |
--------------------------------------------------------------------------------