├── README.md
├── __init__.py
├── inference.py
├── modules
│   ├── __init__.py
│   ├── attention.py
│   ├── autoencoder.py
│   ├── conditioner.py
│   ├── connector_edit.py
│   ├── layers.py
│   └── model_edit.py
├── nodes.py
├── requirements.txt
└── sampling.py

/README.md:
--------------------------------------------------------------------------------
1 | # Step1X Edit for ComfyUI
2 | 
3 | This is a ComfyUI custom node implementation of the Step-1 (Step1X-Edit) model architecture for reference-based image editing guided by text prompts.
4 | 
5 | ## Online Access
6 | You can access RunningHub online to use this plugin and models for free:
7 | ### English Version
8 | - **Run & Download Workflow**:
9 | [https://www.runninghub.ai/post/1916456042962817026](https://www.runninghub.ai/post/1916456042962817026)
10 | ### 中文版本 (Chinese Version)
11 | - **运行并下载工作流 (Run & Download Workflow)**:
12 | [https://www.runninghub.cn/post/1916456042962817026](https://www.runninghub.cn/post/1916456042962817026)
13 | ## Features
14 | 
15 | * Implementation of the Step-1 image editing concept within ComfyUI.
16 | * Optimized for running on GPUs with 24GB VRAM.
17 | * Memory-efficient inference (models are offloaded between CPU and GPU, with an optional FP8 mode) to reduce VRAM usage.
18 | * Simple node interface for ease of use.
19 | * Tested locally on Windows with an RTX 4090; generating a single image takes approximately 100 seconds.
20 | 
21 | ## Model Download Guide
22 | 
23 | Place the downloaded models in your `ComfyUI/models/step-1/` directory.
24 | 
25 | ### Required Models:
26 | 
27 | 1. **Step1X Edit Model:** Contains the main diffusion model weights adapted for editing.
28 | 2. **VAE:** Used for encoding and decoding images to/from latent space.
29 | 3. **Qwen2.5-VL-7B-Instruct:** The vision-language model used for text and image conditioning.
30 | 
31 | ### Choose a Download Method (Pick One)
32 | 
33 | 1. **One-Click Download with Python Script:**
34 | *Requires the `huggingface_hub` library (`pip install huggingface-hub`)*
35 | ```python
36 | from huggingface_hub import snapshot_download
37 | import os
38 | 
39 | # Define the target directory within ComfyUI models
40 | target_dir = "path/to/your/ComfyUI/models/step-1"
41 | os.makedirs(target_dir, exist_ok=True)
42 | 
43 | # --- Download Step1X Edit Model ---
44 | snapshot_download(
45 |     repo_id="stepfun-ai/Step1X-Edit",
46 |     local_dir=target_dir,
47 |     allow_patterns=["step1x-edit-i1258.safetensors"],
48 |     local_dir_use_symlinks=False
49 | )
50 | 
51 | # --- Download VAE ---
52 | snapshot_download(
53 |     repo_id="stepfun-ai/Step1X-Edit",  # VAE is in the same repo
54 |     local_dir=target_dir,
55 |     allow_patterns=["vae.safetensors"],
56 |     local_dir_use_symlinks=False
57 | )
58 | 
59 | # --- Download Qwen2.5-VL-7B-Instruct ---
60 | qwen_dir = os.path.join(target_dir, "Qwen2.5-VL-7B-Instruct")
61 | snapshot_download(
62 |     repo_id="Qwen/Qwen2.5-VL-7B-Instruct",
63 |     local_dir=qwen_dir,
64 |     # ignore_patterns=["*.git*", "*.log*", "*.md", "*.jpg"],  # Optional: reduce download size
65 |     local_dir_use_symlinks=False
66 | )
67 | 
68 | print(f"Downloads complete. Models should be in {target_dir}")
69 | ```
70 | 
71 | 2. 
**Manual Download:** 72 | * **Step1X Edit:** Download `step1x-edit-i1258.safetensors` ([step1x-edit-i1258.safetensors](https://huggingface.co/stepfun-ai/Step1X-Edit/resolve/main/step1x-edit-i1258.safetensors)) 73 | * **VAE:** Download `vae.safetensors` ([vae.safetensors](https://huggingface.co/stepfun-ai/Step1X-Edit/resolve/main/vae.safetensors)) 74 | * **Qwen2.5-VL:** Download the entire repository: [Qwen/Qwen2.5-VL-7B-Instruct](https://huggingface.co/Qwen/Qwen2.5-VL-7B-Instruct) 75 | 76 | Place the `step1x-edit-i1258.safetensors` and `vae.safetensors` files, and the `Qwen2.5-VL-7B-Instruct` folder into your ComfyUI `models/step-1` directory (you might need to create the `step-1` subfolder). Your final structure should look like: 77 | ``` 78 | ComfyUI/ 79 | └── models/ 80 | └── step-1/ 81 | ├── step1x-edit-i1258.safetensors 82 | ├── vae.safetensors 83 | └── Qwen2.5-VL-7B-Instruct/ 84 | ├── ... (all files from the Qwen repo) 85 | ``` 86 | 87 | ## Installation 88 | 89 | 1. Clone this repository into your `ComfyUI/custom_nodes/` directory: 90 | ```bash 91 | cd path/to/your/ComfyUI/custom_nodes/ 92 | git clone https://github.com/HM-RunningHub/ComfyUI_RH_Step1XEdit.git 93 | ``` 94 | 2. Install the required dependencies: 95 | ```bash 96 | cd ComfyUI_RH_Step1XEdit 97 | pip install -r requirements.txt 98 | ``` 99 | 3. Restart ComfyUI. 100 | 101 | **(Example Image/Workflow)** 102 | 103 | ![image](https://github.com/user-attachments/assets/035274a4-fc47-4249-acf0-a5e31cdd1671) 104 | 105 | 4. Thanks 106 | https://huggingface.co/stepfun-ai/Step1X-Edit 107 | -------------------------------------------------------------------------------- /__init__.py: -------------------------------------------------------------------------------- 1 | 2 | from .nodes import NODE_CLASS_MAPPINGS 3 | NODE_DISPLAY_NAME_MAPPINGS = {k:v.TITLE for k,v in NODE_CLASS_MAPPINGS.items()} 4 | __all__ = ['NODE_CLASS_MAPPINGS', 'NODE_DISPLAY_NAME_MAPPINGS'] 5 | -------------------------------------------------------------------------------- /inference.py: -------------------------------------------------------------------------------- 1 | import argparse 2 | import datetime 3 | import json 4 | import itertools 5 | import math 6 | import os 7 | import time 8 | from pathlib import Path 9 | 10 | import numpy as np 11 | import torch 12 | from einops import rearrange, repeat 13 | from PIL import Image, ImageOps 14 | from safetensors.torch import load_file 15 | from torchvision.transforms import functional as F 16 | from tqdm import tqdm 17 | 18 | from . 
import sampling 19 | from .modules.autoencoder import AutoEncoder 20 | from .modules.conditioner import Qwen25VL_7b_Embedder as Qwen2VLEmbedder 21 | from .modules.model_edit import Step1XParams, Step1XEdit 22 | 23 | import gc 24 | import subprocess 25 | import folder_paths 26 | 27 | # def load_state_dict(model, ckpt_path, device="cuda", strict=False, assign=True): 28 | def load_state_dict(model, ckpt_path, device="cpu", strict=False, assign=True): 29 | if Path(ckpt_path).suffix == ".safetensors": 30 | state_dict = load_file(ckpt_path, device) 31 | else: 32 | state_dict = torch.load(ckpt_path, map_location="cpu") 33 | 34 | missing, unexpected = model.load_state_dict( 35 | state_dict, strict=strict, assign=assign 36 | ) 37 | if len(missing) > 0 and len(unexpected) > 0: 38 | print(f"Got {len(missing)} missing keys:\n\t" + "\n\t".join(missing)) 39 | print("\n" + "-" * 79 + "\n") 40 | print(f"Got {len(unexpected)} unexpected keys:\n\t" + "\n\t".join(unexpected)) 41 | elif len(missing) > 0: 42 | print(f"Got {len(missing)} missing keys:\n\t" + "\n\t".join(missing)) 43 | elif len(unexpected) > 0: 44 | print(f"Got {len(unexpected)} unexpected keys:\n\t" + "\n\t".join(unexpected)) 45 | return model 46 | 47 | 48 | def load_models( 49 | dit_path=None, 50 | ae_path=None, 51 | qwen2vl_model_path=None, 52 | # device="cuda", 53 | device='cpu', 54 | max_length=256, 55 | dtype=torch.bfloat16, 56 | ): 57 | qwen2vl_encoder = Qwen2VLEmbedder( 58 | qwen2vl_model_path, 59 | device=device, 60 | max_length=max_length, 61 | dtype=dtype, 62 | ) 63 | 64 | with torch.device("meta"): 65 | ae = AutoEncoder( 66 | resolution=256, 67 | in_channels=3, 68 | ch=128, 69 | out_ch=3, 70 | ch_mult=[1, 2, 4, 4], 71 | num_res_blocks=2, 72 | z_channels=16, 73 | scale_factor=0.3611, 74 | shift_factor=0.1159, 75 | ) 76 | 77 | step1x_params = Step1XParams( 78 | in_channels=64, 79 | out_channels=64, 80 | vec_in_dim=768, 81 | context_in_dim=4096, 82 | hidden_size=3072, 83 | mlp_ratio=4.0, 84 | num_heads=24, 85 | depth=19, 86 | depth_single_blocks=38, 87 | axes_dim=[16, 56, 56], 88 | theta=10_000, 89 | qkv_bias=True, 90 | ) 91 | dit = Step1XEdit(step1x_params) 92 | 93 | # ae = load_state_dict(ae, ae_path) 94 | # dit = load_state_dict( 95 | # dit, dit_path 96 | # ) 97 | ae = load_state_dict(ae, ae_path) 98 | dit = load_state_dict( 99 | dit, dit_path 100 | ) 101 | 102 | dit = dit.to(device=device, dtype=dtype) 103 | ae = ae.to(device=device, dtype=torch.float32) 104 | 105 | return ae, dit, qwen2vl_encoder 106 | 107 | 108 | class Step1XImageGenerator: 109 | def __init__( 110 | self, 111 | dit_path=None, 112 | ae_path=None, 113 | qwen2vl_model_path=None, 114 | device="cuda", 115 | max_length=640, 116 | dtype=torch.bfloat16, 117 | ) -> None: 118 | self.device = torch.device(device) 119 | self.ae, self.dit, self.llm_encoder = load_models( 120 | dit_path=dit_path, 121 | ae_path=ae_path, 122 | qwen2vl_model_path=qwen2vl_model_path, 123 | max_length=max_length, 124 | dtype=dtype, 125 | ) 126 | 127 | def prepare(self, prompt, img, ref_image, ref_image_raw): 128 | bs, _, h, w = img.shape 129 | bs, _, ref_h, ref_w = ref_image.shape 130 | 131 | assert h == ref_h and w == ref_w 132 | 133 | if bs == 1 and not isinstance(prompt, str): 134 | bs = len(prompt) 135 | elif bs >= 1 and isinstance(prompt, str): 136 | prompt = [prompt] * bs 137 | 138 | img = rearrange(img, "b c (h ph) (w pw) -> b (h w) (c ph pw)", ph=2, pw=2) 139 | ref_img = rearrange(ref_image, "b c (ref_h ph) (ref_w pw) -> b (ref_h ref_w) (c ph pw)", ph=2, pw=2) 140 | if img.shape[0] == 1 and 
bs > 1: 141 | img = repeat(img, "1 ... -> bs ...", bs=bs) 142 | ref_img = repeat(ref_img, "1 ... -> bs ...", bs=bs) 143 | 144 | img_ids = torch.zeros(h // 2, w // 2, 3) 145 | 146 | img_ids[..., 1] = img_ids[..., 1] + torch.arange(h // 2)[:, None] 147 | img_ids[..., 2] = img_ids[..., 2] + torch.arange(w // 2)[None, :] 148 | img_ids = repeat(img_ids, "h w c -> b (h w) c", b=bs) 149 | 150 | ref_img_ids = torch.zeros(ref_h // 2, ref_w // 2, 3) 151 | 152 | ref_img_ids[..., 1] = ref_img_ids[..., 1] + torch.arange(ref_h // 2)[:, None] 153 | ref_img_ids[..., 2] = ref_img_ids[..., 2] + torch.arange(ref_w // 2)[None, :] 154 | ref_img_ids = repeat(ref_img_ids, "ref_h ref_w c -> b (ref_h ref_w) c", b=bs) 155 | 156 | if isinstance(prompt, str): 157 | prompt = [prompt] 158 | 159 | self.llm_encoder.model.to('cuda') 160 | txt, mask = self.llm_encoder(prompt, ref_image_raw) 161 | self.llm_encoder.model.to('cpu') 162 | 163 | txt_ids = torch.zeros(bs, txt.shape[1], 3) 164 | 165 | img = torch.cat([img, ref_img.to(device=img.device, dtype=img.dtype)], dim=-2) 166 | img_ids = torch.cat([img_ids, ref_img_ids], dim=-2) 167 | 168 | 169 | return { 170 | "img": img, 171 | "mask": mask, 172 | "img_ids": img_ids.to(img.device), 173 | "llm_embedding": txt.to(img.device), 174 | "txt_ids": txt_ids.to(img.device), 175 | } 176 | 177 | @staticmethod 178 | def process_diff_norm(diff_norm, k): 179 | pow_result = torch.pow(diff_norm, k) 180 | 181 | result = torch.where( 182 | diff_norm > 1.0, 183 | pow_result, 184 | torch.where(diff_norm < 1.0, torch.ones_like(diff_norm), diff_norm), 185 | ) 186 | return result 187 | 188 | def denoise( 189 | self, 190 | img: torch.Tensor, 191 | img_ids: torch.Tensor, 192 | llm_embedding: torch.Tensor, 193 | txt_ids: torch.Tensor, 194 | timesteps: list[float], 195 | cfg_guidance: float = 4.5, 196 | mask=None, 197 | show_progress=False, 198 | timesteps_truncate=1.0, 199 | **kwargs 200 | ): 201 | if show_progress: 202 | pbar = tqdm(itertools.pairwise(timesteps), desc='denoising...') 203 | else: 204 | pbar = itertools.pairwise(timesteps) 205 | for t_curr, t_prev in pbar: 206 | if 'rh_hook' in kwargs: 207 | kwargs['rh_hook']() 208 | if img.shape[0] == 1 and cfg_guidance != -1: 209 | img = torch.cat([img, img], dim=0) 210 | t_vec = torch.full( 211 | (img.shape[0],), t_curr, dtype=img.dtype, device=img.device 212 | ) 213 | 214 | txt, vec = self.dit.connector(llm_embedding, t_vec, mask) 215 | 216 | 217 | pred = self.dit( 218 | img=img, 219 | img_ids=img_ids, 220 | txt=txt, 221 | txt_ids=txt_ids, 222 | y=vec, 223 | timesteps=t_vec, 224 | **kwargs 225 | ) 226 | 227 | if cfg_guidance != -1: 228 | cond, uncond = ( 229 | pred[0 : pred.shape[0] // 2, :], 230 | pred[pred.shape[0] // 2 :, :], 231 | ) 232 | if t_curr > timesteps_truncate: 233 | diff = cond - uncond 234 | diff_norm = torch.norm(diff, dim=(2), keepdim=True) 235 | pred = uncond + cfg_guidance * ( 236 | cond - uncond 237 | ) / self.process_diff_norm(diff_norm, k=0.4) 238 | else: 239 | pred = uncond + cfg_guidance * (cond - uncond) 240 | tem_img = img[0 : img.shape[0] // 2, :] + (t_prev - t_curr) * pred 241 | img_input_length = img.shape[1] // 2 242 | img = torch.cat( 243 | [ 244 | tem_img[:, :img_input_length], 245 | img[ : img.shape[0] // 2, img_input_length:], 246 | ], dim=1 247 | ) 248 | 249 | return img[:, :img.shape[1] // 2] 250 | 251 | @staticmethod 252 | def unpack(x: torch.Tensor, height: int, width: int) -> torch.Tensor: 253 | return rearrange( 254 | x, 255 | "b (h w) (c ph pw) -> b c (h ph) (w pw)", 256 | h=math.ceil(height / 16), 257 | 
w=math.ceil(width / 16), 258 | ph=2, 259 | pw=2, 260 | ) 261 | 262 | @staticmethod 263 | def load_image(image): 264 | from PIL import Image 265 | 266 | if isinstance(image, np.ndarray): 267 | image = torch.from_numpy(image).permute(2, 0, 1).float() / 255.0 268 | image = image.unsqueeze(0) 269 | return image 270 | elif isinstance(image, Image.Image): 271 | image = F.to_tensor(image.convert("RGB")) 272 | image = image.unsqueeze(0) 273 | return image 274 | elif isinstance(image, torch.Tensor): 275 | return image 276 | elif isinstance(image, str): 277 | image = F.to_tensor(Image.open(image).convert("RGB")) 278 | image = image.unsqueeze(0) 279 | return image 280 | else: 281 | raise ValueError(f"Unsupported image type: {type(image)}") 282 | 283 | def output_process_image(self, resize_img, image_size): 284 | res_image = resize_img.resize(image_size) 285 | return res_image 286 | 287 | def input_process_image(self, img, img_size=512): 288 | # 1. 打开图片 289 | w, h = img.size 290 | r = w / h 291 | 292 | if w > h: 293 | w_new = math.ceil(math.sqrt(img_size * img_size * r)) 294 | h_new = math.ceil(w_new / r) 295 | else: 296 | h_new = math.ceil(math.sqrt(img_size * img_size / r)) 297 | w_new = math.ceil(h_new * r) 298 | h_new = math.ceil(h_new) // 16 * 16 299 | w_new = math.ceil(w_new) // 16 * 16 300 | 301 | img_resized = img.resize((w_new, h_new)) 302 | return img_resized, img.size 303 | 304 | @torch.inference_mode() 305 | def generate_image( 306 | self, 307 | prompt, 308 | negative_prompt, 309 | ref_images, 310 | num_steps, 311 | cfg_guidance, 312 | seed, 313 | num_samples=1, 314 | init_image=None, 315 | image2image_strength=0.0, 316 | show_progress=False, 317 | size_level=512, 318 | **kwargs 319 | ): 320 | assert num_samples == 1, "num_samples > 1 is not supported yet." 
321 | ref_images_raw, img_info = self.input_process_image(ref_images, img_size=size_level) 322 | 323 | use_fp8 = kwargs['use_fp8'] if 'use_fp8' in kwargs else False 324 | print(f"use_fp8:{use_fp8}") 325 | 326 | width, height = ref_images_raw.width, ref_images_raw.height 327 | 328 | 329 | ref_images_raw = self.load_image(ref_images_raw) 330 | ref_images_raw = ref_images_raw.to(self.device) 331 | 332 | #kiki 333 | self.ae.to('cuda') 334 | ref_images = self.ae.encode(ref_images_raw.to(self.device) * 2 - 1) 335 | 336 | seed = int(seed) 337 | seed = torch.Generator(device="cpu").seed() if seed < 0 else seed 338 | 339 | t0 = time.perf_counter() 340 | 341 | if init_image is not None: 342 | init_image = self.load_image(init_image) 343 | init_image = init_image.to(self.device) 344 | init_image = torch.nn.functional.interpolate(init_image, (height, width)) 345 | init_image = self.ae.encode(init_image.to() * 2 - 1) 346 | 347 | x = torch.randn( 348 | num_samples, 349 | 16, 350 | height // 8, 351 | width // 8, 352 | device=self.device, 353 | dtype=torch.bfloat16, 354 | generator=torch.Generator(device=self.device).manual_seed(seed), 355 | ) 356 | 357 | timesteps = sampling.get_schedule( 358 | num_steps, x.shape[-1] * x.shape[-2] // 4, shift=True 359 | ) 360 | 361 | if init_image is not None: 362 | t_idx = int((1 - image2image_strength) * num_steps) 363 | t = timesteps[t_idx] 364 | timesteps = timesteps[t_idx:] 365 | x = t * x + (1.0 - t) * init_image.to(x.dtype) 366 | 367 | x = torch.cat([x, x], dim=0) 368 | ref_images = torch.cat([ref_images, ref_images], dim=0) 369 | ref_images_raw = torch.cat([ref_images_raw, ref_images_raw], dim=0) 370 | inputs = self.prepare([prompt, negative_prompt], x, ref_image=ref_images, ref_image_raw=ref_images_raw) 371 | 372 | if use_fp8: 373 | self.dit.to('cuda', dtype=torch.float8_e4m3fn) 374 | else: 375 | self.dit.to('cuda', dtype=torch.bfloat16) 376 | 377 | with torch.autocast(enabled=use_fp8, device_type='cuda', dtype=torch.bfloat16): 378 | x = self.denoise( 379 | **inputs, 380 | cfg_guidance=cfg_guidance, 381 | timesteps=timesteps, 382 | show_progress=show_progress, 383 | timesteps_truncate=1.0, 384 | **kwargs 385 | ) 386 | 387 | self.dit.to('cpu', dtype=torch.bfloat16) 388 | x = self.unpack(x.float(), height, width) 389 | with torch.autocast(device_type=self.device.type, dtype=torch.bfloat16): 390 | self.ae.to('cuda') 391 | x = self.ae.decode(x) 392 | self.ae.to('cpu') 393 | x = x.clamp(-1, 1) 394 | x = x.mul(0.5).add(0.5) 395 | 396 | t1 = time.perf_counter() 397 | print(f"Done in {t1 - t0:.1f}s.") 398 | images_list = [] 399 | for img in x.float(): 400 | images_list.append(self.output_process_image(F.to_pil_image(img), img_info)) 401 | return images_list 402 | 403 | 404 | def main(): 405 | 406 | parser = argparse.ArgumentParser() 407 | parser.add_argument('--model_path', type=str, required=True, help='Path to the model checkpoint') 408 | parser.add_argument('--input_dir', type=str, required=True, help='Path to the input image directory') 409 | parser.add_argument('--output_dir', type=str, required=True, help='Path to the output image directory') 410 | parser.add_argument('--json_path', type=str, required=True, help='Path to the JSON file containing image names and prompts') 411 | parser.add_argument('--seed', type=int, default=42, help='Random seed for generation') 412 | parser.add_argument('--num_steps', type=int, default=28, help='Number of diffusion steps') 413 | parser.add_argument('--cfg_guidance', type=float, default=6.0, help='CFG guidance strength') 414 | 
parser.add_argument('--size_level', default=512, type=int) 415 | args = parser.parse_args() 416 | 417 | assert os.path.exists(args.input_dir), f"Input directory {args.input_dir} does not exist." 418 | assert os.path.exists(args.json_path), f"JSON file {args.json_path} does not exist." 419 | os.makedirs(args.output_dir, exist_ok=True) 420 | 421 | image_and_prompts = json.load(open(args.json_path, 'r')) 422 | 423 | #kiki-modify 424 | qwen2vl_model_path = '/workspace/comfyui/models/step-1/Qwen2.5-VL-7B-Instruct/' 425 | args.model_path = '/workspace/comfyui/models/step-1' 426 | 427 | image_edit = Step1XImageGenerator( 428 | ae_path=os.path.join(args.model_path, 'vae.safetensors'), 429 | dit_path=os.path.join(args.model_path, "step1x-edit-i1258.safetensors"), 430 | # qwen2vl_model_path='Qwen/Qwen2.5-VL-7B-Instruct', 431 | qwen2vl_model_path=qwen2vl_model_path, 432 | max_length=640, 433 | ) 434 | 435 | for image_name, prompt in image_and_prompts.items(): 436 | image_path = os.path.join(args.input_dir, image_name) 437 | output_path = os.path.join(args.output_dir, image_name) 438 | start_time = time.time() 439 | 440 | image = image_edit.generate_image( 441 | prompt, 442 | negative_prompt="", 443 | ref_images=Image.open(image_path).convert("RGB"), 444 | num_samples=1, 445 | num_steps=args.num_steps, 446 | cfg_guidance=args.cfg_guidance, 447 | seed=args.seed, 448 | show_progress=True, 449 | size_level=args.size_level, 450 | )[0] 451 | 452 | print(f"Time taken: {time.time() - start_time:.2f} seconds") 453 | 454 | image.save( 455 | os.path.join(output_path), lossless=True 456 | ) 457 | 458 | def kiki_tensor_to_pil(image): 459 | i = 255. * image.cpu().numpy() 460 | img = Image.fromarray(np.clip(i, 0, 255).astype(np.uint8)) 461 | return img 462 | 463 | def run(**kwargs): 464 | 465 | img = kiki_tensor_to_pil(kwargs['ref_image'][0]) 466 | prompt = kwargs['prompt'] 467 | model_path = os.path.join(folder_paths.models_dir, 'step-1') 468 | qwen2vl_model_path = os.path.join(model_path, 'Qwen2.5-VL-7B-Instruct') 469 | 470 | image_edit = Step1XImageGenerator( 471 | ae_path=os.path.join(model_path, 'vae.safetensors'), 472 | dit_path=os.path.join(model_path, "step1x-edit-i1258.safetensors"), 473 | qwen2vl_model_path=qwen2vl_model_path, 474 | max_length=640, 475 | ) 476 | 477 | image = image_edit.generate_image( 478 | prompt, 479 | negative_prompt="", 480 | ref_images=img, 481 | num_samples=1, 482 | num_steps=kwargs['num_steps'], 483 | cfg_guidance=kwargs['cfg_guidance'], 484 | seed=kwargs['seed'], 485 | show_progress=True, 486 | size_level=kwargs['size_level'], 487 | )[0] 488 | 489 | image = np.array(image).astype(np.float32) / 255.0 490 | image = torch.from_numpy(image)[None,] 491 | return image 492 | 493 | if __name__ == "__main__": 494 | main() 495 | -------------------------------------------------------------------------------- /modules/__init__.py: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/HM-RunningHub/ComfyUI_RH_Step1XEdit/7aa4217cf581984d8ada0d43533c3fca81904512/modules/__init__.py -------------------------------------------------------------------------------- /modules/attention.py: -------------------------------------------------------------------------------- 1 | import math 2 | 3 | import torch 4 | import torch.nn.functional as F 5 | 6 | 7 | try: 8 | import flash_attn 9 | from flash_attn.flash_attn_interface import ( 10 | _flash_attn_forward, 11 | flash_attn_func, 12 | flash_attn_varlen_func, 13 | ) 14 | except ImportError: 15 | 
flash_attn = None
16 |     flash_attn_varlen_func = None
17 |     _flash_attn_forward = None
18 |     flash_attn_func = None
19 | 
20 | MEMORY_LAYOUT = {
21 |     # "flash" mode:
22 |     # pre-process: input is [batch_size, seq_len, num_heads, head_dim]
23 |     # post-process: shape is left unchanged
24 |     "flash": (
25 |         lambda x: x,  # keep shape
26 |         lambda x: x,  # keep shape
27 |     ),
28 |     # "torch"/"vanilla" mode:
29 |     # pre-process: swap the sequence and head dimensions [B,S,A,D] -> [B,A,S,D]
30 |     # post-process: swap back to the original layout [B,A,S,D] -> [B,S,A,D]
31 |     "torch": (
32 |         lambda x: x.transpose(1, 2),  # (B,S,A,D) -> (B,A,S,D)
33 |         lambda x: x.transpose(1, 2),  # (B,A,S,D) -> (B,S,A,D)
34 |     ),
35 |     "vanilla": (
36 |         lambda x: x.transpose(1, 2),
37 |         lambda x: x.transpose(1, 2),
38 |     ),
39 | }
40 | 
41 | 
42 | def attention(
43 |     q,
44 |     k,
45 |     v,
46 |     mode="flash",
47 |     drop_rate=0,
48 |     attn_mask=None,
49 |     causal=False,
50 | ):
51 |     """
52 |     Perform QKV self-attention.
53 | 
54 |     Args:
55 |         q (torch.Tensor): Query tensor of shape [batch_size, seq_len, num_heads, head_dim]
56 |         k (torch.Tensor): Key tensor of shape [batch_size, seq_len_kv, num_heads, head_dim]
57 |         v (torch.Tensor): Value tensor of shape [batch_size, seq_len_kv, num_heads, head_dim]
58 |         mode (str): Attention mode, one of 'flash', 'torch', or 'vanilla'
59 |         drop_rate (float): Dropout probability applied to the attention weights
60 |         attn_mask (torch.Tensor): Attention mask; its shape depends on the mode
61 |         causal (bool): Whether to use causal attention (attend only to earlier positions)
62 | 
63 |     Returns:
64 |         torch.Tensor: Attention output of shape [batch_size, seq_len, num_heads * head_dim]
65 |     """
66 |     # Look up the pre- and post-processing functions
67 |     pre_attn_layout, post_attn_layout = MEMORY_LAYOUT[mode]
68 | 
69 |     # Apply the pre-processing transform
70 |     q = pre_attn_layout(q)  # resulting shape depends on the mode
71 |     k = pre_attn_layout(k)
72 |     v = pre_attn_layout(v)
73 | 
74 |     if mode == "torch":
75 |         # Use PyTorch's native scaled_dot_product_attention
76 |         if attn_mask is not None and attn_mask.dtype != torch.bool:
77 |             attn_mask = attn_mask.to(q.dtype)
78 |         x = F.scaled_dot_product_attention(
79 |             q, k, v, attn_mask=attn_mask, dropout_p=drop_rate, is_causal=causal
80 |         )
81 |     elif mode == "flash":
82 |         assert flash_attn_func is not None, "flash_attn_func is not available"
83 |         assert attn_mask is None, "attn_mask is not supported in flash mode"
84 |         x: torch.Tensor = flash_attn_func(
85 |             q, k, v, dropout_p=drop_rate, causal=causal, softmax_scale=None
86 |         )  # type: ignore
87 |     elif mode == "vanilla":
88 |         # Manual attention implementation
89 |         scale_factor = 1 / math.sqrt(q.size(-1))  # scaling factor 1/sqrt(d_k)
90 | 
91 |         b, a, s, _ = q.shape  # unpack shape parameters
92 |         s1 = k.size(2)  # key/value sequence length
93 | 
94 |         # Initialize the attention bias
95 |         attn_bias = torch.zeros(b, a, s, s1, dtype=q.dtype, device=q.device)
96 | 
97 |         # Handle the causal mask
98 |         if causal:
99 |             assert attn_mask is None, "causal masking and an explicit attention mask cannot be combined"
100 |             # Build a lower-triangular causal mask
101 |             temp_mask = torch.ones(b, a, s, s, dtype=torch.bool, device=q.device).tril(
102 |                 diagonal=0
103 |             )
104 |             attn_bias.masked_fill_(temp_mask.logical_not(), float("-inf"))
105 |             attn_bias = attn_bias.to(q.dtype)
106 | 
107 |         # Handle a user-supplied attention mask
108 |         if attn_mask is not None:
109 |             if attn_mask.dtype == torch.bool:
110 |                 attn_bias.masked_fill_(attn_mask.logical_not(), float("-inf"))
111 |             else:
112 |                 attn_bias += attn_mask  # allows ALiBi-style positional biases
113 | 
114 |         # Compute the attention matrix
115 |         attn = (q @ k.transpose(-2, -1)) * scale_factor  # [B,A,S,S1]
116 |         attn += attn_bias
117 | 
118 |         # softmax and dropout
119 |         attn = attn.softmax(dim=-1)
120 |         attn = torch.dropout(attn, p=drop_rate, train=True)
121 | 
122 |         # Compute the output
123 |         x = attn @ v  # [B,A,S,D]
124 |     else:
125 |         raise NotImplementedError(f"Unsupported attention mode: {mode}")
126 | 
127 |     # Apply the post-processing transform
128 |     x = post_attn_layout(x)  # restore the original dimension order
129 | 
130 |     # Merge the attention-head dimension
131 |     b, s, a, d = x.shape
132 |     out = x.reshape(b, s, -1)  # [B,S,A*D]
133 |     return out
134 | 
--------------------------------------------------------------------------------
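For reference, a minimal usage sketch of the `attention()` helper above (hypothetical shapes; it assumes the file is importable as `modules.attention`, e.g. from the repo root, and uses the `torch` mode so it runs without flash-attn installed):

```python
import torch
from modules.attention import attention  # this repo's modules/attention.py

B, S, H, D = 1, 16, 8, 64                 # batch, seq_len, num_heads, head_dim
q = torch.randn(B, S, H, D)
k = torch.randn(B, S, H, D)
v = torch.randn(B, S, H, D)

# "torch" mode routes through F.scaled_dot_product_attention and returns
# the heads merged into the last dimension: [B, S, H * D].
out = attention(q, k, v, mode="torch")
print(out.shape)  # torch.Size([1, 16, 512])
```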
/modules/autoencoder.py: -------------------------------------------------------------------------------- 1 | # Modified from Flux 2 | # 3 | # Copyright 2024 Black Forest Labs 4 | 5 | # Licensed under the Apache License, Version 2.0 (the "License"); 6 | # you may not use this file except in compliance with the License. 7 | # You may obtain a copy of the License at 8 | 9 | # http://www.apache.org/licenses/LICENSE-2.0 10 | 11 | # Unless required by applicable law or agreed to in writing, software 12 | # distributed under the License is distributed on an "AS IS" BASIS, 13 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 14 | # See the License for the specific language governing permissions and 15 | # limitations under the License. 16 | # 17 | # This source code is licensed under the license found in the 18 | # LICENSE file in the root directory of this source tree. 19 | import torch 20 | from einops import rearrange 21 | from torch import Tensor, nn 22 | 23 | 24 | def swish(x: Tensor) -> Tensor: 25 | return x * torch.sigmoid(x) 26 | 27 | 28 | class AttnBlock(nn.Module): 29 | def __init__(self, in_channels: int): 30 | super().__init__() 31 | self.in_channels = in_channels 32 | 33 | self.norm = nn.GroupNorm(num_groups=32, num_channels=in_channels, eps=1e-6, affine=True) 34 | 35 | self.q = nn.Conv2d(in_channels, in_channels, kernel_size=1) 36 | self.k = nn.Conv2d(in_channels, in_channels, kernel_size=1) 37 | self.v = nn.Conv2d(in_channels, in_channels, kernel_size=1) 38 | self.proj_out = nn.Conv2d(in_channels, in_channels, kernel_size=1) 39 | 40 | def attention(self, h_: Tensor) -> Tensor: 41 | h_ = self.norm(h_) 42 | q = self.q(h_) 43 | k = self.k(h_) 44 | v = self.v(h_) 45 | 46 | b, c, h, w = q.shape 47 | q = rearrange(q, "b c h w -> b 1 (h w) c").contiguous() 48 | k = rearrange(k, "b c h w -> b 1 (h w) c").contiguous() 49 | v = rearrange(v, "b c h w -> b 1 (h w) c").contiguous() 50 | h_ = nn.functional.scaled_dot_product_attention(q, k, v) 51 | 52 | return rearrange(h_, "b 1 (h w) c -> b c h w", h=h, w=w, c=c, b=b) 53 | 54 | def forward(self, x: Tensor) -> Tensor: 55 | return x + self.proj_out(self.attention(x)) 56 | 57 | 58 | class ResnetBlock(nn.Module): 59 | def __init__(self, in_channels: int, out_channels: int): 60 | super().__init__() 61 | self.in_channels = in_channels 62 | out_channels = in_channels if out_channels is None else out_channels 63 | self.out_channels = out_channels 64 | 65 | self.norm1 = nn.GroupNorm(num_groups=32, num_channels=in_channels, eps=1e-6, affine=True) 66 | self.conv1 = nn.Conv2d(in_channels, out_channels, kernel_size=3, stride=1, padding=1) 67 | self.norm2 = nn.GroupNorm(num_groups=32, num_channels=out_channels, eps=1e-6, affine=True) 68 | self.conv2 = nn.Conv2d(out_channels, out_channels, kernel_size=3, stride=1, padding=1) 69 | if self.in_channels != self.out_channels: 70 | self.nin_shortcut = nn.Conv2d(in_channels, out_channels, kernel_size=1, stride=1, padding=0) 71 | 72 | def forward(self, x): 73 | h = x 74 | h = self.norm1(h) 75 | h = swish(h) 76 | h = self.conv1(h) 77 | 78 | h = self.norm2(h) 79 | h = swish(h) 80 | h = self.conv2(h) 81 | 82 | if self.in_channels != self.out_channels: 83 | x = self.nin_shortcut(x) 84 | 85 | return x + h 86 | 87 | 88 | class Downsample(nn.Module): 89 | def __init__(self, in_channels: int): 90 | super().__init__() 91 | # no asymmetric padding in torch conv, must do it ourselves 92 | self.conv = nn.Conv2d(in_channels, in_channels, kernel_size=3, stride=2, padding=0) 93 | 94 | def forward(self, x: 
Tensor): 95 | pad = (0, 1, 0, 1) 96 | x = nn.functional.pad(x, pad, mode="constant", value=0) 97 | x = self.conv(x) 98 | return x 99 | 100 | 101 | class Upsample(nn.Module): 102 | def __init__(self, in_channels: int): 103 | super().__init__() 104 | self.conv = nn.Conv2d(in_channels, in_channels, kernel_size=3, stride=1, padding=1) 105 | 106 | def forward(self, x: Tensor): 107 | x = nn.functional.interpolate(x, scale_factor=2.0, mode="nearest") 108 | x = self.conv(x) 109 | return x 110 | 111 | 112 | class Encoder(nn.Module): 113 | def __init__( 114 | self, 115 | resolution: int, 116 | in_channels: int, 117 | ch: int, 118 | ch_mult: list[int], 119 | num_res_blocks: int, 120 | z_channels: int, 121 | ): 122 | super().__init__() 123 | self.ch = ch 124 | self.num_resolutions = len(ch_mult) 125 | self.num_res_blocks = num_res_blocks 126 | self.resolution = resolution 127 | self.in_channels = in_channels 128 | # downsampling 129 | self.conv_in = nn.Conv2d(in_channels, self.ch, kernel_size=3, stride=1, padding=1) 130 | 131 | curr_res = resolution 132 | in_ch_mult = (1, *tuple(ch_mult)) 133 | self.in_ch_mult = in_ch_mult 134 | self.down = nn.ModuleList() 135 | block_in = self.ch 136 | for i_level in range(self.num_resolutions): 137 | block = nn.ModuleList() 138 | attn = nn.ModuleList() 139 | block_in = ch * in_ch_mult[i_level] 140 | block_out = ch * ch_mult[i_level] 141 | for _ in range(self.num_res_blocks): 142 | block.append(ResnetBlock(in_channels=block_in, out_channels=block_out)) 143 | block_in = block_out 144 | down = nn.Module() 145 | down.block = block 146 | down.attn = attn 147 | if i_level != self.num_resolutions - 1: 148 | down.downsample = Downsample(block_in) 149 | curr_res = curr_res // 2 150 | self.down.append(down) 151 | 152 | # middle 153 | self.mid = nn.Module() 154 | self.mid.block_1 = ResnetBlock(in_channels=block_in, out_channels=block_in) 155 | self.mid.attn_1 = AttnBlock(block_in) 156 | self.mid.block_2 = ResnetBlock(in_channels=block_in, out_channels=block_in) 157 | 158 | # end 159 | self.norm_out = nn.GroupNorm(num_groups=32, num_channels=block_in, eps=1e-6, affine=True) 160 | self.conv_out = nn.Conv2d(block_in, 2 * z_channels, kernel_size=3, stride=1, padding=1) 161 | 162 | def forward(self, x: Tensor) -> Tensor: 163 | # downsampling 164 | hs = [self.conv_in(x)] 165 | for i_level in range(self.num_resolutions): 166 | for i_block in range(self.num_res_blocks): 167 | h = self.down[i_level].block[i_block](hs[-1]) 168 | if len(self.down[i_level].attn) > 0: 169 | h = self.down[i_level].attn[i_block](h) 170 | hs.append(h) 171 | if i_level != self.num_resolutions - 1: 172 | hs.append(self.down[i_level].downsample(hs[-1])) 173 | 174 | # middle 175 | h = hs[-1] 176 | h = self.mid.block_1(h) 177 | h = self.mid.attn_1(h) 178 | h = self.mid.block_2(h) 179 | # end 180 | h = self.norm_out(h) 181 | h = swish(h) 182 | h = self.conv_out(h) 183 | return h 184 | 185 | 186 | class Decoder(nn.Module): 187 | def __init__( 188 | self, 189 | ch: int, 190 | out_ch: int, 191 | ch_mult: list[int], 192 | num_res_blocks: int, 193 | in_channels: int, 194 | resolution: int, 195 | z_channels: int, 196 | ): 197 | super().__init__() 198 | self.ch = ch 199 | self.num_resolutions = len(ch_mult) 200 | self.num_res_blocks = num_res_blocks 201 | self.resolution = resolution 202 | self.in_channels = in_channels 203 | self.ffactor = 2 ** (self.num_resolutions - 1) 204 | 205 | # compute in_ch_mult, block_in and curr_res at lowest res 206 | block_in = ch * ch_mult[self.num_resolutions - 1] 207 | curr_res = 
resolution // 2 ** (self.num_resolutions - 1) 208 | self.z_shape = (1, z_channels, curr_res, curr_res) 209 | 210 | # z to block_in 211 | self.conv_in = nn.Conv2d(z_channels, block_in, kernel_size=3, stride=1, padding=1) 212 | 213 | # middle 214 | self.mid = nn.Module() 215 | self.mid.block_1 = ResnetBlock(in_channels=block_in, out_channels=block_in) 216 | self.mid.attn_1 = AttnBlock(block_in) 217 | self.mid.block_2 = ResnetBlock(in_channels=block_in, out_channels=block_in) 218 | 219 | # upsampling 220 | self.up = nn.ModuleList() 221 | for i_level in reversed(range(self.num_resolutions)): 222 | block = nn.ModuleList() 223 | attn = nn.ModuleList() 224 | block_out = ch * ch_mult[i_level] 225 | for _ in range(self.num_res_blocks + 1): 226 | block.append(ResnetBlock(in_channels=block_in, out_channels=block_out)) 227 | block_in = block_out 228 | up = nn.Module() 229 | up.block = block 230 | up.attn = attn 231 | if i_level != 0: 232 | up.upsample = Upsample(block_in) 233 | curr_res = curr_res * 2 234 | self.up.insert(0, up) # prepend to get consistent order 235 | 236 | # end 237 | self.norm_out = nn.GroupNorm(num_groups=32, num_channels=block_in, eps=1e-6, affine=True) 238 | self.conv_out = nn.Conv2d(block_in, out_ch, kernel_size=3, stride=1, padding=1) 239 | 240 | def forward(self, z: Tensor) -> Tensor: 241 | # z to block_in 242 | h = self.conv_in(z) 243 | 244 | # middle 245 | h = self.mid.block_1(h) 246 | h = self.mid.attn_1(h) 247 | h = self.mid.block_2(h) 248 | 249 | # upsampling 250 | for i_level in reversed(range(self.num_resolutions)): 251 | for i_block in range(self.num_res_blocks + 1): 252 | h = self.up[i_level].block[i_block](h) 253 | if len(self.up[i_level].attn) > 0: 254 | h = self.up[i_level].attn[i_block](h) 255 | if i_level != 0: 256 | h = self.up[i_level].upsample(h) 257 | 258 | # end 259 | h = self.norm_out(h) 260 | h = swish(h) 261 | h = self.conv_out(h) 262 | return h 263 | 264 | 265 | class DiagonalGaussian(nn.Module): 266 | def __init__(self, sample: bool = True, chunk_dim: int = 1): 267 | super().__init__() 268 | self.sample = sample 269 | self.chunk_dim = chunk_dim 270 | 271 | def forward(self, z: Tensor) -> Tensor: 272 | mean, logvar = torch.chunk(z, 2, dim=self.chunk_dim) 273 | if self.sample: 274 | std = torch.exp(0.5 * logvar) 275 | return mean + std * torch.randn_like(mean) 276 | else: 277 | return mean 278 | 279 | 280 | class AutoEncoder(nn.Module): 281 | def __init__( 282 | self, 283 | resolution: int, 284 | in_channels: int, 285 | ch: int, 286 | out_ch: int, 287 | ch_mult: list[int], 288 | num_res_blocks: int, 289 | z_channels: int, 290 | scale_factor: float, 291 | shift_factor: float, 292 | ): 293 | super().__init__() 294 | self.encoder = Encoder( 295 | resolution=resolution, 296 | in_channels=in_channels, 297 | ch=ch, 298 | ch_mult=ch_mult, 299 | num_res_blocks=num_res_blocks, 300 | z_channels=z_channels, 301 | ) 302 | self.decoder = Decoder( 303 | resolution=resolution, 304 | in_channels=in_channels, 305 | ch=ch, 306 | out_ch=out_ch, 307 | ch_mult=ch_mult, 308 | num_res_blocks=num_res_blocks, 309 | z_channels=z_channels, 310 | ) 311 | self.reg = DiagonalGaussian() 312 | 313 | self.scale_factor = scale_factor 314 | self.shift_factor = shift_factor 315 | 316 | def encode(self, x: Tensor) -> Tensor: 317 | z = self.reg(self.encoder(x)) 318 | z = self.scale_factor * (z - self.shift_factor) 319 | return z 320 | 321 | def decode(self, z: Tensor) -> Tensor: 322 | z = z / self.scale_factor + self.shift_factor 323 | return self.decoder(z) 324 | 325 | def forward(self, x: 
Tensor) -> Tensor: 326 | return self.decode(self.encode(x)) 327 | -------------------------------------------------------------------------------- /modules/conditioner.py: -------------------------------------------------------------------------------- 1 | import torch 2 | from qwen_vl_utils import process_vision_info 3 | from transformers import ( 4 | AutoProcessor, 5 | Qwen2VLForConditionalGeneration, 6 | Qwen2_5_VLForConditionalGeneration, 7 | ) 8 | from torchvision.transforms import ToPILImage 9 | 10 | to_pil = ToPILImage() 11 | 12 | Qwen25VL_7b_PREFIX = '''Given a user prompt, generate an "Enhanced prompt" that provides detailed visual descriptions suitable for image generation. Evaluate the level of detail in the user prompt: 13 | - If the prompt is simple, focus on adding specifics about colors, shapes, sizes, textures, and spatial relationships to create vivid and concrete scenes. 14 | - If the prompt is already detailed, refine and enhance the existing details slightly without overcomplicating.\n 15 | Here are examples of how to transform or refine prompts: 16 | - User Prompt: A cat sleeping -> Enhanced: A small, fluffy white cat curled up in a round shape, sleeping peacefully on a warm sunny windowsill, surrounded by pots of blooming red flowers. 17 | - User Prompt: A busy city street -> Enhanced: A bustling city street scene at dusk, featuring glowing street lamps, a diverse crowd of people in colorful clothing, and a double-decker bus passing by towering glass skyscrapers.\n 18 | Please generate only the enhanced description for the prompt below and avoid including any additional commentary or evaluations: 19 | User Prompt:''' 20 | 21 | 22 | def split_string(s): 23 | # 将中文引号替换为英文引号 24 | s = s.replace("“", '"').replace("”", '"') # use english quotes 25 | result = [] 26 | # 标记是否在引号内 27 | in_quotes = False 28 | temp = "" 29 | 30 | # 遍历字符串中的每个字符及其索引 31 | for idx, char in enumerate(s): 32 | # 如果字符是引号且索引大于 155 33 | if char == '"' and idx > 155: 34 | # 将引号添加到临时字符串 35 | temp += char 36 | # 如果不在引号内 37 | if not in_quotes: 38 | # 将临时字符串添加到结果列表 39 | result.append(temp) 40 | # 清空临时字符串 41 | temp = "" 42 | 43 | # 切换引号状态 44 | in_quotes = not in_quotes 45 | continue 46 | # 如果在引号内 47 | if in_quotes: 48 | # 如果字符是空格 49 | if char.isspace(): 50 | pass # have space token 51 | 52 | # 将字符用中文引号包裹后添加到结果列表 53 | result.append("“" + char + "”") 54 | else: 55 | # 将字符添加到临时字符串 56 | temp += char 57 | 58 | # 如果临时字符串不为空 59 | if temp: 60 | # 将临时字符串添加到结果列表 61 | result.append(temp) 62 | 63 | return result 64 | 65 | 66 | class Qwen25VL_7b_Embedder(torch.nn.Module): 67 | def __init__(self, model_path, max_length=640, dtype=torch.bfloat16, device="cuda"): 68 | super(Qwen25VL_7b_Embedder, self).__init__() 69 | self.max_length = max_length 70 | self.dtype = dtype 71 | self.device = device 72 | 73 | self.model = Qwen2_5_VLForConditionalGeneration.from_pretrained( 74 | model_path, 75 | torch_dtype=dtype, 76 | #attn_implementation="flash_attention_2", 77 | attn_implementation="eager", 78 | 79 | # ).to(torch.cuda.current_device()) 80 | ) 81 | 82 | self.model.requires_grad_(False) 83 | self.processor = AutoProcessor.from_pretrained( 84 | model_path, min_pixels=256 * 28 * 28, max_pixels=324 * 28 * 28 85 | ) 86 | 87 | self.prefix = Qwen25VL_7b_PREFIX 88 | 89 | def forward(self, caption, ref_images): 90 | text_list = caption 91 | embs = torch.zeros( 92 | len(text_list), 93 | self.max_length, 94 | self.model.config.hidden_size, 95 | dtype=torch.bfloat16, 96 | device=torch.cuda.current_device(), 97 | ) 98 | hidden_states = 
torch.zeros( 99 | len(text_list), 100 | self.max_length, 101 | self.model.config.hidden_size, 102 | dtype=torch.bfloat16, 103 | device=torch.cuda.current_device(), 104 | ) 105 | masks = torch.zeros( 106 | len(text_list), 107 | self.max_length, 108 | dtype=torch.long, 109 | device=torch.cuda.current_device(), 110 | ) 111 | input_ids_list = [] 112 | attention_mask_list = [] 113 | emb_list = [] 114 | 115 | def split_string(s): 116 | s = s.replace("“", '"').replace("”", '"').replace("'", '''"''') # use english quotes 117 | result = [] 118 | in_quotes = False 119 | temp = "" 120 | 121 | for idx,char in enumerate(s): 122 | if char == '"' and idx>155: 123 | temp += char 124 | if not in_quotes: 125 | result.append(temp) 126 | temp = "" 127 | 128 | in_quotes = not in_quotes 129 | continue 130 | if in_quotes: 131 | if char.isspace(): 132 | pass # have space token 133 | 134 | result.append("“" + char + "”") 135 | else: 136 | temp += char 137 | 138 | if temp: 139 | result.append(temp) 140 | 141 | return result 142 | 143 | for idx, (txt, imgs) in enumerate(zip(text_list, ref_images)): 144 | 145 | messages = [{"role": "user", "content": []}] 146 | 147 | messages[0]["content"].append({"type": "text", "text": f"{self.prefix}"}) 148 | 149 | messages[0]["content"].append({"type": "image", "image": to_pil(imgs)}) 150 | 151 | # 再添加 text 152 | messages[0]["content"].append({"type": "text", "text": f"{txt}"}) 153 | 154 | # Preparation for inference 155 | text = self.processor.apply_chat_template( 156 | messages, tokenize=False, add_generation_prompt=True, add_vision_id=True 157 | ) 158 | 159 | image_inputs, video_inputs = process_vision_info(messages) 160 | 161 | inputs = self.processor( 162 | text=[text], 163 | images=image_inputs, 164 | padding=True, 165 | return_tensors="pt", 166 | ) 167 | 168 | old_inputs_ids = inputs.input_ids 169 | text_split_list = split_string(text) 170 | 171 | token_list = [] 172 | for text_each in text_split_list: 173 | txt_inputs = self.processor( 174 | text=text_each, 175 | images=None, 176 | videos=None, 177 | padding=True, 178 | return_tensors="pt", 179 | ) 180 | token_each = txt_inputs.input_ids 181 | if token_each[0][0] == 2073 and token_each[0][-1] == 854: 182 | token_each = token_each[:, 1:-1] 183 | token_list.append(token_each) 184 | else: 185 | token_list.append(token_each) 186 | 187 | new_txt_ids = torch.cat(token_list, dim=1).to("cuda") 188 | 189 | new_txt_ids = new_txt_ids.to(old_inputs_ids.device) 190 | 191 | idx1 = (old_inputs_ids == 151653).nonzero(as_tuple=True)[1][0] 192 | idx2 = (new_txt_ids == 151653).nonzero(as_tuple=True)[1][0] 193 | inputs.input_ids = ( 194 | torch.cat([old_inputs_ids[0, :idx1], new_txt_ids[0, idx2:]], dim=0) 195 | .unsqueeze(0) 196 | .to("cuda") 197 | ) 198 | inputs.attention_mask = (inputs.input_ids > 0).long().to("cuda") 199 | outputs = self.model( 200 | input_ids=inputs.input_ids, 201 | attention_mask=inputs.attention_mask, 202 | pixel_values=inputs.pixel_values.to("cuda"), 203 | image_grid_thw=inputs.image_grid_thw.to("cuda"), 204 | output_hidden_states=True, 205 | ) 206 | 207 | emb = outputs["hidden_states"][-1] 208 | 209 | embs[idx, : min(self.max_length, emb.shape[1] - 217)] = emb[0, 217:][ 210 | : self.max_length 211 | ] 212 | 213 | masks[idx, : min(self.max_length, emb.shape[1] - 217)] = torch.ones( 214 | (min(self.max_length, emb.shape[1] - 217)), 215 | dtype=torch.long, 216 | device=torch.cuda.current_device(), 217 | ) 218 | 219 | return embs, masks -------------------------------------------------------------------------------- 
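A small sketch (dummy tensors, assumed sizes) of the cropping and padding that `Qwen25VL_7b_Embedder.forward` above performs on the Qwen hidden states: the leading 217 tokens are dropped and the remainder is right-padded to `max_length`, with `masks` marking the valid positions:

```python
import torch

max_length, hidden_size = 640, 3584        # values used by this node / Qwen2.5-VL-7B (assumed here)
emb = torch.randn(1, 300, hidden_size)     # stand-in for the last Qwen hidden state

embs = torch.zeros(1, max_length, hidden_size)
masks = torch.zeros(1, max_length, dtype=torch.long)

keep = min(max_length, emb.shape[1] - 217)  # tokens left after dropping the leading 217
embs[0, :keep] = emb[0, 217:][:max_length]  # same slicing as in forward()
masks[0, :keep] = 1
print(keep, int(masks.sum()))               # 83 83
```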
/modules/connector_edit.py: -------------------------------------------------------------------------------- 1 | from typing import Optional 2 | 3 | import torch 4 | import torch.nn 5 | from einops import rearrange 6 | from torch import nn 7 | 8 | from .layers import MLP, TextProjection, TimestepEmbedder, apply_gate, attention 9 | 10 | 11 | class RMSNorm(nn.Module): 12 | def __init__( 13 | self, 14 | dim: int, 15 | elementwise_affine=True, 16 | eps: float = 1e-6, 17 | device=None, 18 | dtype=None, 19 | ): 20 | """ 21 | Initialize the RMSNorm normalization layer. 22 | 23 | Args: 24 | dim (int): The dimension of the input tensor. 25 | eps (float, optional): A small value added to the denominator for numerical stability. Default is 1e-6. 26 | 27 | Attributes: 28 | eps (float): A small value added to the denominator for numerical stability. 29 | weight (nn.Parameter): Learnable scaling parameter. 30 | 31 | """ 32 | factory_kwargs = {"device": device, "dtype": dtype} 33 | super().__init__() 34 | self.eps = eps 35 | if elementwise_affine: 36 | self.weight = nn.Parameter(torch.ones(dim, **factory_kwargs)) 37 | 38 | def _norm(self, x): 39 | """ 40 | Apply the RMSNorm normalization to the input tensor. 41 | 42 | Args: 43 | x (torch.Tensor): The input tensor. 44 | 45 | Returns: 46 | torch.Tensor: The normalized tensor. 47 | 48 | """ 49 | return x * torch.rsqrt(x.pow(2).mean(-1, keepdim=True) + self.eps) 50 | 51 | def forward(self, x): 52 | """ 53 | Forward pass through the RMSNorm layer. 54 | 55 | Args: 56 | x (torch.Tensor): The input tensor. 57 | 58 | Returns: 59 | torch.Tensor: The output tensor after applying RMSNorm. 60 | 61 | """ 62 | output = self._norm(x.float()).type_as(x) 63 | if hasattr(self, "weight"): 64 | output = output * self.weight 65 | return output 66 | 67 | 68 | def get_norm_layer(norm_layer): 69 | """ 70 | Get the normalization layer. 71 | 72 | Args: 73 | norm_layer (str): The type of normalization layer. 74 | 75 | Returns: 76 | norm_layer (nn.Module): The normalization layer. 
77 | """ 78 | if norm_layer == "layer": 79 | return nn.LayerNorm 80 | elif norm_layer == "rms": 81 | return RMSNorm 82 | else: 83 | raise NotImplementedError(f"Norm layer {norm_layer} is not implemented") 84 | 85 | 86 | def get_activation_layer(act_type): 87 | """get activation layer 88 | 89 | Args: 90 | act_type (str): the activation type 91 | 92 | Returns: 93 | torch.nn.functional: the activation layer 94 | """ 95 | if act_type == "gelu": 96 | return lambda: nn.GELU() 97 | elif act_type == "gelu_tanh": 98 | return lambda: nn.GELU(approximate="tanh") 99 | elif act_type == "relu": 100 | return nn.ReLU 101 | elif act_type == "silu": 102 | return nn.SiLU 103 | else: 104 | raise ValueError(f"Unknown activation type: {act_type}") 105 | 106 | class IndividualTokenRefinerBlock(torch.nn.Module): 107 | def __init__( 108 | self, 109 | hidden_size, 110 | heads_num, 111 | mlp_width_ratio: str = 4.0, 112 | mlp_drop_rate: float = 0.0, 113 | act_type: str = "silu", 114 | qk_norm: bool = False, 115 | qk_norm_type: str = "layer", 116 | qkv_bias: bool = True, 117 | need_CA: bool = False, 118 | dtype: Optional[torch.dtype] = None, 119 | device: Optional[torch.device] = None, 120 | ): 121 | factory_kwargs = {"device": device, "dtype": dtype} 122 | super().__init__() 123 | self.need_CA = need_CA 124 | self.heads_num = heads_num 125 | head_dim = hidden_size // heads_num 126 | mlp_hidden_dim = int(hidden_size * mlp_width_ratio) 127 | 128 | self.norm1 = nn.LayerNorm( 129 | hidden_size, elementwise_affine=True, eps=1e-6, **factory_kwargs 130 | ) 131 | self.self_attn_qkv = nn.Linear( 132 | hidden_size, hidden_size * 3, bias=qkv_bias, **factory_kwargs 133 | ) 134 | qk_norm_layer = get_norm_layer(qk_norm_type) 135 | self.self_attn_q_norm = ( 136 | qk_norm_layer(head_dim, elementwise_affine=True, eps=1e-6, **factory_kwargs) 137 | if qk_norm 138 | else nn.Identity() 139 | ) 140 | self.self_attn_k_norm = ( 141 | qk_norm_layer(head_dim, elementwise_affine=True, eps=1e-6, **factory_kwargs) 142 | if qk_norm 143 | else nn.Identity() 144 | ) 145 | self.self_attn_proj = nn.Linear( 146 | hidden_size, hidden_size, bias=qkv_bias, **factory_kwargs 147 | ) 148 | 149 | self.norm2 = nn.LayerNorm( 150 | hidden_size, elementwise_affine=True, eps=1e-6, **factory_kwargs 151 | ) 152 | act_layer = get_activation_layer(act_type) 153 | self.mlp = MLP( 154 | in_channels=hidden_size, 155 | hidden_channels=mlp_hidden_dim, 156 | act_layer=act_layer, 157 | drop=mlp_drop_rate, 158 | **factory_kwargs, 159 | ) 160 | 161 | self.adaLN_modulation = nn.Sequential( 162 | act_layer(), 163 | nn.Linear(hidden_size, 2 * hidden_size, bias=True, **factory_kwargs), 164 | ) 165 | 166 | if self.need_CA: 167 | self.cross_attnblock=CrossAttnBlock(hidden_size=hidden_size, 168 | heads_num=heads_num, 169 | mlp_width_ratio=mlp_width_ratio, 170 | mlp_drop_rate=mlp_drop_rate, 171 | act_type=act_type, 172 | qk_norm=qk_norm, 173 | qk_norm_type=qk_norm_type, 174 | qkv_bias=qkv_bias, 175 | **factory_kwargs,) 176 | # Zero-initialize the modulation 177 | nn.init.zeros_(self.adaLN_modulation[1].weight) 178 | nn.init.zeros_(self.adaLN_modulation[1].bias) 179 | 180 | def forward( 181 | self, 182 | x: torch.Tensor, 183 | c: torch.Tensor, # timestep_aware_representations + context_aware_representations 184 | attn_mask: torch.Tensor = None, 185 | y: torch.Tensor = None, 186 | ): 187 | gate_msa, gate_mlp = self.adaLN_modulation(c).chunk(2, dim=1) 188 | 189 | norm_x = self.norm1(x) 190 | qkv = self.self_attn_qkv(norm_x) 191 | q, k, v = rearrange(qkv, "B L (K H D) -> K B L H D", K=3, 
H=self.heads_num) 192 | # Apply QK-Norm if needed 193 | q = self.self_attn_q_norm(q).to(v) 194 | k = self.self_attn_k_norm(k).to(v) 195 | 196 | # Self-Attention 197 | attn = attention(q, k, v, mode="torch", attn_mask=attn_mask) 198 | 199 | x = x + apply_gate(self.self_attn_proj(attn), gate_msa) 200 | 201 | if self.need_CA: 202 | x = self.cross_attnblock(x, c, attn_mask, y) 203 | 204 | # FFN Layer 205 | x = x + apply_gate(self.mlp(self.norm2(x)), gate_mlp) 206 | 207 | return x 208 | 209 | 210 | 211 | 212 | class CrossAttnBlock(torch.nn.Module): 213 | def __init__( 214 | self, 215 | hidden_size, 216 | heads_num, 217 | mlp_width_ratio: str = 4.0, 218 | mlp_drop_rate: float = 0.0, 219 | act_type: str = "silu", 220 | qk_norm: bool = False, 221 | qk_norm_type: str = "layer", 222 | qkv_bias: bool = True, 223 | dtype: Optional[torch.dtype] = None, 224 | device: Optional[torch.device] = None, 225 | ): 226 | factory_kwargs = {"device": device, "dtype": dtype} 227 | super().__init__() 228 | self.heads_num = heads_num 229 | head_dim = hidden_size // heads_num 230 | 231 | self.norm1 = nn.LayerNorm( 232 | hidden_size, elementwise_affine=True, eps=1e-6, **factory_kwargs 233 | ) 234 | self.norm1_2 = nn.LayerNorm( 235 | hidden_size, elementwise_affine=True, eps=1e-6, **factory_kwargs 236 | ) 237 | self.self_attn_q = nn.Linear( 238 | hidden_size, hidden_size, bias=qkv_bias, **factory_kwargs 239 | ) 240 | self.self_attn_kv = nn.Linear( 241 | hidden_size, hidden_size*2, bias=qkv_bias, **factory_kwargs 242 | ) 243 | qk_norm_layer = get_norm_layer(qk_norm_type) 244 | self.self_attn_q_norm = ( 245 | qk_norm_layer(head_dim, elementwise_affine=True, eps=1e-6, **factory_kwargs) 246 | if qk_norm 247 | else nn.Identity() 248 | ) 249 | self.self_attn_k_norm = ( 250 | qk_norm_layer(head_dim, elementwise_affine=True, eps=1e-6, **factory_kwargs) 251 | if qk_norm 252 | else nn.Identity() 253 | ) 254 | self.self_attn_proj = nn.Linear( 255 | hidden_size, hidden_size, bias=qkv_bias, **factory_kwargs 256 | ) 257 | 258 | self.norm2 = nn.LayerNorm( 259 | hidden_size, elementwise_affine=True, eps=1e-6, **factory_kwargs 260 | ) 261 | act_layer = get_activation_layer(act_type) 262 | 263 | self.adaLN_modulation = nn.Sequential( 264 | act_layer(), 265 | nn.Linear(hidden_size, 2 * hidden_size, bias=True, **factory_kwargs), 266 | ) 267 | # Zero-initialize the modulation 268 | nn.init.zeros_(self.adaLN_modulation[1].weight) 269 | nn.init.zeros_(self.adaLN_modulation[1].bias) 270 | 271 | def forward( 272 | self, 273 | x: torch.Tensor, 274 | c: torch.Tensor, # timestep_aware_representations + context_aware_representations 275 | attn_mask: torch.Tensor = None, 276 | y: torch.Tensor=None, 277 | 278 | ): 279 | gate_msa, gate_mlp = self.adaLN_modulation(c).chunk(2, dim=1) 280 | 281 | norm_x = self.norm1(x) 282 | norm_y = self.norm1_2(y) 283 | q = self.self_attn_q(norm_x) 284 | q = rearrange(q, "B L (H D) -> B L H D", H=self.heads_num) 285 | kv = self.self_attn_kv(norm_y) 286 | k, v = rearrange(kv, "B L (K H D) -> K B L H D", K=2, H=self.heads_num) 287 | # Apply QK-Norm if needed 288 | q = self.self_attn_q_norm(q).to(v) 289 | k = self.self_attn_k_norm(k).to(v) 290 | 291 | # Self-Attention 292 | attn = attention(q, k, v, mode="torch", attn_mask=attn_mask) 293 | 294 | x = x + apply_gate(self.self_attn_proj(attn), gate_msa) 295 | 296 | return x 297 | 298 | 299 | 300 | class IndividualTokenRefiner(torch.nn.Module): 301 | def __init__( 302 | self, 303 | hidden_size, 304 | heads_num, 305 | depth, 306 | mlp_width_ratio: float = 4.0, 307 | 
mlp_drop_rate: float = 0.0, 308 | act_type: str = "silu", 309 | qk_norm: bool = False, 310 | qk_norm_type: str = "layer", 311 | qkv_bias: bool = True, 312 | need_CA:bool=False, 313 | dtype: Optional[torch.dtype] = None, 314 | device: Optional[torch.device] = None, 315 | ): 316 | 317 | factory_kwargs = {"device": device, "dtype": dtype} 318 | super().__init__() 319 | self.need_CA = need_CA 320 | self.blocks = nn.ModuleList( 321 | [ 322 | IndividualTokenRefinerBlock( 323 | hidden_size=hidden_size, 324 | heads_num=heads_num, 325 | mlp_width_ratio=mlp_width_ratio, 326 | mlp_drop_rate=mlp_drop_rate, 327 | act_type=act_type, 328 | qk_norm=qk_norm, 329 | qk_norm_type=qk_norm_type, 330 | qkv_bias=qkv_bias, 331 | need_CA=self.need_CA, 332 | **factory_kwargs, 333 | ) 334 | for _ in range(depth) 335 | ] 336 | ) 337 | 338 | 339 | def forward( 340 | self, 341 | x: torch.Tensor, 342 | c: torch.LongTensor, 343 | mask: Optional[torch.Tensor] = None, 344 | y:torch.Tensor=None, 345 | ): 346 | self_attn_mask = None 347 | if mask is not None: 348 | batch_size = mask.shape[0] 349 | seq_len = mask.shape[1] 350 | mask = mask.to(x.device) 351 | # batch_size x 1 x seq_len x seq_len 352 | self_attn_mask_1 = mask.view(batch_size, 1, 1, seq_len).repeat( 353 | 1, 1, seq_len, 1 354 | ) 355 | # batch_size x 1 x seq_len x seq_len 356 | self_attn_mask_2 = self_attn_mask_1.transpose(2, 3) 357 | # batch_size x 1 x seq_len x seq_len, 1 for broadcasting of heads_num 358 | self_attn_mask = (self_attn_mask_1 & self_attn_mask_2).bool() 359 | # avoids self-attention weight being NaN for padding tokens 360 | self_attn_mask[:, :, :, 0] = True 361 | 362 | 363 | for block in self.blocks: 364 | x = block(x, c, self_attn_mask,y) 365 | 366 | return x 367 | 368 | 369 | class SingleTokenRefiner(torch.nn.Module): 370 | """ 371 | A single token refiner block for llm text embedding refine. 372 | """ 373 | def __init__( 374 | self, 375 | in_channels, 376 | hidden_size, 377 | heads_num, 378 | depth, 379 | mlp_width_ratio: float = 4.0, 380 | mlp_drop_rate: float = 0.0, 381 | act_type: str = "silu", 382 | qk_norm: bool = False, 383 | qk_norm_type: str = "layer", 384 | qkv_bias: bool = True, 385 | need_CA:bool=False, 386 | attn_mode: str = "torch", 387 | dtype: Optional[torch.dtype] = None, 388 | device: Optional[torch.device] = None, 389 | ): 390 | factory_kwargs = {"device": device, "dtype": dtype} 391 | super().__init__() 392 | self.attn_mode = attn_mode 393 | self.need_CA = need_CA 394 | assert self.attn_mode == "torch", "Only support 'torch' mode for token refiner." 
395 | 396 | self.input_embedder = nn.Linear( 397 | in_channels, hidden_size, bias=True, **factory_kwargs 398 | ) 399 | if self.need_CA: 400 | self.input_embedder_CA = nn.Linear( 401 | in_channels, hidden_size, bias=True, **factory_kwargs 402 | ) 403 | 404 | act_layer = get_activation_layer(act_type) 405 | # Build timestep embedding layer 406 | self.t_embedder = TimestepEmbedder(hidden_size, act_layer, **factory_kwargs) 407 | # Build context embedding layer 408 | self.c_embedder = TextProjection( 409 | in_channels, hidden_size, act_layer, **factory_kwargs 410 | ) 411 | 412 | self.individual_token_refiner = IndividualTokenRefiner( 413 | hidden_size=hidden_size, 414 | heads_num=heads_num, 415 | depth=depth, 416 | mlp_width_ratio=mlp_width_ratio, 417 | mlp_drop_rate=mlp_drop_rate, 418 | act_type=act_type, 419 | qk_norm=qk_norm, 420 | qk_norm_type=qk_norm_type, 421 | qkv_bias=qkv_bias, 422 | need_CA=need_CA, 423 | **factory_kwargs, 424 | ) 425 | 426 | def forward( 427 | self, 428 | x: torch.Tensor, 429 | t: torch.LongTensor, 430 | mask: Optional[torch.LongTensor] = None, 431 | y: torch.LongTensor=None, 432 | ): 433 | timestep_aware_representations = self.t_embedder(t) 434 | 435 | if mask is None: 436 | context_aware_representations = x.mean(dim=1) 437 | else: 438 | mask_float = mask.unsqueeze(-1) # [b, s1, 1] 439 | context_aware_representations = (x * mask_float).sum( 440 | dim=1 441 | ) / mask_float.sum(dim=1) 442 | context_aware_representations = self.c_embedder(context_aware_representations) 443 | c = timestep_aware_representations + context_aware_representations 444 | 445 | x = self.input_embedder(x) 446 | if self.need_CA: 447 | y = self.input_embedder_CA(y) 448 | x = self.individual_token_refiner(x, c, mask, y) 449 | else: 450 | x = self.individual_token_refiner(x, c, mask) 451 | 452 | return x 453 | 454 | 455 | 456 | class Qwen2Connector(torch.nn.Module): 457 | def __init__( 458 | self, 459 | # biclip_dim=1024, 460 | in_channels=3584, 461 | hidden_size=4096, 462 | heads_num=32, 463 | depth=2, 464 | need_CA=False, 465 | device=None, 466 | dtype=torch.bfloat16, 467 | ): 468 | super().__init__() 469 | factory_kwargs = {"device": device, "dtype":dtype} 470 | 471 | self.S =SingleTokenRefiner(in_channels=in_channels,hidden_size=hidden_size,heads_num=heads_num,depth=depth,need_CA=need_CA,**factory_kwargs) 472 | self.global_proj_out=nn.Linear(in_channels,768) 473 | 474 | self.scale_factor = nn.Parameter(torch.zeros(1)) 475 | with torch.no_grad(): 476 | self.scale_factor.data += -(1 - 0.09) 477 | 478 | def forward(self, x,t,mask): 479 | if self.scale_factor.dtype == torch.float8_e4m3fn: 480 | self.scale_factor.data = self.scale_factor.data.to(x.dtype) 481 | mask_float = mask.unsqueeze(-1) # [b, s1, 1] 482 | x_mean = (x * mask_float).sum( 483 | dim=1 484 | ) / mask_float.sum(dim=1) * (1 + self.scale_factor) 485 | 486 | global_out=self.global_proj_out(x_mean) 487 | encoder_hidden_states = self.S(x,t,mask) 488 | return encoder_hidden_states,global_out -------------------------------------------------------------------------------- /modules/layers.py: -------------------------------------------------------------------------------- 1 | # Modified from Flux 2 | # 3 | # Copyright 2024 Black Forest Labs 4 | 5 | # Licensed under the Apache License, Version 2.0 (the "License"); 6 | # you may not use this file except in compliance with the License. 
7 | # You may obtain a copy of the License at 8 | 9 | # http://www.apache.org/licenses/LICENSE-2.0 10 | 11 | # Unless required by applicable law or agreed to in writing, software 12 | # distributed under the License is distributed on an "AS IS" BASIS, 13 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 14 | # See the License for the specific language governing permissions and 15 | # limitations under the License. 16 | # 17 | # This source code is licensed under the license found in the 18 | # LICENSE file in the root directory of this source tree. 19 | 20 | import math # noqa: I001 21 | from dataclasses import dataclass 22 | from functools import partial 23 | 24 | import torch 25 | import torch.nn.functional as F 26 | from einops import rearrange 27 | # from liger_kernel.ops.rms_norm import LigerRMSNormFunction # Removed original import 28 | from torch import Tensor, nn 29 | 30 | # --- Added --- 31 | LigerRMSNormFunction = None # Default to None 32 | try: 33 | # Try importing only if CUDA might be available, as liger seems GPU-focused 34 | # and specifically checks for HIP inside its own utils during import. 35 | # A more robust check might be needed depending on liger_kernel's full init process. 36 | if torch.cuda.is_available() or hasattr(torch, 'hip'): # Check if potential GPU environment 37 | from liger_kernel.ops.rms_norm import LigerRMSNormFunction 38 | except (ImportError, AttributeError): 39 | # If import fails (missing library or attribute like 'hip'), keep it None 40 | print("ComfyUI_RH_Step1XEdit: liger_kernel not found or failed to import. Using fallback RMSNorm.") 41 | LigerRMSNormFunction = None 42 | # --- End Added --- 43 | 44 | 45 | try: 46 | import flash_attn 47 | from flash_attn.flash_attn_interface import ( 48 | _flash_attn_forward, 49 | flash_attn_varlen_func, 50 | ) 51 | except ImportError: 52 | flash_attn = None 53 | flash_attn_varlen_func = None 54 | _flash_attn_forward = None 55 | 56 | 57 | MEMORY_LAYOUT = { 58 | "flash": ( 59 | lambda x: x.view(x.shape[0] * x.shape[1], *x.shape[2:]), 60 | lambda x: x, 61 | ), 62 | "torch": ( 63 | lambda x: x.transpose(1, 2), 64 | lambda x: x.transpose(1, 2), 65 | ), 66 | "vanilla": ( 67 | lambda x: x.transpose(1, 2), 68 | lambda x: x.transpose(1, 2), 69 | ), 70 | } 71 | 72 | 73 | def attention( 74 | q, 75 | k, 76 | v, 77 | mode="flash", 78 | drop_rate=0, 79 | attn_mask=None, 80 | causal=False, 81 | cu_seqlens_q=None, 82 | cu_seqlens_kv=None, 83 | max_seqlen_q=None, 84 | max_seqlen_kv=None, 85 | batch_size=1, 86 | ): 87 | """ 88 | Perform QKV self attention. 89 | 90 | Args: 91 | q (torch.Tensor): Query tensor with shape [b, s, a, d], where a is the number of heads. 92 | k (torch.Tensor): Key tensor with shape [b, s1, a, d] 93 | v (torch.Tensor): Value tensor with shape [b, s1, a, d] 94 | mode (str): Attention mode. Choose from 'self_flash', 'cross_flash', 'torch', and 'vanilla'. 95 | drop_rate (float): Dropout rate in attention map. (default: 0) 96 | attn_mask (torch.Tensor): Attention mask with shape [b, s1] (cross_attn), or [b, a, s, s1] (torch or vanilla). 97 | (default: None) 98 | causal (bool): Whether to use causal attention. (default: False) 99 | cu_seqlens_q (torch.Tensor): dtype torch.int32. The cumulative sequence lengths of the sequences in the batch, 100 | used to index into q. 101 | cu_seqlens_kv (torch.Tensor): dtype torch.int32. The cumulative sequence lengths of the sequences in the batch, 102 | used to index into kv. 103 | max_seqlen_q (int): The maximum sequence length in the batch of q. 
104 |         max_seqlen_kv (int): The maximum sequence length in the batch of k and v.
105 | 
106 |     Returns:
107 |         torch.Tensor: Output tensor after self attention with shape [b, s, ad]
108 |     """
109 |     pre_attn_layout, post_attn_layout = MEMORY_LAYOUT[mode]
110 |     q = pre_attn_layout(q)
111 |     k = pre_attn_layout(k)
112 |     v = pre_attn_layout(v)
113 | 
114 |     if mode == "torch":
115 |         if attn_mask is not None and attn_mask.dtype != torch.bool:
116 |             attn_mask = attn_mask.to(q.dtype)
117 |         x = F.scaled_dot_product_attention(
118 |             q, k, v, attn_mask=attn_mask, dropout_p=drop_rate, is_causal=causal
119 |         )
120 |     elif mode == "flash":
121 |         assert flash_attn_varlen_func is not None
122 |         x: torch.Tensor = flash_attn_varlen_func(
123 |             q,
124 |             k,
125 |             v,
126 |             cu_seqlens_q,
127 |             cu_seqlens_kv,
128 |             max_seqlen_q,
129 |             max_seqlen_kv,
130 |         ) # type: ignore
131 |         # x with shape [(bxs), a, d]
132 |         x = x.view(batch_size, max_seqlen_q, x.shape[-2], x.shape[-1]) # type: ignore # reshape x to [b, s, a, d]
133 |     elif mode == "vanilla":
134 |         scale_factor = 1 / math.sqrt(q.size(-1))
135 | 
136 |         b, a, s, _ = q.shape
137 |         s1 = k.size(2)
138 |         attn_bias = torch.zeros(b, a, s, s1, dtype=q.dtype, device=q.device)
139 |         if causal:
140 |             # Only applied to self attention
141 |             assert attn_mask is None, (
142 |                 "Causal mask and attn_mask cannot be used together"
143 |             )
144 |             temp_mask = torch.ones(b, a, s, s, dtype=torch.bool, device=q.device).tril(
145 |                 diagonal=0
146 |             )
147 |             attn_bias.masked_fill_(temp_mask.logical_not(), float("-inf"))
148 |             attn_bias.to(q.dtype)
149 | 
150 |         if attn_mask is not None:
151 |             if attn_mask.dtype == torch.bool:
152 |                 attn_bias.masked_fill_(attn_mask.logical_not(), float("-inf"))
153 |             else:
154 |                 attn_bias += attn_mask
155 | 
156 |         # TODO: Maybe force q and k to be float32 to avoid numerical overflow
157 |         attn = (q @ k.transpose(-2, -1)) * scale_factor
158 |         attn += attn_bias
159 |         attn = attn.softmax(dim=-1)
160 |         attn = torch.dropout(attn, p=drop_rate, train=True)
161 |         x = attn @ v
162 |     else:
163 |         raise NotImplementedError(f"Unsupported attention mode: {mode}")
164 | 
165 |     x = post_attn_layout(x)
166 |     b, s, a, d = x.shape
167 |     out = x.reshape(b, s, -1)
168 |     return out
169 | 
170 | 
171 | def apply_gate(x, gate=None, tanh=False):
172 |     """Apply a per-sample gate to a tensor, optionally passing the gate through tanh.
173 | 
174 |     Args:
175 |         x (torch.Tensor): input tensor.
176 |         gate (torch.Tensor, optional): gate tensor. Defaults to None.
177 |         tanh (bool, optional): whether to use tanh function. Defaults to False.
178 | 
179 |     Returns:
180 |         torch.Tensor: the output tensor after applying the gate.
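
    Example (illustrative):
        >>> x = torch.randn(2, 4, 8)
        >>> gate = torch.ones(2, 8)
        >>> apply_gate(x, gate).shape
        torch.Size([2, 4, 8])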
181 | """ 182 | if gate is None: 183 | return x 184 | if tanh: 185 | return x * gate.unsqueeze(1).tanh() 186 | else: 187 | return x * gate.unsqueeze(1) 188 | 189 | 190 | class MLP(nn.Module): 191 | """MLP as used in Vision Transformer, MLP-Mixer and related networks""" 192 | 193 | def __init__( 194 | self, 195 | in_channels, 196 | hidden_channels=None, 197 | out_features=None, 198 | act_layer=nn.GELU, 199 | norm_layer=None, 200 | bias=True, 201 | drop=0.0, 202 | use_conv=False, 203 | device=None, 204 | dtype=None, 205 | ): 206 | super().__init__() 207 | out_features = out_features or in_channels 208 | hidden_channels = hidden_channels or in_channels 209 | bias = (bias, bias) 210 | drop_probs = (drop, drop) 211 | linear_layer = partial(nn.Conv2d, kernel_size=1) if use_conv else nn.Linear 212 | 213 | self.fc1 = linear_layer( 214 | in_channels, hidden_channels, bias=bias[0], device=device, dtype=dtype 215 | ) 216 | self.act = act_layer() 217 | self.drop1 = nn.Dropout(drop_probs[0]) 218 | self.norm = ( 219 | norm_layer(hidden_channels, device=device, dtype=dtype) 220 | if norm_layer is not None 221 | else nn.Identity() 222 | ) 223 | self.fc2 = linear_layer( 224 | hidden_channels, out_features, bias=bias[1], device=device, dtype=dtype 225 | ) 226 | self.drop2 = nn.Dropout(drop_probs[1]) 227 | 228 | def forward(self, x): 229 | x = self.fc1(x) 230 | x = self.act(x) 231 | x = self.drop1(x) 232 | x = self.norm(x) 233 | x = self.fc2(x) 234 | x = self.drop2(x) 235 | return x 236 | 237 | 238 | class TextProjection(nn.Module): 239 | """ 240 | Projects text embeddings. Also handles dropout for classifier-free guidance. 241 | 242 | Adapted from https://github.com/PixArt-alpha/PixArt-alpha/blob/master/diffusion/model/nets/PixArt_blocks.py 243 | """ 244 | 245 | def __init__(self, in_channels, hidden_size, act_layer, dtype=None, device=None): 246 | factory_kwargs = {"dtype": dtype, "device": device} 247 | super().__init__() 248 | self.linear_1 = nn.Linear( 249 | in_features=in_channels, 250 | out_features=hidden_size, 251 | bias=True, 252 | **factory_kwargs, 253 | ) 254 | self.act_1 = act_layer() 255 | self.linear_2 = nn.Linear( 256 | in_features=hidden_size, 257 | out_features=hidden_size, 258 | bias=True, 259 | **factory_kwargs, 260 | ) 261 | 262 | def forward(self, caption): 263 | hidden_states = self.linear_1(caption) 264 | hidden_states = self.act_1(hidden_states) 265 | hidden_states = self.linear_2(hidden_states) 266 | return hidden_states 267 | 268 | 269 | class TimestepEmbedder(nn.Module): 270 | """ 271 | Embeds scalar timesteps into vector representations. 272 | """ 273 | 274 | def __init__( 275 | self, 276 | hidden_size, 277 | act_layer, 278 | frequency_embedding_size=256, 279 | max_period=10000, 280 | out_size=None, 281 | dtype=None, 282 | device=None, 283 | ): 284 | factory_kwargs = {"dtype": dtype, "device": device} 285 | super().__init__() 286 | self.frequency_embedding_size = frequency_embedding_size 287 | self.max_period = max_period 288 | if out_size is None: 289 | out_size = hidden_size 290 | 291 | self.mlp = nn.Sequential( 292 | nn.Linear( 293 | frequency_embedding_size, hidden_size, bias=True, **factory_kwargs 294 | ), 295 | act_layer(), 296 | nn.Linear(hidden_size, out_size, bias=True, **factory_kwargs), 297 | ) 298 | nn.init.normal_(self.mlp[0].weight, std=0.02) # type: ignore 299 | nn.init.normal_(self.mlp[2].weight, std=0.02) # type: ignore 300 | 301 | @staticmethod 302 | def timestep_embedding(t, dim, max_period=10000): 303 | """ 304 | Create sinusoidal timestep embeddings. 
305 | 
306 |         Args:
307 |             t (torch.Tensor): a 1-D Tensor of N indices, one per batch element. These may be fractional.
308 |             dim (int): the dimension of the output.
309 |             max_period (int): controls the minimum frequency of the embeddings.
310 | 
311 |         Returns:
312 |             embedding (torch.Tensor): An (N, D) Tensor of positional embeddings.
313 | 
314 |         .. ref_link: https://github.com/openai/glide-text2im/blob/main/glide_text2im/nn.py
315 |         """
316 |         half = dim // 2
317 |         freqs = torch.exp(
318 |             -math.log(max_period)
319 |             * torch.arange(start=0, end=half, dtype=torch.float32)
320 |             / half
321 |         ).to(device=t.device)
322 |         args = t[:, None].float() * freqs[None]
323 |         embedding = torch.cat([torch.cos(args), torch.sin(args)], dim=-1)
324 |         if dim % 2:
325 |             embedding = torch.cat(
326 |                 [embedding, torch.zeros_like(embedding[:, :1])], dim=-1
327 |             )
328 |         return embedding
329 | 
330 |     def forward(self, t):
331 |         t_freq = self.timestep_embedding(
332 |             t, self.frequency_embedding_size, self.max_period
333 |         ).type(self.mlp[0].weight.dtype) # type: ignore
334 |         t_emb = self.mlp(t_freq)
335 |         return t_emb
336 | 
337 | 
338 | class EmbedND(nn.Module):
339 |     def __init__(self, dim: int, theta: int, axes_dim: list[int]):
340 |         super().__init__()
341 |         self.dim = dim
342 |         self.theta = theta
343 |         self.axes_dim = axes_dim
344 | 
345 |     def forward(self, ids: Tensor) -> Tensor:
346 |         n_axes = ids.shape[-1]
347 |         emb = torch.cat(
348 |             [rope(ids[..., i], self.axes_dim[i], self.theta) for i in range(n_axes)],
349 |             dim=-3,
350 |         )
351 | 
352 |         return emb.unsqueeze(1)
353 | 
354 | 
355 | class MLPEmbedder(nn.Module):
356 |     def __init__(self, in_dim: int, hidden_dim: int):
357 |         super().__init__()
358 |         self.in_layer = nn.Linear(in_dim, hidden_dim, bias=True)
359 |         self.silu = nn.SiLU()
360 |         self.out_layer = nn.Linear(hidden_dim, hidden_dim, bias=True)
361 | 
362 |     def forward(self, x: Tensor) -> Tensor:
363 |         return self.out_layer(self.silu(self.in_layer(x)))
364 | 
365 | 
366 | from .attention import attention, flash_attn_func # Import flash_attn_func as well
367 | 
368 | def rope(pos, dim: int, theta: int):
369 |     assert dim % 2 == 0
370 |     scale = torch.arange(0, dim, 2, dtype=torch.float64, device=pos.device) / dim
371 |     omega = 1.0 / (theta**scale)
372 |     out = torch.einsum("...n,d->...nd", pos, omega)
373 |     out = torch.stack(
374 |         [torch.cos(out), -torch.sin(out), torch.sin(out), torch.cos(out)], dim=-1
375 |     )
376 |     out = rearrange(out, "b n d (i j) -> b n d i j", i=2, j=2)
377 |     return out.float()
378 | 
379 | 
380 | def attention_after_rope(q, k, v, pe):
381 |     q, k = apply_rope(q, k, pe)
382 | 
383 |     # Determine attention mode based on flash_attn availability
384 |     mode = "flash" if flash_attn_func is not None else "torch"
385 |     # print(f"Using attention mode: {mode}") # Optional: for debugging
386 | 
387 |     # from .attention import attention # Original import location moved up
388 | 
389 |     x = attention(q, k, v, mode=mode) # Use determined mode
390 |     return x
391 | 
392 | 
393 | #@torch.compile(mode="max-autotune-no-cudagraphs", dynamic=True)
394 | def apply_rope(xq, xk, freqs_cis):
395 |     # Swap the num_heads and seq_len dimensions back to the ordering the original function expects
396 |     xq = xq.transpose(1, 2) # [batch, num_heads, seq_len, head_dim]
397 |     xk = xk.transpose(1, 2)
398 | 
399 |     # Split head_dim into complex components (real and imaginary parts)
400 |     xq_ = xq.float().reshape(*xq.shape[:-1], -1, 1, 2)
401 |     xk_ = xk.float().reshape(*xk.shape[:-1], -1, 1, 2)
402 | 
403 |     # Apply the rotary position embedding (complex multiplication)
404 |     xq_out = freqs_cis[..., 0] * xq_[..., 0] + freqs_cis[..., 1] * xq_[..., 1]
405 |     xk_out = freqs_cis[..., 0] *
xk_[..., 0] + freqs_cis[..., 1] * xk_[..., 1]
406 | 
407 |     # Restore the tensor shape and transpose back to the target dimension order
408 |     xq_out = xq_out.reshape(*xq.shape).type_as(xq).transpose(1, 2)
409 |     xk_out = xk_out.reshape(*xk.shape).type_as(xk).transpose(1, 2)
410 | 
411 |     return xq_out, xk_out
412 | 
413 | 
414 | #@torch.compile(mode="max-autotune-no-cudagraphs", dynamic=True)
415 | def scale_add_residual(
416 |     x: torch.Tensor, scale: torch.Tensor, residual: torch.Tensor
417 | ) -> torch.Tensor:
418 |     return x * scale + residual
419 | 
420 | 
421 | #@torch.compile(mode="max-autotune-no-cudagraphs", dynamic=True)
422 | def layernorm_and_scale_shift(
423 |     x: torch.Tensor, scale: torch.Tensor, shift: torch.Tensor
424 | ) -> torch.Tensor:
425 |     return torch.nn.functional.layer_norm(x, (x.size(-1),)) * (scale + 1) + shift
426 | 
427 | 
428 | class SelfAttention(nn.Module):
429 |     def __init__(self, dim: int, num_heads: int = 8, qkv_bias: bool = False):
430 |         super().__init__()
431 |         self.num_heads = num_heads
432 |         head_dim = dim // num_heads
433 | 
434 |         self.qkv = nn.Linear(dim, dim * 3, bias=qkv_bias)
435 |         self.norm = QKNorm(head_dim)
436 |         self.proj = nn.Linear(dim, dim)
437 | 
438 |     def forward(self, x: Tensor, pe: Tensor) -> Tensor:
439 |         qkv = self.qkv(x)
440 |         q, k, v = rearrange(qkv, "B L (K H D) -> K B L H D", K=3, H=self.num_heads)
441 |         q, k = self.norm(q, k, v)
442 |         x = attention_after_rope(q, k, v, pe=pe)
443 |         x = self.proj(x)
444 |         return x
445 | 
446 | 
447 | @dataclass
448 | class ModulationOut:
449 |     shift: Tensor
450 |     scale: Tensor
451 |     gate: Tensor
452 | 
453 | 
454 | class RMSNorm(torch.nn.Module):
455 |     def __init__(self, dim: int):
456 |         super().__init__()
457 |         self.scale = nn.Parameter(torch.ones(dim))
458 | 
459 |     @staticmethod
460 |     def rms_norm_fast(x, weight, eps):
461 |         # Use the potentially imported Liger function if available
462 |         if LigerRMSNormFunction is not None:
463 |             return LigerRMSNormFunction.apply(
464 |                 x, weight, eps
465 |             )
466 |         else:
467 |             # Fallback to the slower pure PyTorch implementation
468 |             # This check was missing here, causing the NoneType error
469 |             # when LigerRMSNormFunction was None but rms_norm_fast was still called.
470 | return RMSNorm.rms_norm(x, weight, eps) 471 | 472 | @staticmethod 473 | def rms_norm(x, weight, eps): 474 | x_dtype = x.dtype 475 | # Convert both input and weight to float32 for calculation 476 | x_float = x.float() 477 | weight_float = weight.float() 478 | # Calculate RMS norm in float32 479 | rrms = torch.rsqrt(torch.mean(x_float**2, dim=-1, keepdim=True) + eps) 480 | result_float = (x_float * rrms) * weight_float 481 | # Convert the final result back to the original input dtype 482 | return result_float.to(dtype=x_dtype) 483 | 484 | def forward(self, x: Tensor): 485 | return self.rms_norm_fast(x, self.scale, 1e-6) 486 | 487 | 488 | class QKNorm(torch.nn.Module): 489 | def __init__(self, dim: int): 490 | super().__init__() 491 | self.query_norm = RMSNorm(dim) 492 | self.key_norm = RMSNorm(dim) 493 | 494 | def forward(self, q: Tensor, k: Tensor, v: Tensor) -> tuple[Tensor, Tensor]: 495 | q = self.query_norm(q) 496 | k = self.key_norm(k) 497 | return q.to(v), k.to(v) 498 | 499 | 500 | class Modulation(nn.Module): 501 | def __init__(self, dim: int, double: bool): 502 | super().__init__() 503 | self.is_double = double 504 | self.multiplier = 6 if double else 3 505 | self.lin = nn.Linear(dim, self.multiplier * dim, bias=True) 506 | 507 | def forward(self, vec: Tensor) -> tuple[ModulationOut, ModulationOut | None]: 508 | out = self.lin(nn.functional.silu(vec))[:, None, :].chunk( 509 | self.multiplier, dim=-1 510 | ) 511 | 512 | return ( 513 | ModulationOut(*out[:3]), 514 | ModulationOut(*out[3:]) if self.is_double else None, 515 | ) 516 | 517 | 518 | class DoubleStreamBlock(nn.Module): 519 | def __init__( 520 | self, hidden_size: int, num_heads: int, mlp_ratio: float, qkv_bias: bool = False 521 | ): 522 | super().__init__() 523 | 524 | mlp_hidden_dim = int(hidden_size * mlp_ratio) 525 | self.num_heads = num_heads 526 | self.hidden_size = hidden_size 527 | self.img_mod = Modulation(hidden_size, double=True) 528 | self.img_norm1 = nn.LayerNorm(hidden_size, elementwise_affine=False, eps=1e-6) 529 | self.img_attn = SelfAttention( 530 | dim=hidden_size, num_heads=num_heads, qkv_bias=qkv_bias 531 | ) 532 | 533 | self.img_norm2 = nn.LayerNorm(hidden_size, elementwise_affine=False, eps=1e-6) 534 | self.img_mlp = nn.Sequential( 535 | nn.Linear(hidden_size, mlp_hidden_dim, bias=True), 536 | nn.GELU(approximate="tanh"), 537 | nn.Linear(mlp_hidden_dim, hidden_size, bias=True), 538 | ) 539 | 540 | self.txt_mod = Modulation(hidden_size, double=True) 541 | self.txt_norm1 = nn.LayerNorm(hidden_size, elementwise_affine=False, eps=1e-6) 542 | self.txt_attn = SelfAttention( 543 | dim=hidden_size, num_heads=num_heads, qkv_bias=qkv_bias 544 | ) 545 | 546 | self.txt_norm2 = nn.LayerNorm(hidden_size, elementwise_affine=False, eps=1e-6) 547 | self.txt_mlp = nn.Sequential( 548 | nn.Linear(hidden_size, mlp_hidden_dim, bias=True), 549 | nn.GELU(approximate="tanh"), 550 | nn.Linear(mlp_hidden_dim, hidden_size, bias=True), 551 | ) 552 | 553 | def forward( 554 | self, img: Tensor, txt: Tensor, vec: Tensor, pe: Tensor 555 | ) -> tuple[Tensor, Tensor]: 556 | img_mod1, img_mod2 = self.img_mod(vec) 557 | txt_mod1, txt_mod2 = self.txt_mod(vec) 558 | 559 | # prepare image for attention 560 | img_modulated = self.img_norm1(img) 561 | img_modulated = (1 + img_mod1.scale) * img_modulated + img_mod1.shift 562 | img_qkv = self.img_attn.qkv(img_modulated) 563 | img_q, img_k, img_v = rearrange( 564 | img_qkv, "B L (K H D) -> K B L H D", K=3, H=self.num_heads 565 | ) 566 | img_q, img_k = self.img_attn.norm(img_q, img_k, img_v) 
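        # Note: the text tokens are prepared the same way below; the two streams are
        # then concatenated along the sequence dimension, attended jointly in a single
        # attention call, and split back into text and image parts afterwards.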
567 | 568 | # prepare txt for attention 569 | txt_modulated = self.txt_norm1(txt) 570 | txt_modulated = (1 + txt_mod1.scale) * txt_modulated + txt_mod1.shift 571 | txt_qkv = self.txt_attn.qkv(txt_modulated) 572 | txt_q, txt_k, txt_v = rearrange( 573 | txt_qkv, "B L (K H D) -> K B L H D", K=3, H=self.num_heads 574 | ) 575 | txt_q, txt_k = self.txt_attn.norm(txt_q, txt_k, txt_v) 576 | 577 | # run actual attention 578 | q = torch.cat((txt_q, img_q), dim=1) 579 | k = torch.cat((txt_k, img_k), dim=1) 580 | v = torch.cat((txt_v, img_v), dim=1) 581 | 582 | attn = attention_after_rope(q, k, v, pe=pe) 583 | txt_attn, img_attn = attn[:, : txt.shape[1]], attn[:, txt.shape[1] :] 584 | 585 | # calculate the img bloks 586 | img = img + img_mod1.gate * self.img_attn.proj(img_attn) 587 | img_mlp = self.img_mlp( 588 | (1 + img_mod2.scale) * self.img_norm2(img) + img_mod2.shift 589 | ) 590 | img = scale_add_residual(img_mlp, img_mod2.gate, img) 591 | 592 | # calculate the txt bloks 593 | txt = txt + txt_mod1.gate * self.txt_attn.proj(txt_attn) 594 | txt_mlp = self.txt_mlp( 595 | (1 + txt_mod2.scale) * self.txt_norm2(txt) + txt_mod2.shift 596 | ) 597 | txt = scale_add_residual(txt_mlp, txt_mod2.gate, txt) 598 | return img, txt 599 | 600 | 601 | class SingleStreamBlock(nn.Module): 602 | """ 603 | A DiT block with parallel linear layers as described in 604 | https://arxiv.org/abs/2302.05442 and adapted modulation interface. 605 | """ 606 | 607 | def __init__( 608 | self, 609 | hidden_size: int, 610 | num_heads: int, 611 | mlp_ratio: float = 4.0, 612 | qk_scale: float | None = None, 613 | ): 614 | super().__init__() 615 | self.hidden_dim = hidden_size 616 | self.num_heads = num_heads 617 | head_dim = hidden_size // num_heads 618 | self.scale = qk_scale or head_dim**-0.5 619 | 620 | self.mlp_hidden_dim = int(hidden_size * mlp_ratio) 621 | # qkv and mlp_in 622 | self.linear1 = nn.Linear(hidden_size, hidden_size * 3 + self.mlp_hidden_dim) 623 | # proj and mlp_out 624 | self.linear2 = nn.Linear(hidden_size + self.mlp_hidden_dim, hidden_size) 625 | 626 | self.norm = QKNorm(head_dim) 627 | 628 | self.hidden_size = hidden_size 629 | self.pre_norm = nn.LayerNorm(hidden_size, elementwise_affine=False, eps=1e-6) 630 | 631 | self.mlp_act = nn.GELU(approximate="tanh") 632 | self.modulation = Modulation(hidden_size, double=False) 633 | 634 | def forward(self, x: Tensor, vec: Tensor, pe: Tensor) -> Tensor: 635 | mod, _ = self.modulation(vec) 636 | x_mod = (1 + mod.scale) * self.pre_norm(x) + mod.shift 637 | qkv, mlp = torch.split( 638 | self.linear1(x_mod), [3 * self.hidden_size, self.mlp_hidden_dim], dim=-1 639 | ) 640 | 641 | q, k, v = rearrange(qkv, "B L (K H D) -> K B L H D", K=3, H=self.num_heads) 642 | q, k = self.norm(q, k, v) 643 | 644 | # compute attention 645 | attn = attention_after_rope(q, k, v, pe=pe) 646 | # compute activation in mlp stream, cat again and run second linear layer 647 | output = self.linear2(torch.cat((attn, self.mlp_act(mlp)), 2)) 648 | return scale_add_residual(output, mod.gate, x) 649 | 650 | 651 | class LastLayer(nn.Module): 652 | def __init__(self, hidden_size: int, patch_size: int, out_channels: int): 653 | super().__init__() 654 | self.norm_final = nn.LayerNorm(hidden_size, elementwise_affine=False, eps=1e-6) 655 | self.linear = nn.Linear( 656 | hidden_size, patch_size * patch_size * out_channels, bias=True 657 | ) 658 | self.adaLN_modulation = nn.Sequential( 659 | nn.SiLU(), nn.Linear(hidden_size, 2 * hidden_size, bias=True) 660 | ) 661 | 662 | def forward(self, x: Tensor, vec: Tensor) -> 
Tensor: 663 | shift, scale = self.adaLN_modulation(vec).chunk(2, dim=1) 664 | x = (1 + scale[:, None, :]) * self.norm_final(x) + shift[:, None, :] 665 | x = self.linear(x) 666 | return x 667 | -------------------------------------------------------------------------------- /modules/model_edit.py: -------------------------------------------------------------------------------- 1 | import math 2 | from dataclasses import dataclass 3 | 4 | import numpy as np 5 | import torch 6 | from torch import Tensor, nn 7 | 8 | from .connector_edit import Qwen2Connector 9 | from .layers import DoubleStreamBlock, EmbedND, LastLayer, MLPEmbedder, SingleStreamBlock 10 | 11 | 12 | @dataclass 13 | class Step1XParams: 14 | in_channels: int 15 | out_channels: int 16 | vec_in_dim: int 17 | context_in_dim: int 18 | hidden_size: int 19 | mlp_ratio: float 20 | num_heads: int 21 | depth: int 22 | depth_single_blocks: int 23 | axes_dim: list[int] 24 | theta: int 25 | qkv_bias: bool 26 | 27 | 28 | class Step1XEdit(nn.Module): 29 | """ 30 | Transformer model for flow matching on sequences. 31 | """ 32 | 33 | def __init__(self, params: Step1XParams): 34 | super().__init__() 35 | 36 | self.params = params 37 | self.in_channels = params.in_channels 38 | self.out_channels = params.out_channels 39 | if params.hidden_size % params.num_heads != 0: 40 | raise ValueError( 41 | f"Hidden size {params.hidden_size} must be divisible by num_heads {params.num_heads}" 42 | ) 43 | pe_dim = params.hidden_size // params.num_heads 44 | if sum(params.axes_dim) != pe_dim: 45 | raise ValueError( 46 | f"Got {params.axes_dim} but expected positional dim {pe_dim}" 47 | ) 48 | self.hidden_size = params.hidden_size 49 | self.num_heads = params.num_heads 50 | self.pe_embedder = EmbedND( 51 | dim=pe_dim, theta=params.theta, axes_dim=params.axes_dim 52 | ) 53 | self.img_in = nn.Linear(self.in_channels, self.hidden_size, bias=True) 54 | self.time_in = MLPEmbedder(in_dim=256, hidden_dim=self.hidden_size) 55 | self.vector_in = MLPEmbedder(params.vec_in_dim, self.hidden_size) 56 | self.txt_in = nn.Linear(params.context_in_dim, self.hidden_size) 57 | 58 | self.double_blocks = nn.ModuleList( 59 | [ 60 | DoubleStreamBlock( 61 | self.hidden_size, 62 | self.num_heads, 63 | mlp_ratio=params.mlp_ratio, 64 | qkv_bias=params.qkv_bias, 65 | ) 66 | for _ in range(params.depth) 67 | ] 68 | ) 69 | 70 | self.single_blocks = nn.ModuleList( 71 | [ 72 | SingleStreamBlock( 73 | self.hidden_size, self.num_heads, mlp_ratio=params.mlp_ratio 74 | ) 75 | for _ in range(params.depth_single_blocks) 76 | ] 77 | ) 78 | 79 | self.final_layer = LastLayer(self.hidden_size, 1, self.out_channels) 80 | 81 | self.connector = Qwen2Connector() 82 | 83 | @staticmethod 84 | def timestep_embedding( 85 | t: Tensor, dim, max_period=10000, time_factor: float = 1000.0 86 | ): 87 | """ 88 | Create sinusoidal timestep embeddings. 89 | :param t: a 1-D Tensor of N indices, one per batch element. 90 | These may be fractional. 91 | :param dim: the dimension of the output. 92 | :param max_period: controls the minimum frequency of the embeddings. 93 | :return: an (N, D) Tensor of positional embeddings. 
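        Illustrative example: when called with dim=256 (as in forward) and the default
        time_factor=1000.0, a timestep t=0.5 is scaled to 500 and encoded as 128 cosine
        plus 128 sine features over log-spaced frequencies.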
94 | """ 95 | t = time_factor * t 96 | half = dim // 2 97 | freqs = torch.exp( 98 | -math.log(max_period) 99 | * torch.arange(start=0, end=half, dtype=torch.float32) 100 | / half 101 | ).to(t.device) 102 | 103 | args = t[:, None].float() * freqs[None] 104 | embedding = torch.cat([torch.cos(args), torch.sin(args)], dim=-1) 105 | if dim % 2: 106 | embedding = torch.cat( 107 | [embedding, torch.zeros_like(embedding[:, :1])], dim=-1 108 | ) 109 | if torch.is_floating_point(t): 110 | embedding = embedding.to(t) 111 | return embedding 112 | 113 | def forward( 114 | self, 115 | img: Tensor, 116 | img_ids: Tensor, 117 | txt: Tensor, 118 | txt_ids: Tensor, 119 | timesteps: Tensor, 120 | y: Tensor, 121 | **kwargs 122 | ) -> Tensor: 123 | use_fp8 = kwargs['use_fp8'] if 'use_fp8' in kwargs else False 124 | print(f'in Step1XEdit use_fp8:{use_fp8}') 125 | if img.ndim != 3 or txt.ndim != 3: 126 | raise ValueError("Input img and txt tensors must have 3 dimensions.") 127 | 128 | img = self.img_in(img) 129 | vec = self.time_in(self.timestep_embedding(timesteps, 256)) 130 | 131 | vec = vec + self.vector_in(y) 132 | txt = self.txt_in(txt) 133 | 134 | ids = torch.cat((txt_ids, img_ids), dim=1) 135 | pe = self.pe_embedder(ids) 136 | 137 | if not use_fp8: 138 | print('swap1') 139 | self.single_blocks.to('cpu') 140 | self.double_blocks.to('cuda') 141 | for block in self.double_blocks: 142 | img, txt = block(img=img, txt=txt, vec=vec, pe=pe) 143 | 144 | img = torch.cat((txt, img), 1) 145 | 146 | if not use_fp8: 147 | print('swap2') 148 | self.double_blocks.to('cpu') 149 | self.single_blocks.to('cuda') 150 | for block in self.single_blocks: 151 | img = block(img, vec=vec, pe=pe) 152 | img = img[:, txt.shape[1] :, ...] 153 | 154 | img = self.final_layer(img, vec) # (N, T, patch_size ** 2 * out_channels) 155 | return img -------------------------------------------------------------------------------- /nodes.py: -------------------------------------------------------------------------------- 1 | #import os 2 | #import sys 3 | #current_dir = os.path.dirname(os.path.abspath(__file__)) 4 | #sys.path.insert(0, current_dir) 5 | 6 | 7 | import sys 8 | import os 9 | import importlib 10 | import folder_paths 11 | from .inference import Step1XImageGenerator, kiki_tensor_to_pil 12 | import torch 13 | import numpy as np 14 | import comfy.utils 15 | 16 | # sys.path.insert(0, os.path.dirname(os.path.realpath(__file__))) 17 | 18 | class Kiki_Step1XEdit: 19 | @classmethod 20 | def INPUT_TYPES(s): 21 | return { 22 | "required": { 23 | "ref_image": ("IMAGE",), 24 | "prompt": ("STRING", {"multiline": True, 25 | "default": ''}), 26 | "num_steps": ("INT", {"default": 28, "min": 1, "max": 0xffffffffffffffff}), 27 | "cfg_guidance": ("FLOAT", {"default": 6.0}), 28 | "size_level": ("INT", {"default": 1024}), 29 | "seed": ("INT", {"default": 42, "min": 0, "max": 0xffffffffffffffff, 30 | "tooltip": "The random seed used for creating the noise."}), 31 | "use_fp8": ("BOOLEAN", {"default": True}) 32 | } 33 | } 34 | 35 | RETURN_TYPES = ("IMAGE",) 36 | RETURN_NAMES = ("image",) 37 | FUNCTION = "run" 38 | TITLE = 'RunningHub Step1X Edit' 39 | 40 | CATEGORY = "Runninghub/Step1XEdit" 41 | DESCRIPTION = "RunningHub Step1X Edit in 24G" 42 | 43 | def __init__(self): 44 | model_path = os.path.join(folder_paths.models_dir, 'step-1') 45 | qwen2vl_model_path = os.path.join(model_path, 'Qwen2.5-VL-7B-Instruct') 46 | 47 | self.image_edit = Step1XImageGenerator( 48 | ae_path=os.path.join(model_path, 'vae.safetensors'), 49 | dit_path=os.path.join(model_path, 
"step1x-edit-i1258.safetensors"), 50 | qwen2vl_model_path=qwen2vl_model_path, 51 | max_length=640, 52 | ) 53 | 54 | # def run(self, **kwargs): 55 | # lib_name = 'ComfyUI_RH_Step1XEdit.inference' 56 | # if lib_name in sys.modules: 57 | # importlib.reload(sys.modules[lib_name]) 58 | # import ComfyUI_RH_Step1XEdit.inference as rsi 59 | # img = rsi.run(**kwargs) 60 | # return (img, ) 61 | 62 | def run(self, **kwargs): 63 | prompt = kwargs['prompt'] 64 | img = kiki_tensor_to_pil(kwargs['ref_image'][0]) 65 | num_steps = kwargs['num_steps'] 66 | self.pbar = comfy.utils.ProgressBar(num_steps) 67 | use_fp8 = kwargs['use_fp8'] if 'use_fp8' in kwargs else False 68 | 69 | image = self.image_edit.generate_image( 70 | prompt, 71 | negative_prompt="", 72 | ref_images=img, 73 | num_samples=1, 74 | num_steps=num_steps, 75 | cfg_guidance=kwargs['cfg_guidance'], 76 | seed=kwargs['seed'], 77 | show_progress=True, 78 | size_level=kwargs['size_level'], 79 | rh_hook=self.update, 80 | use_fp8=use_fp8 81 | )[0] 82 | 83 | image = np.array(image).astype(np.float32) / 255.0 84 | image = torch.from_numpy(image)[None,] 85 | return (image, ) 86 | 87 | def update(self): 88 | self.pbar.update(1) 89 | 90 | NODE_CLASS_MAPPINGS = { 91 | "RunningHub_Step1XEdit": Kiki_Step1XEdit, 92 | } 93 | -------------------------------------------------------------------------------- /requirements.txt: -------------------------------------------------------------------------------- 1 | # Core ML/DL Framework 2 | torch>=2.3.0 3 | torchvision>=0.18.0 4 | torchaudio>=2.3.0 5 | 6 | # LLM/Transformer Support 7 | transformers>=4.49.0 8 | accelerate>=0.30.1 9 | 10 | # Tensor Operations & Model Loading 11 | einops>=0.8.0 12 | safetensors>=0.4.3 13 | 14 | # Image Processing 15 | Pillow>=10.3.0 16 | 17 | # Utilities 18 | numpy>=1.26.4 19 | tqdm>=4.66.4 20 | 21 | # --- Optional Performance Libraries --- 22 | # Note: Installation for these often depends on your specific CUDA version and hardware. 23 | # Consult their respective documentation for installation instructions. 
24 | # flash-attn>=2.5.8 25 | # liger-kernel (Installation method may vary) 26 | -------------------------------------------------------------------------------- /sampling.py: -------------------------------------------------------------------------------- 1 | import math 2 | from collections.abc import Callable 3 | 4 | import torch 5 | from torch import Tensor 6 | 7 | 8 | def get_noise(num_samples: int, height: int, width: int, device: torch.device, dtype: torch.dtype, seed: int): 9 | return torch.randn( 10 | num_samples, 11 | 16, 12 | # allow for packing 13 | 2 * math.ceil(height / 16), 14 | 2 * math.ceil(width / 16), 15 | device=device, 16 | dtype=dtype, 17 | generator=torch.Generator(device=device).manual_seed(seed), 18 | ) 19 | 20 | 21 | def time_shift(mu: float, sigma: float, t: Tensor): 22 | return math.exp(mu) / (math.exp(mu) + (1 / t - 1) ** sigma) 23 | 24 | 25 | def get_lin_function(x1: float = 256, y1: float = 0.5, x2: float = 4096, y2: float = 1.15) -> Callable[[float], float]: 26 | m = (y2 - y1) / (x2 - x1) 27 | b = y1 - m * x1 28 | return lambda x: m * x + b 29 | 30 | 31 | def get_schedule( 32 | num_steps: int, 33 | image_seq_len: int, 34 | base_shift: float = 0.5, 35 | max_shift: float = 1.15, 36 | shift: bool = True, 37 | ) -> list[float]: 38 | # extra step for zero 39 | timesteps = torch.linspace(1, 0, num_steps + 1) 40 | 41 | # shifting the schedule to favor high timesteps for higher signal images 42 | if shift: 43 | # estimate mu based on linear estimation between two points 44 | mu = get_lin_function(y1=base_shift, y2=max_shift)(image_seq_len) 45 | timesteps = time_shift(mu, 1.0, timesteps) 46 | 47 | return timesteps.tolist() 48 | --------------------------------------------------------------------------------
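As a quick illustration of how the helpers in `sampling.py` fit together, the sketch below builds the initial noise and the shifted timestep schedule for a 1024x1024 edit. It is a minimal example rather than code from the repository: the flat import path and the assumption that `image_seq_len` equals the packed latent length `(H/16) * (W/16)` follow common Flux-style usage and may differ from how `inference.py` wires these calls.

```python
import torch

# Assumed import path; inside the node package these are relative imports.
from sampling import get_noise, get_schedule

height, width = 1024, 1024
# Packed latent length assumed to be (H/16) * (W/16) = 64 * 64 = 4096 tokens.
image_seq_len = (height // 16) * (width // 16)

device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
noise = get_noise(
    num_samples=1,
    height=height,
    width=width,
    device=device,
    dtype=torch.bfloat16,
    seed=42,
)
# get_noise returns (1, 16, 2 * ceil(1024 / 16), 2 * ceil(1024 / 16)) == (1, 16, 128, 128).
print(noise.shape)

timesteps = get_schedule(num_steps=28, image_seq_len=image_seq_len, shift=True)
# With shift=True and 4096 tokens, mu is about 1.15, so the schedule is bent toward
# high noise: the midpoint t = 0.5 maps to exp(1.15) / (exp(1.15) + 1), roughly 0.76.
print(len(timesteps), timesteps[:3])
```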