├── LICENSE
├── README.md
├── examples
│   ├── 20240708-234944.jpg
│   ├── 20240708-235001.jpg
│   ├── 20240708-235005.jpg
│   ├── 20240708-235013.jpg
│   ├── 20240708-235015.jpg
│   └── 20240709-000053.jpg
├── hunyuan_utils
│   ├── diffusers_learned_conditioning.py
│   ├── sd_hijack_clip_diffusers.py
│   └── utils.py
├── requirements.txt
└── scripts
    └── hunyuandit.py

/LICENSE:
--------------------------------------------------------------------------------
 1 | MIT License
 2 | 
 3 | Copyright (c) 2024 sethgggg
 4 | 
 5 | Permission is hereby granted, free of charge, to any person obtaining a copy
 6 | of this software and associated documentation files (the "Software"), to deal
 7 | in the Software without restriction, including without limitation the rights
 8 | to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
 9 | copies of the Software, and to permit persons to whom the Software is
10 | furnished to do so, subject to the following conditions:
11 | 
12 | The above copyright notice and this permission notice shall be included in all
13 | copies or substantial portions of the Software.
14 | 
15 | THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
16 | IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
17 | FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
18 | AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
19 | LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
20 | OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE
21 | SOFTWARE.
22 | 
--------------------------------------------------------------------------------
/README.md:
--------------------------------------------------------------------------------
 1 | # Hunyuan extension for sd-webui
 2 | 
 3 | This extension lets you use the [Hunyuan DiT Model](https://github.com/Tencent/HunyuanDiT) in [Stable-diffusion-webui](https://github.com/AUTOMATIC1111/stable-diffusion-webui):
 4 | 
 5 | ### Features
 6 | 
 7 | - Core
 8 |   - [x] Txt2Img
 9 |   - [x] Img2Img
10 |   - [ ] LoRA
11 |   - [ ] ControlNet
12 |   - [ ] HiresUpscaler
13 | - Advanced
14 |   - [ ] MultiDiffusion
15 |   - [ ] Adetailer
16 | 
17 | ### Installation
18 | 
19 | 1. You can install this extension via the webui extension installer by copying the git repository URL ```https://github.com/sethgggg/sd-webui-hunyuan-dit.git```
20 | 
21 | ![install](examples/20240709-000053.jpg)
22 | 
23 | 2. Download the HunyuanDiT model from [Huggingface](https://huggingface.co/Tencent-Hunyuan/HunyuanDiT-v1.2-Diffusers) to local storage. The default storage location is ```models/hunyuan``` inside the webui folder; you can change it in the webui settings.
24 | 
25 | ![folder](examples/20240708-235015.jpg)
26 | 
27 | ![settings](examples/20240708-235001.jpg)
28 | 
29 | 3. Place the transformer model in ```models/Stable-Diffusion```, the main storage location for checkpoints. If you have fine-tuned a new model, place its transformer in the same folder and select it there.
30 | 
31 | 4. Find the HunyuanDiT card and enable it; if you want to use Stable Diffusion models instead, remember to disable the HunyuanDiT model.
32 | 
33 | ![enable](examples/20240708-235013.jpg)
34 | 
35 | 5. This project uses diffusers as the inference backend, so the following samplers are supported:
36 | 
37 | | Sampler Name            | Sampler Instance in diffusers                                                          |
38 | |-------------------------|----------------------------------------------------------------------------------------|
39 | | Euler a                 | EulerAncestralDiscreteScheduler()                                                      |
40 | | Euler                   | EulerDiscreteScheduler()                                                               |
41 | | LMS                     | LMSDiscreteScheduler()                                                                 |
42 | | Heun                    | HeunDiscreteScheduler()                                                                |
43 | | DPM2                    | KDPM2DiscreteScheduler()                                                               |
44 | | DPM2 a                  | KDPM2AncestralDiscreteScheduler()                                                      |
45 | | DPM++ SDE               | DPMSolverSinglestepScheduler()                                                         |
46 | | DPM++ 2M                | DPMSolverMultistepScheduler()                                                          |
47 | | DPM++ 2S a              | DPMSolverSinglestepScheduler()                                                         |
48 | | LMS Karras              | LMSDiscreteScheduler(use_karras_sigmas=True)                                           |
49 | | DPM2 Karras             | KDPM2DiscreteScheduler(use_karras_sigmas=True)                                         |
50 | | DPM2 a Karras           | KDPM2AncestralDiscreteScheduler(use_karras_sigmas=True)                                |
51 | | DPM++ SDE Karras        | DPMSolverSinglestepScheduler(use_karras_sigmas=True)                                   |
52 | | DPM++ 2M Karras         | DPMSolverMultistepScheduler(use_karras_sigmas=True)                                    |
53 | | DPM++ 2S a Karras       | DPMSolverSinglestepScheduler(use_karras_sigmas=True)                                   |
54 | | DDIM                    | DDIMScheduler()                                                                        |
55 | | UniPC                   | UniPCMultistepScheduler()                                                              |
56 | | DPM++ 2M SDE Karras     | DPMSolverMultistepScheduler(use_karras_sigmas=True, algorithm_type="sde-dpmsolver++")  |
57 | | DPM++ 2M SDE            | DPMSolverMultistepScheduler(algorithm_type="sde-dpmsolver++")                          |
58 | | LCM                     | LCMScheduler()                                                                         |
59 | 
60 | ### Examples
61 | 
62 | ⚪ Txt2img: generate images from text prompts; webui-style prompt weighting is supported.
63 | 
64 | ![txt2img](examples/20240708-235005.jpg)
65 | 
66 | ⚪ Img2img: given an image, use the Hunyuan DiT model to generate more images from it.
67 | 
68 | ![img2img](examples/20240708-234944.jpg)
69 | 
--------------------------------------------------------------------------------
/examples/20240708-234944.jpg:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/sethgggg/sd-webui-hunyuan-dit/b0b3be88417acbc4460435e9269d371d70392652/examples/20240708-234944.jpg
--------------------------------------------------------------------------------
/examples/20240708-235001.jpg:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/sethgggg/sd-webui-hunyuan-dit/b0b3be88417acbc4460435e9269d371d70392652/examples/20240708-235001.jpg
--------------------------------------------------------------------------------
/examples/20240708-235005.jpg:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/sethgggg/sd-webui-hunyuan-dit/b0b3be88417acbc4460435e9269d371d70392652/examples/20240708-235005.jpg
--------------------------------------------------------------------------------
/examples/20240708-235013.jpg:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/sethgggg/sd-webui-hunyuan-dit/b0b3be88417acbc4460435e9269d371d70392652/examples/20240708-235013.jpg
--------------------------------------------------------------------------------
/examples/20240708-235015.jpg:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/sethgggg/sd-webui-hunyuan-dit/b0b3be88417acbc4460435e9269d371d70392652/examples/20240708-235015.jpg
--------------------------------------------------------------------------------
/examples/20240709-000053.jpg:
-------------------------------------------------------------------------------- https://raw.githubusercontent.com/sethgggg/sd-webui-hunyuan-dit/b0b3be88417acbc4460435e9269d371d70392652/examples/20240709-000053.jpg -------------------------------------------------------------------------------- /hunyuan_utils/diffusers_learned_conditioning.py: -------------------------------------------------------------------------------- 1 | from modules import prompt_parser, shared 2 | 3 | def get_learned_conditioning_hunyuan(batch: prompt_parser.SdConditioning | list[str]): 4 | clip_l_conds, clip_l_attention = shared.clip_l_model(batch) 5 | t5_conds, t5_attention = shared.mt5_model(batch) 6 | return {"crossattn":clip_l_conds, "mask":clip_l_attention, "crossattn_2":t5_conds, "mask_2":t5_attention} -------------------------------------------------------------------------------- /hunyuan_utils/sd_hijack_clip_diffusers.py: -------------------------------------------------------------------------------- 1 | import math 2 | from collections import namedtuple 3 | 4 | import torch 5 | 6 | from modules import prompt_parser, devices, sd_hijack, sd_emphasis 7 | from modules.shared import opts 8 | 9 | 10 | class PromptChunk: 11 | """ 12 | This object contains token ids, weight (multipliers:1.4) and textual inversion embedding info for a chunk of prompt. 13 | If a prompt is short, it is represented by one PromptChunk, otherwise, multiple are necessary. 14 | Each PromptChunk contains an exact amount of tokens - 77, which includes one for start and end token, 15 | so just 75 tokens from prompt. 16 | """ 17 | 18 | def __init__(self): 19 | self.tokens = [] 20 | self.multipliers = [] 21 | self.fixes = [] 22 | 23 | 24 | PromptChunkFix = namedtuple('PromptChunkFix', ['offset', 'embedding']) 25 | """An object of this type is a marker showing that textual inversion embedding's vectors have to placed at offset in the prompt 26 | chunk. Thos objects are found in PromptChunk.fixes and, are placed into FrozenCLIPEmbedderWithCustomWordsBase.hijack.fixes, and finally 27 | are applied by sd_hijack.EmbeddingsWithFixes's forward function.""" 28 | 29 | 30 | class FrozenCLIPEmbedderWithCustomWordsBase(torch.nn.Module): 31 | """A pytorch module that is a wrapper for FrozenCLIPEmbedder module. it enhances FrozenCLIPEmbedder, making it possible to 32 | have unlimited prompt length and assign weights to tokens in prompt. 
33 | """ 34 | 35 | def __init__(self, wrapped, hijack): 36 | super().__init__() 37 | 38 | self.wrapped = wrapped 39 | """Original FrozenCLIPEmbedder module; can also be FrozenOpenCLIPEmbedder or xlmr.BertSeriesModelWithTransformation, 40 | depending on model.""" 41 | 42 | self.hijack: sd_hijack.StableDiffusionModelHijack = hijack 43 | self.chunk_length = 75 44 | 45 | self.is_trainable = getattr(wrapped, 'is_trainable', False) 46 | self.input_key = getattr(wrapped, 'input_key', 'txt') 47 | self.legacy_ucg_val = None 48 | 49 | def empty_chunk(self): 50 | """creates an empty PromptChunk and returns it""" 51 | 52 | chunk = PromptChunk() 53 | chunk.tokens = [self.id_start] + [self.id_end] * (self.chunk_length + 1) 54 | chunk.multipliers = [1.0] * (self.chunk_length + 2) 55 | return chunk 56 | 57 | def get_target_prompt_token_count(self, token_count): 58 | """returns the maximum number of tokens a prompt of a known length can have before it requires one more PromptChunk to be represented""" 59 | 60 | return math.ceil(max(token_count, 1) / self.chunk_length) * self.chunk_length 61 | 62 | def tokenize(self, texts): 63 | """Converts a batch of texts into a batch of token ids""" 64 | 65 | raise NotImplementedError 66 | 67 | def encode_with_transformers(self, tokens): 68 | """ 69 | converts a batch of token ids (in python lists) into a single tensor with numeric respresentation of those tokens; 70 | All python lists with tokens are assumed to have same length, usually 77. 71 | if input is a list with B elements and each element has T tokens, expected output shape is (B, T, C), where C depends on 72 | model - can be 768 and 1024. 73 | Among other things, this call will read self.hijack.fixes, apply it to its inputs, and clear it (setting it to None). 74 | """ 75 | 76 | raise NotImplementedError 77 | 78 | def encode_embedding_init_text(self, init_text, nvpt): 79 | """Converts text into a tensor with this text's tokens' embeddings. Note that those are embeddings before they are passed through 80 | transformers. nvpt is used as a maximum length in tokens. If text produces less teokens than nvpt, only this many is returned.""" 81 | 82 | raise NotImplementedError 83 | 84 | def tokenize_line(self, line): 85 | """ 86 | this transforms a single prompt into a list of PromptChunk objects - as many as needed to 87 | represent the prompt. 88 | Returns the list and the total number of tokens in the prompt. 
89 | """ 90 | 91 | if opts.emphasis != "None": 92 | parsed = prompt_parser.parse_prompt_attention(line) 93 | else: 94 | parsed = [[line, 1.0]] 95 | 96 | tokenized = self.tokenize([text for text, _ in parsed]) 97 | 98 | chunks = [] 99 | chunk = PromptChunk() 100 | token_count = 0 101 | last_comma = -1 102 | 103 | def next_chunk(is_last=False): 104 | """puts current chunk into the list of results and produces the next one - empty; 105 | if is_last is true, tokens tokens at the end won't add to token_count""" 106 | nonlocal token_count 107 | nonlocal last_comma 108 | nonlocal chunk 109 | 110 | if is_last: 111 | token_count += len(chunk.tokens) 112 | else: 113 | token_count += self.chunk_length 114 | 115 | to_add = self.chunk_length - len(chunk.tokens) 116 | if to_add > 0: 117 | chunk.tokens += [self.id_end] * to_add 118 | chunk.multipliers += [1.0] * to_add 119 | 120 | chunk.tokens = [self.id_start] + chunk.tokens + [self.id_end] 121 | chunk.multipliers = [1.0] + chunk.multipliers + [1.0] 122 | 123 | last_comma = -1 124 | chunks.append(chunk) 125 | chunk = PromptChunk() 126 | 127 | for tokens, (text, weight) in zip(tokenized, parsed): 128 | if text == 'BREAK' and weight == -1: 129 | next_chunk() 130 | continue 131 | 132 | position = 0 133 | while position < len(tokens): 134 | token = tokens[position] 135 | 136 | if token == self.comma_token: 137 | last_comma = len(chunk.tokens) 138 | 139 | # this is when we are at the end of alloted 75 tokens for the current chunk, and the current token is not a comma. opts.comma_padding_backtrack 140 | # is a setting that specifies that if there is a comma nearby, the text after the comma should be moved out of this chunk and into the next. 141 | elif opts.comma_padding_backtrack != 0 and len(chunk.tokens) == self.chunk_length and last_comma != -1 and len(chunk.tokens) - last_comma <= opts.comma_padding_backtrack: 142 | break_location = last_comma + 1 143 | 144 | reloc_tokens = chunk.tokens[break_location:] 145 | reloc_mults = chunk.multipliers[break_location:] 146 | 147 | chunk.tokens = chunk.tokens[:break_location] 148 | chunk.multipliers = chunk.multipliers[:break_location] 149 | 150 | next_chunk() 151 | chunk.tokens = reloc_tokens 152 | chunk.multipliers = reloc_mults 153 | 154 | if len(chunk.tokens) == self.chunk_length: 155 | next_chunk() 156 | 157 | embedding, embedding_length_in_tokens = self.hijack.embedding_db.find_embedding_at_position(tokens, position) 158 | if embedding is None: 159 | chunk.tokens.append(token) 160 | chunk.multipliers.append(weight) 161 | position += 1 162 | continue 163 | 164 | emb_len = int(embedding.vectors) 165 | if len(chunk.tokens) + emb_len > self.chunk_length: 166 | next_chunk() 167 | 168 | chunk.fixes.append(PromptChunkFix(len(chunk.tokens), embedding)) 169 | 170 | chunk.tokens += [0] * emb_len 171 | chunk.multipliers += [weight] * emb_len 172 | position += embedding_length_in_tokens 173 | 174 | if chunk.tokens or not chunks: 175 | next_chunk(is_last=True) 176 | 177 | return chunks, token_count 178 | 179 | def process_texts(self, texts): 180 | """ 181 | Accepts a list of texts and calls tokenize_line() on each, with cache. Returns the list of results and maximum 182 | length, in tokens, of all texts. 
183 | """ 184 | 185 | token_count = 0 186 | 187 | cache = {} 188 | batch_chunks = [] 189 | for line in texts: 190 | if line in cache: 191 | chunks = cache[line] 192 | else: 193 | chunks, current_token_count = self.tokenize_line(line) 194 | token_count = max(current_token_count, token_count) 195 | 196 | cache[line] = chunks 197 | 198 | batch_chunks.append(chunks) 199 | 200 | return batch_chunks, token_count 201 | 202 | def forward(self, texts): 203 | """ 204 | Accepts an array of texts; Passes texts through transformers network to create a tensor with numerical representation of those texts. 205 | Returns a tensor with shape of (B, T, C), where B is length of the array; T is length, in tokens, of texts (including padding) - T will 206 | be a multiple of 77; and C is dimensionality of each token - for SD1 it's 768, for SD2 it's 1024, and for SDXL it's 1280. 207 | An example shape returned by this function can be: (2, 77, 768). 208 | For SDXL, instead of returning one tensor avobe, it returns a tuple with two: the other one with shape (B, 1280) with pooled values. 209 | Webui usually sends just one text at a time through this function - the only time when texts is an array with more than one elemenet 210 | is when you do prompt editing: "a picture of a [cat:dog:0.4] eating ice cream" 211 | """ 212 | 213 | if opts.use_old_emphasis_implementation: 214 | import modules.sd_hijack_clip_old 215 | return modules.sd_hijack_clip_old.forward_old(self, texts) 216 | 217 | batch_chunks, token_count = self.process_texts(texts) 218 | 219 | used_embeddings = {} 220 | chunk_count = max([len(x) for x in batch_chunks]) 221 | 222 | zs = [] 223 | tk = [] 224 | for i in range(chunk_count): 225 | batch_chunk = [chunks[i] if i < len(chunks) else self.empty_chunk() for chunks in batch_chunks] 226 | 227 | tokens = [x.tokens for x in batch_chunk] 228 | attn_mask = [] 229 | for token in tokens: 230 | temp_mask = [] 231 | for token_id in token: 232 | if token_id != self.id_end: 233 | temp_mask.append(1) 234 | else: 235 | temp_mask.append(0) 236 | attn_mask.append(temp_mask) 237 | multipliers = [x.multipliers for x in batch_chunk] 238 | self.hijack.fixes = [x.fixes for x in batch_chunk] 239 | 240 | for fixes in self.hijack.fixes: 241 | for _position, embedding in fixes: 242 | used_embeddings[embedding.name] = embedding 243 | 244 | z = self.process_tokens(tokens, multipliers) 245 | zs.append(z) 246 | tk.append(torch.tensor(attn_mask)) 247 | 248 | if opts.textual_inversion_add_hashes_to_infotext and used_embeddings: 249 | hashes = [] 250 | for name, embedding in used_embeddings.items(): 251 | shorthash = embedding.shorthash 252 | if not shorthash: 253 | continue 254 | 255 | name = name.replace(":", "").replace(",", "") 256 | hashes.append(f"{name}: {shorthash}") 257 | 258 | if hashes: 259 | if self.hijack.extra_generation_params.get("TI hashes"): 260 | hashes.append(self.hijack.extra_generation_params.get("TI hashes")) 261 | self.hijack.extra_generation_params["TI hashes"] = ", ".join(hashes) 262 | 263 | if any(x for x in texts if "(" in x or "[" in x) and opts.emphasis != "Original": 264 | self.hijack.extra_generation_params["Emphasis"] = opts.emphasis 265 | 266 | if getattr(self, 'return_pooled', False): 267 | return torch.hstack(zs), zs[0].pooled 268 | elif getattr(self, 'return_masks', False): 269 | return torch.hstack(zs), torch.hstack(tk) 270 | else: 271 | return torch.hstack(zs) 272 | 273 | def process_tokens(self, remade_batch_tokens, batch_multipliers): 274 | """ 275 | sends one single prompt chunk to be encoded by 
transformers neural network. 276 | remade_batch_tokens is a batch of tokens - a list, where every element is a list of tokens; usually 277 | there are exactly 77 tokens in the list. batch_multipliers is the same but for multipliers instead of tokens. 278 | Multipliers are used to give more or less weight to the outputs of transformers network. Each multiplier 279 | corresponds to one token. 280 | """ 281 | tokens = torch.asarray(remade_batch_tokens).to(devices.device) 282 | 283 | # this is for SD2: SD1 uses the same token for padding and end of text, while SD2 uses different ones. 284 | if self.id_end != self.id_pad: 285 | for batch_pos in range(len(remade_batch_tokens)): 286 | index = remade_batch_tokens[batch_pos].index(self.id_end) 287 | tokens[batch_pos, index+1:tokens.shape[1]] = self.id_pad 288 | 289 | z = self.encode_with_transformers(tokens) 290 | 291 | pooled = getattr(z, 'pooled', None) 292 | 293 | emphasis = sd_emphasis.get_current_option(opts.emphasis)() 294 | emphasis.tokens = remade_batch_tokens 295 | emphasis.multipliers = torch.asarray(batch_multipliers).to(devices.device) 296 | emphasis.z = z 297 | 298 | emphasis.after_transformers() 299 | 300 | z = emphasis.z 301 | 302 | if pooled is not None: 303 | z.pooled = pooled 304 | 305 | return z 306 | 307 | 308 | class FrozenCLIPEmbedderWithCustomWords(FrozenCLIPEmbedderWithCustomWordsBase): 309 | def __init__(self, wrapped, hijack): 310 | super().__init__(wrapped, hijack) 311 | self.tokenizer = wrapped.tokenizer 312 | 313 | vocab = self.tokenizer.get_vocab() 314 | 315 | self.comma_token = vocab.get(',', None) 316 | 317 | self.token_mults = {} 318 | tokens_with_parens = [(k, v) for k, v in vocab.items() if '(' in k or ')' in k or '[' in k or ']' in k] 319 | for text, ident in tokens_with_parens: 320 | mult = 1.0 321 | for c in text: 322 | if c == '[': 323 | mult /= 1.1 324 | if c == ']': 325 | mult *= 1.1 326 | if c == '(': 327 | mult *= 1.1 328 | if c == ')': 329 | mult /= 1.1 330 | 331 | if mult != 1.0: 332 | self.token_mults[ident] = mult 333 | 334 | self.id_start = self.wrapped.tokenizer.bos_token_id 335 | self.id_end = self.wrapped.tokenizer.eos_token_id 336 | self.id_pad = self.id_end 337 | 338 | def tokenize(self, texts): 339 | tokenized = self.wrapped.tokenizer(texts, truncation=False, add_special_tokens=False)["input_ids"] 340 | 341 | return tokenized 342 | 343 | def encode_with_transformers(self, tokens): 344 | outputs = self.wrapped(input_ids=tokens, output_hidden_states=-opts.CLIP_stop_at_last_layers) 345 | 346 | if opts.CLIP_stop_at_last_layers > 1: 347 | z = outputs.hidden_states[-opts.CLIP_stop_at_last_layers] 348 | else: 349 | z = outputs.last_hidden_state 350 | z.pooled = outputs.text_embeds 351 | 352 | return z 353 | 354 | def encode_embedding_init_text(self, init_text, nvpt): 355 | embedding_layer = self.wrapped.text_model.embeddings 356 | ids = self.wrapped.tokenizer(init_text, max_length=nvpt, return_tensors="pt", add_special_tokens=False)["input_ids"] 357 | embedded = embedding_layer.token_embedding(ids.to(embedding_layer.token_embedding.weight.device)).squeeze(0) 358 | 359 | return embedded 360 | 361 | class FrozenBertEmbedderWithCustomWords(FrozenCLIPEmbedderWithCustomWordsBase): 362 | def __init__(self, wrapped, hijack): 363 | super().__init__(wrapped, hijack) 364 | self.tokenizer = wrapped.tokenizer 365 | 366 | vocab = self.tokenizer.get_vocab() 367 | 368 | self.comma_token = vocab.get(',', None) 369 | 370 | self.token_mults = {} 371 | tokens_with_parens = [(k, v) for k, v in vocab.items() if '(' in k or 
')' in k or '[' in k or ']' in k] 372 | for text, ident in tokens_with_parens: 373 | mult = 1.0 374 | for c in text: 375 | if c == '[': 376 | mult /= 1.1 377 | if c == ']': 378 | mult *= 1.1 379 | if c == '(': 380 | mult *= 1.1 381 | if c == ')': 382 | mult /= 1.1 383 | 384 | if mult != 1.0: 385 | self.token_mults[ident] = mult 386 | 387 | self.id_start = self.wrapped.tokenizer.cls_token_id 388 | self.id_end = self.wrapped.tokenizer.sep_token_id 389 | self.id_pad = self.wrapped.tokenizer.pad_token_id 390 | 391 | def empty_chunk(self): 392 | """creates an empty PromptChunk and returns it""" 393 | 394 | chunk = PromptChunk() 395 | chunk.tokens = [self.id_start] + [self.id_end] + [self.id_pad] * (self.chunk_length) 396 | chunk.multipliers = [1.0] * (self.chunk_length + 2) 397 | return chunk 398 | 399 | def tokenize_line(self, line): 400 | """ 401 | this transforms a single prompt into a list of PromptChunk objects - as many as needed to 402 | represent the prompt. 403 | Returns the list and the total number of tokens in the prompt. 404 | """ 405 | 406 | if opts.emphasis != "None": 407 | parsed = prompt_parser.parse_prompt_attention(line) 408 | else: 409 | parsed = [[line, 1.0]] 410 | 411 | tokenized = self.tokenize([text for text, _ in parsed]) 412 | 413 | chunks = [] 414 | chunk = PromptChunk() 415 | token_count = 0 416 | last_comma = -1 417 | 418 | def next_chunk(is_last=False): 419 | """puts current chunk into the list of results and produces the next one - empty; 420 | if is_last is true, tokens tokens at the end won't add to token_count""" 421 | nonlocal token_count 422 | nonlocal last_comma 423 | nonlocal chunk 424 | 425 | if is_last: 426 | token_count += len(chunk.tokens) 427 | else: 428 | token_count += self.chunk_length 429 | 430 | to_add = self.chunk_length - len(chunk.tokens) 431 | if to_add > 0: 432 | chunk.tokens += [self.id_end] + [self.id_pad] * to_add 433 | chunk.multipliers += [1.0] * to_add 434 | else: 435 | chunk.tokens += [self.id_end] 436 | 437 | chunk.tokens = [self.id_start] + chunk.tokens 438 | chunk.multipliers = [1.0] + chunk.multipliers + [1.0] 439 | 440 | last_comma = -1 441 | chunks.append(chunk) 442 | chunk = PromptChunk() 443 | 444 | for tokens, (text, weight) in zip(tokenized, parsed): 445 | if text == 'BREAK' and weight == -1: 446 | next_chunk() 447 | continue 448 | 449 | position = 0 450 | while position < len(tokens): 451 | token = tokens[position] 452 | 453 | if token == self.comma_token: 454 | last_comma = len(chunk.tokens) 455 | 456 | # this is when we are at the end of alloted 75 tokens for the current chunk, and the current token is not a comma. opts.comma_padding_backtrack 457 | # is a setting that specifies that if there is a comma nearby, the text after the comma should be moved out of this chunk and into the next. 
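                # For example, with comma_padding_backtrack=20: if the chunk fills up at 75 tokens
                # and the last comma was seen 10 tokens earlier, the 10 tokens after that comma are
                # relocated to the start of the next chunk instead of splitting the phrase mid-way.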
458 | elif opts.comma_padding_backtrack != 0 and len(chunk.tokens) == self.chunk_length and last_comma != -1 and len(chunk.tokens) - last_comma <= opts.comma_padding_backtrack: 459 | break_location = last_comma + 1 460 | 461 | reloc_tokens = chunk.tokens[break_location:] 462 | reloc_mults = chunk.multipliers[break_location:] 463 | 464 | chunk.tokens = chunk.tokens[:break_location] 465 | chunk.multipliers = chunk.multipliers[:break_location] 466 | 467 | next_chunk() 468 | chunk.tokens = reloc_tokens 469 | chunk.multipliers = reloc_mults 470 | 471 | if len(chunk.tokens) == self.chunk_length: 472 | next_chunk() 473 | 474 | embedding, embedding_length_in_tokens = self.hijack.embedding_db.find_embedding_at_position(tokens, position) 475 | if embedding is None: 476 | chunk.tokens.append(token) 477 | chunk.multipliers.append(weight) 478 | position += 1 479 | continue 480 | 481 | emb_len = int(embedding.vectors) 482 | if len(chunk.tokens) + emb_len > self.chunk_length: 483 | next_chunk() 484 | 485 | chunk.fixes.append(PromptChunkFix(len(chunk.tokens), embedding)) 486 | 487 | chunk.tokens += [0] * emb_len 488 | chunk.multipliers += [weight] * emb_len 489 | position += embedding_length_in_tokens 490 | 491 | if chunk.tokens or not chunks: 492 | next_chunk(is_last=True) 493 | 494 | return chunks, token_count 495 | 496 | def tokenize(self, texts): 497 | tokenized = self.wrapped.tokenizer(texts, truncation=False, add_special_tokens=False)["input_ids"] 498 | 499 | return tokenized 500 | 501 | def encode_with_transformers(self, tokens): 502 | attn_mask = [] 503 | for token in tokens: 504 | temp_mask = [] 505 | for token_id in token: 506 | if token_id != self.id_pad: 507 | temp_mask.append(1) 508 | else: 509 | temp_mask.append(0) 510 | attn_mask.append(temp_mask) 511 | outputs = self.wrapped(input_ids=tokens,attention_mask=torch.tensor(attn_mask).to(devices.device),output_hidden_states=-opts.CLIP_stop_at_last_layers) 512 | 513 | if opts.CLIP_stop_at_last_layers > 1: 514 | z = outputs.hidden_states[-opts.CLIP_stop_at_last_layers] 515 | else: 516 | z = outputs.last_hidden_state 517 | 518 | return z 519 | 520 | def encode_embedding_init_text(self, init_text, nvpt): 521 | embedding_layer = self.wrapped.text_model.embeddings 522 | ids = self.wrapped.tokenizer(init_text, max_length=nvpt, return_tensors="pt", add_special_tokens=False)["input_ids"] 523 | embedded = embedding_layer.token_embedding.wrapped(ids.to(embedding_layer.token_embedding.wrapped.weight.device)).squeeze(0) 524 | 525 | return embedded 526 | 527 | def forward(self, texts): 528 | """ 529 | Accepts an array of texts; Passes texts through transformers network to create a tensor with numerical representation of those texts. 530 | Returns a tensor with shape of (B, T, C), where B is length of the array; T is length, in tokens, of texts (including padding) - T will 531 | be a multiple of 77; and C is dimensionality of each token - for SD1 it's 768, for SD2 it's 1024, and for SDXL it's 1280. 532 | An example shape returned by this function can be: (2, 77, 768). 533 | For SDXL, instead of returning one tensor avobe, it returns a tuple with two: the other one with shape (B, 1280) with pooled values. 
534 | Webui usually sends just one text at a time through this function - the only time when texts is an array with more than one elemenet 535 | is when you do prompt editing: "a picture of a [cat:dog:0.4] eating ice cream" 536 | """ 537 | 538 | batch_chunks, token_count = self.process_texts(texts) 539 | 540 | chunk_count = max([len(x) for x in batch_chunks]) 541 | 542 | zs = [] 543 | tk = [] 544 | for i in range(chunk_count): 545 | batch_chunk = [chunks[i] if i < len(chunks) else self.empty_chunk() for chunks in batch_chunks] 546 | 547 | tokens = [x.tokens for x in batch_chunk] 548 | attn_mask = [] 549 | for token in tokens: 550 | temp_mask = [] 551 | for token_id in token: 552 | if token_id != self.id_pad: 553 | temp_mask.append(1) 554 | else: 555 | temp_mask.append(0) 556 | attn_mask.append(temp_mask) 557 | multipliers = [x.multipliers for x in batch_chunk] 558 | 559 | z = self.process_tokens(tokens, multipliers) 560 | zs.append(z) 561 | tk.append(torch.tensor(attn_mask)) 562 | 563 | if getattr(self, 'return_masks', False): 564 | return torch.hstack([zs[0]]), torch.hstack([tk[0]]).to(devices.device) 565 | else: 566 | return torch.hstack([zs[0]]) 567 | 568 | class FrozenT5EmbedderWithCustomWords(FrozenCLIPEmbedderWithCustomWords): 569 | def __init__(self, wrapped, hijack): 570 | super().__init__(wrapped, hijack) 571 | 572 | self.tokenizer = wrapped.tokenizer 573 | self.chunk_length = 255 574 | self.id_start = self.tokenizer.bos_token_id 575 | self.id_end = self.tokenizer.eos_token_id 576 | self.id_pad = 0 577 | 578 | def empty_chunk(self): 579 | """creates an empty PromptChunk and returns it""" 580 | 581 | chunk = PromptChunk() 582 | chunk.tokens = [self.id_end] + [self.id_pad] * self.chunk_length 583 | chunk.multipliers = [1.0] + [1.0]* self.chunk_length 584 | return chunk 585 | 586 | def tokenize_line(self, line): 587 | """ 588 | this transforms a single prompt into a list of PromptChunk objects - as many as needed to 589 | represent the prompt. 590 | Returns the list and the total number of tokens in the prompt. 591 | """ 592 | 593 | parsed = prompt_parser.parse_prompt_attention(line) 594 | 595 | tokenized = self.tokenize([text for text, _ in parsed]) 596 | 597 | chunks = [] 598 | chunk = PromptChunk() 599 | token_count = 0 600 | last_comma = -1 601 | 602 | def next_chunk(is_last=False): 603 | """puts current chunk into the list of results and produces the next one - empty; 604 | if is_last is true, tokens tokens at the end won't add to token_count""" 605 | nonlocal token_count 606 | nonlocal last_comma 607 | nonlocal chunk 608 | 609 | if is_last: 610 | token_count += len(chunk.tokens) 611 | else: 612 | token_count += self.chunk_length 613 | 614 | to_add = self.chunk_length - len(chunk.tokens) 615 | if to_add > 0: 616 | chunk.tokens += [self.id_end] + [self.id_pad] * to_add 617 | chunk.multipliers += [1.0] * to_add 618 | else: 619 | chunk.tokens += [self.id_end] 620 | 621 | chunk.tokens = [] + chunk.tokens 622 | chunk.multipliers = [] + chunk.multipliers + [1.0] 623 | 624 | last_comma = -1 625 | chunks.append(chunk) 626 | chunk = PromptChunk() 627 | 628 | for tokens, (text, weight) in zip(tokenized, parsed): 629 | if text == 'BREAK' and weight == -1: 630 | next_chunk() 631 | continue 632 | 633 | position = 0 634 | while position < len(tokens): 635 | token = tokens[position] 636 | 637 | if token == self.comma_token: 638 | last_comma = len(chunk.tokens) 639 | 640 | # this is when we are at the end of alloted 75 tokens for the current chunk, and the current token is not a comma. 
opts.comma_padding_backtrack 641 | # is a setting that specifies that if there is a comma nearby, the text after the comma should be moved out of this chunk and into the next. 642 | elif len(chunk.tokens) == self.chunk_length and last_comma != -1 and len(chunk.tokens) - last_comma <= 20: 643 | break_location = last_comma + 1 644 | 645 | reloc_tokens = chunk.tokens[break_location:] 646 | reloc_mults = chunk.multipliers[break_location:] 647 | 648 | chunk.tokens = chunk.tokens[:break_location] 649 | chunk.multipliers = chunk.multipliers[:break_location] 650 | 651 | next_chunk() 652 | chunk.tokens = reloc_tokens 653 | chunk.multipliers = reloc_mults 654 | 655 | if len(chunk.tokens) == self.chunk_length: 656 | next_chunk() 657 | 658 | chunk.tokens.append(token) 659 | chunk.multipliers.append(weight) 660 | position += 1 661 | 662 | if chunk.tokens or not chunks: 663 | next_chunk(is_last=True) 664 | 665 | return chunks, token_count 666 | 667 | def encode_with_transformers(self, tokens): 668 | attn_mask = [] 669 | for token in tokens: 670 | temp_mask = [] 671 | for token_id in token: 672 | if token_id != self.id_pad: 673 | temp_mask.append(1) 674 | else: 675 | temp_mask.append(0) 676 | attn_mask.append(temp_mask) 677 | outputs = self.wrapped(input_ids=tokens, attention_mask=torch.tensor(attn_mask).to(devices.device), output_hidden_states=True) 678 | 679 | ''' 680 | if self.wrapped.layer == "last": 681 | z = outputs.last_hidden_state 682 | else: 683 | z = outputs.hidden_states[self.wrapped.layer_idx] 684 | ''' 685 | z = outputs.last_hidden_state 686 | return z 687 | 688 | def forward(self, texts): 689 | """ 690 | Accepts an array of texts; Passes texts through transformers network to create a tensor with numerical representation of those texts. 691 | Returns a tensor with shape of (B, T, C), where B is length of the array; T is length, in tokens, of texts (including padding) - T will 692 | be a multiple of 77; and C is dimensionality of each token - for SD1 it's 768, for SD2 it's 1024, and for SDXL it's 1280. 693 | An example shape returned by this function can be: (2, 77, 768). 694 | For SDXL, instead of returning one tensor avobe, it returns a tuple with two: the other one with shape (B, 1280) with pooled values. 
695 | Webui usually sends just one text at a time through this function - the only time when texts is an array with more than one elemenet 696 | is when you do prompt editing: "a picture of a [cat:dog:0.4] eating ice cream" 697 | """ 698 | 699 | batch_chunks, token_count = self.process_texts(texts) 700 | 701 | chunk_count = max([len(x) for x in batch_chunks]) 702 | 703 | zs = [] 704 | tk = [] 705 | for i in range(chunk_count): 706 | batch_chunk = [chunks[i] if i < len(chunks) else self.empty_chunk() for chunks in batch_chunks] 707 | 708 | tokens = [x.tokens for x in batch_chunk] 709 | attn_mask = [] 710 | for token in tokens: 711 | temp_mask = [] 712 | for token_id in token: 713 | if token_id != self.id_pad: 714 | temp_mask.append(1) 715 | else: 716 | temp_mask.append(0) 717 | attn_mask.append(temp_mask) 718 | multipliers = [x.multipliers for x in batch_chunk] 719 | 720 | z = self.process_tokens(tokens, multipliers) 721 | zs.append(z) 722 | tk.append(torch.tensor(attn_mask)) 723 | 724 | if getattr(self, 'return_masks', False): 725 | return torch.hstack([zs[0]]), torch.hstack([tk[0]]).to(devices.device) 726 | else: 727 | return torch.hstack([zs[0]]) -------------------------------------------------------------------------------- /hunyuan_utils/utils.py: -------------------------------------------------------------------------------- 1 | from modules import devices, rng, shared 2 | import numpy as np 3 | import gc 4 | import inspect 5 | import torch 6 | from typing import Any, Dict, List, Optional, Union, Tuple 7 | from diffusers.schedulers import ( 8 | DDIMScheduler, 9 | DDPMScheduler, 10 | LMSDiscreteScheduler, 11 | EulerDiscreteScheduler, 12 | HeunDiscreteScheduler, 13 | EulerAncestralDiscreteScheduler, 14 | DPMSolverMultistepScheduler, 15 | DPMSolverSinglestepScheduler, 16 | KDPM2DiscreteScheduler, 17 | KDPM2AncestralDiscreteScheduler, 18 | UniPCMultistepScheduler, 19 | LCMScheduler, 20 | ) 21 | 22 | hunyuan_transformer_config_v12 = { 23 | "_class_name": "HunyuanDiT2DModel", 24 | "_diffusers_version": "0.30.0.dev0", 25 | "activation_fn": "gelu-approximate", 26 | "attention_head_dim": 88, 27 | "cross_attention_dim": 1024, 28 | "cross_attention_dim_t5": 2048, 29 | "hidden_size": 1408, 30 | "in_channels": 4, 31 | "learn_sigma": True, 32 | "mlp_ratio": 4.3637, 33 | "norm_type": "layer_norm", 34 | "num_attention_heads": 16, 35 | "num_layers": 40, 36 | "patch_size": 2, 37 | "pooled_projection_dim": 1024, 38 | "sample_size": 128, 39 | "text_len": 77, 40 | "text_len_t5": 256, 41 | "use_style_cond_and_image_meta_size": False 42 | } 43 | 44 | dit_sampler_dict = { 45 | "Euler a":EulerAncestralDiscreteScheduler(), 46 | "Euler":EulerDiscreteScheduler(), 47 | "LMS":LMSDiscreteScheduler(), 48 | "Heun":HeunDiscreteScheduler(), 49 | "DPM2":KDPM2DiscreteScheduler(), 50 | "DPM2 a":KDPM2AncestralDiscreteScheduler(), 51 | "DPM++ SDE":DPMSolverSinglestepScheduler(), 52 | "DPM++ 2M":DPMSolverMultistepScheduler(), 53 | "DPM++ 2S a":DPMSolverSinglestepScheduler(), 54 | "LMS Karras":LMSDiscreteScheduler(use_karras_sigmas=True), 55 | "DPM2 Karras":KDPM2DiscreteScheduler(use_karras_sigmas=True), 56 | "DPM2 a Karras":KDPM2AncestralDiscreteScheduler(use_karras_sigmas=True), 57 | "DPM++ SDE Karras":DPMSolverSinglestepScheduler(use_karras_sigmas=True), 58 | "DPM++ 2M Karras":DPMSolverMultistepScheduler(use_karras_sigmas=True), 59 | "DPM++ 2S a Karras":DPMSolverSinglestepScheduler(use_karras_sigmas=True), 60 | "DDIM":DDIMScheduler(), 61 | "UniPC":UniPCMultistepScheduler(), 62 | "DPM++ 2M SDE 
Karras":DPMSolverMultistepScheduler(use_karras_sigmas=True,algorithm_type="sde-dpmsolver++"), 63 | "DPM++ 2M SDE":DPMSolverMultistepScheduler(algorithm_type="sde-dpmsolver++"), 64 | "LCM":LCMScheduler() 65 | } 66 | 67 | def get_resize_crop_region_for_grid(src, tgt_size): 68 | th = tw = tgt_size 69 | h, w = src 70 | 71 | r = h / w 72 | 73 | # resize 74 | if r > 1: 75 | resize_height = th 76 | resize_width = int(round(th / h * w)) 77 | else: 78 | resize_width = tw 79 | resize_height = int(round(tw / w * h)) 80 | 81 | crop_top = int(round((th - resize_height) / 2.0)) 82 | crop_left = int(round((tw - resize_width) / 2.0)) 83 | 84 | return (crop_top, crop_left), (crop_top + resize_height, crop_left + resize_width) 85 | 86 | # Copied from diffusers.pipelines.stable_diffusion.pipeline_stable_diffusion.rescale_noise_cfg 87 | def rescale_noise_cfg(noise_cfg, noise_pred_text, guidance_rescale=0.0): 88 | """ 89 | Rescale `noise_cfg` according to `guidance_rescale`. Based on findings of [Common Diffusion Noise Schedules and 90 | Sample Steps are Flawed](https://arxiv.org/pdf/2305.08891.pdf). See Section 3.4 91 | """ 92 | std_text = noise_pred_text.std(dim=list(range(1, noise_pred_text.ndim)), keepdim=True) 93 | std_cfg = noise_cfg.std(dim=list(range(1, noise_cfg.ndim)), keepdim=True) 94 | # rescale the results from guidance (fixes overexposure) 95 | noise_pred_rescaled = noise_cfg * (std_text / std_cfg) 96 | # mix with the original results from guidance by factor guidance_rescale to avoid "plain looking" images 97 | noise_cfg = guidance_rescale * noise_pred_rescaled + (1 - guidance_rescale) * noise_cfg 98 | return noise_cfg 99 | 100 | def unload_model(current_model): 101 | if current_model is not None: 102 | current_model.to(devices.cpu) 103 | current_model = None 104 | gc.collect() 105 | devices.torch_gc() 106 | return current_model 107 | 108 | # Copied from diffusers.pipelines.stable_diffusion.pipeline_stable_diffusion.StableDiffusionPipeline.prepare_extra_step_kwargs 109 | def prepare_extra_step_kwargs(scheduler, generator, eta): 110 | # prepare extra kwargs for the scheduler step, since not all schedulers have the same signature 111 | # eta (η) is only used with the DDIMScheduler, it will be ignored for other schedulers. 112 | # eta corresponds to η in DDIM paper: https://arxiv.org/abs/2010.02502 113 | # and should be between [0, 1] 114 | 115 | accepts_eta = "eta" in set(inspect.signature(scheduler.step).parameters.keys()) 116 | extra_step_kwargs = {} 117 | if accepts_eta: 118 | extra_step_kwargs["eta"] = eta 119 | 120 | # check if the scheduler accepts generator 121 | accepts_generator = "generator" in set(inspect.signature(scheduler.step).parameters.keys()) 122 | if accepts_generator: 123 | extra_step_kwargs["generator"] = generator 124 | return extra_step_kwargs 125 | 126 | def randn_tensor( 127 | shape: Union[Tuple, List], 128 | generator: Optional[Union[List["torch.Generator"], "torch.Generator"]] = None, 129 | device: Optional["torch.device"] = None, 130 | dtype: Optional["torch.dtype"] = None, 131 | layout: Optional["torch.layout"] = None, 132 | ): 133 | """A helper function to create random tensors on the desired `device` with the desired `dtype`. When 134 | passing a list of generators, you can seed each batch size individually. If CPU generators are passed, the tensor 135 | is always created on the CPU. 
136 | """ 137 | # device on which tensor is created defaults to device 138 | rand_device = device 139 | batch_size = shape[0] 140 | 141 | layout = layout or torch.strided 142 | device = device or torch.device("cpu") 143 | 144 | if generator is not None: 145 | gen_device_type = generator.device.type if not isinstance(generator, list) else generator[0].device.type 146 | if gen_device_type != device.type and gen_device_type == "cpu": 147 | rand_device = "cpu" 148 | if device != "mps": 149 | print( 150 | f"The passed generator was created on 'cpu' even though a tensor on {device} was expected." 151 | f" Tensors will be created on 'cpu' and then moved to {device}. Note that one can probably" 152 | f" slighly speed up this function by passing a generator that was created on the {device} device." 153 | ) 154 | elif gen_device_type != device.type and gen_device_type == "cuda": 155 | raise ValueError(f"Cannot generate a {device} tensor from a generator of type {gen_device_type}.") 156 | 157 | # make sure generator list of length 1 is treated like a non-list 158 | if isinstance(generator, list) and len(generator) == 1: 159 | generator = generator[0] 160 | 161 | if isinstance(generator, list): 162 | shape = (1,) + shape[1:] 163 | latents = [ 164 | torch.randn(shape, generator=generator[i], device=rand_device, dtype=dtype, layout=layout) 165 | for i in range(batch_size) 166 | ] 167 | latents = torch.cat(latents, dim=0).to(device) 168 | else: 169 | latents = torch.randn(shape, generator=generator, device=rand_device, dtype=dtype, layout=layout).to(device) 170 | 171 | return latents 172 | 173 | # Copied from diffusers.pipelines.stable_diffusion.pipeline_stable_diffusion.StableDiffusionPipeline.prepare_latents 174 | def prepare_latents_txt2img(vae_scale_factor, scheduler, batch_size, num_channels_latents, height, width, dtype, device, generator, latents=None): 175 | shape = ( 176 | batch_size, 177 | num_channels_latents, 178 | int(height) // vae_scale_factor, 179 | int(width) // vae_scale_factor, 180 | ) 181 | if isinstance(generator, list) and len(generator) != batch_size: 182 | raise ValueError( 183 | f"You have passed a list of generators of length {len(generator)}, but requested an effective batch" 184 | f" size of {batch_size}. Make sure the batch size matches the length of the generators." 185 | ) 186 | 187 | if latents is None: 188 | latents = randn_tensor(shape, generator=generator, device=device, dtype=dtype) 189 | else: 190 | latents = latents.to(device) 191 | 192 | # scale the initial noise by the standard deviation required by the scheduler 193 | if hasattr(scheduler, 'init_noise_sigma'): 194 | latents = latents * scheduler.init_noise_sigma 195 | return latents 196 | 197 | # Copied from diffusers.pipelines.stable_diffusion.pipeline_stable_diffusion.retrieve_timesteps 198 | def retrieve_timesteps( 199 | scheduler, 200 | num_inference_steps: Optional[int] = None, 201 | device: Optional[Union[str, torch.device]] = None, 202 | timesteps: Optional[List[int]] = None, 203 | sigmas: Optional[List[float]] = None, 204 | **kwargs, 205 | ): 206 | """ 207 | Calls the scheduler's `set_timesteps` method and retrieves timesteps from the scheduler after the call. Handles 208 | custom timesteps. Any kwargs will be supplied to `scheduler.set_timesteps`. 209 | 210 | Args: 211 | scheduler (`SchedulerMixin`): 212 | The scheduler to get timesteps from. 213 | num_inference_steps (`int`): 214 | The number of diffusion steps used when generating samples with a pre-trained model. 
If used, `timesteps` 215 | must be `None`. 216 | device (`str` or `torch.device`, *optional*): 217 | The device to which the timesteps should be moved to. If `None`, the timesteps are not moved. 218 | timesteps (`List[int]`, *optional*): 219 | Custom timesteps used to override the timestep spacing strategy of the scheduler. If `timesteps` is passed, 220 | `num_inference_steps` and `sigmas` must be `None`. 221 | sigmas (`List[float]`, *optional*): 222 | Custom sigmas used to override the timestep spacing strategy of the scheduler. If `sigmas` is passed, 223 | `num_inference_steps` and `timesteps` must be `None`. 224 | 225 | Returns: 226 | `Tuple[torch.Tensor, int]`: A tuple where the first element is the timestep schedule from the scheduler and the 227 | second element is the number of inference steps. 228 | """ 229 | if timesteps is not None and sigmas is not None: 230 | raise ValueError("Only one of `timesteps` or `sigmas` can be passed. Please choose one to set custom values") 231 | if timesteps is not None: 232 | accepts_timesteps = "timesteps" in set(inspect.signature(scheduler.set_timesteps).parameters.keys()) 233 | if not accepts_timesteps: 234 | raise ValueError( 235 | f"The current scheduler class {scheduler.__class__}'s `set_timesteps` does not support custom" 236 | f" timestep schedules. Please check whether you are using the correct scheduler." 237 | ) 238 | scheduler.set_timesteps(timesteps=timesteps, device=device, **kwargs) 239 | timesteps = scheduler.timesteps 240 | num_inference_steps = len(timesteps) 241 | elif sigmas is not None: 242 | accept_sigmas = "sigmas" in set(inspect.signature(scheduler.set_timesteps).parameters.keys()) 243 | if not accept_sigmas: 244 | raise ValueError( 245 | f"The current scheduler class {scheduler.__class__}'s `set_timesteps` does not support custom" 246 | f" sigmas schedules. Please check whether you are using the correct scheduler." 247 | ) 248 | scheduler.set_timesteps(sigmas=sigmas, device=device, **kwargs) 249 | timesteps = scheduler.timesteps 250 | num_inference_steps = len(timesteps) 251 | else: 252 | scheduler.set_timesteps(num_inference_steps, device=device, **kwargs) 253 | timesteps = scheduler.timesteps 254 | return timesteps, num_inference_steps 255 | 256 | # Copied from diffusers.pipelines.stable_diffusion_xl.pipeline_stable_diffusion_xl_img2img.StableDiffusionXLImg2ImgPipeline.get_timesteps 257 | def get_timesteps(scheduler, num_inference_steps, strength, device, denoising_start=None): 258 | # get the original timestep using init_timestep 259 | if denoising_start is None: 260 | init_timestep = min(int(num_inference_steps * strength), num_inference_steps) 261 | t_start = max(num_inference_steps - init_timestep, 0) 262 | else: 263 | t_start = 0 264 | 265 | timesteps = scheduler.timesteps[t_start * scheduler.order :] 266 | 267 | # Strength is irrelevant if we directly request a timestep to start at; 268 | # that is, strength is determined by the denoising_start instead. 269 | if denoising_start is not None: 270 | discrete_timestep_cutoff = int( 271 | round( 272 | scheduler.config.num_train_timesteps 273 | - (denoising_start * scheduler.config.num_train_timesteps) 274 | ) 275 | ) 276 | 277 | num_inference_steps = (timesteps < discrete_timestep_cutoff).sum().item() 278 | if scheduler.order == 2 and num_inference_steps % 2 == 0: 279 | # if the scheduler is a 2nd order scheduler we might have to do +1 280 | # because `num_inference_steps` might be even given that every timestep 281 | # (except the highest one) is duplicated. 
If `num_inference_steps` is even it would 282 | # mean that we cut the timesteps in the middle of the denoising step 283 | # (between 1st and 2nd derivative) which leads to incorrect results. By adding 1 284 | # we ensure that the denoising process always ends after the 2nd derivate step of the scheduler 285 | num_inference_steps = num_inference_steps + 1 286 | 287 | # because t_n+1 >= t_n, we slice the timesteps starting from the end 288 | timesteps = timesteps[-num_inference_steps:] 289 | return timesteps, num_inference_steps 290 | 291 | return timesteps, num_inference_steps - t_start 292 | 293 | # Copied from diffusers.pipelines.stable_diffusion.pipeline_stable_diffusion_img2img.retrieve_latents 294 | def retrieve_latents( 295 | encoder_output: torch.Tensor, generator: Optional[torch.Generator] = None, sample_mode: str = "sample" 296 | ): 297 | if hasattr(encoder_output, "latent_dist") and sample_mode == "sample": 298 | return encoder_output.latent_dist.sample(generator) 299 | elif hasattr(encoder_output, "latent_dist") and sample_mode == "argmax": 300 | return encoder_output.latent_dist.mode() 301 | elif hasattr(encoder_output, "latents"): 302 | return encoder_output.latents 303 | else: 304 | raise AttributeError("Could not access latents of provided encoder_output") 305 | 306 | def _encode_vae_image(image: torch.Tensor, generator: torch.Generator): 307 | #dtype = image.dtype 308 | #image = image.float() 309 | #self.vae_model.to(dtype=torch.float32) 310 | 311 | if isinstance(generator, list): 312 | image_latents = [ 313 | retrieve_latents(shared.vae_model.encode(image[i : i + 1]), generator=generator[i]) 314 | for i in range(image.shape[0]) 315 | ] 316 | image_latents = torch.cat(image_latents, dim=0) 317 | else: 318 | image_latents = retrieve_latents(shared.vae_model.encode(image), generator=generator) 319 | 320 | #self.vae_model.to(dtype) 321 | 322 | #image_latents = image_latents.to(dtype) 323 | image_latents = shared.vae_model.config.scaling_factor * image_latents 324 | 325 | return image_latents 326 | 327 | # Copied from diffusers.pipelines.stable_diffusion.pipeline_stable_diffusion.StableDiffusionPipeline.prepare_latents 328 | def prepare_latents_img2img(vae_scale_factor, scheduler, image, batch_size, num_channels_latents, height, width, dtype, device, generator, seeds, timestep): 329 | shape = ( 330 | batch_size, 331 | num_channels_latents, 332 | int(height) // vae_scale_factor, 333 | int(width) // vae_scale_factor, 334 | ) 335 | generators = [rng.create_generator(seed) for seed in seeds] 336 | if isinstance(generator, list) and len(generator) != batch_size: 337 | raise ValueError( 338 | f"You have passed a list of generators of length {len(generator)}, but requested an effective batch" 339 | f" size of {batch_size}. Make sure the batch size matches the length of the generators." 
340 | ) 341 | 342 | image_latents = _encode_vae_image(image, generator=generators) 343 | image_latents = image_latents.repeat(batch_size // image_latents.shape[0], 1, 1, 1) 344 | 345 | noise = randn_tensor(shape, generator=generator, device=device, dtype=dtype).to(devices.device) 346 | init_latents = scheduler.add_noise(image_latents, noise, timestep) 347 | latents = init_latents.to(device=devices.device, dtype=dtype) 348 | 349 | return latents, noise, image_latents 350 | 351 | def guess_dit_model(state_dict): 352 | if "state_dict" in state_dict: 353 | state_dict = state_dict["state_dict"] 354 | elif "module" in state_dict: 355 | state_dict = state_dict["module"] 356 | if "mlp_t5.0.weight" in state_dict: 357 | return "hunyuan-original" 358 | elif "text_embedder.linear_1.weight" in state_dict: 359 | return "hunyuan" 360 | else: 361 | return "non supported dit" 362 | 363 | def convert_hunyuan_to_diffusers(state_dict): 364 | if "state_dict" in state_dict: 365 | state_dict = state_dict["state_dict"] 366 | elif "module" in state_dict: 367 | state_dict = state_dict["module"] 368 | # input_size -> sample_size, text_dim -> cross_attention_dim 369 | num_layers = 40 370 | for i in range(num_layers): 371 | # attn1 372 | # Wkqv -> to_q, to_k, to_v 373 | q, k, v = torch.chunk(state_dict[f"blocks.{i}.attn1.Wqkv.weight"], 3, dim=0) 374 | q_bias, k_bias, v_bias = torch.chunk(state_dict[f"blocks.{i}.attn1.Wqkv.bias"], 3, dim=0) 375 | state_dict[f"blocks.{i}.attn1.to_q.weight"] = q 376 | state_dict[f"blocks.{i}.attn1.to_q.bias"] = q_bias 377 | state_dict[f"blocks.{i}.attn1.to_k.weight"] = k 378 | state_dict[f"blocks.{i}.attn1.to_k.bias"] = k_bias 379 | state_dict[f"blocks.{i}.attn1.to_v.weight"] = v 380 | state_dict[f"blocks.{i}.attn1.to_v.bias"] = v_bias 381 | state_dict.pop(f"blocks.{i}.attn1.Wqkv.weight") 382 | state_dict.pop(f"blocks.{i}.attn1.Wqkv.bias") 383 | 384 | # q_norm, k_norm -> norm_q, norm_k 385 | state_dict[f"blocks.{i}.attn1.norm_q.weight"] = state_dict[f"blocks.{i}.attn1.q_norm.weight"] 386 | state_dict[f"blocks.{i}.attn1.norm_q.bias"] = state_dict[f"blocks.{i}.attn1.q_norm.bias"] 387 | state_dict[f"blocks.{i}.attn1.norm_k.weight"] = state_dict[f"blocks.{i}.attn1.k_norm.weight"] 388 | state_dict[f"blocks.{i}.attn1.norm_k.bias"] = state_dict[f"blocks.{i}.attn1.k_norm.bias"] 389 | 390 | state_dict.pop(f"blocks.{i}.attn1.q_norm.weight") 391 | state_dict.pop(f"blocks.{i}.attn1.q_norm.bias") 392 | state_dict.pop(f"blocks.{i}.attn1.k_norm.weight") 393 | state_dict.pop(f"blocks.{i}.attn1.k_norm.bias") 394 | 395 | # out_proj -> to_out 396 | state_dict[f"blocks.{i}.attn1.to_out.0.weight"] = state_dict[f"blocks.{i}.attn1.out_proj.weight"] 397 | state_dict[f"blocks.{i}.attn1.to_out.0.bias"] = state_dict[f"blocks.{i}.attn1.out_proj.bias"] 398 | state_dict.pop(f"blocks.{i}.attn1.out_proj.weight") 399 | state_dict.pop(f"blocks.{i}.attn1.out_proj.bias") 400 | 401 | # attn2 402 | # kq_proj -> to_k, to_v 403 | k, v = torch.chunk(state_dict[f"blocks.{i}.attn2.kv_proj.weight"], 2, dim=0) 404 | k_bias, v_bias = torch.chunk(state_dict[f"blocks.{i}.attn2.kv_proj.bias"], 2, dim=0) 405 | state_dict[f"blocks.{i}.attn2.to_k.weight"] = k 406 | state_dict[f"blocks.{i}.attn2.to_k.bias"] = k_bias 407 | state_dict[f"blocks.{i}.attn2.to_v.weight"] = v 408 | state_dict[f"blocks.{i}.attn2.to_v.bias"] = v_bias 409 | state_dict.pop(f"blocks.{i}.attn2.kv_proj.weight") 410 | state_dict.pop(f"blocks.{i}.attn2.kv_proj.bias") 411 | 412 | # q_proj -> to_q 413 | state_dict[f"blocks.{i}.attn2.to_q.weight"] = 
state_dict[f"blocks.{i}.attn2.q_proj.weight"] 414 | state_dict[f"blocks.{i}.attn2.to_q.bias"] = state_dict[f"blocks.{i}.attn2.q_proj.bias"] 415 | state_dict.pop(f"blocks.{i}.attn2.q_proj.weight") 416 | state_dict.pop(f"blocks.{i}.attn2.q_proj.bias") 417 | 418 | # q_norm, k_norm -> norm_q, norm_k 419 | state_dict[f"blocks.{i}.attn2.norm_q.weight"] = state_dict[f"blocks.{i}.attn2.q_norm.weight"] 420 | state_dict[f"blocks.{i}.attn2.norm_q.bias"] = state_dict[f"blocks.{i}.attn2.q_norm.bias"] 421 | state_dict[f"blocks.{i}.attn2.norm_k.weight"] = state_dict[f"blocks.{i}.attn2.k_norm.weight"] 422 | state_dict[f"blocks.{i}.attn2.norm_k.bias"] = state_dict[f"blocks.{i}.attn2.k_norm.bias"] 423 | 424 | state_dict.pop(f"blocks.{i}.attn2.q_norm.weight") 425 | state_dict.pop(f"blocks.{i}.attn2.q_norm.bias") 426 | state_dict.pop(f"blocks.{i}.attn2.k_norm.weight") 427 | state_dict.pop(f"blocks.{i}.attn2.k_norm.bias") 428 | 429 | # out_proj -> to_out 430 | state_dict[f"blocks.{i}.attn2.to_out.0.weight"] = state_dict[f"blocks.{i}.attn2.out_proj.weight"] 431 | state_dict[f"blocks.{i}.attn2.to_out.0.bias"] = state_dict[f"blocks.{i}.attn2.out_proj.bias"] 432 | state_dict.pop(f"blocks.{i}.attn2.out_proj.weight") 433 | state_dict.pop(f"blocks.{i}.attn2.out_proj.bias") 434 | 435 | # switch norm 2 and norm 3 436 | norm2_weight = state_dict[f"blocks.{i}.norm2.weight"] 437 | norm2_bias = state_dict[f"blocks.{i}.norm2.bias"] 438 | state_dict[f"blocks.{i}.norm2.weight"] = state_dict[f"blocks.{i}.norm3.weight"] 439 | state_dict[f"blocks.{i}.norm2.bias"] = state_dict[f"blocks.{i}.norm3.bias"] 440 | state_dict[f"blocks.{i}.norm3.weight"] = norm2_weight 441 | state_dict[f"blocks.{i}.norm3.bias"] = norm2_bias 442 | 443 | # norm1 -> norm1.norm 444 | # default_modulation.1 -> norm1.linear 445 | state_dict[f"blocks.{i}.norm1.norm.weight"] = state_dict[f"blocks.{i}.norm1.weight"] 446 | state_dict[f"blocks.{i}.norm1.norm.bias"] = state_dict[f"blocks.{i}.norm1.bias"] 447 | state_dict[f"blocks.{i}.norm1.linear.weight"] = state_dict[f"blocks.{i}.default_modulation.1.weight"] 448 | state_dict[f"blocks.{i}.norm1.linear.bias"] = state_dict[f"blocks.{i}.default_modulation.1.bias"] 449 | state_dict.pop(f"blocks.{i}.norm1.weight") 450 | state_dict.pop(f"blocks.{i}.norm1.bias") 451 | state_dict.pop(f"blocks.{i}.default_modulation.1.weight") 452 | state_dict.pop(f"blocks.{i}.default_modulation.1.bias") 453 | 454 | # mlp.fc1 -> ff.net.0, mlp.fc2 -> ff.net.2 455 | state_dict[f"blocks.{i}.ff.net.0.proj.weight"] = state_dict[f"blocks.{i}.mlp.fc1.weight"] 456 | state_dict[f"blocks.{i}.ff.net.0.proj.bias"] = state_dict[f"blocks.{i}.mlp.fc1.bias"] 457 | state_dict[f"blocks.{i}.ff.net.2.weight"] = state_dict[f"blocks.{i}.mlp.fc2.weight"] 458 | state_dict[f"blocks.{i}.ff.net.2.bias"] = state_dict[f"blocks.{i}.mlp.fc2.bias"] 459 | state_dict.pop(f"blocks.{i}.mlp.fc1.weight") 460 | state_dict.pop(f"blocks.{i}.mlp.fc1.bias") 461 | state_dict.pop(f"blocks.{i}.mlp.fc2.weight") 462 | state_dict.pop(f"blocks.{i}.mlp.fc2.bias") 463 | 464 | # pooler -> time_extra_emb 465 | state_dict["time_extra_emb.pooler.positional_embedding"] = state_dict["pooler.positional_embedding"] 466 | state_dict["time_extra_emb.pooler.k_proj.weight"] = state_dict["pooler.k_proj.weight"] 467 | state_dict["time_extra_emb.pooler.k_proj.bias"] = state_dict["pooler.k_proj.bias"] 468 | state_dict["time_extra_emb.pooler.q_proj.weight"] = state_dict["pooler.q_proj.weight"] 469 | state_dict["time_extra_emb.pooler.q_proj.bias"] = state_dict["pooler.q_proj.bias"] 470 | 
state_dict["time_extra_emb.pooler.v_proj.weight"] = state_dict["pooler.v_proj.weight"] 471 | state_dict["time_extra_emb.pooler.v_proj.bias"] = state_dict["pooler.v_proj.bias"] 472 | state_dict["time_extra_emb.pooler.c_proj.weight"] = state_dict["pooler.c_proj.weight"] 473 | state_dict["time_extra_emb.pooler.c_proj.bias"] = state_dict["pooler.c_proj.bias"] 474 | state_dict.pop("pooler.k_proj.weight") 475 | state_dict.pop("pooler.k_proj.bias") 476 | state_dict.pop("pooler.q_proj.weight") 477 | state_dict.pop("pooler.q_proj.bias") 478 | state_dict.pop("pooler.v_proj.weight") 479 | state_dict.pop("pooler.v_proj.bias") 480 | state_dict.pop("pooler.c_proj.weight") 481 | state_dict.pop("pooler.c_proj.bias") 482 | state_dict.pop("pooler.positional_embedding") 483 | 484 | # t_embedder -> time_embedding (`TimestepEmbedding`) 485 | state_dict["time_extra_emb.timestep_embedder.linear_1.bias"] = state_dict["t_embedder.mlp.0.bias"] 486 | state_dict["time_extra_emb.timestep_embedder.linear_1.weight"] = state_dict["t_embedder.mlp.0.weight"] 487 | state_dict["time_extra_emb.timestep_embedder.linear_2.bias"] = state_dict["t_embedder.mlp.2.bias"] 488 | state_dict["time_extra_emb.timestep_embedder.linear_2.weight"] = state_dict["t_embedder.mlp.2.weight"] 489 | 490 | state_dict.pop("t_embedder.mlp.0.bias") 491 | state_dict.pop("t_embedder.mlp.0.weight") 492 | state_dict.pop("t_embedder.mlp.2.bias") 493 | state_dict.pop("t_embedder.mlp.2.weight") 494 | 495 | # x_embedder -> pos_embd (`PatchEmbed`) 496 | state_dict["pos_embed.proj.weight"] = state_dict["x_embedder.proj.weight"] 497 | state_dict["pos_embed.proj.bias"] = state_dict["x_embedder.proj.bias"] 498 | state_dict.pop("x_embedder.proj.weight") 499 | state_dict.pop("x_embedder.proj.bias") 500 | 501 | # mlp_t5 -> text_embedder 502 | state_dict["text_embedder.linear_1.bias"] = state_dict["mlp_t5.0.bias"] 503 | state_dict["text_embedder.linear_1.weight"] = state_dict["mlp_t5.0.weight"] 504 | state_dict["text_embedder.linear_2.bias"] = state_dict["mlp_t5.2.bias"] 505 | state_dict["text_embedder.linear_2.weight"] = state_dict["mlp_t5.2.weight"] 506 | state_dict.pop("mlp_t5.0.bias") 507 | state_dict.pop("mlp_t5.0.weight") 508 | state_dict.pop("mlp_t5.2.bias") 509 | state_dict.pop("mlp_t5.2.weight") 510 | 511 | # extra_embedder -> extra_embedder 512 | state_dict["time_extra_emb.extra_embedder.linear_1.bias"] = state_dict["extra_embedder.0.bias"] 513 | state_dict["time_extra_emb.extra_embedder.linear_1.weight"] = state_dict["extra_embedder.0.weight"] 514 | state_dict["time_extra_emb.extra_embedder.linear_2.bias"] = state_dict["extra_embedder.2.bias"] 515 | state_dict["time_extra_emb.extra_embedder.linear_2.weight"] = state_dict["extra_embedder.2.weight"] 516 | state_dict.pop("extra_embedder.0.bias") 517 | state_dict.pop("extra_embedder.0.weight") 518 | state_dict.pop("extra_embedder.2.bias") 519 | state_dict.pop("extra_embedder.2.weight") 520 | 521 | # model.final_adaLN_modulation.1 -> norm_out.linear 522 | def swap_scale_shift(weight): 523 | shift, scale = weight.chunk(2, dim=0) 524 | new_weight = torch.cat([scale, shift], dim=0) 525 | return new_weight 526 | 527 | state_dict["norm_out.linear.weight"] = swap_scale_shift(state_dict["final_layer.adaLN_modulation.1.weight"]) 528 | state_dict["norm_out.linear.bias"] = swap_scale_shift(state_dict["final_layer.adaLN_modulation.1.bias"]) 529 | state_dict.pop("final_layer.adaLN_modulation.1.weight") 530 | state_dict.pop("final_layer.adaLN_modulation.1.bias") 531 | 532 | # final_linear -> proj_out 533 | 
state_dict["proj_out.weight"] = state_dict["final_layer.linear.weight"] 534 | state_dict["proj_out.bias"] = state_dict["final_layer.linear.bias"] 535 | state_dict.pop("final_layer.linear.weight") 536 | state_dict.pop("final_layer.linear.bias") 537 | return state_dict 538 | -------------------------------------------------------------------------------- /requirements.txt: -------------------------------------------------------------------------------- 1 | transformers==4.40.1 2 | git+https://github.com/huggingface/diffusers.git -------------------------------------------------------------------------------- /scripts/hunyuandit.py: -------------------------------------------------------------------------------- 1 | 2 | import torch 3 | import gradio as gr 4 | from transformers import T5EncoderModel, MT5Tokenizer, BertModel, BertTokenizer 5 | from diffusers import AutoencoderKL, DDPMScheduler 6 | from modules import prompt_parser, shared, rng, devices, processing, scripts, masking, sd_models, sd_samplers_common, images, paths, face_restoration, script_callbacks 7 | from modules.sd_hijack import model_hijack 8 | from modules.timer import Timer 9 | from hunyuan_utils.utils import dit_sampler_dict, hunyuan_transformer_config_v12, retrieve_timesteps, get_timesteps, get_resize_crop_region_for_grid, unload_model, prepare_extra_step_kwargs, prepare_latents_txt2img, prepare_latents_img2img, guess_dit_model, convert_hunyuan_to_diffusers 10 | from hunyuan_utils import sd_hijack_clip_diffusers, diffusers_learned_conditioning 11 | import os 12 | import numpy as np 13 | from PIL import Image, ImageOps 14 | import cv2 15 | import hashlib 16 | 17 | shared.clip_l_model = None 18 | shared.mt5_model = None 19 | shared.vae_model = None 20 | 21 | def sample_txt2img(self, conditioning, unconditional_conditioning, seeds): 22 | # define sampler"" 23 | self.sampler = dit_sampler_dict.get((self.sampler_name+" "+self.scheduler.replace("Automatic","")).strip(),DDPMScheduler()).from_pretrained(shared.opts.Hunyuan_model_path,subfolder="scheduler") 24 | # reuse webui generated conditionings 25 | _, tensor = prompt_parser.reconstruct_multicond_batch(conditioning, 0) 26 | prompt_embeds = tensor["crossattn"] 27 | prompt_attention_mask = tensor["mask"] 28 | prompt_embeds_2 = tensor["crossattn_2"] 29 | prompt_attention_mask_2 = tensor["mask_2"] 30 | uncond = prompt_parser.reconstruct_cond_batch(unconditional_conditioning, 0) 31 | negative_prompt_embeds = uncond["crossattn"] 32 | negative_prompt_attention_mask = uncond["mask"] 33 | negative_prompt_embeds_2 = uncond["crossattn_2"] 34 | negative_prompt_attention_mask_2 = uncond["mask_2"] 35 | # 4. Prepare timesteps 36 | self.sampler.set_timesteps(self.steps, device=devices.device) 37 | timesteps = self.sampler.timesteps 38 | shared.state.sampling_steps = len(timesteps) 39 | # 5. Prepare latents. 40 | latent_channels = self.sd_model.config.in_channels 41 | generators = [rng.create_generator(seed) for seed in seeds] 42 | latents = prepare_latents_txt2img( 43 | 2 ** (len(shared.vae_model.config.block_out_channels) - 1), 44 | self.sampler, 45 | self.batch_size, 46 | latent_channels, 47 | self.height, 48 | self.width, 49 | prompt_embeds.dtype, 50 | torch.device("cuda") if shared.opts.randn_source == "GPU" else torch.device("cpu"), 51 | generators, 52 | None 53 | ).to(devices.device) 54 | # 6. Prepare extra step kwargs. 
TODO: Logic should ideally just be moved out of the pipeline 55 | extra_step_kwargs = prepare_extra_step_kwargs(self.sampler, generators, 0.0) 56 | 57 | # 7 create image_rotary_emb, style embedding & time ids 58 | grid_height = self.height // 8 // self.sd_model.config.patch_size 59 | grid_width = self.width // 8 // self.sd_model.config.patch_size 60 | base_size = 512 // 8 // self.sd_model.config.patch_size 61 | grid_crops_coords = get_resize_crop_region_for_grid((grid_height, grid_width), base_size) 62 | from diffusers.models.embeddings import get_2d_rotary_pos_embed 63 | image_rotary_emb = get_2d_rotary_pos_embed( 64 | self.sd_model.inner_dim // self.sd_model.num_heads, grid_crops_coords, (grid_height, grid_width) 65 | ) 66 | style = torch.tensor([0], device=devices.device) 67 | 68 | target_size = (self.height, self.width) 69 | add_time_ids = list((1024, 1024) + target_size + (0,0)) 70 | add_time_ids = torch.tensor([add_time_ids], dtype=prompt_embeds.dtype) 71 | if self.cfg_scale > 1: 72 | prompt_embeds = torch.cat([negative_prompt_embeds, prompt_embeds]) 73 | prompt_attention_mask = torch.cat([negative_prompt_attention_mask, prompt_attention_mask]) 74 | prompt_embeds_2 = torch.cat([negative_prompt_embeds_2, prompt_embeds_2]) 75 | prompt_attention_mask_2 = torch.cat([negative_prompt_attention_mask_2, prompt_attention_mask_2]) 76 | add_time_ids = torch.cat([add_time_ids] * 2, dim=0) 77 | style = torch.cat([style] * 2, dim=0) 78 | add_time_ids = add_time_ids.to(dtype=prompt_embeds.dtype, device=devices.device).repeat( 79 | self.batch_size, 1 80 | ) 81 | style = style.to(device=devices.device).repeat(self.batch_size) 82 | for i, t in enumerate(timesteps): 83 | if shared.state.interrupted or shared.state.skipped: 84 | raise sd_samplers_common.InterruptedException 85 | # expand the latents if we are doing classifier free guidance 86 | latent_model_input = torch.cat([latents] * 2) if self.cfg_scale > 1.0 else latents 87 | latent_model_input = self.sampler.scale_model_input(latent_model_input, t) 88 | 89 | # expand scalar t to 1-D tensor to match the 1st dim of latent_model_input 90 | t_expand = torch.tensor([t] * latent_model_input.shape[0], device=devices.device).to( 91 | dtype=latent_model_input.dtype 92 | ) 93 | # predict the noise residual 94 | noise_pred = self.sd_model( 95 | latent_model_input, 96 | t_expand, 97 | encoder_hidden_states=prompt_embeds, 98 | text_embedding_mask=prompt_attention_mask, 99 | encoder_hidden_states_t5=prompt_embeds_2, 100 | text_embedding_mask_t5=prompt_attention_mask_2, 101 | image_meta_size=add_time_ids, 102 | style=style, 103 | image_rotary_emb=image_rotary_emb, 104 | return_dict=False, 105 | )[0] 106 | 107 | noise_pred, _ = noise_pred.chunk(2, dim=1) 108 | 109 | # perform guidance 110 | if self.cfg_scale > 1.0: 111 | noise_pred_uncond, noise_pred_text = noise_pred.chunk(2) 112 | noise_pred = noise_pred_uncond + self.cfg_scale * (noise_pred_text - noise_pred_uncond) 113 | 114 | # compute the previous noisy sample x_t -> x_t-1 115 | latents = self.sampler.step(noise_pred, t, latents, **extra_step_kwargs, return_dict=False)[0] 116 | # update process 117 | shared.state.sampling_step += 1 118 | shared.total_tqdm.update() 119 | return latents.to(devices.dtype) 120 | 121 | def sample_img2img(self, conditioning, unconditional_conditioning, seeds): 122 | # define sampler 123 | self.sampler = dit_sampler_dict.get((self.sampler_name+" "+self.scheduler.replace("Automatic","")).strip(),DDPMScheduler()).from_pretrained(shared.opts.Hunyuan_model_path,subfolder="scheduler") 
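# The scheduler lookup mirrors sample_txt2img above: the webui sampler name and scheduler suffix
# (e.g. "DPM++ 2M" + "Karras") are concatenated and resolved through dit_sampler_dict, with
# DDPMScheduler() as the fallback for unknown combinations; from_pretrained() then rebuilds the
# chosen scheduler from the config stored in the model's "scheduler" subfolder.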
124 | # reuse webui generated conditionings 125 | _, tensor = prompt_parser.reconstruct_multicond_batch(conditioning, 0) 126 | prompt_embeds = tensor["crossattn"] 127 | prompt_attention_mask = tensor["mask"] 128 | prompt_embeds_2 = tensor["crossattn_2"] 129 | prompt_attention_mask_2 = tensor["mask_2"] 130 | uncond = prompt_parser.reconstruct_cond_batch(unconditional_conditioning, 0) 131 | negative_prompt_embeds = uncond["crossattn"] 132 | negative_prompt_attention_mask = uncond["mask"] 133 | negative_prompt_embeds_2 = uncond["crossattn_2"] 134 | negative_prompt_attention_mask_2 = uncond["mask_2"] 135 | # 4. Prepare timesteps 136 | timesteps, num_inference_steps = retrieve_timesteps( 137 | self.sampler, self.steps, devices.device, None, None 138 | ) 139 | timesteps, num_inference_steps = get_timesteps( 140 | self.sampler, 141 | num_inference_steps, 142 | self.denoising_strength, 143 | devices.device, 144 | denoising_start=None, 145 | ) 146 | latent_timestep = timesteps[:1].repeat(self.batch_size) 147 | shared.state.sampling_steps = len(timesteps) 148 | # 5. Prepare latents. 149 | latent_channels = self.sd_model.config.in_channels 150 | latents_outputs = prepare_latents_img2img( 151 | 2 ** (len(shared.vae_model.config.block_out_channels) - 1), 152 | self.sampler, 153 | self.image, 154 | self.batch_size, 155 | latent_channels, 156 | self.height, 157 | self.width, 158 | prompt_embeds.dtype, 159 | torch.device("cuda") if shared.opts.randn_source == "GPU" else torch.device("cpu"), 160 | None, 161 | seeds, 162 | latent_timestep 163 | ) 164 | latents, noise, image_latents = latents_outputs 165 | self.init_latent = latents 166 | # 6. Prepare extra step kwargs. TODO: Logic should ideally just be moved out of the pipeline 167 | extra_step_kwargs = prepare_extra_step_kwargs(self.sampler, None, 0.0) 168 | 169 | # 7 create image_rotary_emb, style embedding & time ids 170 | grid_height = self.height // 8 // self.sd_model.config.patch_size 171 | grid_width = self.width // 8 // self.sd_model.config.patch_size 172 | base_size = 512 // 8 // self.sd_model.config.patch_size 173 | grid_crops_coords = get_resize_crop_region_for_grid((grid_height, grid_width), base_size) 174 | from diffusers.models.embeddings import get_2d_rotary_pos_embed 175 | image_rotary_emb = get_2d_rotary_pos_embed( 176 | self.sd_model.inner_dim // self.sd_model.num_heads, grid_crops_coords, (grid_height, grid_width) 177 | ) 178 | style = torch.tensor([0], device=devices.device) 179 | 180 | target_size = (self.height, self.width) 181 | add_time_ids = list((1024, 1024) + target_size + (0,0)) 182 | add_time_ids = torch.tensor([add_time_ids], dtype=prompt_embeds.dtype) 183 | if self.cfg_scale > 1: 184 | prompt_embeds = torch.cat([negative_prompt_embeds, prompt_embeds]) 185 | prompt_attention_mask = torch.cat([negative_prompt_attention_mask, prompt_attention_mask]) 186 | prompt_embeds_2 = torch.cat([negative_prompt_embeds_2, prompt_embeds_2]) 187 | prompt_attention_mask_2 = torch.cat([negative_prompt_attention_mask_2, prompt_attention_mask_2]) 188 | add_time_ids = torch.cat([add_time_ids] * 2, dim=0) 189 | style = torch.cat([style] * 2, dim=0) 190 | add_time_ids = add_time_ids.to(dtype=prompt_embeds.dtype, device=devices.device).repeat( 191 | self.batch_size, 1 192 | ) 193 | style = style.to(device=devices.device).repeat(self.batch_size) 194 | for i, t in enumerate(timesteps): 195 | if shared.state.interrupted or shared.state.skipped: 196 | raise sd_samplers_common.InterruptedException 197 | # expand the latents if we are doing classifier free 
guidance 198 | latent_model_input = torch.cat([latents] * 2) if self.cfg_scale > 1.0 else latents 199 | latent_model_input = self.sampler.scale_model_input(latent_model_input, t) 200 | 201 | # expand scalar t to 1-D tensor to match the 1st dim of latent_model_input 202 | t_expand = torch.tensor([t] * latent_model_input.shape[0], device=devices.device).to( 203 | dtype=latent_model_input.dtype 204 | ) 205 | 206 | # predict the noise residual 207 | noise_pred = self.sd_model( 208 | latent_model_input, 209 | t_expand, 210 | encoder_hidden_states=prompt_embeds, 211 | text_embedding_mask=prompt_attention_mask, 212 | encoder_hidden_states_t5=prompt_embeds_2, 213 | text_embedding_mask_t5=prompt_attention_mask_2, 214 | image_meta_size=add_time_ids, 215 | style=style, 216 | image_rotary_emb=image_rotary_emb, 217 | return_dict=False, 218 | )[0] 219 | 220 | noise_pred, _ = noise_pred.chunk(2, dim=1) 221 | 222 | # perform guidance 223 | if self.cfg_scale > 1.0: 224 | noise_pred_uncond, noise_pred_text = noise_pred.chunk(2) 225 | noise_pred = noise_pred_uncond + self.cfg_scale * (noise_pred_text - noise_pred_uncond) 226 | # Based on 3.4. in https://arxiv.org/pdf/2305.08891.pdf 227 | 228 | # compute the previous noisy sample x_t -> x_t-1 229 | latents = self.sampler.step(noise_pred, t, latents, **extra_step_kwargs, return_dict=False)[0] 230 | if latent_channels == 4 and self.image_mask is not None: 231 | latents = self.mask * self.init_latent + self.nmask * latents 232 | # update process 233 | shared.state.sampling_step += 1 234 | shared.total_tqdm.update() 235 | 236 | return latents.to(devices.dtype) 237 | 238 | def init_img2img(self, all_prompts, all_seeds, all_subseeds): 239 | self.extra_generation_params["Denoising strength"] = self.denoising_strength 240 | 241 | self.image_cfg_scale: float = self.image_cfg_scale if shared.sd_model.cond_stage_key == "edit" else None 242 | 243 | #self.sampler = sd_samplers.create_sampler(self.sampler_name, self.sd_model) 244 | crop_region = None 245 | 246 | image_mask = self.image_mask 247 | 248 | if image_mask is not None: 249 | # image_mask is passed in as RGBA by Gradio to support alpha masks, 250 | # but we still want to support binary masks.
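# The remainder mirrors webui's StableDiffusionProcessingImg2Img.init(): the RGBA mask is reduced
# to a binary mask, optionally inverted and blurred, and (for "Inpaint only masked") cropped to the
# padded mask region; the init images are then resized, kept as overlay previews, and converted to
# a float tensor (self.image) that prepare_latents_img2img later encodes with the Hunyuan VAE.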
251 | image_mask = processing.create_binary_mask(image_mask, round=self.mask_round) 252 | 253 | if self.inpainting_mask_invert: 254 | image_mask = ImageOps.invert(image_mask) 255 | self.extra_generation_params["Mask mode"] = "Inpaint not masked" 256 | 257 | if self.mask_blur_x > 0: 258 | np_mask = np.array(image_mask) 259 | kernel_size = 2 * int(2.5 * self.mask_blur_x + 0.5) + 1 260 | np_mask = cv2.GaussianBlur(np_mask, (kernel_size, 1), self.mask_blur_x) 261 | image_mask = Image.fromarray(np_mask) 262 | 263 | if self.mask_blur_y > 0: 264 | np_mask = np.array(image_mask) 265 | kernel_size = 2 * int(2.5 * self.mask_blur_y + 0.5) + 1 266 | np_mask = cv2.GaussianBlur(np_mask, (1, kernel_size), self.mask_blur_y) 267 | image_mask = Image.fromarray(np_mask) 268 | 269 | if self.mask_blur_x > 0 or self.mask_blur_y > 0: 270 | self.extra_generation_params["Mask blur"] = self.mask_blur 271 | 272 | if self.inpaint_full_res: 273 | self.mask_for_overlay = image_mask 274 | mask = image_mask.convert('L') 275 | crop_region = masking.get_crop_region_v2(mask, self.inpaint_full_res_padding) 276 | if crop_region: 277 | crop_region = masking.expand_crop_region(crop_region, self.width, self.height, mask.width, mask.height) 278 | x1, y1, x2, y2 = crop_region 279 | mask = mask.crop(crop_region) 280 | image_mask = images.resize_image(2, mask, self.width, self.height) 281 | self.paste_to = (x1, y1, x2-x1, y2-y1) 282 | self.extra_generation_params["Inpaint area"] = "Only masked" 283 | self.extra_generation_params["Masked area padding"] = self.inpaint_full_res_padding 284 | else: 285 | crop_region = None 286 | image_mask = None 287 | self.mask_for_overlay = None 288 | self.inpaint_full_res = False 289 | massage = 'Unable to perform "Inpaint Only mask" because mask is blank, switch to img2img mode.' 
290 | model_hijack.comments.append(massage) 291 | else: 292 | image_mask = images.resize_image(self.resize_mode, image_mask, self.width, self.height) 293 | np_mask = np.array(image_mask) 294 | np_mask = np.clip((np_mask.astype(np.float32)) * 2, 0, 255).astype(np.uint8) 295 | self.mask_for_overlay = Image.fromarray(np_mask) 296 | 297 | self.overlay_images = [] 298 | 299 | latent_mask = self.latent_mask if self.latent_mask is not None else image_mask 300 | 301 | add_color_corrections = shared.opts.img2img_color_correction and self.color_corrections is None 302 | if add_color_corrections: 303 | self.color_corrections = [] 304 | imgs = [] 305 | for img in self.init_images: 306 | 307 | # Save init image 308 | if shared.opts.save_init_img: 309 | self.init_img_hash = hashlib.md5(img.tobytes()).hexdigest() 310 | images.save_image(img, path=shared.opts.outdir_init_images, basename=None, forced_filename=self.init_img_hash, save_to_dirs=False, existing_info=img.info) 311 | 312 | image = images.flatten(img, shared.opts.img2img_background_color) 313 | 314 | if crop_region is None and self.resize_mode != 3: 315 | image = images.resize_image(self.resize_mode, image, self.width, self.height) 316 | 317 | if image_mask is not None: 318 | if self.mask_for_overlay.size != (image.width, image.height): 319 | self.mask_for_overlay = images.resize_image(self.resize_mode, self.mask_for_overlay, image.width, image.height) 320 | image_masked = Image.new('RGBa', (image.width, image.height)) 321 | image_masked.paste(image.convert("RGBA").convert("RGBa"), mask=ImageOps.invert(self.mask_for_overlay.convert('L'))) 322 | 323 | self.overlay_images.append(image_masked.convert('RGBA')) 324 | 325 | # crop_region is not None if we are doing inpaint full res 326 | if crop_region is not None: 327 | image = image.crop(crop_region) 328 | image = images.resize_image(2, image, self.width, self.height) 329 | 330 | if image_mask is not None: 331 | if self.inpainting_fill != 1: 332 | image = masking.fill(image, latent_mask) 333 | 334 | if self.inpainting_fill == 0: 335 | self.extra_generation_params["Masked content"] = 'fill' 336 | 337 | if add_color_corrections: 338 | self.color_corrections.append(processing.setup_color_correction(image)) 339 | 340 | image = np.array(image).astype(np.float32) / 255.0 341 | image = np.moveaxis(image, 2, 0) 342 | 343 | imgs.append(image) 344 | 345 | if len(imgs) == 1: 346 | batch_images = np.expand_dims(imgs[0], axis=0).repeat(self.batch_size, axis=0) 347 | if self.overlay_images is not None: 348 | self.overlay_images = self.overlay_images * self.batch_size 349 | 350 | if self.color_corrections is not None and len(self.color_corrections) == 1: 351 | self.color_corrections = self.color_corrections * self.batch_size 352 | 353 | elif len(imgs) <= self.batch_size: 354 | self.batch_size = len(imgs) 355 | batch_images = np.array(imgs) 356 | else: 357 | raise RuntimeError(f"bad number of images passed: {len(imgs)}; expecting {self.batch_size} or less") 358 | 359 | image = torch.from_numpy(batch_images) 360 | self.image = image.to(shared.device, dtype=devices.dtype_vae) 361 | 362 | def process_images_inner_hunyuan(p: processing.StableDiffusionProcessing) -> processing.Processed: 363 | """this is the main loop that both txt2img and img2img use; it calls func_init once inside all the scopes and func_sample once per batch""" 364 | 365 | if isinstance(p.prompt, list): 366 | assert(len(p.prompt) > 0) 367 | else: 368 | assert p.prompt is not None 369 | 370 | devices.torch_gc() 371 | 372 | seed = 
processing.get_fixed_seed(p.seed) 373 | subseed = processing.get_fixed_seed(p.subseed) 374 | 375 | if p.restore_faces is None: 376 | p.restore_faces = shared.opts.face_restoration 377 | 378 | if p.tiling is None: 379 | p.tiling = shared.opts.tiling 380 | 381 | # disable refiner 382 | ''' 383 | if p.refiner_checkpoint not in (None, "", "None", "none"): 384 | p.refiner_checkpoint_info = sd_models.get_closet_checkpoint_match(p.refiner_checkpoint) 385 | if p.refiner_checkpoint_info is None: 386 | raise Exception(f'Could not find checkpoint with name {p.refiner_checkpoint}') 387 | ''' 388 | p.sd_model_name = shared.sd_model.sd_checkpoint_info.name_for_extra 389 | p.sd_model_hash = shared.sd_model.sd_model_hash 390 | # disable stable diffusion vae 391 | ''' 392 | p.sd_vae_name = sd_vae.get_loaded_vae_name() 393 | p.sd_vae_hash = sd_vae.get_loaded_vae_hash() 394 | ''' 395 | model_hijack.apply_circular(p.tiling) 396 | model_hijack.clear_comments() 397 | 398 | p.setup_prompts() 399 | 400 | if isinstance(seed, list): 401 | p.all_seeds = seed 402 | else: 403 | p.all_seeds = [int(seed) + (x if p.subseed_strength == 0 else 0) for x in range(len(p.all_prompts))] 404 | 405 | if isinstance(subseed, list): 406 | p.all_subseeds = subseed 407 | else: 408 | p.all_subseeds = [int(subseed) + x for x in range(len(p.all_prompts))] 409 | 410 | if os.path.exists(shared.cmd_opts.embeddings_dir) and not p.do_not_reload_embeddings: 411 | model_hijack.embedding_db.load_textual_inversion_embeddings() 412 | 413 | if p.scripts is not None: 414 | p.scripts.process(p) 415 | 416 | infotexts = [] 417 | output_images = [] 418 | with torch.no_grad(): 419 | with devices.autocast(): 420 | p.init(p.all_prompts, p.all_seeds, p.all_subseeds) 421 | 422 | # disable stable diffusion vae 423 | ''' 424 | # for OSX, loading the model during sampling changes the generated picture, so it is loaded here 425 | if shared.opts.live_previews_enable and opts.show_progress_type == "Approx NN": 426 | sd_vae_approx.model() 427 | 428 | sd_unet.apply_unet() 429 | ''' 430 | if shared.state.job_count == -1: 431 | shared.state.job_count = p.n_iter 432 | 433 | for n in range(p.n_iter): 434 | p.iteration = n 435 | 436 | if shared.state.skipped: 437 | shared.state.skipped = False 438 | 439 | if shared.state.interrupted or shared.state.stopping_generation: 440 | break 441 | 442 | sd_models.reload_model_weights() # model can be changed for example by refiner 443 | 444 | p.prompts = p.all_prompts[n * p.batch_size:(n + 1) * p.batch_size] 445 | p.negative_prompts = p.all_negative_prompts[n * p.batch_size:(n + 1) * p.batch_size] 446 | p.seeds = p.all_seeds[n * p.batch_size:(n + 1) * p.batch_size] 447 | p.subseeds = p.all_subseeds[n * p.batch_size:(n + 1) * p.batch_size] 448 | 449 | # disable webui rng for stable diffusion 450 | #p.rng = rng.ImageRNG((opt_C, p.height // opt_f, p.width // opt_f), p.seeds, subseeds=p.subseeds, subseed_strength=p.subseed_strength, seed_resize_from_h=p.seed_resize_from_h, seed_resize_from_w=p.seed_resize_from_w) 451 | 452 | if p.scripts is not None: 453 | p.scripts.before_process_batch(p, batch_number=n, prompts=p.prompts, seeds=p.seeds, subseeds=p.subseeds) 454 | 455 | if len(p.prompts) == 0: 456 | break 457 | # disabled sd webui type loras 458 | ''' 459 | p.parse_extra_network_prompts() 460 | 461 | if not p.disable_extra_networks: 462 | with devices.autocast(): 463 | extra_networks.activate(p, p.extra_network_data) 464 | ''' 465 | if p.scripts is not None: 466 | p.scripts.process_batch(p, batch_number=n, prompts=p.prompts, 
seeds=p.seeds, subseeds=p.subseeds) 467 | 468 | p.setup_conds() 469 | 470 | # p.extra_generation_params.update(model_hijack.extra_generation_params) 471 | 472 | # params.txt should be saved after scripts.process_batch, since the 473 | # infotext could be modified by that callback 474 | # Example: a wildcard processed by process_batch sets an extra model 475 | # strength, which is saved as "Model Strength: 1.0" in the infotext 476 | if n == 0 and not shared.cmd_opts.no_prompt_history: 477 | with open(os.path.join(paths.data_path, "params.txt"), "w", encoding="utf8") as file: 478 | processed = processing.Processed(p, []) 479 | file.write(processed.infotext(p, 0)) 480 | 481 | for comment in model_hijack.comments: 482 | p.comment(comment) 483 | 484 | if p.n_iter > 1: 485 | shared.state.job = f"Batch {n+1} out of {p.n_iter}" 486 | 487 | sd_models.apply_alpha_schedule_override(p.sd_model, p) 488 | 489 | with devices.without_autocast() if devices.unet_needs_upcast else devices.autocast(): 490 | samples_ddim = p.sample(conditioning=p.c, unconditional_conditioning=p.uc, seeds=p.seeds) 491 | 492 | if p.scripts is not None: 493 | ps = scripts.PostSampleArgs(samples_ddim) 494 | p.scripts.post_sample(p, ps) 495 | samples_ddim = ps.samples 496 | 497 | if getattr(samples_ddim, 'already_decoded', False): 498 | x_samples_ddim = samples_ddim 499 | else: 500 | if shared.opts.sd_vae_decode_method != 'Full': 501 | p.extra_generation_params['VAE Decoder'] = shared.opts.sd_vae_decode_method 502 | x_samples_ddim = shared.vae_model.decode(samples_ddim / shared.vae_model.config.scaling_factor, return_dict=False)[0] 503 | 504 | x_samples_ddim = torch.clamp((x_samples_ddim + 1.0) / 2.0, min=0.0, max=1.0) 505 | x_samples_ddim = x_samples_ddim.cpu().permute(0, 2, 3, 1).float().numpy() 506 | 507 | del samples_ddim 508 | 509 | devices.torch_gc() 510 | 511 | shared.state.nextjob() 512 | 513 | if p.scripts is not None: 514 | p.scripts.postprocess_batch(p, x_samples_ddim, batch_number=n) 515 | 516 | p.prompts = p.all_prompts[n * p.batch_size:(n + 1) * p.batch_size] 517 | p.negative_prompts = p.all_negative_prompts[n * p.batch_size:(n + 1) * p.batch_size] 518 | 519 | batch_params = scripts.PostprocessBatchListArgs(list(x_samples_ddim)) 520 | p.scripts.postprocess_batch_list(p, batch_params, batch_number=n) 521 | x_samples_ddim = batch_params.images 522 | 523 | def infotext(index=0, use_main_prompt=False): 524 | return processing.create_infotext(p, p.prompts, p.seeds, p.subseeds, use_main_prompt=use_main_prompt, index=index, all_negative_prompts=p.negative_prompts) 525 | 526 | save_samples = p.save_samples() 527 | 528 | for i, x_sample in enumerate(x_samples_ddim): 529 | p.batch_index = i 530 | 531 | x_sample = 255. 
* x_sample 532 | x_sample = x_sample.astype(np.uint8) 533 | 534 | if p.restore_faces: 535 | if save_samples and shared.opts.save_images_before_face_restoration: 536 | images.save_image(Image.fromarray(x_sample), p.outpath_samples, "", p.seeds[i], p.prompts[i], shared.opts.samples_format, info=infotext(i), p=p, suffix="-before-face-restoration") 537 | 538 | devices.torch_gc() 539 | 540 | x_sample = face_restoration.restore_faces(x_sample) 541 | devices.torch_gc() 542 | 543 | image = Image.fromarray(x_sample) 544 | 545 | if p.scripts is not None: 546 | pp = scripts.PostprocessImageArgs(image) 547 | p.scripts.postprocess_image(p, pp) 548 | image = pp.image 549 | 550 | mask_for_overlay = getattr(p, "mask_for_overlay", None) 551 | 552 | if not shared.opts.overlay_inpaint: 553 | overlay_image = None 554 | elif getattr(p, "overlay_images", None) is not None and i < len(p.overlay_images): 555 | overlay_image = p.overlay_images[i] 556 | else: 557 | overlay_image = None 558 | 559 | if p.scripts is not None: 560 | ppmo = scripts.PostProcessMaskOverlayArgs(i, mask_for_overlay, overlay_image) 561 | p.scripts.postprocess_maskoverlay(p, ppmo) 562 | mask_for_overlay, overlay_image = ppmo.mask_for_overlay, ppmo.overlay_image 563 | 564 | if p.color_corrections is not None and i < len(p.color_corrections): 565 | if save_samples and shared.opts.save_images_before_color_correction: 566 | image_without_cc, _ = processing.apply_overlay(image, p.paste_to, overlay_image) 567 | images.save_image(image_without_cc, p.outpath_samples, "", p.seeds[i], p.prompts[i], shared.opts.samples_format, info=infotext(i), p=p, suffix="-before-color-correction") 568 | image = processing.apply_color_correction(p.color_corrections[i], image) 569 | 570 | # If the intention is to show the output from the model 571 | # that is being composited over the original image, 572 | # we need to keep the original image around 573 | # and use it in the composite step. 
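# apply_overlay() pastes the generated image back into the original at self.paste_to (set when
# inpainting only the masked region) and composites the inpaint overlay on top; the un-composited
# copy is returned as original_denoised_image and reused below for the mask-composite preview.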
574 | image, original_denoised_image = processing.apply_overlay(image, p.paste_to, overlay_image) 575 | 576 | if p.scripts is not None: 577 | pp = scripts.PostprocessImageArgs(image) 578 | p.scripts.postprocess_image_after_composite(p, pp) 579 | image = pp.image 580 | 581 | if save_samples: 582 | images.save_image(image, p.outpath_samples, "", p.seeds[i], p.prompts[i], shared.opts.samples_format, info=infotext(i), p=p) 583 | 584 | text = infotext(i) 585 | infotexts.append(text) 586 | if shared.opts.enable_pnginfo: 587 | image.info["parameters"] = text 588 | output_images.append(image) 589 | 590 | if mask_for_overlay is not None: 591 | if shared.opts.return_mask or shared.opts.save_mask: 592 | image_mask = mask_for_overlay.convert('RGB') 593 | if save_samples and shared.opts.save_mask: 594 | images.save_image(image_mask, p.outpath_samples, "", p.seeds[i], p.prompts[i], shared.opts.samples_format, info=infotext(i), p=p, suffix="-mask") 595 | if shared.opts.return_mask: 596 | output_images.append(image_mask) 597 | 598 | if shared.opts.return_mask_composite or shared.opts.save_mask_composite: 599 | image_mask_composite = Image.composite(original_denoised_image.convert('RGBA').convert('RGBa'), Image.new('RGBa', image.size), images.resize_image(2, mask_for_overlay, image.width, image.height).convert('L')).convert('RGBA') 600 | if save_samples and shared.opts.save_mask_composite: 601 | images.save_image(image_mask_composite, p.outpath_samples, "", p.seeds[i], p.prompts[i], shared.opts.samples_format, info=infotext(i), p=p, suffix="-mask-composite") 602 | if shared.opts.return_mask_composite: 603 | output_images.append(image_mask_composite) 604 | 605 | del x_samples_ddim 606 | 607 | devices.torch_gc() 608 | 609 | if not infotexts: 610 | infotexts.append(processing.Processed(p, []).infotext(p, 0)) 611 | 612 | p.color_corrections = None 613 | 614 | index_of_first_image = 0 615 | unwanted_grid_because_of_img_count = len(output_images) < 2 and shared.opts.grid_only_if_multiple 616 | if (shared.opts.return_grid or shared.opts.grid_save) and not p.do_not_save_grid and not unwanted_grid_because_of_img_count: 617 | grid = images.image_grid(output_images, p.batch_size) 618 | 619 | if shared.opts.return_grid: 620 | text = infotext(use_main_prompt=True) 621 | infotexts.insert(0, text) 622 | if shared.opts.enable_pnginfo: 623 | grid.info["parameters"] = text 624 | output_images.insert(0, grid) 625 | index_of_first_image = 1 626 | if shared.opts.grid_save: 627 | images.save_image(grid, p.outpath_grids, "grid", p.all_seeds[0], p.all_prompts[0], shared.opts.grid_format, info=infotext(use_main_prompt=True), short_filename=not shared.opts.grid_extended_filename, p=p, grid=True) 628 | 629 | # disable sd webui type loras 630 | ''' 631 | if not p.disable_extra_networks and p.extra_network_data: 632 | extra_networks.deactivate(p, p.extra_network_data) 633 | ''' 634 | devices.torch_gc() 635 | 636 | res = processing.Processed( 637 | p, 638 | images_list=output_images, 639 | seed=p.all_seeds[0], 640 | info=infotexts[0], 641 | subseed=p.all_subseeds[0], 642 | index_of_first_image=index_of_first_image, 643 | infotexts=infotexts, 644 | ) 645 | 646 | if p.scripts is not None: 647 | p.scripts.postprocess(p, res) 648 | 649 | return res 650 | 651 | def load_model_hunyuan(checkpoint_info=None, already_loaded_state_dict=None): 652 | from modules import sd_hijack 653 | from diffusers import HunyuanDiT2DModel 654 | checkpoint_info = checkpoint_info or sd_models.select_checkpoint() 655 | 656 | timer = Timer() 657 | 658 | if 
sd_models.model_data.sd_model: 659 | sd_models.model_data.sd_model.to("cpu") 660 | sd_models.model_data.sd_model = None 661 | devices.torch_gc() 662 | 663 | timer.record("unload existing model") 664 | 665 | if already_loaded_state_dict is not None: 666 | state_dict = already_loaded_state_dict 667 | else: 668 | state_dict = sd_models.get_checkpoint_state_dict(checkpoint_info, timer) 669 | 670 | timer.record("load weights from state dict") 671 | 672 | sd_model = HunyuanDiT2DModel.from_config(hunyuan_transformer_config_v12) 673 | print("loading hunyuan DiT") 674 | checkpoint_config = guess_dit_model(state_dict) 675 | sd_model.used_config = checkpoint_config 676 | if checkpoint_config == "hunyuan-original": 677 | state_dict = convert_hunyuan_to_diffusers(state_dict) 678 | elif "hunyuan" not in checkpoint_config: 679 | raise ValueError("Found no hunyuan DiT model") 680 | sd_model.load_state_dict(state_dict, strict=False) 681 | del state_dict 682 | 683 | print("loading text encoder and vae") 684 | shared.clip_l_model = BertModel.from_pretrained(shared.opts.Hunyuan_model_path,subfolder="text_encoder",torch_dtype=devices.dtype).to(devices.device) 685 | shared.mt5_model = T5EncoderModel.from_pretrained(shared.opts.Hunyuan_model_path,subfolder="text_encoder_2",torch_dtype=devices.dtype).to(devices.device) 686 | shared.clip_l_model.tokenizer = BertTokenizer.from_pretrained(shared.opts.Hunyuan_model_path,subfolder="tokenizer") 687 | shared.mt5_model.tokenizer = MT5Tokenizer.from_pretrained(shared.opts.Hunyuan_model_path,subfolder="tokenizer_2") 688 | shared.clip_l_model = sd_hijack_clip_diffusers.FrozenBertEmbedderWithCustomWords(shared.clip_l_model,sd_hijack.model_hijack) 689 | shared.mt5_model = sd_hijack_clip_diffusers.FrozenT5EmbedderWithCustomWords(shared.mt5_model,sd_hijack.model_hijack) 690 | shared.clip_l_model.return_masks = True 691 | shared.mt5_model.return_masks = True 692 | shared.vae_model = AutoencoderKL.from_pretrained(shared.opts.Hunyuan_model_path,subfolder="vae",torch_dtype=devices.dtype).to(devices.device) 693 | 694 | sd_model.to(devices.dtype) 695 | sd_model.to(devices.device) 696 | sd_model.eval() 697 | sd_model_hash = checkpoint_info.calculate_shorthash() 698 | sd_model.sd_model_hash = sd_model_hash 699 | sd_model.sd_model_checkpoint = checkpoint_info.filename 700 | sd_model.sd_checkpoint_info = checkpoint_info 701 | sd_model.lowvram = False 702 | sd_model.is_sd1 = False 703 | sd_model.is_sd2 = False 704 | sd_model.is_sdxl = False 705 | sd_model.is_ssd = False 706 | sd_model.is_sd3 = False 707 | sd_model.model = None 708 | sd_model.first_stage_model = None 709 | sd_model.cond_stage_key = None 710 | sd_model.cond_stage_model = None 711 | sd_model.get_learned_conditioning = diffusers_learned_conditioning.get_learned_conditioning_hunyuan 712 | sd_models.model_data.set_sd_model(sd_model) 713 | sd_models.model_data.was_loaded_at_least_once = True 714 | 715 | script_callbacks.model_loaded_callback(sd_model) 716 | 717 | timer.record("scripts callbacks") 718 | 719 | print(f"Model loaded in {timer.summary()}.") 720 | 721 | return sd_model 722 | 723 | def reload_model_weights_hunyuan(sd_model=None, info=None, forced_reload=False): 724 | checkpoint_info = info or sd_models.select_checkpoint() 725 | 726 | timer = Timer() 727 | 728 | if not sd_model: 729 | sd_model = sd_models.model_data.sd_model 730 | 731 | if sd_model is None: # previous model load failed 732 | current_checkpoint_info = None 733 | else: 734 | current_checkpoint_info = sd_model.sd_checkpoint_info 735 | if 
sd_model.sd_model_checkpoint == checkpoint_info.filename and not forced_reload: 736 | return sd_model 737 | 738 | sd_model.to(devices.dtype) 739 | sd_model.to(devices.device) 740 | if not forced_reload and sd_model is not None and sd_model.sd_checkpoint_info.filename == checkpoint_info.filename: 741 | return sd_model 742 | 743 | if sd_model is not None: 744 | sd_models.send_model_to_cpu(sd_model) 745 | 746 | state_dict = sd_models.get_checkpoint_state_dict(checkpoint_info, timer) 747 | 748 | checkpoint_config = guess_dit_model(state_dict) 749 | if checkpoint_config == "hunyuan-original": 750 | state_dict = convert_hunyuan_to_diffusers(state_dict) 751 | elif "hunyuan" not in checkpoint_config: 752 | raise ValueError("Found no hunyuan DiT model") 753 | timer.record("find config") 754 | 755 | if sd_model is None or checkpoint_config != sd_model.used_config: 756 | load_model_hunyuan(checkpoint_info, already_loaded_state_dict=state_dict) 757 | return sd_models.model_data.sd_model 758 | try: 759 | sd_model.load_state_dict(state_dict, strict=False) 760 | del state_dict 761 | sd_model_hash = checkpoint_info.calculate_shorthash() 762 | sd_model.sd_model_hash = sd_model_hash 763 | sd_model.sd_model_checkpoint = checkpoint_info.filename 764 | sd_model.sd_checkpoint_info = checkpoint_info 765 | sd_model.lowvram = False 766 | sd_model.is_sd1 = False 767 | sd_model.is_sd2 = False 768 | sd_model.is_sdxl = False 769 | sd_model.is_ssd = False 770 | sd_model.is_sd3 = False 771 | sd_model.model = None 772 | sd_model.first_stage_model = None 773 | sd_model.cond_stage_key = None 774 | sd_model.cond_stage_model = None 775 | except Exception: 776 | print("Failed to load checkpoint, restoring previous") 777 | state_dict = sd_models.get_checkpoint_state_dict(current_checkpoint_info, timer) 778 | sd_model.load_state_dict(state_dict, strict=False) 779 | del state_dict 780 | sd_model_hash = checkpoint_info.calculate_shorthash() 781 | sd_model.sd_model_hash = sd_model_hash 782 | sd_model.sd_model_checkpoint = checkpoint_info.filename 783 | sd_model.sd_checkpoint_info = checkpoint_info 784 | sd_model.lowvram = False 785 | sd_model.is_sd1 = False 786 | sd_model.is_sd2 = False 787 | sd_model.is_sdxl = False 788 | sd_model.is_ssd = False 789 | sd_model.is_sd3 = False 790 | sd_model.model = None 791 | sd_model.first_stage_model = None 792 | sd_model.cond_stage_key = None 793 | sd_model.cond_stage_model = None 794 | raise 795 | finally: 796 | script_callbacks.model_loaded_callback(sd_model) 797 | timer.record("script callbacks") 798 | 799 | print(f"Weights loaded in {timer.summary()}.") 800 | 801 | sd_models.model_data.set_sd_model(sd_model) 802 | 803 | return sd_model 804 | 805 | class Script(scripts.Script): 806 | 807 | def __init__(self): 808 | super(Script, self).__init__() 809 | def title(self): 810 | return 'Hunyuan DiT' 811 | 812 | def show(self, is_img2img): 813 | return scripts.AlwaysVisible 814 | 815 | def ui(self, is_img2img): 816 | tab = 't2i' if not is_img2img else 'i2i' 817 | is_t2i = 'true' if not is_img2img else 'false' 818 | uid = lambda name: f'MD-{tab}-{name}' 819 | 820 | with gr.Accordion('Hunyuan DiT', open=False): 821 | with gr.Row(variant='compact') as tab_enable: 822 | enabled = gr.Checkbox(label='Enable Hunyuan DiT', value=False, elem_id=uid('enabled')) 823 | enabled.change( 824 | fn=on_enable_change, 825 | inputs=[enabled], 826 | outputs=None 827 | ) 828 | return [ 829 | enabled 830 | ] 831 | 832 | def on_enable_change(enabled: bool): 833 | if enabled: 834 | print("Enable Hunyuan DiT") 835 | 
hijack() 836 | else: 837 | print("Disable Hunyuan DiT") 838 | reset() 839 | shared.clip_l_model = unload_model(shared.clip_l_model) 840 | shared.mt5_model = unload_model(shared.mt5_model) 841 | shared.vae_model = unload_model(shared.vae_model) 842 | 843 | def reset(): 844 | ''' unhijack inner APIs ''' 845 | if hasattr(processing,"process_images_inner_original"): 846 | processing.process_images_inner = processing.process_images_inner_original 847 | if hasattr(processing.StableDiffusionProcessingTxt2Img,"sample_original"): 848 | processing.StableDiffusionProcessingTxt2Img.sample = processing.StableDiffusionProcessingTxt2Img.sample_original 849 | if hasattr(processing.StableDiffusionProcessingImg2Img,"sample_original"): 850 | processing.StableDiffusionProcessingImg2Img.sample = processing.StableDiffusionProcessingImg2Img.sample_original 851 | if hasattr(sd_models,"load_model_original"): 852 | sd_models.load_model = sd_models.load_model_original 853 | if hasattr(sd_models,"reload_model_weights_original"): 854 | sd_models.reload_model_weights = sd_models.reload_model_weights_original 855 | if hasattr(processing.StableDiffusionProcessingImg2Img,"init_img2img_original"): 856 | processing.StableDiffusionProcessingImg2Img.init = processing.StableDiffusionProcessingImg2Img.init_img2img_original 857 | 858 | def hijack(): 859 | ''' hijack inner APIs ''' 860 | if not hasattr(processing,"process_images_inner_original"): 861 | processing.process_images_inner_original = processing.process_images_inner 862 | if not hasattr(processing.StableDiffusionProcessingTxt2Img,"sample_original"): 863 | processing.StableDiffusionProcessingTxt2Img.sample_original = processing.StableDiffusionProcessingTxt2Img.sample 864 | if not hasattr(processing.StableDiffusionProcessingImg2Img,"sample_original"): 865 | processing.StableDiffusionProcessingImg2Img.sample_original = processing.StableDiffusionProcessingImg2Img.sample 866 | if not hasattr(sd_models,"load_model_original"): 867 | sd_models.load_model_original = sd_models.load_model 868 | if not hasattr(sd_models,"reload_model_weights_original"): 869 | sd_models.reload_model_weights_original = sd_models.reload_model_weights 870 | if not hasattr(processing.StableDiffusionProcessingImg2Img,"init_img2img_original"): 871 | processing.StableDiffusionProcessingImg2Img.init_img2img_original = processing.StableDiffusionProcessingImg2Img.init 872 | processing.process_images_inner = process_images_inner_hunyuan 873 | processing.StableDiffusionProcessingTxt2Img.sample = sample_txt2img 874 | processing.StableDiffusionProcessingImg2Img.sample = sample_img2img 875 | sd_models.load_model = load_model_hunyuan 876 | sd_models.reload_model_weights = reload_model_weights_hunyuan 877 | processing.StableDiffusionProcessingImg2Img.init = init_img2img 878 | 879 | def on_ui_settings(): 880 | 881 | shared.opts.add_option("Hunyuan_model_path", shared.OptionInfo("./models/hunyuan", "Hunyuan Model Path",section=('hunyuanDiT', "HunyuanDiT"))) 882 | 883 | script_callbacks.on_ui_settings(on_ui_settings) 884 | --------------------------------------------------------------------------------