├── .gitignore
├── README.md
├── __init__.py
├── normal_crafter_nodes.py
├── normalcrafter
│   ├── __init__.py
│   ├── normal_crafter_ppl.py
│   ├── unet.py
│   └── utils.py
└── requirements.txt

/.gitignore:
--------------------------------------------------------------------------------
 1 | # Python
 2 | __pycache__/
 3 | *.py[cod]
 4 | *$py.class
 5 | *.so
 6 | .Python
 7 | build/
 8 | develop-eggs/
 9 | dist/
10 | downloads/
11 | eggs/
12 | .eggs/
13 | lib/
14 | lib64/
15 | parts/
16 | sdist/
17 | var/
18 | wheels/
19 | share/python-wheels/
20 | *.egg-info/
21 | .installed.cfg
22 | *.egg
23 | MANIFEST
24 | 
25 | # Virtualenv
26 | .env
27 | .venv
28 | env/
29 | venv/
30 | ENV/
31 | env.bak/
32 | venv.bak/
33 | 
34 | # IDE / Editor specific
35 | .vscode/
36 | .idea/
37 | *.swp
38 | *.swo
39 | 
40 | # OS specific
41 | .DS_Store
42 | Thumbs.db

--------------------------------------------------------------------------------
/README.md:
--------------------------------------------------------------------------------
 1 | # ComfyUI NormalCrafter
 2 | 
 3 | This is a ComfyUI custom node implementation of [NormalCrafter: Learning Temporally Consistent Normals from Video Diffusion Priors](https://github.com/Binyr/NormalCrafter) by Yanrui Bin, Wenbo Hu, Haoyuan Wang, Xinya Chen, and Bing Wang.
 4 | 
 5 | It lets you generate temporally consistent normal map sequences from input video frames.
 6 | 
 7 | ## Installation
 8 | 
 9 | 1. Clone this repository into your `ComfyUI/custom_nodes/` directory:
10 | ```bash
11 | cd ComfyUI/custom_nodes/
12 | git clone https://github.com/AIWarper/ComfyUI-NormalCrafterWrapper.git
13 | ```
14 | 2. Install the required dependencies. Activate your ComfyUI virtual environment, then navigate to the `ComfyUI-NormalCrafterWrapper` directory and install from `requirements.txt`:
15 | ```bash
16 | # activate your ComfyUI virtual environment first
17 | cd ComfyUI-NormalCrafterWrapper
18 | pip install -r requirements.txt
19 | ```
20 | (See `requirements.txt` for notes on the `diffusers` dependency, which ComfyUI often manages itself.)
21 | 3. Restart ComfyUI.
22 | 
23 | ## Node: NormalCrafter (Process Video)
24 | 
25 | This node takes a sequence of images (video frames) and outputs a corresponding sequence of normal map images. A standalone usage sketch of the underlying pipeline is included at the end of this README.
26 | 
27 | ### Parameters
28 | 
29 | * **`images` (Input Socket)**: The input image sequence (video frames).
30 | * **`pipe_override` (Input Socket, Optional)**: Allows providing a pre-loaded NormalCrafter pipeline instance. If unconnected, the node loads its own.
31 | * **`seed`**: (Integer, Default: 42) Random seed used for generation; keep it fixed for reproducible results.
32 | * **`control_after_generate`**: (Fixed, Increment, Decrement, Randomize) Standard ComfyUI widget controlling how the seed changes on subsequent runs. Note: the underlying pipeline uses a single seed for the whole video.
33 | * **`max_res_dimension`**: (Integer, Default: 1024) The maximum dimension (height or width) to which input frames are resized while maintaining aspect ratio.
34 | * **`window_size`**: (Integer, Default: 14) The number of consecutive frames processed together in a sliding window. Affects temporal consistency.
35 | * **`time_step_size`**: (Integer, Default: 10) How many frames the sliding window moves forward after processing a chunk. If less than `window_size`, windows overlap, which can improve smoothness.
36 | * **`decode_chunk_size`**: (Integer, Default: 4) Number of latent frames decoded by the VAE at once. Primarily a VRAM management setting.
37 | * **`fps_for_time_ids`**: (Integer, Default: 7) Conditions the model on an intended frames-per-second value, influencing motion characteristics in the generated normals. *Note: in testing, this parameter showed minimal to no visible effect on the output for this model and task, so the value is hard-coded in the node.*
38 | * **`motion_bucket_id`**: (Integer, Default: 127) Conditions the model on an expected amount of motion. *Note: in testing, this parameter showed minimal to no visible effect on the output for this model and task, so the value is hard-coded in the node.*
39 | * **`noise_aug_strength`**: (Float, Default: 0.0) Strength of noise augmentation applied to the conditioning information. *Note: in testing, this parameter showed minimal to no visible effect on the output for this model and task, so the value is hard-coded in the node.*
40 | 
41 | ### Troubleshooting Flicker / Improving Temporal Consistency
42 | 
43 | If you are experiencing flickering or temporal inconsistencies in your output:
44 | 
45 | * **Increase `window_size`**: A larger window gives the model more temporal context, which can significantly improve consistency between frames.
46 | * **Adjust `time_step_size`**: Using a `time_step_size` smaller than `window_size` creates an overlap between processed windows. The overlap is merged, which smooths transitions. For example, if `window_size` is 20, try a `time_step_size` of 10 or 15 (see the windowing sketch at the end of this README).
47 | 
48 | You may be able to increase `window_size` and `time_step_size` substantially (e.g., to their maximum values) without running out of memory (OOM), depending on your hardware. Experiment to find the best balance for your needs.
49 | 
50 | ### Dependencies
51 | 
52 | * `mediapy`
53 | * `decord`
54 | * `diffusers` (and its dependencies such as `transformers` and `huggingface_hub`) - ComfyUI usually manages its own `diffusers` version; install it manually only if you encounter import errors related to it.
55 | * `torch`, `numpy`, `Pillow` (standard Python ML/image libraries)
56 | 
57 | Refer to `requirements.txt` for more details.
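
### Standalone Pipeline Usage (Sketch)

The snippet below is a minimal sketch of driving the underlying `NormalCrafterPipeline` directly in Python, mirroring what `NormalCrafterNode` does internally: load the NormalCrafter UNet and VAE, wrap them in the SVD-XT pipeline, and call it with the windowing parameters. The frame-loading helper `load_frames` and the file name `input.mp4` are hypothetical placeholders; adapt them to your own loader. A CUDA device and fp16 weights are assumed.

```python
# Minimal sketch: drive NormalCrafterPipeline directly (what the node does internally).
# Assumptions: CUDA available, fp16 weights; load_frames() is a hypothetical helper
# that returns a list of PIL.Image frames -- replace it with your own video loader.
import torch
from diffusers import AutoencoderKLTemporalDecoder
from normalcrafter.normal_crafter_ppl import NormalCrafterPipeline
from normalcrafter.unet import DiffusersUNetSpatioTemporalConditionModelNormalCrafter

unet = DiffusersUNetSpatioTemporalConditionModelNormalCrafter.from_pretrained(
    "Yanrui95/NormalCrafter", subfolder="unet", torch_dtype=torch.float16
)
vae = AutoencoderKLTemporalDecoder.from_pretrained(
    "Yanrui95/NormalCrafter", subfolder="vae", torch_dtype=torch.float16
)
pipe = NormalCrafterPipeline.from_pretrained(
    "stabilityai/stable-video-diffusion-img2vid-xt",
    unet=unet, vae=vae, torch_dtype=torch.float16, variant="fp16",
).to("cuda")

frames = load_frames("input.mp4")  # hypothetical: list of PIL.Image frames
generator = torch.Generator(device="cuda").manual_seed(42)

with torch.inference_mode():
    normals = pipe(
        images=frames,            # the pipeline pads frames to multiples of 64 internally
        window_size=14,           # frames per sliding window
        time_step_size=10,        # window stride; < window_size means overlap
        decode_chunk_size=4,      # VAE frames decoded at once (VRAM knob)
        generator=generator,
    ).frames[0]                   # numpy array roughly in [-1, 1], shape (T, H, W, 3)

normals_0_1 = (normals.clip(-1.0, 1.0) * 0.5) + 0.5  # map to [0, 1], as the node does
```

The node additionally caches this pipeline globally and can offload it to CPU after each run; the sketch skips that bookkeeping.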
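
### How the Sliding Windows Overlap (Sketch)

To make the `window_size` / `time_step_size` interaction concrete, here is a simplified sketch of how the sliding windows cover a clip. It only approximates the pipeline's internal `get_ses` logic (which additionally clamps and adjusts the final window), so treat it as an illustration rather than the exact implementation.

```python
# Simplified illustration of the sliding windows used by the pipeline.
def sliding_windows(num_frames: int, window_size: int, time_step_size: int):
    """Return (start, end) frame-index pairs; consecutive windows overlap by
    window_size - time_step_size frames when time_step_size < window_size."""
    windows = []
    for start in range(0, num_frames, time_step_size):
        end = start + window_size
        if end >= num_frames:  # clamp the last window so it ends at the clip boundary
            windows.append((max(0, num_frames - window_size), num_frames))
            break
        windows.append((start, end))
    return windows

# 40 frames, window_size=20, time_step_size=10 -> 10-frame overlaps that get merged:
# [(0, 20), (10, 30), (20, 40)]
print(sliding_windows(40, 20, 10))
```

When `time_step_size` equals `window_size` there is no overlap and nothing to merge; when windows do overlap, the pipeline linearly blends the overlapping latents (its `use_linear_merge` path), which is what smooths transitions between chunks.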
58 | -------------------------------------------------------------------------------- /__init__.py: -------------------------------------------------------------------------------- 1 | # ComfyUI/custom_nodes/ComfyUI-NormalCrafter/__init__.py 2 | # Or ComfyUI/custom_nodes/ComfyUI-NormalCrafterWrapper/__init__.py 3 | 4 | try: 5 | # Import BOTH classes from your nodes file 6 | from .normal_crafter_nodes import NormalCrafterNode, DetailTransfer # <--- IMPORT DetailTransfer HERE 7 | 8 | NODE_CLASS_MAPPINGS = { 9 | "NormalCrafterNode": NormalCrafterNode, 10 | "DetailTransfer": DetailTransfer # Now DetailTransfer is defined 11 | } 12 | 13 | NODE_DISPLAY_NAME_MAPPINGS = { 14 | "NormalCrafterNode": "NormalCrafter (Process Video)", # Or your preferred name 15 | "DetailTransfer": "Detail Transfer" 16 | } 17 | 18 | __all__ = ['NODE_CLASS_MAPPINGS', 'NODE_DISPLAY_NAME_MAPPINGS'] 19 | print("✅ ComfyUI-NormalCrafter: Custom nodes loaded successfully.") 20 | 21 | except ImportError as e: 22 | print(f"❌ ComfyUI-NormalCrafter: Failed to import nodes: {e}") 23 | print(" Please ensure all dependencies are installed (e.g., diffusers, transformers, huggingface_hub) and the files are correctly placed.") 24 | NODE_CLASS_MAPPINGS = {} 25 | NODE_DISPLAY_NAME_MAPPINGS = {} 26 | except NameError as e: # Catch NameError specifically if DetailTransfer wasn't found during import 27 | print(f"❌ ComfyUI-NormalCrafter: Failed to define nodes, likely an issue importing a class: {e}") 28 | NODE_CLASS_MAPPINGS = {} 29 | NODE_DISPLAY_NAME_MAPPINGS = {} 30 | -------------------------------------------------------------------------------- /normal_crafter_nodes.py: -------------------------------------------------------------------------------- 1 | # ComfyUI/custom_nodes/ComfyUI-NormalCrafter/normal_crafter_nodes.py 2 | 3 | import torch 4 | import torchvision.transforms as transforms 5 | import numpy as np 6 | from PIL import Image 7 | import os 8 | 9 | import comfy.model_management 10 | import model_management 11 | import comfy.utils 12 | import folder_paths # ComfyUI's way to get model paths 13 | 14 | # Try to import required components and provide guidance if they are missing 15 | try: 16 | from diffusers import AutoencoderKLTemporalDecoder 17 | from huggingface_hub import snapshot_download 18 | except ImportError: 19 | print("ComfyUI-NormalCrafter: Missing essential libraries. Please ensure 'diffusers', 'transformers', 'accelerate', and 'huggingface_hub' are installed.") 20 | # Consider raising an error or making the node unusable if critical dependencies are missing. 21 | 22 | try: 23 | # This assumes the 'normalcrafter' directory is correctly placed within ComfyUI-NormalCrafter 24 | from .normalcrafter.normal_crafter_ppl import NormalCrafterPipeline 25 | from .normalcrafter.unet import DiffusersUNetSpatioTemporalConditionModelNormalCrafter 26 | except ImportError as e: 27 | print(f"ComfyUI-NormalCrafter: Error importing NormalCrafter components: {e}. 
" 28 | "Ensure the 'normalcrafter' directory is correctly placed inside 'ComfyUI-NormalCrafter'.") 29 | 30 | # Global variable to cache the pipeline 31 | NORMALCRAFTER_PIPE = None 32 | CURRENT_PIPE_CONFIG = {} # Stores the config {"device": "cuda"/"cpu", "dtype": "float16"/"float32"} 33 | 34 | # Define model paths and repo ID 35 | NORMALCRAFTER_REPO_ID = "Yanrui95/NormalCrafter" 36 | NORMALCRAFTER_MODELS_SUBDIR_NAME = "normalcrafter_models" # Subdirectory in ComfyUI/models/ 37 | SVD_XT_REPO_ID = "stabilityai/stable-video-diffusion-img2vid-xt" 38 | 39 | class DetailTransfer: 40 | @classmethod 41 | def INPUT_TYPES(s): 42 | return { 43 | "required": { 44 | "target": ("IMAGE", ), 45 | "source": ("IMAGE", ), 46 | "mode": ([ 47 | "add", 48 | "multiply", 49 | "screen", 50 | "overlay", 51 | "soft_light", 52 | "hard_light", 53 | "color_dodge", 54 | "color_burn", 55 | "difference", 56 | "exclusion", 57 | "divide", 58 | 59 | ], 60 | {"default": "add"} 61 | ), 62 | "blur_sigma": ("FLOAT", {"default": 1.0, "min": 0.1, "max": 100.0, "step": 0.01}), 63 | "blend_factor": ("FLOAT", {"default": 1.0, "min": -10.0, "max": 10.0, "step": 0.001, "round": 0.001}), 64 | }, 65 | "optional": { 66 | "mask": ("MASK", ), 67 | } 68 | } 69 | 70 | RETURN_TYPES = ("IMAGE",) 71 | FUNCTION = "process" 72 | CATEGORY = "NormalCrafter" 73 | 74 | def adjust_mask(self, mask, target_tensor): 75 | # Add a channel dimension and repeat to match the channel number of the target tensor 76 | if len(mask.shape) == 3: 77 | mask = mask.unsqueeze(1) # Add a channel dimension 78 | target_channels = target_tensor.shape[1] 79 | mask = mask.expand(-1, target_channels, -1, -1) # Expand the channel dimension to match the target tensor's channels 80 | 81 | return mask 82 | 83 | 84 | def process(self, target, source, mode, blur_sigma, blend_factor, mask=None): 85 | B, H, W, C = target.shape 86 | device = model_management.get_torch_device() 87 | target_tensor = target.permute(0, 3, 1, 2).clone().to(device) 88 | source_tensor = source.permute(0, 3, 1, 2).clone().to(device) 89 | 90 | if target.shape[1:] != source.shape[1:]: 91 | source_tensor = comfy.utils.common_upscale(source_tensor, W, H, "bilinear", "disabled") 92 | 93 | if source.shape[0] < B: 94 | source = source[0].unsqueeze(0).repeat(B, 1, 1, 1) 95 | 96 | kernel_size = int(6 * int(blur_sigma) + 1) 97 | 98 | gaussian_blur = transforms.GaussianBlur(kernel_size=(kernel_size, kernel_size), sigma=(blur_sigma, blur_sigma)) 99 | 100 | blurred_target = gaussian_blur(target_tensor) 101 | blurred_source = gaussian_blur(source_tensor) 102 | 103 | if mode == "add": 104 | tensor_out = (source_tensor - blurred_source) + blurred_target 105 | elif mode == "multiply": 106 | tensor_out = source_tensor * blurred_target 107 | elif mode == "screen": 108 | tensor_out = 1 - (1 - source_tensor) * (1 - blurred_target) 109 | elif mode == "overlay": 110 | tensor_out = torch.where(blurred_target < 0.5, 2 * source_tensor * blurred_target, 1 - 2 * (1 - source_tensor) * (1 - blurred_target)) 111 | elif mode == "soft_light": 112 | tensor_out = (1 - 2 * blurred_target) * source_tensor**2 + 2 * blurred_target * source_tensor 113 | elif mode == "hard_light": 114 | tensor_out = torch.where(source_tensor < 0.5, 2 * source_tensor * blurred_target, 1 - 2 * (1 - source_tensor) * (1 - blurred_target)) 115 | elif mode == "difference": 116 | tensor_out = torch.abs(blurred_target - source_tensor) 117 | elif mode == "exclusion": 118 | tensor_out = 0.5 - 2 * (blurred_target - 0.5) * (source_tensor - 0.5) 119 | elif mode == "color_dodge": 
120 | tensor_out = blurred_target / (1 - source_tensor) 121 | elif mode == "color_burn": 122 | tensor_out = 1 - (1 - blurred_target) / source_tensor 123 | elif mode == "divide": 124 | tensor_out = (source_tensor / blurred_source) * blurred_target 125 | else: 126 | tensor_out = source_tensor 127 | 128 | tensor_out = torch.lerp(target_tensor, tensor_out, blend_factor) 129 | if mask is not None: 130 | # Call the function and pass in mask and target_tensor 131 | mask = self.adjust_mask(mask, target_tensor) 132 | mask = mask.to(device) 133 | tensor_out = torch.lerp(target_tensor, tensor_out, mask) 134 | tensor_out = torch.clamp(tensor_out, 0, 1) 135 | tensor_out = tensor_out.permute(0, 2, 3, 1).cpu().float() 136 | return (tensor_out,) 137 | 138 | class NormalCrafterNode: 139 | def __init__(self): 140 | self.pipe = None # Instance variable for the pipeline 141 | 142 | @classmethod 143 | def INPUT_TYPES(cls): 144 | return { 145 | "required": { 146 | "images": ("IMAGE",), 147 | "seed": ("INT", {"default": 42, "min": 0, "max": 0xffffffffffffffff}), 148 | "max_res_dimension": ("INT", {"default": 1024, "min": 256, "max": 2048, "step": 64}), 149 | "window_size": ("INT", {"default": 14, "min": 2, "max": 64}), 150 | "time_step_size": ("INT", {"default": 10, "min": 1, "max": 64}), 151 | "decode_chunk_size": ("INT", {"default": 4, "min": 1, "max": 64}), 152 | "offload_pipe_to_cpu_on_finish": ("BOOLEAN", {"default": True}), # <<< NEW INPUT 153 | }, 154 | "optional": { 155 | "pipe_override": ("NORMALCRAFTER_PIPE",), 156 | } 157 | } 158 | 159 | RETURN_TYPES = ("IMAGE",) 160 | FUNCTION = "process" 161 | CATEGORY = "NormalCrafter" 162 | 163 | def _get_local_nc_model_path(self): 164 | base_models_dir = os.path.join(folder_paths.models_dir, NORMALCRAFTER_MODELS_SUBDIR_NAME) 165 | os.makedirs(base_models_dir, exist_ok=True) 166 | model_name_folder = NORMALCRAFTER_REPO_ID.split('/')[-1] 167 | specific_model_path = os.path.join(base_models_dir, model_name_folder) 168 | return specific_model_path 169 | 170 | def _download_normalcrafter_model_if_needed(self): 171 | local_nc_path = self._get_local_nc_model_path() 172 | unet_config_path = os.path.join(local_nc_path, "unet", "config.json") 173 | 174 | if not os.path.exists(unet_config_path): 175 | print(f"ComfyUI-NormalCrafter: Downloading {NORMALCRAFTER_REPO_ID} model to {local_nc_path}...") 176 | try: 177 | snapshot_download( 178 | repo_id=NORMALCRAFTER_REPO_ID, 179 | local_dir=local_nc_path, 180 | local_dir_use_symlinks=False, 181 | ) 182 | print(f"ComfyUI-NormalCrafter: Model {NORMALCRAFTER_REPO_ID} download complete.") 183 | except Exception as e: 184 | print(f"ComfyUI-NormalCrafter: Failed to download model {NORMALCRAFTER_REPO_ID}: {e}") 185 | raise 186 | else: 187 | # print(f"ComfyUI-NormalCrafter: Model {NORMALCRAFTER_REPO_ID} found at {local_nc_path}.") 188 | pass 189 | return local_nc_path 190 | 191 | def _load_pipeline(self, device_str_requested="cuda", dtype_str_requested="float16"): 192 | global NORMALCRAFTER_PIPE, CURRENT_PIPE_CONFIG 193 | 194 | target_device_for_this_run = comfy.model_management.get_torch_device() if device_str_requested == "cuda" else torch.device("cpu") 195 | 196 | # Determine the torch_dtype we want for this run 197 | if dtype_str_requested == "float16": final_torch_dtype_for_load = torch.float16 198 | elif dtype_str_requested == "bf16": final_torch_dtype_for_load = torch.bfloat16 199 | else: final_torch_dtype_for_load = torch.float32 200 | 201 | if NORMALCRAFTER_PIPE is not None: 202 | # Pipe exists. 
Check if its current dtype and target device are suitable. 203 | pipe_actual_dtype = NORMALCRAFTER_PIPE.dtype # The true dtype of the existing pipe 204 | pipe_actual_dtype_str = "float16" if pipe_actual_dtype == torch.float16 else \ 205 | "bf16" if pipe_actual_dtype == torch.bfloat16 else "float32" 206 | 207 | if dtype_str_requested == pipe_actual_dtype_str: 208 | # Dtypes match. Just check device. 209 | if NORMALCRAFTER_PIPE.device != target_device_for_this_run: 210 | print(f"ComfyUI-NormalCrafter: Moving existing pipeline from {NORMALCRAFTER_PIPE.device} to {target_device_for_this_run}.") 211 | NORMALCRAFTER_PIPE.to(target_device_for_this_run) 212 | # The warning about fp16 on cpu will appear here if target_device_for_this_run is cpu, 213 | # but we're moving an fp16 pipe from cpu to gpu in the typical reload case, which is fine. 214 | 215 | CURRENT_PIPE_CONFIG = {"device": str(target_device_for_this_run), "dtype": dtype_str_requested} 216 | self.pipe = NORMALCRAFTER_PIPE 217 | # print(f"ComfyUI-NormalCrafter: Reusing existing pipeline. Now on {target_device_for_this_run} with {dtype_str_requested}.") 218 | return self.pipe 219 | else: 220 | # Dtype mismatch (e.g., user wants float32 now, but pipe is float16). Must reload fully. 221 | print(f"ComfyUI-NormalCrafter: Requested dtype {dtype_str_requested} differs from existing pipe's dtype {pipe_actual_dtype_str}. Reloading pipeline.") 222 | NORMALCRAFTER_PIPE = None # Force a full reload by clearing the global pipe 223 | 224 | # If NORMALCRAFTER_PIPE is None (either initially, or forced by dtype mismatch above) 225 | print("ComfyUI-NormalCrafter: Loading/Re-loading NormalCrafter pipeline from scratch...") 226 | local_nc_model_path = self._download_normalcrafter_model_if_needed() 227 | 228 | # Load components with final_torch_dtype_for_load 229 | unet = DiffusersUNetSpatioTemporalConditionModelNormalCrafter.from_pretrained( 230 | local_nc_model_path, subfolder="unet", low_cpu_mem_usage=True, torch_dtype=final_torch_dtype_for_load 231 | ) 232 | vae = AutoencoderKLTemporalDecoder.from_pretrained( 233 | local_nc_model_path, subfolder="vae", low_cpu_mem_usage=True, torch_dtype=final_torch_dtype_for_load 234 | ) 235 | svd_xt_variant = "fp16" if final_torch_dtype_for_load == torch.float16 else None 236 | 237 | pipe = NormalCrafterPipeline.from_pretrained( 238 | SVD_XT_REPO_ID, unet=unet, vae=vae, torch_dtype=final_torch_dtype_for_load, variant=svd_xt_variant, 239 | ) 240 | 241 | try: 242 | pipe.enable_xformers_memory_efficient_attention() 243 | # print("ComfyUI-NormalCrafter: Xformers memory efficient attention enabled.") # Already printed during first load often 244 | except Exception: # Simplified error handling 245 | pass 246 | 247 | pipe.to(target_device_for_this_run) 248 | NORMALCRAFTER_PIPE = pipe 249 | CURRENT_PIPE_CONFIG = { 250 | "device": str(target_device_for_this_run), 251 | "dtype": dtype_str_requested # The dtype it's configured with for this run 252 | } 253 | self.pipe = NORMALCRAFTER_PIPE 254 | print(f"ComfyUI-NormalCrafter: Pipeline loaded to {target_device_for_this_run} with dtype {final_torch_dtype_for_load}.") 255 | return self.pipe 256 | 257 | def tensor_to_pil_list(self, images_tensor: torch.Tensor) -> list: 258 | pil_images = [] 259 | for i in range(images_tensor.shape[0]): 260 | img_np = (images_tensor[i].cpu().numpy() * 255).astype(np.uint8) 261 | pil_images.append(Image.fromarray(img_np)) 262 | return pil_images 263 | 264 | def resize_pil_images(self, pil_images: list, max_res_dim: int) -> list: 265 | resized_images = [] 
266 | if not pil_images: return [] 267 | for img in pil_images: 268 | original_width, original_height = img.size 269 | if max(original_height, original_width) > max_res_dim: 270 | scale = max_res_dim / max(original_height, original_width) 271 | target_height = round(original_height * scale) 272 | target_width = round(original_width * scale) 273 | else: 274 | target_height = original_height 275 | target_width = original_width 276 | resized_images.append(img.resize((target_width, target_height), Image.LANCZOS)) 277 | return resized_images 278 | 279 | def process(self, images: torch.Tensor, seed: int, max_res_dimension: int, 280 | window_size: int, time_step_size: int, decode_chunk_size: int, 281 | offload_pipe_to_cpu_on_finish: bool, # <<< NEW PARAMETER 282 | pipe_override=None): 283 | 284 | default_fps_for_time_ids = 7 285 | default_motion_bucket_id = 127 286 | default_noise_aug_strength = 0.0 287 | 288 | if pipe_override is not None: 289 | self.pipe = pipe_override 290 | print("ComfyUI-NormalCrafter: Using provided pipe_override.") 291 | else: 292 | current_comfy_device = comfy.model_management.get_torch_device() 293 | device_str = "cuda" if current_comfy_device.type == 'cuda' else "cpu" 294 | # Determine dtype based on device (float16 for CUDA, float32 for CPU) 295 | dtype_str = "float16" if device_str == "cuda" and comfy.model_management.should_use_fp16() else "float32" 296 | self._load_pipeline(device_str, dtype_str) 297 | 298 | if self.pipe is None: 299 | raise RuntimeError("ComfyUI-NormalCrafter: Pipeline could not be loaded.") 300 | 301 | # Ensure the pipe instance self.pipe is on the correct device for processing *before* using it 302 | # This is important if self.pipe came from the global cache and might have been on CPU 303 | processing_device = comfy.model_management.get_torch_device() 304 | if self.pipe.device != processing_device: 305 | print(f"ComfyUI-NormalCrafter: Moving self.pipe from {self.pipe.device} to {processing_device} for processing.") 306 | self.pipe.to(processing_device) 307 | 308 | 309 | pil_frames = self.tensor_to_pil_list(images) 310 | if not pil_frames: return (torch.zeros_like(images),) 311 | 312 | resized_pil_frames = self.resize_pil_images(pil_frames, max_res_dimension) 313 | 314 | num_actual_frames = len(resized_pil_frames) 315 | effective_frames_for_pipeline = list(resized_pil_frames) 316 | 317 | if num_actual_frames == 0: 318 | print("ComfyUI-NormalCrafter: Warning - No frames to process after resizing.") 319 | return (torch.zeros_like(images),) 320 | 321 | if num_actual_frames < window_size: 322 | print(f"ComfyUI-NormalCrafter: Number of frames ({num_actual_frames}) is less than window_size ({window_size}). Padding...") 323 | padding_needed = window_size - num_actual_frames 324 | last_frame_to_duplicate = resized_pil_frames[-1] 325 | for _ in range(padding_needed): 326 | effective_frames_for_pipeline.append(last_frame_to_duplicate) 327 | 328 | generator_device = self.pipe.device # Should be processing_device 329 | generator = torch.Generator(device=generator_device).manual_seed(seed) 330 | 331 | print(f"ComfyUI-NormalCrafter: Processing {len(effective_frames_for_pipeline)} frames (effective) with seed {seed}. Original: {num_actual_frames}.") 332 | print(f"ComfyUI-NormalCrafter: Using (internal defaults) fps={default_fps_for_time_ids}, motion_id={default_motion_bucket_id}, noise_aug={default_noise_aug_strength}") 333 | 334 | pbar = comfy.utils.ProgressBar(len(effective_frames_for_pipeline)) # This pbar seems not used by the pipe. 
335 | 336 | with torch.inference_mode(): 337 | output_frames_np = self.pipe( # self.pipe should be on processing_device here 338 | images=effective_frames_for_pipeline, 339 | decode_chunk_size=decode_chunk_size, 340 | time_step_size=time_step_size, 341 | window_size=window_size, 342 | fps=default_fps_for_time_ids, 343 | motion_bucket_id=default_motion_bucket_id, 344 | noise_aug_strength=default_noise_aug_strength, 345 | generator=generator 346 | # SVD pipeline has its own progress bar, no need to pass pbar here 347 | ).frames[0] 348 | 349 | if len(effective_frames_for_pipeline) > num_actual_frames: 350 | output_frames_np = output_frames_np[:num_actual_frames, :, :, :] 351 | 352 | output_normals_0_1 = (output_frames_np.clip(-1., 1.) * 0.5) + 0.5 353 | output_tensor = torch.from_numpy(output_normals_0_1).float() 354 | 355 | # --- Explicit Offload After Processing --- 356 | # Only offload the globally managed pipe (NORMALCRAFTER_PIPE), not an overridden one. 357 | # And only if no pipe_override was used for this run. 358 | if pipe_override is None and \ 359 | offload_pipe_to_cpu_on_finish and \ 360 | NORMALCRAFTER_PIPE is not None and \ 361 | NORMALCRAFTER_PIPE.device.type == 'cuda': # Check if it's on CUDA before moving 362 | print("ComfyUI-NormalCrafter: Offloading globally cached pipeline to CPU after processing.") 363 | try: 364 | NORMALCRAFTER_PIPE.to("cpu") 365 | global CURRENT_PIPE_CONFIG # Make sure to get the global 366 | if "device" in CURRENT_PIPE_CONFIG: 367 | CURRENT_PIPE_CONFIG["device"] = "cpu" # Update its state to "cpu" 368 | else: # Should not happen if config is always set 369 | CURRENT_PIPE_CONFIG = {"device": "cpu", "dtype": CURRENT_PIPE_CONFIG.get("dtype", "float32")} 370 | 371 | comfy.model_management.soft_empty_cache() # Ask ComfyUI to try and free VRAM 372 | except Exception as e: 373 | print(f"ComfyUI-NormalCrafter: Error offloading pipeline to CPU: {e}") 374 | 375 | print(f"ComfyUI-NormalCrafter: Processing complete. 
Output tensor shape: {output_tensor.shape}") 376 | return (output_tensor,) 377 | -------------------------------------------------------------------------------- /normalcrafter/__init__.py: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/AIWarper/ComfyUI-NormalCrafterWrapper/95046ac2c781c6448796c14bf76196b47a28ae5f/normalcrafter/__init__.py -------------------------------------------------------------------------------- /normalcrafter/normal_crafter_ppl.py: -------------------------------------------------------------------------------- 1 | from dataclasses import dataclass 2 | from typing import Callable, Dict, List, Optional, Union 3 | 4 | import numpy as np 5 | import PIL.Image 6 | import torch 7 | import torch.nn.functional as F 8 | import math 9 | 10 | from diffusers.utils import BaseOutput, logging 11 | from diffusers.utils.torch_utils import is_compiled_module, randn_tensor 12 | from diffusers import DiffusionPipeline 13 | from diffusers.pipelines.stable_video_diffusion.pipeline_stable_video_diffusion import StableVideoDiffusionPipelineOutput, StableVideoDiffusionPipeline 14 | from PIL import Image 15 | import cv2 16 | 17 | logger = logging.get_logger(__name__) # pylint: disable=invalid-name 18 | 19 | class NormalCrafterPipeline(StableVideoDiffusionPipeline): 20 | 21 | def _encode_image(self, image, device, num_videos_per_prompt, do_classifier_free_guidance, scale=1, image_size=None): 22 | dtype = next(self.image_encoder.parameters()).dtype 23 | 24 | if not isinstance(image, torch.Tensor): 25 | image = self.video_processor.pil_to_numpy(image) # (0, 255) -> (0, 1) 26 | image = self.video_processor.numpy_to_pt(image) # (n, h, w, c) -> (n, c, h, w) 27 | 28 | # We normalize the image before resizing to match with the original implementation. 29 | # Then we unnormalize it after resizing. 
30 | pixel_values = image 31 | B, C, H, W = pixel_values.shape 32 | patches = [pixel_values] 33 | # patches = [] 34 | for i in range(1, scale): 35 | num_patches_HW_this_level = i + 1 36 | patch_H = H // num_patches_HW_this_level + 1 37 | patch_W = W // num_patches_HW_this_level + 1 38 | for j in range(num_patches_HW_this_level): 39 | for k in range(num_patches_HW_this_level): 40 | patches.append(pixel_values[:, :, j*patch_H:(j+1)*patch_H, k*patch_W:(k+1)*patch_W]) 41 | 42 | def encode_image(image): 43 | image = image * 2.0 - 1.0 44 | if image_size is not None: 45 | image = _resize_with_antialiasing(image, image_size) 46 | else: 47 | image = _resize_with_antialiasing(image, (224, 224)) 48 | image = (image + 1.0) / 2.0 49 | 50 | # Normalize the image with for CLIP input 51 | image = self.feature_extractor( 52 | images=image, 53 | do_normalize=True, 54 | do_center_crop=False, 55 | do_resize=False, 56 | do_rescale=False, 57 | return_tensors="pt", 58 | ).pixel_values 59 | 60 | image = image.to(device=device, dtype=dtype) 61 | image_embeddings = self.image_encoder(image).image_embeds 62 | if len(image_embeddings.shape) < 3: 63 | image_embeddings = image_embeddings.unsqueeze(1) 64 | return image_embeddings 65 | 66 | image_embeddings = [] 67 | for patch in patches: 68 | image_embeddings.append(encode_image(patch)) 69 | image_embeddings = torch.cat(image_embeddings, dim=1) 70 | 71 | # duplicate image embeddings for each generation per prompt, using mps friendly method 72 | # import pdb 73 | # pdb.set_trace() 74 | bs_embed, seq_len, _ = image_embeddings.shape 75 | image_embeddings = image_embeddings.repeat(1, num_videos_per_prompt, 1) 76 | image_embeddings = image_embeddings.view(bs_embed * num_videos_per_prompt, seq_len, -1) 77 | 78 | if do_classifier_free_guidance: 79 | negative_image_embeddings = torch.zeros_like(image_embeddings) 80 | 81 | # For classifier free guidance, we need to do two forward passes. 
82 | # Here we concatenate the unconditional and text embeddings into a single batch 83 | # to avoid doing two forward passes 84 | image_embeddings = torch.cat([negative_image_embeddings, image_embeddings]) 85 | 86 | return image_embeddings 87 | 88 | def ecnode_video_vae(self, images, chunk_size: int = 14): 89 | if isinstance(images, list): 90 | width, height = images[0].size 91 | else: 92 | height, width = images[0].shape[:2] 93 | needs_upcasting = self.vae.dtype == torch.float16 and self.vae.config.force_upcast 94 | if needs_upcasting: 95 | self.vae.to(dtype=torch.float32) 96 | 97 | device = self._execution_device 98 | images = self.video_processor.preprocess_video(images, height=height, width=width).to(device, self.vae.dtype) # torch type in range(-1, 1) with (1,3,h,w) 99 | images = images.squeeze(0) # from (1, c, t, h, w) -> (c, t, h, w) 100 | images = images.permute(1,0,2,3) # c, t, h, w -> (t, c, h, w) 101 | 102 | video_latents = [] 103 | # chunk_size = 14 104 | for i in range(0, images.shape[0], chunk_size): 105 | video_latents.append(self.vae.encode(images[i : i + chunk_size]).latent_dist.mode()) 106 | image_latents = torch.cat(video_latents) 107 | 108 | # cast back to fp16 if needed 109 | if needs_upcasting: 110 | self.vae.to(dtype=torch.float16) 111 | 112 | return image_latents 113 | 114 | def pad_image(self, images, scale=64): 115 | def get_pad(newW, W): 116 | pad_W = (newW - W) // 2 117 | if W % 2 == 1: 118 | pad_Ws = [pad_W, pad_W + 1] 119 | else: 120 | pad_Ws = [pad_W, pad_W] 121 | return pad_Ws 122 | 123 | if type(images[0]) is np.ndarray: 124 | H, W = images[0].shape[:2] 125 | else: 126 | W, H = images[0].size 127 | 128 | if W % scale == 0 and H % scale == 0: 129 | return images, None 130 | newW = int(np.ceil(W / scale) * scale) 131 | newH = int(np.ceil(H / scale) * scale) 132 | 133 | pad_Ws = get_pad(newW, W) 134 | pad_Hs = get_pad(newH, H) 135 | 136 | new_images = [] 137 | for image in images: 138 | if type(image) is np.ndarray: 139 | image = cv2.copyMakeBorder(image, *pad_Hs, *pad_Ws, cv2.BORDER_CONSTANT, value=(1.,1.,1.)) 140 | new_images.append(image) 141 | else: 142 | image = np.array(image) 143 | image = cv2.copyMakeBorder(image, *pad_Hs, *pad_Ws, cv2.BORDER_CONSTANT, value=(255,255,255)) 144 | new_images.append(Image.fromarray(image)) 145 | return new_images, pad_Hs+pad_Ws 146 | 147 | def unpad_image(self, v, pad_HWs): 148 | t, b, l, r = pad_HWs 149 | if t > 0 or b > 0: 150 | v = v[:, :, t:-b] 151 | if l > 0 or r > 0: 152 | v = v[:, :, :, l:-r] 153 | return v 154 | 155 | @torch.no_grad() 156 | def __call__( 157 | self, 158 | images: Union[PIL.Image.Image, List[PIL.Image.Image], torch.FloatTensor], 159 | decode_chunk_size: Optional[int] = None, 160 | time_step_size: Optional[int] = 1, 161 | window_size: Optional[int] = 1, 162 | fps: int = 7, # <<< CHANGED: Added default, will be overridden by node 163 | motion_bucket_id: int = 127, # <<< CHANGED: Added default, will be overridden by node 164 | noise_aug_strength: float = 0.0, # <<< CHANGED: Added default, will be overridden by node 165 | generator: Optional[Union[torch.Generator, List[torch.Generator]]] = None, 166 | return_dict: bool = True 167 | ): 168 | images, pad_HWs = self.pad_image(images) 169 | 170 | # 0. Default height and width to unet 171 | width, height = images[0].size 172 | num_frames = len(images) 173 | 174 | # 1. Check inputs. Raise error if not correct 175 | self.check_inputs(images, height, width) 176 | 177 | # 2. 
Define call parameters 178 | batch_size = 1 179 | device = self._execution_device 180 | # here `guidance_scale` is defined analog to the guidance weight `w` of equation (2) 181 | # of the Imagen paper: https://arxiv.org/pdf/2205.11487.pdf . `guidance_scale = 1` 182 | # corresponds to doing no classifier free guidance. 183 | self._guidance_scale = 1.0 # NormalCrafter seems to operate without CFG for normals directly 184 | num_videos_per_prompt = 1 185 | do_classifier_free_guidance = False # NormalCrafter specific 186 | num_inference_steps = 1 # For direct normal generation, this is typically 1 step if not iterative refinement 187 | # fps, motion_bucket_id, noise_aug_strength are now passed as arguments 188 | 189 | output_type = "np" # Default output type from SVD pipeline is numpy array 190 | # data_keys = ["normal"] # Not used directly in this simplified flow 191 | use_linear_merge = True 192 | determineTrain = True # This seems to indicate a direct generation rather than denoising from pure noise 193 | encode_image_scale = 1 194 | encode_image_WH = None 195 | 196 | decode_chunk_size = decode_chunk_size if decode_chunk_size is not None else 7 # Default if not provided 197 | 198 | # 3. Encode input image using using clip. (num_image * num_videos_per_prompt, 1, 1024) 199 | image_embeddings = self._encode_image(images, device, num_videos_per_prompt, do_classifier_free_guidance=do_classifier_free_guidance, scale=encode_image_scale, image_size=encode_image_WH) 200 | # 4. Encode input image using VAE 201 | image_latents = self.ecnode_video_vae(images, chunk_size=decode_chunk_size).to(image_embeddings.dtype) 202 | 203 | # image_latents [num_frames, channels, height, width] ->[1, num_frames, channels, height, width] 204 | image_latents = image_latents.unsqueeze(0) 205 | 206 | # 5. 
Get Added Time IDs 207 | added_time_ids = self._get_add_time_ids( 208 | fps, # <<< USES PASSED-IN VALUE 209 | motion_bucket_id, # <<< USES PASSED-IN VALUE 210 | noise_aug_strength, # <<< USES PASSED-IN VALUE 211 | image_embeddings.dtype, 212 | batch_size, 213 | num_videos_per_prompt, 214 | do_classifier_free_guidance, 215 | ) 216 | added_time_ids = added_time_ids.to(device) 217 | 218 | # get Start and End frame idx for each window 219 | def get_ses(num_frames_local): # Renamed to avoid conflict with outer num_frames 220 | ses = [] 221 | for i in range(0, num_frames_local, time_step_size): 222 | ses.append([i, i + window_size]) 223 | num_to_remain = 0 224 | for se_idx, se_val in enumerate(ses): # Use enumerate for clarity 225 | if se_val[1] > num_frames_local: 226 | # Adjust the last window to not exceed num_frames_local if it's shorter than window_size 227 | if se_idx > 0 and ses[se_idx-1][1] < num_frames_local : # Ensure there's a previous valid window 228 | ses[se_idx] = [num_frames_local - window_size, num_frames_local] 229 | if ses[se_idx][0] < 0: # Handle very short videos 230 | ses[se_idx][0] = 0 231 | num_to_remain = se_idx +1 232 | else: # If this is the only window and it's too short or calculation leads to invalid range 233 | ses[se_idx] = [0, num_frames_local] # Process what's available 234 | num_to_remain = se_idx + 1 235 | 236 | break # Stop adding more windows 237 | num_to_remain += 1 238 | ses = ses[:num_to_remain] 239 | 240 | # Ensure the last window covers the end if not already 241 | if ses and ses[-1][1] < num_frames_local and num_frames_local >= window_size : 242 | new_start = num_frames_local - window_size 243 | if new_start > ses[-1][0]: # Add only if it's a new, valid window 244 | ses.append([new_start, num_frames_local]) 245 | elif new_start < ses[-1][0] and new_start >= 0: # If it overlaps but starts earlier, replace last 246 | ses[-1] = [new_start, num_frames_local] 247 | 248 | 249 | # Handle case where num_frames_local < window_size 250 | if not ses and num_frames_local > 0: 251 | ses.append([0, num_frames_local]) # Process all available frames 252 | 253 | return ses 254 | 255 | ses = get_ses(num_frames) # Use the original num_frames from the input 'images' 256 | 257 | pred = None 258 | for i, se in enumerate(ses): 259 | # Adjust window_num_frames based on the actual slice, esp. 
for the last potentially shorter window 260 | current_window_num_frames = se[1] - se[0] 261 | window_image_embeddings = image_embeddings[se[0]:se[1]] 262 | window_image_latents = image_latents[:, se[0]:se[1]] 263 | # added_time_ids might need adjustment if its batch size depends on num_frames, 264 | # but SVD usually computes it once for the whole sequence properties (fps, motion_bucket_id) 265 | window_added_time_ids = added_time_ids # Assuming added_time_ids are constant per generation call 266 | 267 | if i == 0 or time_step_size >= window_size: # Corrected condition: no overlap if step >= window 268 | to_replace_latents = None 269 | else: 270 | last_se = ses[i-1] 271 | overlap_start_in_current_window = 0 272 | overlap_start_in_previous_pred = last_se[1] - (window_size - time_step_size) # More robust overlap calculation 273 | num_to_replace_latents = window_size - time_step_size # This is the overlap size 274 | 275 | if num_to_replace_latents > 0 and pred is not None and pred.shape[1] >= (se[0] + num_to_replace_latents) : 276 | # The slice from `pred` needs to correspond to the overlapping part of the *previous* window's output 277 | # that aligns with the *start* of the current window's processing. 278 | # `se[0]` is the global start index of the current window. 279 | # `last_se[1]` is the global end index of the previous window. 280 | # The overlap is from `se[0]` to `last_se[1]`. 281 | # So, from `pred`, we need from `se[0]` up to `num_to_replace_latents` frames into the current window. 282 | to_replace_latents = pred[:, se[0] : se[0] + num_to_replace_latents] 283 | else: 284 | to_replace_latents = None 285 | 286 | 287 | latents_current_window = self.generate( 288 | num_inference_steps, 289 | device, 290 | batch_size, 291 | num_videos_per_prompt, 292 | current_window_num_frames, # Use actual frames in this window 293 | height, # Original height/width after padding 294 | width, # Original height/width after padding 295 | window_image_embeddings, 296 | generator, 297 | determineTrain, 298 | to_replace_latents, # This is for conditioning the start of the current window 299 | do_classifier_free_guidance, 300 | window_image_latents, 301 | window_added_time_ids 302 | ) 303 | 304 | # Merge latents 305 | if pred is None: 306 | pred = latents_current_window 307 | else: 308 | if to_replace_latents is not None and use_linear_merge and time_step_size < window_size: 309 | num_overlap_frames = window_size - time_step_size 310 | weight = torch.linspace(1.0, 0.0, num_overlap_frames + 2, device=device, dtype=latents_current_window.dtype)[1:-1] 311 | weight = weight.view(1, -1, 1, 1, 1) # Reshape for broadcasting 312 | 313 | # The part of `pred` to merge with is from global index `se[0]` for `num_overlap_frames` 314 | pred_overlap_part = pred[:, se[0] : se[0] + num_overlap_frames] 315 | # The part of `latents_current_window` to merge is its beginning 316 | current_window_overlap_part = latents_current_window[:, :num_overlap_frames] 317 | 318 | merged_overlap = pred_overlap_part * weight + current_window_overlap_part * (1.0 - weight) 319 | 320 | # Update `pred` and append the new, non-overlapping part of `latents_current_window` 321 | new_part_of_current_window = latents_current_window[:, num_overlap_frames:] 322 | pred = torch.cat([pred[:, :se[0]], merged_overlap, new_part_of_current_window], dim=1) 323 | else: 324 | # No overlap or no merge, just append the new part 325 | # This assumes pred ends exactly where the new window (minus overlap) begins 326 | pred = torch.cat([pred[:, :se[0] + 
time_step_size], latents_current_window[:, window_size - time_step_size:] ], dim=1) 327 | # A simpler, but potentially less robust way if windows aren't perfectly aligned: 328 | # pred = torch.cat([pred[:, :se[0]], latents_current_window], dim=1) 329 | # This needs careful handling if just appending parts to avoid duplicated or missing frames. 330 | # For now, let's assume `pred` is built up to `se[0]` and we append `latents_current_window` 331 | # This logic needs to be very precise for stitching. 332 | # A safer approach if `time_step_size` is not less than `window_size`: 333 | if time_step_size >= window_size : 334 | pred = torch.cat([pred, latents_current_window], dim=1) # Non-overlapping windows 335 | else: # Overlapping windows but not merging (or linear merge disabled) - take new window's data 336 | pred = torch.cat([pred[:, :se[0]], latents_current_window], dim=1) 337 | 338 | 339 | 340 | # Ensure pred is not longer than num_frames due to windowing logic 341 | if pred is not None and pred.shape[1] > num_frames: 342 | pred = pred[:, :num_frames] 343 | 344 | 345 | if not output_type == "latent": 346 | needs_upcasting = self.vae.dtype == torch.float16 and self.vae.config.force_upcast 347 | if needs_upcasting: 348 | self.vae.to(dtype=torch.float16) 349 | 350 | def decode_latents(latents_to_decode, num_frames_to_decode, decode_chunk_size_local): 351 | # Ensure latents_to_decode is not None and has frames 352 | if latents_to_decode is None or latents_to_decode.shape[1] == 0: 353 | # Return a dummy tensor or raise an error, based on expected behavior 354 | # For now, let's assume if latents are None, frames should be too. 355 | # This part of SVD pipeline expects latents to be (1, T, C, H, W) 356 | # If num_frames_to_decode is 0, it might also cause issues. 357 | print("Warning: No latents to decode.") 358 | # Create a dummy black frame sequence matching expected output structure 359 | # This depends on what `self.video_processor.postprocess_video` expects. 360 | # Assuming it handles (T, H, W, C) numpy arrays. 361 | # Placeholder: return empty array or array of zeros. 362 | # The original SVD decode_latents doesn't explicitly handle None input. 
363 | # Let's try to return a zero array matching expected shape if num_frames_to_decode > 0 364 | if num_frames_to_decode > 0: 365 | dummy_h = latents_to_decode.shape[3] * self.vae_scale_factor if latents_to_decode is not None else 256 # example H 366 | dummy_w = latents_to_decode.shape[4] * self.vae_scale_factor if latents_to_decode is not None else 256 # example W 367 | return np.zeros((num_frames_to_decode, dummy_h, dummy_w, 3), dtype=np.float32) 368 | return np.array([]) 369 | 370 | 371 | frames_decoded = self.decode_latents(latents_to_decode, num_frames_to_decode, decode_chunk_size_local) 372 | frames_decoded = self.video_processor.postprocess_video(video=frames_decoded, output_type="np") 373 | frames_decoded = frames_decoded * 2 - 1 374 | return frames_decoded 375 | 376 | frames = decode_latents(pred, num_frames, decode_chunk_size) 377 | if pad_HWs is not None: 378 | frames = self.unpad_image(frames, pad_HWs) 379 | else: 380 | frames = pred 381 | 382 | self.maybe_free_model_hooks() 383 | 384 | if not return_dict: 385 | return frames 386 | 387 | return StableVideoDiffusionPipelineOutput(frames=frames) 388 | 389 | 390 | def generate( 391 | self, 392 | num_inference_steps, 393 | device, 394 | batch_size, 395 | num_videos_per_prompt, 396 | num_frames, 397 | height, 398 | width, 399 | image_embeddings, 400 | generator, 401 | determineTrain, 402 | to_replace_latents, 403 | do_classifier_free_guidance, 404 | image_latents, 405 | added_time_ids, 406 | latents=None, 407 | ): 408 | # 6. Prepare timesteps 409 | self.scheduler.set_timesteps(num_inference_steps, device=device) 410 | timesteps = self.scheduler.timesteps 411 | 412 | # 7. Prepare latent variables 413 | num_channels_latents = self.unet.config.in_channels 414 | latents = self.prepare_latents( 415 | batch_size * num_videos_per_prompt, 416 | num_frames, 417 | num_channels_latents, 418 | height, 419 | width, 420 | image_embeddings.dtype, 421 | device, 422 | generator, 423 | latents, 424 | ) 425 | if determineTrain: 426 | latents[...] = 0. 427 | 428 | # 8. Denoising loop 429 | num_warmup_steps = len(timesteps) - num_inference_steps * self.scheduler.order 430 | self._num_timesteps = len(timesteps) 431 | with self.progress_bar(total=num_inference_steps) as progress_bar: 432 | for i, t in enumerate(timesteps): 433 | # replace part of latents with conditons. 
ToDo: t embedding should also replace 434 | if to_replace_latents is not None: 435 | num_img_condition = to_replace_latents.shape[1] 436 | if not determineTrain: 437 | _noise = randn_tensor(to_replace_latents.shape, generator=generator, device=device, dtype=image_embeddings.dtype) 438 | noisy_to_replace_latents = self.scheduler.add_noise(to_replace_latents, _noise, t.unsqueeze(0)) 439 | latents[:, :num_img_condition] = noisy_to_replace_latents 440 | else: 441 | latents[:, :num_img_condition] = to_replace_latents 442 | 443 | 444 | # expand the latents if we are doing classifier free guidance 445 | latent_model_input = torch.cat([latents] * 2) if do_classifier_free_guidance else latents 446 | latent_model_input = self.scheduler.scale_model_input(latent_model_input, t) 447 | timestep = t 448 | # Concatenate image_latents over channels dimention 449 | latent_model_input = torch.cat([latent_model_input, image_latents], dim=2) 450 | # predict the noise residual 451 | noise_pred = self.unet( 452 | latent_model_input, 453 | timestep, 454 | encoder_hidden_states=image_embeddings, 455 | added_time_ids=added_time_ids, 456 | return_dict=False, 457 | )[0] 458 | 459 | # perform guidance 460 | if do_classifier_free_guidance: 461 | noise_pred_uncond, noise_pred_cond = noise_pred.chunk(2) 462 | noise_pred = noise_pred_uncond + self.guidance_scale * (noise_pred_cond - noise_pred_uncond) 463 | 464 | # compute the previous noisy sample x_t -> x_t-1 465 | scheduler_output = self.scheduler.step(noise_pred, t, latents) 466 | latents = scheduler_output.prev_sample 467 | 468 | if i == len(timesteps) - 1 or ((i + 1) > num_warmup_steps and (i + 1) % self.scheduler.order == 0): 469 | progress_bar.update() 470 | 471 | return latents 472 | # resizing utils 473 | # TODO: clean up later 474 | def _resize_with_antialiasing(input, size, interpolation="bicubic", align_corners=True): 475 | h, w = input.shape[-2:] 476 | factors = (h / size[0], w / size[1]) 477 | 478 | # First, we have to determine sigma 479 | # Taken from skimage: https://github.com/scikit-image/scikit-image/blob/v0.19.2/skimage/transform/_warps.py#L171 480 | sigmas = ( 481 | max((factors[0] - 1.0) / 2.0, 0.001), 482 | max((factors[1] - 1.0) / 2.0, 0.001), 483 | ) 484 | 485 | # Now kernel size. Good results are for 3 sigma, but that is kind of slow. Pillow uses 1 sigma 486 | # https://github.com/python-pillow/Pillow/blob/master/src/libImaging/Resample.c#L206 487 | # But they do it in the 2 passes, which gives better results. 
Let's try 2 sigmas for now 488 | ks = int(max(2.0 * 2 * sigmas[0], 3)), int(max(2.0 * 2 * sigmas[1], 3)) 489 | 490 | # Make sure it is odd 491 | if (ks[0] % 2) == 0: 492 | ks = ks[0] + 1, ks[1] 493 | 494 | if (ks[1] % 2) == 0: 495 | ks = ks[0], ks[1] + 1 496 | 497 | input = _gaussian_blur2d(input, ks, sigmas) 498 | 499 | output = torch.nn.functional.interpolate(input, size=size, mode=interpolation, align_corners=align_corners) 500 | return output 501 | 502 | 503 | def _compute_padding(kernel_size): 504 | """Compute padding tuple.""" 505 | # 4 or 6 ints: (padding_left, padding_right,padding_top,padding_bottom) 506 | # https://pytorch.org/docs/stable/nn.html#torch.nn.functional.pad 507 | if len(kernel_size) < 2: 508 | raise AssertionError(kernel_size) 509 | computed = [k - 1 for k in kernel_size] 510 | 511 | # for even kernels we need to do asymmetric padding :( 512 | out_padding = 2 * len(kernel_size) * [0] 513 | 514 | for i in range(len(kernel_size)): 515 | computed_tmp = computed[-(i + 1)] 516 | 517 | pad_front = computed_tmp // 2 518 | pad_rear = computed_tmp - pad_front 519 | 520 | out_padding[2 * i + 0] = pad_front 521 | out_padding[2 * i + 1] = pad_rear 522 | 523 | return out_padding 524 | 525 | 526 | def _filter2d(input, kernel): 527 | # prepare kernel 528 | b, c, h, w = input.shape 529 | tmp_kernel = kernel[:, None, ...].to(device=input.device, dtype=input.dtype) 530 | 531 | tmp_kernel = tmp_kernel.expand(-1, c, -1, -1) 532 | 533 | height, width = tmp_kernel.shape[-2:] 534 | 535 | padding_shape: list[int] = _compute_padding([height, width]) 536 | input = torch.nn.functional.pad(input, padding_shape, mode="reflect") 537 | 538 | # kernel and input tensor reshape to align element-wise or batch-wise params 539 | tmp_kernel = tmp_kernel.reshape(-1, 1, height, width) 540 | input = input.view(-1, tmp_kernel.size(0), input.size(-2), input.size(-1)) 541 | 542 | # convolve the tensor with the kernel. 
543 | output = torch.nn.functional.conv2d(input, tmp_kernel, groups=tmp_kernel.size(0), padding=0, stride=1) 544 | 545 | out = output.view(b, c, h, w) 546 | return out 547 | 548 | 549 | def _gaussian(window_size: int, sigma): 550 | if isinstance(sigma, float): 551 | sigma = torch.tensor([[sigma]]) 552 | 553 | batch_size = sigma.shape[0] 554 | 555 | x = (torch.arange(window_size, device=sigma.device, dtype=sigma.dtype) - window_size // 2).expand(batch_size, -1) 556 | 557 | if window_size % 2 == 0: 558 | x = x + 0.5 559 | 560 | gauss = torch.exp(-x.pow(2.0) / (2 * sigma.pow(2.0))) 561 | 562 | return gauss / gauss.sum(-1, keepdim=True) 563 | 564 | 565 | def _gaussian_blur2d(input, kernel_size, sigma): 566 | if isinstance(sigma, tuple): 567 | sigma = torch.tensor([sigma], dtype=input.dtype) 568 | else: 569 | sigma = sigma.to(dtype=input.dtype) 570 | 571 | ky, kx = int(kernel_size[0]), int(kernel_size[1]) 572 | bs = sigma.shape[0] 573 | kernel_x = _gaussian(kx, sigma[:, 1].view(bs, 1)) 574 | kernel_y = _gaussian(ky, sigma[:, 0].view(bs, 1)) 575 | out_x = _filter2d(input, kernel_x[..., None, :]) 576 | out = _filter2d(out_x, kernel_y[..., None]) 577 | 578 | return out 579 | -------------------------------------------------------------------------------- /normalcrafter/unet.py: -------------------------------------------------------------------------------- 1 | from diffusers import UNetSpatioTemporalConditionModel 2 | from diffusers.models.unets.unet_spatio_temporal_condition import UNetSpatioTemporalConditionOutput 3 | from diffusers.utils import is_torch_version 4 | import torch 5 | from typing import Any, Dict, Optional, Tuple, Union 6 | 7 | def create_custom_forward(module, return_dict=None): 8 | def custom_forward(*inputs): 9 | if return_dict is not None: 10 | return module(*inputs, return_dict=return_dict) 11 | else: 12 | return module(*inputs) 13 | 14 | return custom_forward 15 | CKPT_KWARGS = {"use_reentrant": False} if is_torch_version(">=", "1.11.0") else {} 16 | 17 | 18 | class DiffusersUNetSpatioTemporalConditionModelNormalCrafter(UNetSpatioTemporalConditionModel): 19 | 20 | @staticmethod 21 | def forward_crossattn_down_block_dino( 22 | module, 23 | hidden_states: torch.Tensor, 24 | temb: Optional[torch.Tensor] = None, 25 | encoder_hidden_states: Optional[torch.Tensor] = None, 26 | image_only_indicator: Optional[torch.Tensor] = None, 27 | dino_down_block_res_samples = None, 28 | ) -> Tuple[torch.Tensor, Tuple[torch.Tensor, ...]]: 29 | output_states = () 30 | self = module 31 | blocks = list(zip(self.resnets, self.attentions)) 32 | for resnet, attn in blocks: 33 | if self.training and self.gradient_checkpointing: # TODO 34 | hidden_states = torch.utils.checkpoint.checkpoint( 35 | create_custom_forward(resnet), 36 | hidden_states, 37 | temb, 38 | image_only_indicator, 39 | **CKPT_KWARGS, 40 | ) 41 | 42 | hidden_states = torch.utils.checkpoint.checkpoint( 43 | create_custom_forward(attn), 44 | hidden_states, 45 | encoder_hidden_states, 46 | image_only_indicator, 47 | False, 48 | **CKPT_KWARGS, 49 | )[0] 50 | else: 51 | hidden_states = resnet( 52 | hidden_states, 53 | temb, 54 | image_only_indicator=image_only_indicator, 55 | ) 56 | hidden_states = attn( 57 | hidden_states, 58 | encoder_hidden_states=encoder_hidden_states, 59 | image_only_indicator=image_only_indicator, 60 | return_dict=False, 61 | )[0] 62 | 63 | if dino_down_block_res_samples is not None: 64 | hidden_states += dino_down_block_res_samples.pop(0) 65 | 66 | output_states = output_states + (hidden_states,) 67 | 68 | if 
self.downsamplers is not None: 69 | for downsampler in self.downsamplers: 70 | hidden_states = downsampler(hidden_states) 71 | if dino_down_block_res_samples is not None: 72 | hidden_states += dino_down_block_res_samples.pop(0) 73 | 74 | output_states = output_states + (hidden_states,) 75 | 76 | return hidden_states, output_states 77 | @staticmethod 78 | def forward_down_block_dino( 79 | module, 80 | hidden_states: torch.Tensor, 81 | temb: Optional[torch.Tensor] = None, 82 | image_only_indicator: Optional[torch.Tensor] = None, 83 | dino_down_block_res_samples = None, 84 | ) -> Tuple[torch.Tensor, Tuple[torch.Tensor, ...]]: 85 | self = module 86 | output_states = () 87 | for resnet in self.resnets: 88 | if self.training and self.gradient_checkpointing: 89 | if is_torch_version(">=", "1.11.0"): 90 | hidden_states = torch.utils.checkpoint.checkpoint( 91 | create_custom_forward(resnet), 92 | hidden_states, 93 | temb, 94 | image_only_indicator, 95 | use_reentrant=False, 96 | ) 97 | else: 98 | hidden_states = torch.utils.checkpoint.checkpoint( 99 | create_custom_forward(resnet), 100 | hidden_states, 101 | temb, 102 | image_only_indicator, 103 | ) 104 | else: 105 | hidden_states = resnet( 106 | hidden_states, 107 | temb, 108 | image_only_indicator=image_only_indicator, 109 | ) 110 | if dino_down_block_res_samples is not None: 111 | hidden_states += dino_down_block_res_samples.pop(0) 112 | output_states = output_states + (hidden_states,) 113 | 114 | if self.downsamplers is not None: 115 | for downsampler in self.downsamplers: 116 | hidden_states = downsampler(hidden_states) 117 | if dino_down_block_res_samples is not None: 118 | hidden_states += dino_down_block_res_samples.pop(0) 119 | output_states = output_states + (hidden_states,) 120 | 121 | return hidden_states, output_states 122 | 123 | 124 | def forward( 125 | self, 126 | sample: torch.FloatTensor, 127 | timestep: Union[torch.Tensor, float, int], 128 | encoder_hidden_states: torch.Tensor, 129 | added_time_ids: torch.Tensor, 130 | return_dict: bool = True, 131 | image_controlnet_down_block_res_samples = None, 132 | image_controlnet_mid_block_res_sample = None, 133 | dino_down_block_res_samples = None, 134 | 135 | ) -> Union[UNetSpatioTemporalConditionOutput, Tuple]: 136 | r""" 137 | The [`UNetSpatioTemporalConditionModel`] forward method. 138 | 139 | Args: 140 | sample (`torch.FloatTensor`): 141 | The noisy input tensor with the following shape `(batch, num_frames, channel, height, width)`. 142 | timestep (`torch.FloatTensor` or `float` or `int`): The number of timesteps to denoise an input. 143 | encoder_hidden_states (`torch.FloatTensor`): 144 | The encoder hidden states with shape `(batch, sequence_length, cross_attention_dim)`. 145 | added_time_ids: (`torch.FloatTensor`): 146 | The additional time ids with shape `(batch, num_additional_ids)`. These are encoded with sinusoidal 147 | embeddings and added to the time embeddings. 148 | return_dict (`bool`, *optional*, defaults to `True`): 149 | Whether or not to return a [`~models.unet_slatio_temporal.UNetSpatioTemporalConditionOutput`] instead of a plain 150 | tuple. 151 | Returns: 152 | [`~models.unet_slatio_temporal.UNetSpatioTemporalConditionOutput`] or `tuple`: 153 | If `return_dict` is True, an [`~models.unet_slatio_temporal.UNetSpatioTemporalConditionOutput`] is returned, otherwise 154 | a `tuple` is returned where the first element is the sample tensor. 
155 | """ 156 | if not hasattr(self, "custom_gradient_checkpointing"): 157 | self.custom_gradient_checkpointing = False 158 | 159 | # 1. time 160 | timesteps = timestep 161 | if not torch.is_tensor(timesteps): 162 | # TODO: this requires sync between CPU and GPU. So try to pass timesteps as tensors if you can 163 | # This would be a good case for the `match` statement (Python 3.10+) 164 | is_mps = sample.device.type == "mps" 165 | if isinstance(timestep, float): 166 | dtype = torch.float32 if is_mps else torch.float64 167 | else: 168 | dtype = torch.int32 if is_mps else torch.int64 169 | timesteps = torch.tensor([timesteps], dtype=dtype, device=sample.device) 170 | elif len(timesteps.shape) == 0: 171 | timesteps = timesteps[None].to(sample.device) 172 | 173 | # broadcast to batch dimension in a way that's compatible with ONNX/Core ML 174 | batch_size, num_frames = sample.shape[:2] 175 | if len(timesteps.shape) == 1: 176 | timesteps = timesteps.expand(batch_size) 177 | else: 178 | timesteps = timesteps.reshape(batch_size * num_frames) 179 | t_emb = self.time_proj(timesteps) # (B, C) 180 | 181 | # `Timesteps` does not contain any weights and will always return f32 tensors 182 | # but time_embedding might actually be running in fp16. so we need to cast here. 183 | # there might be better ways to encapsulate this. 184 | t_emb = t_emb.to(dtype=sample.dtype) 185 | 186 | emb = self.time_embedding(t_emb) # (B, C) 187 | 188 | time_embeds = self.add_time_proj(added_time_ids.flatten()) 189 | time_embeds = time_embeds.reshape((batch_size, -1)) 190 | time_embeds = time_embeds.to(emb.dtype) 191 | aug_emb = self.add_embedding(time_embeds) 192 | if emb.shape[0] == 1: 193 | emb = emb + aug_emb 194 | # Repeat the embeddings num_video_frames times 195 | # emb: [batch, channels] -> [batch * frames, channels] 196 | emb = emb.repeat_interleave(num_frames, dim=0) 197 | else: 198 | aug_emb = aug_emb.repeat_interleave(num_frames, dim=0) 199 | emb = emb + aug_emb 200 | 201 | # Flatten the batch and frames dimensions 202 | # sample: [batch, frames, channels, height, width] -> [batch * frames, channels, height, width] 203 | sample = sample.flatten(0, 1) 204 | 205 | # encoder_hidden_states: [batch, 1, channels] -> [batch * frames, 1, channels] 206 | # here, our encoder_hidden_states is [batch * frames, 1, channels] 207 | 208 | if not sample.shape[0] == encoder_hidden_states.shape[0]: 209 | encoder_hidden_states = encoder_hidden_states.repeat_interleave(num_frames, dim=0) 210 | # 2. 
pre-process 211 | sample = self.conv_in(sample) 212 | 213 | image_only_indicator = torch.zeros(batch_size, num_frames, dtype=sample.dtype, device=sample.device) 214 | 215 | if dino_down_block_res_samples is not None: 216 | dino_down_block_res_samples = [x for x in dino_down_block_res_samples] 217 | sample += dino_down_block_res_samples.pop(0) 218 | 219 | down_block_res_samples = (sample,) 220 | for downsample_block in self.down_blocks: 221 | if dino_down_block_res_samples is None: 222 | if self.custom_gradient_checkpointing: 223 | if hasattr(downsample_block, "has_cross_attention") and downsample_block.has_cross_attention: 224 | sample, res_samples = torch.utils.checkpoint.checkpoint( 225 | create_custom_forward(downsample_block), 226 | sample, 227 | emb, 228 | encoder_hidden_states, 229 | image_only_indicator, 230 | **CKPT_KWARGS, 231 | ) 232 | else: 233 | sample, res_samples = torch.utils.checkpoint.checkpoint( 234 | create_custom_forward(downsample_block), 235 | sample, 236 | emb, 237 | image_only_indicator, 238 | **CKPT_KWARGS, 239 | ) 240 | else: 241 | if hasattr(downsample_block, "has_cross_attention") and downsample_block.has_cross_attention: 242 | sample, res_samples = downsample_block( 243 | hidden_states=sample, 244 | temb=emb, 245 | encoder_hidden_states=encoder_hidden_states, 246 | image_only_indicator=image_only_indicator, 247 | ) 248 | else: 249 | sample, res_samples = downsample_block( 250 | hidden_states=sample, 251 | temb=emb, 252 | image_only_indicator=image_only_indicator, 253 | ) 254 | else: 255 | if hasattr(downsample_block, "has_cross_attention") and downsample_block.has_cross_attention: 256 | sample, res_samples = self.forward_crossattn_down_block_dino( 257 | downsample_block, 258 | sample, 259 | emb, 260 | encoder_hidden_states, 261 | image_only_indicator, 262 | dino_down_block_res_samples, 263 | ) 264 | else: 265 | sample, res_samples = self.forward_down_block_dino( 266 | downsample_block, 267 | sample, 268 | emb, 269 | image_only_indicator, 270 | dino_down_block_res_samples, 271 | ) 272 | down_block_res_samples += res_samples 273 | 274 | if image_controlnet_down_block_res_samples is not None: 275 | new_down_block_res_samples = () 276 | 277 | for down_block_res_sample, image_controlnet_down_block_res_sample in zip( 278 | down_block_res_samples, image_controlnet_down_block_res_samples 279 | ): 280 | down_block_res_sample = (down_block_res_sample + image_controlnet_down_block_res_sample) / 2 281 | new_down_block_res_samples = new_down_block_res_samples + (down_block_res_sample,) 282 | 283 | down_block_res_samples = new_down_block_res_samples 284 | 285 | # 4. mid 286 | if self.custom_gradient_checkpointing: 287 | sample = torch.utils.checkpoint.checkpoint( 288 | create_custom_forward(self.mid_block), 289 | sample, 290 | emb, 291 | encoder_hidden_states, 292 | image_only_indicator, 293 | **CKPT_KWARGS, 294 | ) 295 | else: 296 | sample = self.mid_block( 297 | hidden_states=sample, 298 | temb=emb, 299 | encoder_hidden_states=encoder_hidden_states, 300 | image_only_indicator=image_only_indicator, 301 | ) 302 | 303 | if image_controlnet_mid_block_res_sample is not None: 304 | sample = (sample + image_controlnet_mid_block_res_sample) / 2 305 | 306 | # 5. 
up 307 | mid_up_block_out_samples = [sample] 308 | down_block_out_samples = [] 309 | for i, upsample_block in enumerate(self.up_blocks): 310 | res_samples = down_block_res_samples[-len(upsample_block.resnets) :] 311 | down_block_res_samples = down_block_res_samples[: -len(upsample_block.resnets)] 312 | down_block_out_samples.append(res_samples[-1]) 313 | if hasattr(upsample_block, "has_cross_attention") and upsample_block.has_cross_attention: 314 | if self.custom_gradient_checkpointing: 315 | sample = torch.utils.checkpoint.checkpoint( 316 | create_custom_forward(upsample_block), 317 | sample, 318 | res_samples, 319 | emb, 320 | encoder_hidden_states, 321 | image_only_indicator, 322 | **CKPT_KWARGS 323 | ) 324 | else: 325 | sample = upsample_block( 326 | hidden_states=sample, 327 | temb=emb, 328 | res_hidden_states_tuple=res_samples, 329 | encoder_hidden_states=encoder_hidden_states, 330 | image_only_indicator=image_only_indicator, 331 | ) 332 | else: 333 | if self.custom_gradient_checkpointing: 334 | sample = torch.utils.checkpoint.checkpoint( 335 | create_custom_forward(upsample_block), 336 | sample, 337 | res_samples, 338 | emb, 339 | image_only_indicator, 340 | **CKPT_KWARGS 341 | ) 342 | else: 343 | sample = upsample_block( 344 | hidden_states=sample, 345 | temb=emb, 346 | res_hidden_states_tuple=res_samples, 347 | image_only_indicator=image_only_indicator, 348 | ) 349 | mid_up_block_out_samples.append(sample) 350 | # 6. post-process 351 | sample = self.conv_norm_out(sample) 352 | sample = self.conv_act(sample) 353 | if self.custom_gradient_checkpointing: 354 | sample = torch.utils.checkpoint.checkpoint( 355 | create_custom_forward(self.conv_out), 356 | sample, 357 | **CKPT_KWARGS 358 | ) 359 | else: 360 | sample = self.conv_out(sample) 361 | 362 | # 7. Reshape back to original shape
363 | sample = sample.reshape(batch_size, num_frames, *sample.shape[1:]) 364 | 365 | if not return_dict: 366 | return (sample, down_block_out_samples[::-1], mid_up_block_out_samples) 367 | 368 | return UNetSpatioTemporalConditionOutput(sample=sample) -------------------------------------------------------------------------------- /normalcrafter/utils.py: -------------------------------------------------------------------------------- 1 | from typing import Union, List 2 | import tempfile 3 | import numpy as np 4 | import PIL.Image 5 | import matplotlib.cm as cm 6 | import mediapy 7 | import torch 8 | from decord import VideoReader, cpu 9 | 10 | 11 | def read_video_frames(video_path, process_length, target_fps, max_res): 12 | print("==> processing video: ", video_path) 13 | vid = VideoReader(video_path, ctx=cpu(0)) 14 | print("==> original video shape: ", (len(vid), *vid.get_batch([0]).shape[1:])) 15 | original_height, original_width = vid.get_batch([0]).shape[1:3] 16 | 17 | if max(original_height, original_width) > max_res: 18 | scale = max_res / max(original_height, original_width) 19 | height = round(original_height * scale) 20 | width = round(original_width * scale) 21 | else: 22 | height = original_height 23 | width = original_width 24 | 25 | vid = VideoReader(video_path, ctx=cpu(0), width=width, height=height) 26 | 27 | fps = vid.get_avg_fps() if target_fps == -1 else target_fps 28 | stride = round(vid.get_avg_fps() / fps) 29 | stride = max(stride, 1) 30 | frames_idx = list(range(0, len(vid), stride)) 31 | print( 32 | f"==> downsampled shape: {len(frames_idx), *vid.get_batch([0]).shape[1:]}, with stride: {stride}" 33 | ) 34 | if process_length != -1 and process_length < len(frames_idx): 35 | frames_idx = frames_idx[:process_length] 36 | print( 37 | f"==> final processing shape: {len(frames_idx), *vid.get_batch([0]).shape[1:]}" 38 | ) 39 | frames = vid.get_batch(frames_idx).asnumpy().astype(np.uint8) 40 | frames = [PIL.Image.fromarray(x) for x in frames] 41 | 42 | return frames, fps 43 | 44 | def save_video( 45 | video_frames: Union[List[np.ndarray], List[PIL.Image.Image]], 46 | output_video_path: str = None, 47 | fps: int = 10, 48 | crf: int = 18, 49 | ) -> str: 50 | if output_video_path is None: 51 | output_video_path = tempfile.NamedTemporaryFile(suffix=".mp4").name 52 | 53 | if isinstance(video_frames[0], np.ndarray): 54 | video_frames = [(frame * 255).astype(np.uint8) for frame in video_frames] 55 | 56 | elif isinstance(video_frames[0], PIL.Image.Image): 57 | video_frames = [np.array(frame) for frame in video_frames] 58 | mediapy.write_video(output_video_path, video_frames, fps=fps, crf=crf) 59 | return output_video_path 60 | 61 | def vis_sequence_normal(normals: np.ndarray): 62 | normals = normals.clip(-1., 1.) 63 | normals = normals * 0.5 + 0.5 64 | return normals 65 | -------------------------------------------------------------------------------- /requirements.txt: -------------------------------------------------------------------------------- 1 | mediapy 2 | decord 3 | 4 | 5 | # diffusers 6 | # If you do not have diffusers installed, you need to install it. 7 | # You can typically install a compatible version by running the following command inside your activated VENV: 8 | # pip install diffusers 9 | # Repeat this process for any other packages that are reported as missing (import errors) at ComfyUI startup or runtime. --------------------------------------------------------------------------------
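
As a quick smoke test that the dependencies above are importable, the helpers in `normalcrafter/utils.py` can be exercised on their own, outside ComfyUI. This is only a sketch under stated assumptions: `input.mp4` is a placeholder path, and it assumes the repository root is on your `PYTHONPATH` (e.g. you run the script from inside the cloned folder) with `decord` and `mediapy` installed in the active environment.

```python
# Standalone sketch exercising the repo's video helpers; "input.mp4" is a placeholder clip.
from normalcrafter.utils import read_video_frames, save_video

# Decode the clip, downscaling so that max(H, W) <= 1024; target_fps=-1 keeps the source frame rate
# and process_length=-1 keeps every frame.
frames, fps = read_video_frames("input.mp4", process_length=-1, target_fps=-1, max_res=1024)

# Re-encode the decoded PIL frames; save_video also accepts float numpy arrays in [0, 1].
out_path = save_video(frames, "roundtrip.mp4", fps=int(round(fps)), crf=18)
print(f"Wrote {len(frames)} frames at {fps:.2f} fps to {out_path}")
```

If this round trip runs cleanly, `decord` (reading) and `mediapy` (writing) are wired up correctly for the node.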