├── src ├── common │ ├── __init__.py │ ├── distributed │ │ ├── __init__.py │ │ ├── meta_init_utils.py │ │ ├── basic.py │ │ └── advanced.py │ ├── seed.py │ ├── logger.py │ ├── cache.py │ ├── diffusion │ │ ├── timesteps │ │ │ ├── sampling │ │ │ │ └── trailing.py │ │ │ └── base.py │ │ ├── schedules │ │ │ ├── lerp.py │ │ │ └── base.py │ │ ├── __init__.py │ │ ├── types.py │ │ ├── config.py │ │ ├── utils.py │ │ └── samplers │ │ │ ├── base.py │ │ │ └── euler.py │ ├── partition.py │ ├── half_precision_fixes.py │ ├── decorators.py │ └── config.py ├── utils │ ├── __init__.py │ ├── constants.py │ ├── downloads.py │ └── model_registry.py ├── interfaces │ └── __init__.py ├── models │ ├── video_vae_v3 │ │ ├── s8_c16_t4_inflation_sd3.yaml │ │ └── modules │ │ │ ├── global_config.py │ │ │ ├── context_parallel_lib.py │ │ │ ├── types.py │ │ │ ├── inflated_layers.py │ │ │ └── inflated_lib.py │ ├── dit_v2 │ │ ├── patch │ │ │ ├── __init__.py │ │ │ └── patch_v1.py │ │ ├── nablocks │ │ │ ├── attention │ │ │ │ └── __init__.py │ │ │ ├── __init__.py │ │ │ └── mmsr_block.py │ │ ├── mlp.py │ │ ├── embedding.py │ │ ├── mm.py │ │ ├── window.py │ │ ├── attention.py │ │ ├── modulation.py │ │ ├── normalization.py │ │ └── rope.py │ └── dit │ │ ├── nablocks │ │ └── __init__.py │ │ ├── blocks │ │ └── __init__.py │ │ ├── mlp.py │ │ ├── embedding.py │ │ ├── mm.py │ │ ├── window.py │ │ ├── patch.py │ │ ├── rope.py │ │ ├── attention.py │ │ ├── modulation.py │ │ └── normalization.py ├── optimization │ ├── __init__.py │ └── performance.py ├── core │ └── __init__.py ├── data │ └── image │ │ └── transforms │ │ ├── divisible_crop.py │ │ ├── na_resize.py │ │ ├── side_resize.py │ │ └── area_resize.py └── __init__.py ├── neg_emb.pt ├── pos_emb.pt ├── docs ├── node.png ├── usage.png ├── demo_01.jpg ├── demo_02.jpg └── BlockSwap.png ├── requirements.txt ├── .gitignore ├── __init__.py ├── LICENSE ├── CONTRIBUTING.md ├── configs_7b └── main.yaml ├── configs_3b └── main.yaml └── example_workflows ├── SeedVR2_Image_Upscaling.json └── SeedVR2_Video_Upscaling.json /src/common/__init__.py: -------------------------------------------------------------------------------- 1 | -------------------------------------------------------------------------------- /neg_emb.pt: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/lihaoyun6/ComfyUI-SeedVR2_VideoUpscaler/HEAD/neg_emb.pt -------------------------------------------------------------------------------- /pos_emb.pt: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/lihaoyun6/ComfyUI-SeedVR2_VideoUpscaler/HEAD/pos_emb.pt -------------------------------------------------------------------------------- /docs/node.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/lihaoyun6/ComfyUI-SeedVR2_VideoUpscaler/HEAD/docs/node.png -------------------------------------------------------------------------------- /docs/usage.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/lihaoyun6/ComfyUI-SeedVR2_VideoUpscaler/HEAD/docs/usage.png -------------------------------------------------------------------------------- /docs/demo_01.jpg: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/lihaoyun6/ComfyUI-SeedVR2_VideoUpscaler/HEAD/docs/demo_01.jpg 
-------------------------------------------------------------------------------- /docs/demo_02.jpg: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/lihaoyun6/ComfyUI-SeedVR2_VideoUpscaler/HEAD/docs/demo_02.jpg -------------------------------------------------------------------------------- /docs/BlockSwap.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/lihaoyun6/ComfyUI-SeedVR2_VideoUpscaler/HEAD/docs/BlockSwap.png -------------------------------------------------------------------------------- /requirements.txt: -------------------------------------------------------------------------------- 1 | einops 2 | omegaconf>=2.3.0 3 | diffusers>=v0.33.1 4 | pytorch-extension 5 | rotary_embedding_torch>=0.5.3 6 | opencv-python 7 | peft>=0.15.0 8 | gguf 9 | -------------------------------------------------------------------------------- /src/utils/__init__.py: -------------------------------------------------------------------------------- 1 | """ 2 | Utilities package for SeedVR2 3 | Contains general utility functions like downloads, config loading, etc. 4 | """ 5 | ''' 6 | from .downloads import download_weight 7 | 8 | __all__ = [ 9 | "download_weight", 10 | ] 11 | ''' -------------------------------------------------------------------------------- /.gitignore: -------------------------------------------------------------------------------- 1 | .git/* 2 | **/__pycache__/ 3 | tests/ 4 | .vscode/ 5 | .cursor/ 6 | benchmark/ 7 | advanced_optimizations.py 8 | BENCHMARK_* 9 | environment.yml 10 | install_safetensors.py 11 | manage_quantized_models.py 12 | quantization_config.py 13 | quantize_model.py 14 | quantize_vae.py 15 | README_*.md 16 | run_*.py 17 | vram_diagnostic.py 18 | test_*.py 19 | VRAM_OPTIMIZATIONS_SUMMARY.md 20 | seedvr2.py 21 | src/core/isolated_generation.py 22 | src/core/subprocess_runner.py 23 | models/video_vae_v3_mine_bad/ 24 | src/processing/ 25 | TILE_VAE* 26 | .DS_Store -------------------------------------------------------------------------------- /src/interfaces/__init__.py: -------------------------------------------------------------------------------- 1 | """ 2 | Interfaces package for SeedVR2 3 | Contains user interface integrations (ComfyUI, etc.) 4 | """ 5 | 6 | # ComfyUI Interfaces Module 7 | # Handles ComfyUI node integration and user interface 8 | 9 | from .comfyui_node import ( 10 | # Main ComfyUI node class 11 | SeedVR2, 12 | 13 | # ComfyUI mappings 14 | NODE_CLASS_MAPPINGS, 15 | NODE_DISPLAY_NAME_MAPPINGS, 16 | 17 | ) 18 | 19 | __all__ = [ 20 | # Core node interface 21 | 'SeedVR2', 22 | 23 | # ComfyUI mappings 24 | 'NODE_CLASS_MAPPINGS', 25 | 'NODE_DISPLAY_NAME_MAPPINGS', 26 | 27 | ] -------------------------------------------------------------------------------- /__init__.py: -------------------------------------------------------------------------------- 1 | """ 2 | SeedVR2 Video Upscaler - Progressive transition to a modular architecture 3 | 4 | This file manages the transition between: 5 | - Legacy monolithic code (seedvr2.py) 6 | - New modular architecture (src/) 7 | 8 | Migration in progress...
9 | """ 10 | 11 | # 🆕 ATTEMPT: New modular architecture 12 | from .src.interfaces.comfyui_node import NODE_CLASS_MAPPINGS, NODE_DISPLAY_NAME_MAPPINGS 13 | USING_MODULAR = True 14 | 15 | 16 | # Exports for ComfyUI 17 | __all__ = ["NODE_CLASS_MAPPINGS", "NODE_DISPLAY_NAME_MAPPINGS"] 18 | 19 | # Metadata 20 | __version__ = "1.5.0-transition" if not USING_MODULAR else "2.0.0-modular" 21 | -------------------------------------------------------------------------------- /src/models/video_vae_v3/s8_c16_t4_inflation_sd3.yaml: -------------------------------------------------------------------------------- 1 | act_fn: silu 2 | block_out_channels: 3 | - 128 4 | - 256 5 | - 512 6 | - 512 7 | down_block_types: 8 | - DownEncoderBlock3D 9 | - DownEncoderBlock3D 10 | - DownEncoderBlock3D 11 | - DownEncoderBlock3D 12 | in_channels: 3 13 | latent_channels: 16 14 | layers_per_block: 2 15 | norm_num_groups: 32 16 | out_channels: 3 17 | slicing_sample_min_size: 4 18 | temporal_scale_num: 2 19 | inflation_mode: pad 20 | up_block_types: 21 | - UpDecoderBlock3D 22 | - UpDecoderBlock3D 23 | - UpDecoderBlock3D 24 | - UpDecoderBlock3D 25 | spatial_downsample_factor: 8 26 | temporal_downsample_factor: 4 27 | use_quant_conv: False 28 | use_post_quant_conv: False 29 | -------------------------------------------------------------------------------- /src/models/dit_v2/patch/__init__.py: -------------------------------------------------------------------------------- 1 | # // Copyright (c) 2025 Bytedance Ltd. and/or its affiliates 2 | # // 3 | # // Licensed under the Apache License, Version 2.0 (the "License"); 4 | # // you may not use this file except in compliance with the License. 5 | # // You may obtain a copy of the License at 6 | # // 7 | # // http://www.apache.org/licenses/LICENSE-2.0 8 | # // 9 | # // Unless required by applicable law or agreed to in writing, software 10 | # // distributed under the License is distributed on an "AS IS" BASIS, 11 | # // WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 12 | # // See the License for the specific language governing permissions and 13 | # // limitations under the License. 14 | 15 | def get_na_patch_layers(patch_type="v1"): 16 | assert patch_type in ["v1"] 17 | if patch_type == "v1": 18 | from .patch_v1 import NaPatchIn, NaPatchOut 19 | return NaPatchIn, NaPatchOut 20 | -------------------------------------------------------------------------------- /src/models/dit_v2/nablocks/attention/__init__.py: -------------------------------------------------------------------------------- 1 | # // Copyright (c) 2025 Bytedance Ltd. and/or its affiliates 2 | # // 3 | # // Licensed under the Apache License, Version 2.0 (the "License"); 4 | # // you may not use this file except in compliance with the License. 5 | # // You may obtain a copy of the License at 6 | # // 7 | # // http://www.apache.org/licenses/LICENSE-2.0 8 | # // 9 | # // Unless required by applicable law or agreed to in writing, software 10 | # // distributed under the License is distributed on an "AS IS" BASIS, 11 | # // WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 12 | # // See the License for the specific language governing permissions and 13 | # // limitations under the License.
14 | 15 | from .mmattn import NaMMAttention 16 | 17 | attns = { 18 | "mm_full": NaMMAttention, 19 | } 20 | 21 | 22 | def get_attn(attn_type: str): 23 | if attn_type in attns: 24 | return attns[attn_type] 25 | raise NotImplementedError(f"{attn_type} is not supported") 26 | -------------------------------------------------------------------------------- /src/models/dit/nablocks/__init__.py: -------------------------------------------------------------------------------- 1 | # // Copyright (c) 2025 Bytedance Ltd. and/or its affiliates 2 | # // 3 | # // Licensed under the Apache License, Version 2.0 (the "License"); 4 | # // you may not use this file except in compliance with the License. 5 | # // You may obtain a copy of the License at 6 | # // 7 | # // http://www.apache.org/licenses/LICENSE-2.0 8 | # // 9 | # // Unless required by applicable law or agreed to in writing, software 10 | # // distributed under the License is distributed on an "AS IS" BASIS, 11 | # // WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 12 | # // See the License for the specific language governing permissions and 13 | # // limitations under the License. 14 | 15 | from .mmsr_block import NaMMSRTransformerBlock 16 | 17 | nadit_blocks = { 18 | "mmdit_sr": NaMMSRTransformerBlock, 19 | } 20 | 21 | 22 | def get_nablock(block_type: str): 23 | if block_type in nadit_blocks: 24 | return nadit_blocks[block_type] 25 | raise NotImplementedError(f"{block_type} is not supported") 26 | -------------------------------------------------------------------------------- /src/models/video_vae_v3/modules/global_config.py: -------------------------------------------------------------------------------- 1 | # // Copyright (c) 2025 Bytedance Ltd. and/or its affiliates 2 | # // 3 | # // Licensed under the Apache License, Version 2.0 (the "License"); 4 | # // you may not use this file except in compliance with the License. 5 | # // You may obtain a copy of the License at 6 | # // 7 | # // http://www.apache.org/licenses/LICENSE-2.0 8 | # // 9 | # // Unless required by applicable law or agreed to in writing, software 10 | # // distributed under the License is distributed on an "AS IS" BASIS, 11 | # // WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 12 | # // See the License for the specific language governing permissions and 13 | # // limitations under the License. 14 | 15 | from typing import Optional 16 | 17 | _NORM_LIMIT = float("inf") 18 | 19 | 20 | def get_norm_limit(): 21 | return _NORM_LIMIT 22 | 23 | 24 | def set_norm_limit(value: Optional[float] = None): 25 | global _NORM_LIMIT 26 | if value is None: 27 | value = float("inf") 28 | _NORM_LIMIT = value 29 | -------------------------------------------------------------------------------- /src/models/dit/blocks/__init__.py: -------------------------------------------------------------------------------- 1 | # // Copyright (c) 2025 Bytedance Ltd. and/or its affiliates 2 | # // 3 | # // Licensed under the Apache License, Version 2.0 (the "License"); 4 | # // you may not use this file except in compliance with the License. 5 | # // You may obtain a copy of the License at 6 | # // 7 | # // http://www.apache.org/licenses/LICENSE-2.0 8 | # // 9 | # // Unless required by applicable law or agreed to in writing, software 10 | # // distributed under the License is distributed on an "AS IS" BASIS, 11 | # // WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 
12 | # // See the License for the specific language governing permissions and 13 | # // limitations under the License. 14 | 15 | from .mmdit_window_block import MMWindowTransformerBlock 16 | 17 | dit_blocks = { 18 | "mmdit_window": MMWindowTransformerBlock, 19 | } 20 | 21 | 22 | def get_block(block_type: str): 23 | if block_type in dit_blocks: 24 | return dit_blocks[block_type] 25 | raise NotImplementedError(f"{block_type} is not supported") 26 | -------------------------------------------------------------------------------- /src/models/dit_v2/nablocks/__init__.py: -------------------------------------------------------------------------------- 1 | # // Copyright (c) 2025 Bytedance Ltd. and/or its affiliates 2 | # // 3 | # // Licensed under the Apache License, Version 2.0 (the "License"); 4 | # // you may not use this file except in compliance with the License. 5 | # // You may obtain a copy of the License at 6 | # // 7 | # // http://www.apache.org/licenses/LICENSE-2.0 8 | # // 9 | # // Unless required by applicable law or agreed to in writing, software 10 | # // distributed under the License is distributed on an "AS IS" BASIS, 11 | # // WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 12 | # // See the License for the specific language governing permissions and 13 | # // limitations under the License. 14 | 15 | from .mmsr_block import NaMMSRTransformerBlock 16 | 17 | 18 | nadit_blocks = { 19 | "mmdit_sr": NaMMSRTransformerBlock, 20 | } 21 | 22 | 23 | def get_nablock(block_type: str): 24 | if block_type in nadit_blocks: 25 | return nadit_blocks[block_type] 26 | raise NotImplementedError(f"{block_type} is not supported") 27 | -------------------------------------------------------------------------------- /LICENSE: -------------------------------------------------------------------------------- 1 | MIT License 2 | 3 | Copyright (c) 2024 the comfyui-FlowChain 4 | 5 | Permission is hereby granted, free of charge, to any person obtaining a copy 6 | of this software and associated documentation files (the "Software"), to deal 7 | in the Software without restriction, including without limitation the rights 8 | to use, copy, modify, merge, publish, distribute, sublicense, and/or sell 9 | copies of the Software, and to permit persons to whom the Software is 10 | furnished to do so, subject to the following conditions: 11 | 12 | The above copyright notice and this permission notice shall be included in all 13 | copies or substantial portions of the Software. 14 | 15 | THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR 16 | IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, 17 | FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE 18 | AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER 19 | LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, 20 | OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE 21 | SOFTWARE. 
22 | -------------------------------------------------------------------------------- /src/optimization/__init__.py: -------------------------------------------------------------------------------- 1 | """ 2 | Optimization package for SeedVR2 3 | Contains memory management, performance optimizations, and compatibility layers 4 | """ 5 | ''' 6 | # Memory management functions 7 | from .memory_manager import ( 8 | get_vram_usage, 9 | clear_vram_cache, 10 | reset_vram_peak, 11 | preinitialize_rope_cache, 12 | ) 13 | 14 | # Performance optimization functions 15 | from .performance import ( 16 | optimized_video_rearrange, 17 | optimized_single_video_rearrange, 18 | optimized_sample_to_image_format, 19 | temporal_latent_blending, 20 | ) 21 | 22 | # Compatibility functions and classes 23 | from .compatibility import ( 24 | FP8CompatibleDiT, 25 | ) 26 | 27 | __all__ = [ 28 | # Memory management 29 | "get_vram_usage", 30 | "clear_vram_cache", 31 | "reset_vram_peak", 32 | "preinitialize_rope_cache", 33 | 34 | # Performance optimization 35 | "optimized_video_rearrange", 36 | "optimized_single_video_rearrange", 37 | "optimized_sample_to_image_format", 38 | "temporal_latent_blending", 39 | 40 | # Compatibility 41 | "FP8CompatibleDiT", 42 | ] 43 | ''' -------------------------------------------------------------------------------- /src/common/distributed/__init__.py: -------------------------------------------------------------------------------- 1 | # // Copyright (c) 2025 Bytedance Ltd. and/or its affiliates 2 | # // 3 | # // Licensed under the Apache License, Version 2.0 (the "License"); 4 | # // you may not use this file except in compliance with the License. 5 | # // You may obtain a copy of the License at 6 | # // 7 | # // http://www.apache.org/licenses/LICENSE-2.0 8 | # // 9 | # // Unless required by applicable law or agreed to in writing, software 10 | # // distributed under the License is distributed on an "AS IS" BASIS, 11 | # // WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 12 | # // See the License for the specific language governing permissions and 13 | # // limitations under the License. 14 | 15 | """ 16 | Distributed package. 17 | """ 18 | 19 | from .basic import ( 20 | barrier_if_distributed, 21 | convert_to_ddp, 22 | get_device, 23 | get_global_rank, 24 | get_local_rank, 25 | get_world_size, 26 | init_torch, 27 | ) 28 | 29 | __all__ = [ 30 | "barrier_if_distributed", 31 | "convert_to_ddp", 32 | "get_device", 33 | "get_global_rank", 34 | "get_local_rank", 35 | "get_world_size", 36 | "init_torch", 37 | ] 38 | -------------------------------------------------------------------------------- /src/common/seed.py: -------------------------------------------------------------------------------- 1 | # // Copyright (c) 2025 Bytedance Ltd. and/or its affiliates 2 | # // 3 | # // Licensed under the Apache License, Version 2.0 (the "License"); 4 | # // you may not use this file except in compliance with the License. 5 | # // You may obtain a copy of the License at 6 | # // 7 | # // http://www.apache.org/licenses/LICENSE-2.0 8 | # // 9 | # // Unless required by applicable law or agreed to in writing, software 10 | # // distributed under the License is distributed on an "AS IS" BASIS, 11 | # // WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 12 | # // See the License for the specific language governing permissions and 13 | # // limitations under the License. 
14 | 15 | import random 16 | from typing import Optional 17 | import numpy as np 18 | import torch 19 | 20 | from .distributed import get_global_rank 21 | 22 | 23 | def set_seed(seed: Optional[int], same_across_ranks: bool = False): 24 | """Function that sets the seed for pseudo-random number generators.""" 25 | if seed is not None: 26 | seed += get_global_rank() if not same_across_ranks else 0 27 | random.seed(seed) 28 | np.random.seed(seed) 29 | torch.manual_seed(seed) 30 | 31 | -------------------------------------------------------------------------------- /src/core/__init__.py: -------------------------------------------------------------------------------- 1 | """ 2 | Core Module for SeedVR2 3 | 4 | Contains the main business logic and model management functionality: 5 | - Model configuration and loading 6 | - Architecture detection and memory estimation 7 | - Runner creation and management 8 | - Generation pipeline and logic 9 | """ 10 | ''' 11 | from .model_manager import ( 12 | configure_runner, 13 | load_quantized_state_dict, 14 | configure_dit_model_inference, 15 | configure_vae_model_inference, 16 | ) 17 | 18 | from .generation import ( 19 | generation_step, 20 | generation_loop, 21 | cut_videos, 22 | prepare_video_transforms, 23 | load_text_embeddings, 24 | calculate_optimal_batch_params 25 | ) 26 | 27 | from .infer import VideoDiffusionInfer 28 | 29 | __all__ = [ 30 | # Model management 31 | 'configure_runner', 32 | 'load_quantized_state_dict', 33 | 'configure_dit_model_inference', 34 | 'configure_vae_model_inference', 35 | 36 | # Generation logic 37 | 'generation_step', 38 | 'generation_loop', 39 | 'cut_videos', 40 | 'prepare_video_transforms', 41 | 'load_text_embeddings', 42 | 'calculate_optimal_batch_params', 43 | 44 | # Infer 45 | 'VideoDiffusionInfer' 46 | ] 47 | ''' -------------------------------------------------------------------------------- /src/utils/constants.py: -------------------------------------------------------------------------------- 1 | """ 2 | Shared constants and utilities for SeedVR2 3 | Only includes constants actually used in the codebase 4 | """ 5 | 6 | import os 7 | 8 | # Model folder names 9 | SEEDVR2_FOLDER_NAME = "SEEDVR2" # Physical folder name on disk 10 | SEEDVR2_MODEL_TYPE = "seedvr2" # Model type identifier for ComfyUI 11 | 12 | # Supported model file formats 13 | #SUPPORTED_MODEL_EXTENSIONS = {'.safetensors', '.gguf'} 14 | SUPPORTED_MODEL_EXTENSIONS = {'.safetensors'} 15 | 16 | def get_script_directory(): 17 | """Get the root script directory path (3 levels up from this file)""" 18 | return os.path.dirname(os.path.dirname(os.path.dirname(os.path.abspath(__file__)))) 19 | 20 | def get_base_cache_dir(): 21 | """Get or create the model cache directory""" 22 | try: 23 | import folder_paths # only works if comfyui is available 24 | cache_dir = os.path.join(folder_paths.models_dir, SEEDVR2_FOLDER_NAME) 25 | folder_paths.add_model_folder_path(SEEDVR2_MODEL_TYPE, cache_dir) 26 | except: 27 | cache_dir = f"./{SEEDVR2_MODEL_TYPE}_models" 28 | 29 | os.makedirs(cache_dir, exist_ok=True) 30 | return cache_dir 31 | 32 | def is_supported_model_file(filename: str) -> bool: 33 | """Check if a file has a supported model extension""" 34 | return any(filename.endswith(ext) for ext in SUPPORTED_MODEL_EXTENSIONS) -------------------------------------------------------------------------------- /src/common/logger.py: -------------------------------------------------------------------------------- 1 | # // Copyright (c) 2025 Bytedance Ltd. 
and/or its affiliates 2 | # // 3 | # // Licensed under the Apache License, Version 2.0 (the "License"); 4 | # // you may not use this file except in compliance with the License. 5 | # // You may obtain a copy of the License at 6 | # // 7 | # // http://www.apache.org/licenses/LICENSE-2.0 8 | # // 9 | # // Unless required by applicable law or agreed to in writing, software 10 | # // distributed under the License is distributed on an "AS IS" BASIS, 11 | # // WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 12 | # // See the License for the specific language governing permissions and 13 | # // limitations under the License. 14 | 15 | """ 16 | Logging utility functions. 17 | """ 18 | 19 | import logging 20 | import sys 21 | from typing import Optional 22 | 23 | from .distributed import get_global_rank, get_local_rank, get_world_size 24 | 25 | _default_handler = logging.StreamHandler(sys.stdout) 26 | _default_handler.setFormatter( 27 | logging.Formatter( 28 | "%(asctime)s " 29 | + (f"[Rank:{get_global_rank()}]" if get_world_size() > 1 else "") 30 | + (f"[LocalRank:{get_local_rank()}]" if get_world_size() > 1 else "") 31 | + "[%(threadName).12s][%(name)s][%(levelname).5s] " 32 | + "%(message)s" 33 | ) 34 | ) 35 | 36 | 37 | def get_logger(name: Optional[str] = None) -> logging.Logger: 38 | """ 39 | Get a logger. 40 | """ 41 | logger = logging.getLogger(name) 42 | logger.addHandler(_default_handler) 43 | logger.setLevel(logging.INFO) 44 | return logger 45 | -------------------------------------------------------------------------------- /src/data/image/transforms/divisible_crop.py: -------------------------------------------------------------------------------- 1 | # // Copyright (c) 2025 Bytedance Ltd. and/or its affiliates 2 | # // 3 | # // Licensed under the Apache License, Version 2.0 (the "License"); 4 | # // you may not use this file except in compliance with the License. 5 | # // You may obtain a copy of the License at 6 | # // 7 | # // http://www.apache.org/licenses/LICENSE-2.0 8 | # // 9 | # // Unless required by applicable law or agreed to in writing, software 10 | # // distributed under the License is distributed on an "AS IS" BASIS, 11 | # // WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 12 | # // See the License for the specific language governing permissions and 13 | # // limitations under the License. 14 | 15 | from typing import Union 16 | import torch 17 | from PIL import Image 18 | from torchvision.transforms import functional as TVF 19 | 20 | 21 | class DivisibleCrop: 22 | def __init__(self, factor): 23 | if not isinstance(factor, tuple): 24 | factor = (factor, factor) 25 | 26 | self.height_factor, self.width_factor = factor[0], factor[1] 27 | 28 | def __call__(self, image: Union[torch.Tensor, Image.Image]): 29 | if isinstance(image, torch.Tensor): 30 | height, width = image.shape[-2:] 31 | elif isinstance(image, Image.Image): 32 | width, height = image.size 33 | else: 34 | raise NotImplementedError 35 | 36 | cropped_height = height - (height % self.height_factor) 37 | cropped_width = width - (width % self.width_factor) 38 | 39 | image = TVF.center_crop(img=image, output_size=(cropped_height, cropped_width)) 40 | return image 41 | -------------------------------------------------------------------------------- /src/common/cache.py: -------------------------------------------------------------------------------- 1 | # // Copyright (c) 2025 Bytedance Ltd. 
and/or its affiliates 2 | # // 3 | # // Licensed under the Apache License, Version 2.0 (the "License"); 4 | # // you may not use this file except in compliance with the License. 5 | # // You may obtain a copy of the License at 6 | # // 7 | # // http://www.apache.org/licenses/LICENSE-2.0 8 | # // 9 | # // Unless required by applicable law or agreed to in writing, software 10 | # // distributed under the License is distributed on an "AS IS" BASIS, 11 | # // WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 12 | # // See the License for the specific language governing permissions and 13 | # // limitations under the License. 14 | 15 | from typing import Callable 16 | 17 | 18 | class Cache: 19 | """Caching reusable args for faster inference""" 20 | 21 | def __init__(self, disable=False, prefix="", cache=None): 22 | self.cache = cache if cache is not None else {} 23 | self.disable = disable 24 | self.prefix = prefix 25 | 26 | def __call__(self, key: str, fn: Callable): 27 | if self.disable: 28 | return fn() 29 | 30 | key = self.prefix + key 31 | try: 32 | result = self.cache[key] 33 | except KeyError: 34 | result = fn() 35 | self.cache[key] = result 36 | return result 37 | 38 | def namespace(self, namespace: str): 39 | return Cache( 40 | disable=self.disable, 41 | prefix=self.prefix + namespace + ".", 42 | cache=self.cache, 43 | ) 44 | 45 | def get(self, key: str): 46 | key = self.prefix + key 47 | return self.cache[key] 48 | -------------------------------------------------------------------------------- /CONTRIBUTING.md: -------------------------------------------------------------------------------- 1 | # Contributing to ComfyUI-SeedVR2_VideoUpscaler 2 | 3 | Thank you for your interest in contributing to ComfyUI-SeedVR2_VideoUpscaler! We appreciate your effort, and to help us incorporate your contribution as smoothly as possible, please follow these guidelines. 4 | 5 | ## Reporting Bugs 6 | 7 | If you find a bug in the project, we encourage you to report it. Here's how: 8 | 9 | 1. First, check the [existing Issues](url_of_issues) to see if the issue has already been reported. If it has, please add a comment to the existing issue rather than creating a new one. 10 | 2. If you can't find an existing issue that matches your bug, create a new issue. Make sure to include as many details as possible so we can understand and reproduce the problem. 11 | 12 | ## Proposing Changes 13 | 14 | We welcome code contributions from the community. Here's how to propose changes: 15 | 16 | 1. Fork this repository to your own GitHub account. 17 | 2. Create a new branch on your fork for your changes. 18 | 3. Make your changes in this branch. 19 | 4. When you are ready, submit a pull request to the `main` branch of this repository. 20 | 21 | Please note that we use the GitHub Flow workflow, so all pull requests should be made to the `main` branch. 22 | 23 | Before submitting a pull request, please make sure your code adheres to the project's coding conventions and has passed all tests. If you are adding features, please also add appropriate tests. 24 | 25 | ## Contact 26 | 27 | If you have any questions or need help, please ping the developer on Discord (NumZ#7184) to make sure your addition will fit well into the project and to get help if needed. 28 | 29 | Thank you again for your contribution!
30 | -------------------------------------------------------------------------------- /src/common/diffusion/timesteps/sampling/trailing.py: -------------------------------------------------------------------------------- 1 | # // Copyright (c) 2025 Bytedance Ltd. and/or its affiliates 2 | # // 3 | # // Licensed under the Apache License, Version 2.0 (the "License"); 4 | # // you may not use this file except in compliance with the License. 5 | # // You may obtain a copy of the License at 6 | # // 7 | # // http://www.apache.org/licenses/LICENSE-2.0 8 | # // 9 | # // Unless required by applicable law or agreed to in writing, software 10 | # // distributed under the License is distributed on an "AS IS" BASIS, 11 | # // WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 12 | # // See the License for the specific language governing permissions and 13 | # // limitations under the License. 14 | 15 | import torch 16 | 17 | from ...types import SamplingDirection 18 | from ..base import SamplingTimesteps 19 | 20 | 21 | class UniformTrailingSamplingTimesteps(SamplingTimesteps): 22 | """ 23 | Uniform trailing sampling timesteps. 24 | Defined in (https://arxiv.org/abs/2305.08891) 25 | 26 | Shift is proposed in SD3 for RF schedule. 27 | Defined in (https://arxiv.org/pdf/2403.03206) eq.23 28 | """ 29 | 30 | def __init__( 31 | self, 32 | T: int, 33 | steps: int, 34 | shift: float = 1.0, 35 | device: torch.device = "cpu", 36 | ): 37 | # Create trailing timesteps. 38 | timesteps = torch.arange(1.0, 0.0, -1.0 / steps, device=device) 39 | 40 | # Shift timesteps. 41 | timesteps = shift * timesteps / (1 + (shift - 1) * timesteps) 42 | 43 | # Scale to T range. 44 | if isinstance(T, float): 45 | timesteps = timesteps * T 46 | else: 47 | timesteps = timesteps.mul(T + 1).sub(1).round().int() 48 | 49 | super().__init__(T=T, timesteps=timesteps, direction=SamplingDirection.backward) 50 | -------------------------------------------------------------------------------- /src/common/distributed/meta_init_utils.py: -------------------------------------------------------------------------------- 1 | # // Copyright (c) 2025 Bytedance Ltd. and/or its affiliates 2 | # // 3 | # // Licensed under the Apache License, Version 2.0 (the "License"); 4 | # // you may not use this file except in compliance with the License. 5 | # // You may obtain a copy of the License at 6 | # // 7 | # // http://www.apache.org/licenses/LICENSE-2.0 8 | # // 9 | # // Unless required by applicable law or agreed to in writing, software 10 | # // distributed under the License is distributed on an "AS IS" BASIS, 11 | # // WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 12 | # // See the License for the specific language governing permissions and 13 | # // limitations under the License. 14 | 15 | import torch 16 | from rotary_embedding_torch import RotaryEmbedding 17 | from torch import nn 18 | from torch.distributed.fsdp._common_utils import _is_fsdp_flattened 19 | 20 | __all__ = ["meta_non_persistent_buffer_init_fn"] 21 | 22 | 23 | def meta_non_persistent_buffer_init_fn(module: nn.Module) -> nn.Module: 24 | """ 25 | Used for materializing `non-persistent tensor buffers` while model resuming. 26 | 27 | Since non-persistent tensor buffers are not saved in state_dict, 28 | when initializing model with meta device, user should materialize those buffers manually. 29 | 30 | Currently, only `rope.dummy` is this special case. 
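The materialization below simply replaces matching meta buffers with zero tensors on CPU; as non-persistent buffers they carry no restored state.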
31 | """ 32 | with torch.no_grad(): 33 | for submodule in module.modules(): 34 | if not isinstance(submodule, RotaryEmbedding): 35 | continue 36 | for buffer_name, buffer in submodule.named_buffers(recurse=False): 37 | if buffer.is_meta and "dummy" in buffer_name: 38 | materialized_buffer = torch.zeros_like(buffer, device="cpu") 39 | setattr(submodule, buffer_name, materialized_buffer) 40 | assert not any(b.is_meta for n, b in module.named_buffers()) 41 | return module 42 | -------------------------------------------------------------------------------- /src/common/diffusion/schedules/lerp.py: -------------------------------------------------------------------------------- 1 | # // Copyright (c) 2025 Bytedance Ltd. and/or its affiliates 2 | # // 3 | # // Licensed under the Apache License, Version 2.0 (the "License"); 4 | # // you may not use this file except in compliance with the License. 5 | # // You may obtain a copy of the License at 6 | # // 7 | # // http://www.apache.org/licenses/LICENSE-2.0 8 | # // 9 | # // Unless required by applicable law or agreed to in writing, software 10 | # // distributed under the License is distributed on an "AS IS" BASIS, 11 | # // WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 12 | # // See the License for the specific language governing permissions and 13 | # // limitations under the License. 14 | 15 | """ 16 | Linear interpolation schedule (lerp). 17 | """ 18 | 19 | from typing import Union 20 | import torch 21 | 22 | from .base import Schedule 23 | 24 | 25 | class LinearInterpolationSchedule(Schedule): 26 | """ 27 | Linear interpolation schedule (lerp) is proposed by flow matching and rectified flow. 28 | It leads to straighter probability flow theoretically. It is also used by Stable Diffusion 3. 29 | 30 | 31 | 32 | x_t = (1 - t) * x_0 + t * x_T 33 | 34 | Can be either continuous or discrete. 35 | """ 36 | 37 | def __init__(self, T: Union[int, float] = 1.0): 38 | self._T = T 39 | 40 | @property 41 | def T(self) -> Union[int, float]: 42 | return self._T 43 | 44 | def A(self, t: torch.Tensor) -> torch.Tensor: 45 | return 1 - (t / self.T) 46 | 47 | def B(self, t: torch.Tensor) -> torch.Tensor: 48 | return t / self.T 49 | 50 | # ---------------------------------------------------- 51 | 52 | def isnr(self, snr: torch.Tensor) -> torch.Tensor: 53 | t = self.T / (1 + snr**0.5) 54 | t = t if self.is_continuous() else t.round().int() 55 | return t 56 | -------------------------------------------------------------------------------- /src/common/partition.py: -------------------------------------------------------------------------------- 1 | # // Copyright (c) 2025 Bytedance Ltd. and/or its affiliates 2 | # // 3 | # // Licensed under the Apache License, Version 2.0 (the "License"); 4 | # // you may not use this file except in compliance with the License. 5 | # // You may obtain a copy of the License at 6 | # // 7 | # // http://www.apache.org/licenses/LICENSE-2.0 8 | # // 9 | # // Unless required by applicable law or agreed to in writing, software 10 | # // distributed under the License is distributed on an "AS IS" BASIS, 11 | # // WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 12 | # // See the License for the specific language governing permissions and 13 | # // limitations under the License. 14 | 15 | """ 16 | Partition utility functions. 17 | """ 18 | 19 | from typing import Any, List 20 | 21 | 22 | def partition_by_size(data: List[Any], size: int) -> List[List[Any]]: 23 | """ 24 | Partition a list by size. 
25 | When indivisible, the last group contains fewer items than the target size. 26 | 27 | Examples: 28 | - data: [1,2,3,4,5] 29 | - size: 2 30 | - return: [[1,2], [3,4], [5]] 31 | """ 32 | assert size > 0 33 | return [data[i : (i + size)] for i in range(0, len(data), size)] 34 | 35 | 36 | def partition_by_groups(data: List[Any], groups: int) -> List[List[Any]]: 37 | """ 38 | Partition a list by groups. 39 | When indivisible, some groups may have more items than others. 40 | 41 | Examples: 42 | - data: [1,2,3,4,5] 43 | - groups: 2 44 | - return: [[1,3,5], [2,4]] 45 | """ 46 | assert groups > 0 47 | return [data[i::groups] for i in range(groups)] 48 | 49 | 50 | def shift_list(data: List[Any], n: int) -> List[Any]: 51 | """ 52 | Rotate a list by n elements. 53 | 54 | Examples: 55 | - data: [1,2,3,4,5] 56 | - n: 3 57 | - return: [4,5,1,2,3] 58 | """ 59 | return data[(n % len(data)) :] + data[: (n % len(data))] 60 | -------------------------------------------------------------------------------- /src/data/image/transforms/na_resize.py: -------------------------------------------------------------------------------- 1 | # // Copyright (c) 2025 Bytedance Ltd. and/or its affiliates 2 | # // 3 | # // Licensed under the Apache License, Version 2.0 (the "License"); 4 | # // you may not use this file except in compliance with the License. 5 | # // You may obtain a copy of the License at 6 | # // 7 | # // http://www.apache.org/licenses/LICENSE-2.0 8 | # // 9 | # // Unless required by applicable law or agreed to in writing, software 10 | # // distributed under the License is distributed on an "AS IS" BASIS, 11 | # // WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 12 | # // See the License for the specific language governing permissions and 13 | # // limitations under the License. 14 | 15 | import torch 16 | from typing import Literal 17 | from torchvision.transforms import CenterCrop, Compose, InterpolationMode, Resize 18 | 19 | from .area_resize import AreaResize 20 | from .side_resize import SideResize 21 | 22 | def NaResize( 23 | resolution: int, 24 | mode: Literal["area", "side"], 25 | downsample_only: bool, 26 | interpolation: InterpolationMode = InterpolationMode.BICUBIC, 27 | ): 28 | Interpolation = InterpolationMode.BILINEAR if torch.mps.is_available() else interpolation 29 | if mode == "area": 30 | return AreaResize( 31 | max_area=resolution**2, 32 | downsample_only=downsample_only, 33 | interpolation=Interpolation, 34 | ) 35 | if mode == "side": 36 | return SideResize( 37 | size=resolution, 38 | downsample_only=downsample_only, 39 | interpolation=Interpolation, 40 | ) 41 | if mode == "square": 42 | return Compose( 43 | [ 44 | Resize( 45 | size=resolution, 46 | interpolation=Interpolation, 47 | ), 48 | CenterCrop(resolution), 49 | ] 50 | ) 51 | raise ValueError(f"Unknown resize mode: {mode}") 52 | -------------------------------------------------------------------------------- /src/common/diffusion/__init__.py: -------------------------------------------------------------------------------- 1 | # // Copyright (c) 2025 Bytedance Ltd. and/or its affiliates 2 | # // 3 | # // Licensed under the Apache License, Version 2.0 (the "License"); 4 | # // you may not use this file except in compliance with the License. 
5 | # // You may obtain a copy of the License at 6 | # // 7 | # // http://www.apache.org/licenses/LICENSE-2.0 8 | # // 9 | # // Unless required by applicable law or agreed to in writing, software 10 | # // distributed under the License is distributed on an "AS IS" BASIS, 11 | # // WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 12 | # // See the License for the specific language governing permissions and 13 | # // limitations under the License. 14 | 15 | """ 16 | Diffusion package. 17 | """ 18 | 19 | from .config import ( 20 | create_sampler_from_config, 21 | create_sampling_timesteps_from_config, 22 | create_schedule_from_config, 23 | ) 24 | from .samplers.base import Sampler 25 | from .samplers.euler import EulerSampler 26 | from .schedules.base import Schedule 27 | from .schedules.lerp import LinearInterpolationSchedule 28 | from .timesteps.base import SamplingTimesteps, Timesteps 29 | from .timesteps.sampling.trailing import UniformTrailingSamplingTimesteps 30 | from .types import PredictionType, SamplingDirection 31 | from .utils import classifier_free_guidance, classifier_free_guidance_dispatcher, expand_dims 32 | 33 | __all__ = [ 34 | # Configs 35 | "create_sampler_from_config", 36 | "create_sampling_timesteps_from_config", 37 | "create_schedule_from_config", 38 | # Schedules 39 | "Schedule", 40 | "DiscreteVariancePreservingSchedule", 41 | "LinearInterpolationSchedule", 42 | # Samplers 43 | "Sampler", 44 | "EulerSampler", 45 | # Timesteps 46 | "Timesteps", 47 | "SamplingTimesteps", 48 | # Types 49 | "PredictionType", 50 | "SamplingDirection", 51 | "UniformTrailingSamplingTimesteps", 52 | # Utils 53 | "classifier_free_guidance", 54 | "classifier_free_guidance_dispatcher", 55 | "expand_dims", 56 | ] 57 | -------------------------------------------------------------------------------- /src/common/diffusion/types.py: -------------------------------------------------------------------------------- 1 | # // Copyright (c) 2025 Bytedance Ltd. and/or its affiliates 2 | # // 3 | # // Licensed under the Apache License, Version 2.0 (the "License"); 4 | # // you may not use this file except in compliance with the License. 5 | # // You may obtain a copy of the License at 6 | # // 7 | # // http://www.apache.org/licenses/LICENSE-2.0 8 | # // 9 | # // Unless required by applicable law or agreed to in writing, software 10 | # // distributed under the License is distributed on an "AS IS" BASIS, 11 | # // WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 12 | # // See the License for the specific language governing permissions and 13 | # // limitations under the License. 14 | 15 | """ 16 | Type definitions. 17 | """ 18 | 19 | from enum import Enum 20 | 21 | 22 | class PredictionType(str, Enum): 23 | """ 24 | x_0: 25 | Predict data sample. 26 | x_T: 27 | Predict noise sample. 28 | Proposed by DDPM (https://arxiv.org/abs/2006.11239) 29 | Proved problematic by zsnr paper (https://arxiv.org/abs/2305.08891) 30 | v_cos: 31 | Predict velocity dx/dt based on the cosine schedule (A_t * x_T - B_t * x_0). 32 | Proposed by progressive distillation (https://arxiv.org/abs/2202.00512) 33 | v_lerp: 34 | Predict velocity dx/dt based on the lerp schedule (x_T - x_0). 35 | Proposed by rectified flow (https://arxiv.org/abs/2209.03003) 36 | """ 37 | 38 | x_0 = "x_0" 39 | x_T = "x_T" 40 | v_cos = "v_cos" 41 | v_lerp = "v_lerp" 42 | 43 | 44 | class SamplingDirection(str, Enum): 45 | """ 46 | backward: Sample from x_T to x_0 for data generation. 
47 | forward: Sample from x_0 to x_T for noise inversion. 48 | """ 49 | 50 | backward = "backward" 51 | forward = "forward" 52 | 53 | @staticmethod 54 | def reverse(direction): 55 | if direction == SamplingDirection.backward: 56 | return SamplingDirection.forward 57 | if direction == SamplingDirection.forward: 58 | return SamplingDirection.backward 59 | raise NotImplementedError 60 | -------------------------------------------------------------------------------- /src/common/diffusion/timesteps/base.py: -------------------------------------------------------------------------------- 1 | from abc import ABC, abstractmethod 2 | from typing import Sequence, Union 3 | import torch 4 | 5 | from ..types import SamplingDirection 6 | 7 | 8 | class Timesteps(ABC): 9 | """ 10 | Timesteps base class. 11 | """ 12 | 13 | def __init__(self, T: Union[int, float]): 14 | assert T > 0 15 | self._T = T 16 | 17 | @property 18 | def T(self) -> Union[int, float]: 19 | """ 20 | Maximum timestep inclusive. 21 | int if discrete, float if continuous. 22 | """ 23 | return self._T 24 | 25 | def is_continuous(self) -> bool: 26 | """ 27 | Whether the schedule is continuous. 28 | """ 29 | return isinstance(self.T, float) 30 | 31 | 32 | class SamplingTimesteps(Timesteps): 33 | """ 34 | Sampling timesteps. 35 | It defines the discretization of sampling steps. 36 | """ 37 | 38 | def __init__( 39 | self, 40 | T: Union[int, float], 41 | timesteps: torch.Tensor, 42 | direction: SamplingDirection, 43 | ): 44 | assert timesteps.ndim == 1 45 | super().__init__(T) 46 | self.timesteps = timesteps 47 | self.direction = direction 48 | 49 | def __len__(self) -> int: 50 | """ 51 | Number of sampling steps. 52 | """ 53 | return len(self.timesteps) 54 | 55 | def __getitem__(self, idx: Union[int, torch.IntTensor]) -> torch.Tensor: 56 | """ 57 | The timestep at the sampling step. 58 | Returns a scalar tensor if idx is int, 59 | or tensor of the same size if idx is a tensor. 60 | """ 61 | return self.timesteps[idx] 62 | 63 | def index(self, t: torch.Tensor) -> torch.Tensor: 64 | """ 65 | Find index by t. 66 | Return index of the same shape as t. 67 | Index is -1 if t not found in timesteps. 68 | """ 69 | i, j = t.reshape(-1, 1).eq(self.timesteps).nonzero(as_tuple=True) 70 | idx = torch.full_like(t, fill_value=-1, dtype=torch.int) 71 | idx.view(-1)[i] = j.int() 72 | return idx 73 | -------------------------------------------------------------------------------- /src/data/image/transforms/side_resize.py: -------------------------------------------------------------------------------- 1 | # // Copyright (c) 2025 Bytedance Ltd. and/or its affiliates 2 | # // 3 | # // Licensed under the Apache License, Version 2.0 (the "License"); 4 | # // you may not use this file except in compliance with the License. 5 | # // You may obtain a copy of the License at 6 | # // 7 | # // http://www.apache.org/licenses/LICENSE-2.0 8 | # // 9 | # // Unless required by applicable law or agreed to in writing, software 10 | # // distributed under the License is distributed on an "AS IS" BASIS, 11 | # // WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 12 | # // See the License for the specific language governing permissions and 13 | # // limitations under the License. 
14 | 15 | from typing import Union 16 | import torch 17 | from PIL import Image 18 | from torchvision.transforms import InterpolationMode 19 | from torchvision.transforms import functional as TVF 20 | 21 | class SideResize: 22 | def __init__( 23 | self, 24 | size: int, 25 | downsample_only: bool = False, 26 | interpolation: InterpolationMode = InterpolationMode.BICUBIC, 27 | ): 28 | self.size = size 29 | self.downsample_only = downsample_only 30 | self.interpolation = interpolation 31 | if torch.mps.is_available(): 32 | self.interpolation = InterpolationMode.BILINEAR 33 | 34 | def __call__(self, image: Union[torch.Tensor, Image.Image]): 35 | """ 36 | Args: 37 | image (PIL Image or Tensor): Image to be scaled. 38 | 39 | Returns: 40 | PIL Image or Tensor: Rescaled image. 41 | """ 42 | if isinstance(image, torch.Tensor): 43 | height, width = image.shape[-2:] 44 | elif isinstance(image, Image.Image): 45 | width, height = image.size 46 | else: 47 | raise NotImplementedError 48 | 49 | if self.downsample_only and min(width, height) < self.size: 50 | # keep original height and width for small pictures. 51 | size = min(width, height) 52 | else: 53 | size = self.size 54 | 55 | return TVF.resize(image, size, self.interpolation) 56 | -------------------------------------------------------------------------------- /src/models/dit/mlp.py: -------------------------------------------------------------------------------- 1 | # // Copyright (c) 2025 Bytedance Ltd. and/or its affiliates 2 | # // 3 | # // Licensed under the Apache License, Version 2.0 (the "License"); 4 | # // you may not use this file except in compliance with the License. 5 | # // You may obtain a copy of the License at 6 | # // 7 | # // http://www.apache.org/licenses/LICENSE-2.0 8 | # // 9 | # // Unless required by applicable law or agreed to in writing, software 10 | # // distributed under the License is distributed on an "AS IS" BASIS, 11 | # // WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 12 | # // See the License for the specific language governing permissions and 13 | # // limitations under the License. 
14 | 15 | from typing import Optional 16 | import torch 17 | import torch.nn.functional as F 18 | from torch import nn 19 | 20 | 21 | def get_mlp(mlp_type: Optional[str] = "normal"): 22 | if mlp_type == "normal": 23 | return MLP 24 | elif mlp_type == "swiglu": 25 | return SwiGLUMLP 26 | 27 | 28 | class MLP(nn.Module): 29 | def __init__( 30 | self, 31 | dim: int, 32 | expand_ratio: int, 33 | ): 34 | super().__init__() 35 | self.proj_in = nn.Linear(dim, dim * expand_ratio) 36 | self.act = nn.GELU("tanh") 37 | self.proj_out = nn.Linear(dim * expand_ratio, dim) 38 | 39 | def forward(self, x: torch.FloatTensor) -> torch.FloatTensor: 40 | x = self.proj_in(x) 41 | x = self.act(x) 42 | x = self.proj_out(x) 43 | return x 44 | 45 | 46 | class SwiGLUMLP(nn.Module): 47 | def __init__( 48 | self, 49 | dim: int, 50 | expand_ratio: int, 51 | multiple_of: int = 256, 52 | ): 53 | super().__init__() 54 | hidden_dim = int(2 * dim * expand_ratio / 3) 55 | hidden_dim = multiple_of * ((hidden_dim + multiple_of - 1) // multiple_of) 56 | self.proj_in_gate = nn.Linear(dim, hidden_dim, bias=False) 57 | self.proj_out = nn.Linear(hidden_dim, dim, bias=False) 58 | self.proj_in = nn.Linear(dim, hidden_dim, bias=False) 59 | 60 | def forward(self, x: torch.FloatTensor) -> torch.FloatTensor: 61 | x = self.proj_out(F.silu(self.proj_in_gate(x)) * self.proj_in(x)) 62 | return x 63 | -------------------------------------------------------------------------------- /src/models/dit_v2/mlp.py: -------------------------------------------------------------------------------- 1 | # // Copyright (c) 2025 Bytedance Ltd. and/or its affiliates 2 | # // 3 | # // Licensed under the Apache License, Version 2.0 (the "License"); 4 | # // you may not use this file except in compliance with the License. 5 | # // You may obtain a copy of the License at 6 | # // 7 | # // http://www.apache.org/licenses/LICENSE-2.0 8 | # // 9 | # // Unless required by applicable law or agreed to in writing, software 10 | # // distributed under the License is distributed on an "AS IS" BASIS, 11 | # // WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 12 | # // See the License for the specific language governing permissions and 13 | # // limitations under the License. 
14 | 15 | from typing import Optional 16 | import torch 17 | import torch.nn.functional as F 18 | from torch import nn 19 | 20 | 21 | def get_mlp(mlp_type: Optional[str] = "normal"): 22 | if mlp_type == "normal": 23 | return MLP 24 | elif mlp_type == "swiglu": 25 | return SwiGLUMLP 26 | 27 | 28 | class MLP(nn.Module): 29 | def __init__( 30 | self, 31 | dim: int, 32 | expand_ratio: int, 33 | ): 34 | super().__init__() 35 | self.proj_in = nn.Linear(dim, dim * expand_ratio) 36 | self.act = nn.GELU("tanh") 37 | self.proj_out = nn.Linear(dim * expand_ratio, dim) 38 | 39 | def forward(self, x: torch.FloatTensor) -> torch.FloatTensor: 40 | x = self.proj_in(x) 41 | x = self.act(x) 42 | x = self.proj_out(x) 43 | return x 44 | 45 | 46 | class SwiGLUMLP(nn.Module): 47 | def __init__( 48 | self, 49 | dim: int, 50 | expand_ratio: int, 51 | multiple_of: int = 256, 52 | ): 53 | super().__init__() 54 | hidden_dim = int(2 * dim * expand_ratio / 3) 55 | hidden_dim = multiple_of * ((hidden_dim + multiple_of - 1) // multiple_of) 56 | self.proj_in_gate = nn.Linear(dim, hidden_dim, bias=False) 57 | self.proj_out = nn.Linear(hidden_dim, dim, bias=False) 58 | self.proj_in = nn.Linear(dim, hidden_dim, bias=False) 59 | 60 | def forward(self, x: torch.FloatTensor) -> torch.FloatTensor: 61 | x = self.proj_out(F.silu(self.proj_in_gate(x)) * self.proj_in(x)) 62 | return x 63 | -------------------------------------------------------------------------------- /src/models/dit/embedding.py: -------------------------------------------------------------------------------- 1 | # // Copyright (c) 2025 Bytedance Ltd. and/or its affiliates 2 | # // 3 | # // Licensed under the Apache License, Version 2.0 (the "License"); 4 | # // you may not use this file except in compliance with the License. 5 | # // You may obtain a copy of the License at 6 | # // 7 | # // http://www.apache.org/licenses/LICENSE-2.0 8 | # // 9 | # // Unless required by applicable law or agreed to in writing, software 10 | # // distributed under the License is distributed on an "AS IS" BASIS, 11 | # // WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 12 | # // See the License for the specific language governing permissions and 13 | # // limitations under the License. 
14 | 15 | from typing import Optional, Union 16 | import torch 17 | from diffusers.models.embeddings import get_timestep_embedding 18 | from torch import nn 19 | 20 | 21 | def emb_add(emb1: torch.Tensor, emb2: Optional[torch.Tensor]): 22 | return emb1 if emb2 is None else emb1 + emb2 23 | 24 | 25 | class TimeEmbedding(nn.Module): 26 | def __init__( 27 | self, 28 | sinusoidal_dim: int, 29 | hidden_dim: int, 30 | output_dim: int, 31 | ): 32 | super().__init__() 33 | self.sinusoidal_dim = sinusoidal_dim 34 | self.proj_in = nn.Linear(sinusoidal_dim, hidden_dim) 35 | self.proj_hid = nn.Linear(hidden_dim, hidden_dim) 36 | self.proj_out = nn.Linear(hidden_dim, output_dim) 37 | self.act = nn.SiLU() 38 | 39 | def forward( 40 | self, 41 | timestep: Union[int, float, torch.IntTensor, torch.FloatTensor], 42 | device: torch.device, 43 | dtype: torch.dtype, 44 | ) -> torch.FloatTensor: 45 | if not torch.is_tensor(timestep): 46 | timestep = torch.tensor([timestep], device=device, dtype=dtype) 47 | if timestep.ndim == 0: 48 | timestep = timestep[None] 49 | 50 | emb = get_timestep_embedding( 51 | timesteps=timestep, 52 | embedding_dim=self.sinusoidal_dim, 53 | flip_sin_to_cos=False, 54 | downscale_freq_shift=0, 55 | ) 56 | emb = emb.to(dtype) 57 | emb = self.proj_in(emb) 58 | emb = self.act(emb) 59 | emb = self.proj_hid(emb) 60 | emb = self.act(emb) 61 | emb = self.proj_out(emb) 62 | return emb 63 | -------------------------------------------------------------------------------- /src/models/dit_v2/embedding.py: -------------------------------------------------------------------------------- 1 | # // Copyright (c) 2025 Bytedance Ltd. and/or its affiliates 2 | # // 3 | # // Licensed under the Apache License, Version 2.0 (the "License"); 4 | # // you may not use this file except in compliance with the License. 5 | # // You may obtain a copy of the License at 6 | # // 7 | # // http://www.apache.org/licenses/LICENSE-2.0 8 | # // 9 | # // Unless required by applicable law or agreed to in writing, software 10 | # // distributed under the License is distributed on an "AS IS" BASIS, 11 | # // WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 12 | # // See the License for the specific language governing permissions and 13 | # // limitations under the License. 
14 | 15 | from typing import Optional, Union 16 | import torch 17 | from diffusers.models.embeddings import get_timestep_embedding 18 | from torch import nn 19 | 20 | 21 | def emb_add(emb1: torch.Tensor, emb2: Optional[torch.Tensor]): 22 | return emb1 if emb2 is None else emb1 + emb2 23 | 24 | 25 | class TimeEmbedding(nn.Module): 26 | def __init__( 27 | self, 28 | sinusoidal_dim: int, 29 | hidden_dim: int, 30 | output_dim: int, 31 | ): 32 | super().__init__() 33 | self.sinusoidal_dim = sinusoidal_dim 34 | self.proj_in = nn.Linear(sinusoidal_dim, hidden_dim) 35 | self.proj_hid = nn.Linear(hidden_dim, hidden_dim) 36 | self.proj_out = nn.Linear(hidden_dim, output_dim) 37 | self.act = nn.SiLU() 38 | 39 | def forward( 40 | self, 41 | timestep: Union[int, float, torch.IntTensor, torch.FloatTensor], 42 | device: torch.device, 43 | dtype: torch.dtype, 44 | ) -> torch.FloatTensor: 45 | if not torch.is_tensor(timestep): 46 | timestep = torch.tensor([timestep], device=device, dtype=dtype) 47 | if timestep.ndim == 0: 48 | timestep = timestep[None] 49 | 50 | emb = get_timestep_embedding( 51 | timesteps=timestep, 52 | embedding_dim=self.sinusoidal_dim, 53 | flip_sin_to_cos=False, 54 | downscale_freq_shift=0, 55 | ) 56 | emb = emb.to(dtype) 57 | emb = self.proj_in(emb) 58 | emb = self.act(emb) 59 | emb = self.proj_hid(emb) 60 | emb = self.act(emb) 61 | emb = self.proj_out(emb) 62 | return emb 63 | -------------------------------------------------------------------------------- /src/common/diffusion/config.py: -------------------------------------------------------------------------------- 1 | # // Copyright (c) 2025 Bytedance Ltd. and/or its affiliates 2 | # // 3 | # // Licensed under the Apache License, Version 2.0 (the "License"); 4 | # // you may not use this file except in compliance with the License. 5 | # // You may obtain a copy of the License at 6 | # // 7 | # // http://www.apache.org/licenses/LICENSE-2.0 8 | # // 9 | # // Unless required by applicable law or agreed to in writing, software 10 | # // distributed under the License is distributed on an "AS IS" BASIS, 11 | # // WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 12 | # // See the License for the specific language governing permissions and 13 | # // limitations under the License. 14 | 15 | """ 16 | Utility functions for creating schedules and samplers from config. 17 | """ 18 | 19 | import torch 20 | from omegaconf import DictConfig 21 | 22 | from .samplers.base import Sampler 23 | from .samplers.euler import EulerSampler 24 | from .schedules.base import Schedule 25 | from .schedules.lerp import LinearInterpolationSchedule 26 | from .timesteps.base import SamplingTimesteps 27 | from .timesteps.sampling.trailing import UniformTrailingSamplingTimesteps 28 | 29 | 30 | def create_schedule_from_config( 31 | config: DictConfig, 32 | device: torch.device, 33 | dtype: torch.dtype = torch.float32, 34 | ) -> Schedule: 35 | """ 36 | Create a schedule from configuration. 37 | """ 38 | if config.type == "lerp": 39 | return LinearInterpolationSchedule(T=config.get("T", 1.0)) 40 | 41 | raise NotImplementedError 42 | 43 | 44 | def create_sampler_from_config( 45 | config: DictConfig, 46 | schedule: Schedule, 47 | timesteps: SamplingTimesteps, 48 | ) -> Sampler: 49 | """ 50 | Create a sampler from configuration. 
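For example, the sampler section in configs_7b/main.yaml (type: euler, prediction_type: v_lerp) resolves to an EulerSampler here.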
51 | """ 52 | if config.type == "euler": 53 | return EulerSampler( 54 | schedule=schedule, 55 | timesteps=timesteps, 56 | prediction_type=config.prediction_type, 57 | ) 58 | raise NotImplementedError 59 | 60 | 61 | def create_sampling_timesteps_from_config( 62 | config: DictConfig, 63 | schedule: Schedule, 64 | device: torch.device, 65 | dtype: torch.dtype = torch.float32, 66 | ) -> SamplingTimesteps: 67 | if config.type == "uniform_trailing": 68 | return UniformTrailingSamplingTimesteps( 69 | T=schedule.T, 70 | steps=config.steps, 71 | shift=config.get("shift", 1.0), 72 | device=device, 73 | ) 74 | raise NotImplementedError -------------------------------------------------------------------------------- /src/models/dit/mm.py: -------------------------------------------------------------------------------- 1 | # // Copyright (c) 2025 Bytedance Ltd. and/or its affiliates 2 | # // 3 | # // Licensed under the Apache License, Version 2.0 (the "License"); 4 | # // you may not use this file except in compliance with the License. 5 | # // You may obtain a copy of the License at 6 | # // 7 | # // http://www.apache.org/licenses/LICENSE-2.0 8 | # // 9 | # // Unless required by applicable law or agreed to in writing, software 10 | # // distributed under the License is distributed on an "AS IS" BASIS, 11 | # // WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 12 | # // See the License for the specific language governing permissions and 13 | # // limitations under the License. 14 | 15 | from dataclasses import dataclass 16 | from typing import Any, Callable, Dict, List, Tuple 17 | import torch 18 | from torch import nn 19 | 20 | 21 | @dataclass 22 | class MMArg: 23 | vid: Any 24 | txt: Any 25 | 26 | 27 | def get_args(key: str, args: List[Any]) -> List[Any]: 28 | return [getattr(v, key) if isinstance(v, MMArg) else v for v in args] 29 | 30 | 31 | def get_kwargs(key: str, kwargs: Dict[str, Any]) -> Dict[str, Any]: 32 | return {k: getattr(v, key) if isinstance(v, MMArg) else v for k, v in kwargs.items()} 33 | 34 | 35 | class MMModule(nn.Module): 36 | def __init__( 37 | self, 38 | module: Callable[..., nn.Module], 39 | *args, 40 | shared_weights: bool = False, 41 | **kwargs, 42 | ): 43 | super().__init__() 44 | self.shared_weights = shared_weights 45 | if self.shared_weights: 46 | assert get_args("vid", args) == get_args("txt", args) 47 | assert get_kwargs("vid", kwargs) == get_kwargs("txt", kwargs) 48 | self.all = module(*get_args("vid", args), **get_kwargs("vid", kwargs)) 49 | else: 50 | self.vid = module(*get_args("vid", args), **get_kwargs("vid", kwargs)) 51 | self.txt = module(*get_args("txt", args), **get_kwargs("txt", kwargs)) 52 | 53 | def forward( 54 | self, 55 | vid: torch.FloatTensor, 56 | txt: torch.FloatTensor, 57 | *args, 58 | **kwargs, 59 | ) -> Tuple[ 60 | torch.FloatTensor, 61 | torch.FloatTensor, 62 | ]: 63 | vid_module = self.vid if not self.shared_weights else self.all 64 | txt_module = self.txt if not self.shared_weights else self.all 65 | vid = vid_module(vid, *get_args("vid", args), **get_kwargs("vid", kwargs)) 66 | txt = txt_module(txt, *get_args("txt", args), **get_kwargs("txt", kwargs)) 67 | return vid, txt 68 | -------------------------------------------------------------------------------- /configs_7b/main.yaml: -------------------------------------------------------------------------------- 1 | __object__: 2 | path: projects.video_diffusion_sr.train 3 | name: VideoDiffusionTrainer 4 | 5 | dit: 6 | model: 7 | __object__: 8 | path: 9 | - 
"custom_nodes.ComfyUI-SeedVR2_VideoUpscaler.src.models.dit.nadit" 10 | - "ComfyUI.custom_nodes.ComfyUI-SeedVR2_VideoUpscaler.src.models.dit.nadit" 11 | - "src.models.dit.nadit" 12 | name: "NaDiT" 13 | args: "as_params" 14 | vid_in_channels: 33 15 | vid_out_channels: 16 16 | vid_dim: 3072 17 | txt_in_dim: 5120 18 | txt_dim: ${.vid_dim} 19 | emb_dim: ${eval:'6 * ${.vid_dim}'} 20 | heads: 24 21 | head_dim: 128 # llm-like 22 | expand_ratio: 4 23 | norm: fusedrms 24 | norm_eps: 1e-5 25 | ada: single 26 | qk_bias: False 27 | qk_rope: True 28 | qk_norm: fusedrms 29 | patch_size: [1, 2, 2] 30 | num_layers: 36 # llm-like 31 | shared_mlp: False 32 | shared_qkv: False 33 | mlp_type: normal 34 | block_type: ${eval:'${.num_layers} * ["mmdit_sr"]'} # space-full 35 | window: ${eval:'${.num_layers} * [(4,3,3)]'} # space-full 36 | window_method: ${eval:'${.num_layers} // 2 * ["720pwin_by_size_bysize","720pswin_by_size_bysize"]'} # space-full 37 | compile: False 38 | gradient_checkpoint: True 39 | fsdp: 40 | sharding_strategy: _HYBRID_SHARD_ZERO2 41 | 42 | ema: 43 | decay: 0.9998 44 | 45 | vae: 46 | model: 47 | __object__: 48 | path: 49 | - "custom_nodes.ComfyUI-SeedVR2_VideoUpscaler.src.models.video_vae_v3.modules.attn_video_vae" 50 | - "ComfyUI.custom_nodes.ComfyUI-SeedVR2_VideoUpscaler.src.models.video_vae_v3.modules.attn_video_vae" 51 | - "src.models.video_vae_v3.modules.attn_video_vae" 52 | name: "VideoAutoencoderKLWrapper" 53 | args: "as_params" 54 | freeze_encoder: False 55 | # gradient_checkpoint: True 56 | slicing: 57 | split_size: 4 58 | memory_device: same 59 | memory_limit: 60 | conv_max_mem: 0.5 61 | norm_max_mem: 0.5 62 | checkpoint: ema_vae_fp16.safetensors 63 | scaling_factor: 0.9152 64 | compile: False 65 | grouping: False 66 | dtype: float16 67 | 68 | diffusion: 69 | schedule: 70 | type: lerp 71 | T: 1000.0 72 | sampler: 73 | type: euler 74 | prediction_type: v_lerp 75 | timesteps: 76 | training: 77 | type: logitnormal 78 | loc: 0.0 79 | scale: 1.0 80 | sampling: 81 | type: uniform_trailing 82 | steps: 50 83 | transform: True 84 | loss: 85 | type: v_lerp 86 | cfg: 87 | scale: 7.5 88 | rescale: 0 89 | 90 | condition: 91 | i2v: 0.0 92 | v2v: 0.0 93 | sr: 1.0 94 | noise_scale: 0.25 95 | -------------------------------------------------------------------------------- /src/common/half_precision_fixes.py: -------------------------------------------------------------------------------- 1 | import torch 2 | import torch.nn.functional as F 3 | 4 | def safe_pad_operation(x, padding, mode='constant', value=0.0): 5 | """Safe padding operation that handles Half precision only for problematic modes""" 6 | # Modes qui nécessitent le fix Half precision 7 | problematic_modes = ['replicate', 'reflect', 'circular'] 8 | 9 | if mode in problematic_modes: 10 | try: 11 | return F.pad(x, padding, mode=mode, value=value) 12 | except RuntimeError as e: 13 | if "not implemented for 'Half'" in str(e): 14 | original_dtype = x.dtype 15 | return F.pad(x.float(), padding, mode=mode, value=value).to(original_dtype) 16 | else: 17 | raise e 18 | else: 19 | # Pour 'constant' et autres modes compatibles, pas de fix nécessaire 20 | return F.pad(x, padding, mode=mode, value=value) 21 | 22 | 23 | def safe_interpolate_operation(x, size=None, scale_factor=None, mode='nearest', align_corners=None, recompute_scale_factor=None): 24 | """Safe interpolate operation that handles Half precision for problematic modes""" 25 | # Modes qui peuvent causer des problèmes avec Half precision 26 | problematic_modes = ['bilinear', 'bicubic', 
'trilinear'] 27 | 28 | if mode in problematic_modes: 29 | try: 30 | return F.interpolate( 31 | x, 32 | size=size, 33 | scale_factor=scale_factor, 34 | mode=mode, 35 | align_corners=align_corners, 36 | recompute_scale_factor=recompute_scale_factor 37 | ) 38 | except RuntimeError as e: 39 | if ("not implemented for 'Half'" in str(e) or 40 | "compute_indices_weights" in str(e)): 41 | original_dtype = x.dtype 42 | return F.interpolate( 43 | x.float(), 44 | size=size, 45 | scale_factor=scale_factor, 46 | mode=mode, 47 | align_corners=align_corners, 48 | recompute_scale_factor=recompute_scale_factor 49 | ).to(original_dtype) 50 | else: 51 | raise e 52 | else: 53 | # Pour 'nearest' et autres modes compatibles, pas de fix nécessaire 54 | return F.interpolate( 55 | x, 56 | size=size, 57 | scale_factor=scale_factor, 58 | mode=mode, 59 | align_corners=align_corners, 60 | recompute_scale_factor=recompute_scale_factor 61 | ) 62 | -------------------------------------------------------------------------------- /src/common/distributed/basic.py: -------------------------------------------------------------------------------- 1 | # // Copyright (c) 2025 Bytedance Ltd. and/or its affiliates 2 | # // 3 | # // Licensed under the Apache License, Version 2.0 (the "License"); 4 | # // you may not use this file except in compliance with the License. 5 | # // You may obtain a copy of the License at 6 | # // 7 | # // http://www.apache.org/licenses/LICENSE-2.0 8 | # // 9 | # // Unless required by applicable law or agreed to in writing, software 10 | # // distributed under the License is distributed on an "AS IS" BASIS, 11 | # // WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 12 | # // See the License for the specific language governing permissions and 13 | # // limitations under the License. 14 | 15 | """ 16 | Distributed basic functions. 17 | """ 18 | 19 | import os 20 | from datetime import timedelta 21 | import torch 22 | import torch.distributed as dist 23 | from torch.nn.parallel import DistributedDataParallel 24 | 25 | def get_global_rank() -> int: 26 | """ 27 | Get the global rank, the global index of the GPU. 28 | """ 29 | return int(os.environ.get("RANK", "0")) 30 | 31 | 32 | def get_local_rank() -> int: 33 | """ 34 | Get the local rank, the local index of the GPU. 35 | """ 36 | return int(os.environ.get("LOCAL_RANK", "0")) 37 | 38 | 39 | def get_world_size() -> int: 40 | """ 41 | Get the world size, the total amount of GPUs. 42 | """ 43 | return int(os.environ.get("WORLD_SIZE", "1")) 44 | 45 | 46 | def get_device() -> torch.device: 47 | """ 48 | Get current rank device. 49 | """ 50 | if torch.mps.is_available(): 51 | return "mps" 52 | return torch.device("cuda", get_local_rank()) 53 | 54 | 55 | def barrier_if_distributed(*args, **kwargs): 56 | """ 57 | Synchronizes all processes if under distributed context. 58 | """ 59 | if dist.is_initialized(): 60 | return dist.barrier(*args, **kwargs) 61 | 62 | 63 | def init_torch(cudnn_benchmark=True, timeout=timedelta(seconds=600)): 64 | """ 65 | Common PyTorch initialization configuration. 
66 | """ 67 | torch.backends.cuda.matmul.allow_tf32 = True 68 | torch.backends.cudnn.allow_tf32 = True 69 | torch.backends.cudnn.benchmark = cudnn_benchmark 70 | torch.cuda.set_device(get_local_rank()) 71 | dist.init_process_group( 72 | backend="nccl", 73 | rank=get_global_rank(), 74 | world_size=get_world_size(), 75 | timeout=timeout, 76 | ) 77 | 78 | 79 | def convert_to_ddp(module: torch.nn.Module, **kwargs) -> DistributedDataParallel: 80 | return DistributedDataParallel( 81 | module=module, 82 | device_ids=[get_local_rank()], 83 | output_device=get_local_rank(), 84 | **kwargs, 85 | ) 86 | -------------------------------------------------------------------------------- /src/models/video_vae_v3/modules/context_parallel_lib.py: -------------------------------------------------------------------------------- 1 | # // Copyright (c) 2025 Bytedance Ltd. and/or its affiliates 2 | # // 3 | # // Licensed under the Apache License, Version 2.0 (the "License"); 4 | # // you may not use this file except in compliance with the License. 5 | # // You may obtain a copy of the License at 6 | # // 7 | # // http://www.apache.org/licenses/LICENSE-2.0 8 | # // 9 | # // Unless required by applicable law or agreed to in writing, software 10 | # // distributed under the License is distributed on an "AS IS" BASIS, 11 | # // WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 12 | # // See the License for the specific language governing permissions and 13 | # // limitations under the License. 14 | 15 | from typing import List 16 | import torch 17 | import torch.nn.functional as F 18 | from torch import Tensor 19 | 20 | from .types import MemoryState 21 | 22 | # Single GPU inference - no distributed processing needed 23 | # print("Warning: Using single GPU inference mode - distributed features disabled") 24 | 25 | 26 | def causal_conv_slice_inputs(x, split_size, memory_state): 27 | # Single GPU inference - no slicing needed, return full tensor 28 | return x 29 | 30 | 31 | def causal_conv_gather_outputs(x): 32 | # Single GPU inference - no gathering needed, return tensor as is 33 | return x 34 | 35 | 36 | def get_output_len(conv_module, input_len, pad_len, dim=0): 37 | dilated_kernerl_size = conv_module.dilation[dim] * (conv_module.kernel_size[dim] - 1) + 1 38 | output_len = (input_len + pad_len - dilated_kernerl_size) // conv_module.stride[dim] + 1 39 | return output_len 40 | 41 | 42 | def get_cache_size(conv_module, input_len, pad_len, dim=0): 43 | dilated_kernerl_size = conv_module.dilation[dim] * (conv_module.kernel_size[dim] - 1) + 1 44 | output_len = (input_len + pad_len - dilated_kernerl_size) // conv_module.stride[dim] + 1 45 | remain_len = ( 46 | input_len + pad_len - ((output_len - 1) * conv_module.stride[dim] + dilated_kernerl_size) 47 | ) 48 | overlap_len = dilated_kernerl_size - conv_module.stride[dim] 49 | cache_len = overlap_len + remain_len # >= 0 50 | 51 | assert output_len > 0 52 | return cache_len 53 | 54 | 55 | def cache_send_recv(tensor: List[Tensor], cache_size, times, memory=None): 56 | # Single GPU inference - simplified cache handling 57 | recv_buffer = None 58 | 59 | # Handle memory buffer for single GPU case 60 | if memory is not None: 61 | recv_buffer = memory.to(tensor[0]) 62 | elif times > 0: 63 | tile_repeat = [1] * tensor[0].ndim 64 | tile_repeat[2] = times 65 | recv_buffer = torch.tile(tensor[0][:, :, :1], tile_repeat) 66 | 67 | return recv_buffer 68 | -------------------------------------------------------------------------------- /configs_3b/main.yaml: 
-------------------------------------------------------------------------------- 1 | __object__: 2 | path: projects.video_diffusion_sr.train 3 | name: VideoDiffusionTrainer 4 | 5 | dit: 6 | model: 7 | __object__: 8 | path: 9 | - "custom_nodes.ComfyUI-SeedVR2_VideoUpscaler.src.models.dit_v2.nadit" 10 | - "ComfyUI.custom_nodes.ComfyUI-SeedVR2_VideoUpscaler.src.models.dit_v2.nadit" 11 | - "src.models.dit_v2.nadit" 12 | name: "NaDiT" 13 | args: "as_params" 14 | vid_in_channels: 33 15 | vid_out_channels: 16 16 | vid_dim: 2560 17 | vid_out_norm: fusedrms 18 | txt_in_dim: 5120 19 | txt_in_norm: fusedln 20 | txt_dim: ${.vid_dim} 21 | emb_dim: ${eval:'6 * ${.vid_dim}'} 22 | heads: 20 23 | head_dim: 128 # llm-like 24 | expand_ratio: 4 25 | norm: fusedrms 26 | norm_eps: 1.0e-05 27 | ada: single 28 | qk_bias: False 29 | qk_norm: fusedrms 30 | patch_size: [1, 2, 2] 31 | num_layers: 32 # llm-like 32 | mm_layers: 10 33 | mlp_type: swiglu 34 | msa_type: None 35 | block_type: ${eval:'${.num_layers} * ["mmdit_sr"]'} # space-full 36 | window: ${eval:'${.num_layers} * [(4,3,3)]'} # space-full 37 | window_method: ${eval:'${.num_layers} // 2 * ["720pwin_by_size_bysize","720pswin_by_size_bysize"]'} # space-full 38 | rope_type: mmrope3d 39 | rope_dim: 128 40 | compile: False 41 | gradient_checkpoint: True 42 | fsdp: 43 | sharding_strategy: _HYBRID_SHARD_ZERO2 44 | 45 | ema: 46 | decay: 0.9998 47 | 48 | vae: 49 | model: 50 | __object__: 51 | path: 52 | - "custom_nodes.ComfyUI-SeedVR2_VideoUpscaler.src.models.video_vae_v3.modules.attn_video_vae" 53 | - "ComfyUI.custom_nodes.ComfyUI-SeedVR2_VideoUpscaler.src.models.video_vae_v3.modules.attn_video_vae" 54 | - "src.models.video_vae_v3.modules.attn_video_vae" 55 | name: "VideoAutoencoderKLWrapper" 56 | args: "as_params" 57 | freeze_encoder: False 58 | gradient_checkpoint: True # Disabled to prevent VRAM leaks in inference 59 | slicing: 60 | split_size: 4 61 | memory_device: same 62 | memory_limit: 63 | conv_max_mem: 0.5 64 | norm_max_mem: 0.5 65 | checkpoint: ema_vae_fp16.safetensors 66 | scaling_factor: 0.9152 67 | compile: False 68 | grouping: False 69 | dtype: float16 70 | 71 | diffusion: 72 | schedule: 73 | type: lerp 74 | T: 1000.0 75 | sampler: 76 | type: euler 77 | prediction_type: v_lerp 78 | timesteps: 79 | training: 80 | type: logitnormal 81 | loc: 0.0 82 | scale: 1.0 83 | sampling: 84 | type: uniform_trailing 85 | steps: 50 86 | transform: True 87 | loss: 88 | type: v_lerp 89 | cfg: 90 | scale: 7.5 91 | rescale: 0 92 | 93 | condition: 94 | i2v: 0.0 95 | v2v: 0.0 96 | sr: 1.0 97 | noise_scale: 0.25 98 | -------------------------------------------------------------------------------- /src/models/dit_v2/mm.py: -------------------------------------------------------------------------------- 1 | # // Copyright (c) 2025 Bytedance Ltd. and/or its affiliates 2 | # // 3 | # // Licensed under the Apache License, Version 2.0 (the "License"); 4 | # // you may not use this file except in compliance with the License. 5 | # // You may obtain a copy of the License at 6 | # // 7 | # // http://www.apache.org/licenses/LICENSE-2.0 8 | # // 9 | # // Unless required by applicable law or agreed to in writing, software 10 | # // distributed under the License is distributed on an "AS IS" BASIS, 11 | # // WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 12 | # // See the License for the specific language governing permissions and 13 | # // limitations under the License. 
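# Minimal sketch of how the MMArg/MMModule pair defined below dispatches
# per-branch arguments (illustrative dimensions only):
#
#   mm_norm = MMModule(nn.LayerNorm, MMArg(2560, 5120))
#   vid_out, txt_out = mm_norm(vid, txt)   # vid: (..., 2560), txt: (..., 5120)
#
# Every argument wrapped in MMArg is split into its "vid" and "txt" values, so a
# single constructor call builds two parallel sub-modules (or one shared module
# when shared_weights=True, or a video-only branch when vid_only=True).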
14 | 15 | from dataclasses import dataclass 16 | from typing import Any, Callable, Dict, List, Tuple 17 | import torch 18 | from torch import nn 19 | 20 | 21 | @dataclass 22 | class MMArg: 23 | vid: Any 24 | txt: Any 25 | 26 | 27 | def get_args(key: str, args: List[Any]) -> List[Any]: 28 | return [getattr(v, key) if isinstance(v, MMArg) else v for v in args] 29 | 30 | 31 | def get_kwargs(key: str, kwargs: Dict[str, Any]) -> Dict[str, Any]: 32 | return {k: getattr(v, key) if isinstance(v, MMArg) else v for k, v in kwargs.items()} 33 | 34 | 35 | class MMModule(nn.Module): 36 | def __init__( 37 | self, 38 | module: Callable[..., nn.Module], 39 | *args, 40 | shared_weights: bool = False, 41 | vid_only: bool = False, 42 | **kwargs, 43 | ): 44 | super().__init__() 45 | self.shared_weights = shared_weights 46 | self.vid_only = vid_only 47 | if self.shared_weights: 48 | assert get_args("vid", args) == get_args("txt", args) 49 | assert get_kwargs("vid", kwargs) == get_kwargs("txt", kwargs) 50 | self.all = module(*get_args("vid", args), **get_kwargs("vid", kwargs)) 51 | else: 52 | self.vid = module(*get_args("vid", args), **get_kwargs("vid", kwargs)) 53 | self.txt = ( 54 | module(*get_args("txt", args), **get_kwargs("txt", kwargs)) 55 | if not vid_only 56 | else None 57 | ) 58 | 59 | def forward( 60 | self, 61 | vid: torch.FloatTensor, 62 | txt: torch.FloatTensor, 63 | *args, 64 | **kwargs, 65 | ) -> Tuple[ 66 | torch.FloatTensor, 67 | torch.FloatTensor, 68 | ]: 69 | vid_module = self.vid if not self.shared_weights else self.all 70 | vid = vid_module(vid, *get_args("vid", args), **get_kwargs("vid", kwargs)) 71 | if not self.vid_only: 72 | txt_module = self.txt if not self.shared_weights else self.all 73 | txt = txt_module(txt, *get_args("txt", args), **get_kwargs("txt", kwargs)) 74 | return vid, txt 75 | -------------------------------------------------------------------------------- /src/models/video_vae_v3/modules/types.py: -------------------------------------------------------------------------------- 1 | # // Copyright (c) 2025 Bytedance Ltd. and/or its affiliates 2 | # // 3 | # // Licensed under the Apache License, Version 2.0 (the "License"); 4 | # // you may not use this file except in compliance with the License. 5 | # // You may obtain a copy of the License at 6 | # // 7 | # // http://www.apache.org/licenses/LICENSE-2.0 8 | # // 9 | # // Unless required by applicable law or agreed to in writing, software 10 | # // distributed under the License is distributed on an "AS IS" BASIS, 11 | # // WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 12 | # // See the License for the specific language governing permissions and 13 | # // limitations under the License. 
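# Minimal sketch of the DiagonalGaussianDistribution defined below (illustrative
# latent shape):
#
#   mean = torch.zeros(1, 16, 4, 8, 8)
#   logvar = torch.zeros(1, 16, 4, 8, 8)
#   posterior = DiagonalGaussianDistribution(mean, logvar)
#   z = posterior.sample()    # mean + std * N(0, 1), same shape as mean
#   z_det = posterior.mode()  # deterministic latent: the mean itself
#   kl = posterior.kl()       # KL against N(0, I), reduced to shape (1,)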
14 | 15 | from enum import Enum 16 | from typing import Dict, Literal, NamedTuple, Optional 17 | import torch 18 | 19 | _receptive_field_t = Literal["half", "full"] 20 | _inflation_mode_t = Literal["none", "tail", "replicate"] 21 | _memory_device_t = Optional[Literal["cpu", "same"]] 22 | _gradient_checkpointing_t = Optional[Literal["half", "full"]] 23 | _selective_checkpointing_t = Optional[Literal["coarse", "fine"]] 24 | 25 | class DiagonalGaussianDistribution: 26 | def __init__(self, mean: torch.Tensor, logvar: torch.Tensor): 27 | self.mean = mean 28 | self.logvar = torch.clamp(logvar, -30.0, 20.0) 29 | self.std = torch.exp(0.5 * self.logvar) 30 | self.var = torch.exp(self.logvar) 31 | 32 | def mode(self) -> torch.Tensor: 33 | return self.mean 34 | 35 | def sample(self) -> torch.FloatTensor: 36 | return self.mean + self.std * torch.randn_like(self.mean) 37 | 38 | def kl(self) -> torch.Tensor: 39 | return 0.5 * torch.sum( 40 | self.mean**2 + self.var - 1.0 - self.logvar, 41 | dim=list(range(1, self.mean.ndim)), 42 | ) 43 | 44 | class MemoryState(Enum): 45 | """ 46 | State[Disabled]: No memory bank will be enabled. 47 | State[Initializing]: The model is handling the first clip, need to reset the memory bank. 48 | State[Active]: There has been some data in the memory bank. 49 | State[Unset]: Error state, indicating users didn't pass correct memory state in. 50 | """ 51 | 52 | DISABLED = 0 53 | INITIALIZING = 1 54 | ACTIVE = 2 55 | UNSET = 3 56 | 57 | 58 | class QuantizerOutput(NamedTuple): 59 | latent: torch.Tensor 60 | extra_loss: torch.Tensor 61 | statistics: Dict[str, torch.Tensor] 62 | 63 | 64 | class CausalAutoencoderOutput(NamedTuple): 65 | sample: torch.Tensor 66 | latent: torch.Tensor 67 | posterior: Optional[DiagonalGaussianDistribution] 68 | 69 | 70 | class CausalEncoderOutput(NamedTuple): 71 | latent: torch.Tensor 72 | posterior: Optional[DiagonalGaussianDistribution] 73 | 74 | 75 | class CausalDecoderOutput(NamedTuple): 76 | sample: torch.Tensor 77 | -------------------------------------------------------------------------------- /src/common/diffusion/utils.py: -------------------------------------------------------------------------------- 1 | # // Copyright (c) 2025 Bytedance Ltd. and/or its affiliates 2 | # // 3 | # // Licensed under the Apache License, Version 2.0 (the "License"); 4 | # // you may not use this file except in compliance with the License. 5 | # // You may obtain a copy of the License at 6 | # // 7 | # // http://www.apache.org/licenses/LICENSE-2.0 8 | # // 9 | # // Unless required by applicable law or agreed to in writing, software 10 | # // distributed under the License is distributed on an "AS IS" BASIS, 11 | # // WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 12 | # // See the License for the specific language governing permissions and 13 | # // limitations under the License. 14 | 15 | """ 16 | Utility functions. 17 | """ 18 | 19 | from typing import Callable 20 | import torch 21 | 22 | 23 | def expand_dims(tensor: torch.Tensor, ndim: int): 24 | """ 25 | Expand tensor to target ndim. New dims are added to the right. 26 | For example, if the tensor shape was (8,), target ndim is 4, return (8, 1, 1, 1). 27 | """ 28 | shape = tensor.shape + (1,) * (ndim - tensor.ndim) 29 | return tensor.reshape(shape) 30 | 31 | 32 | def assert_schedule_timesteps_compatible(schedule, timesteps): 33 | """ 34 | Check if schedule and timesteps are compatible. 
35 | """ 36 | if schedule.T != timesteps.T: 37 | raise ValueError("Schedule and timesteps must have the same T.") 38 | if schedule.is_continuous() != timesteps.is_continuous(): 39 | raise ValueError("Schedule and timesteps must have the same continuity.") 40 | 41 | 42 | def classifier_free_guidance( 43 | pos: torch.Tensor, 44 | neg: torch.Tensor, 45 | scale: float, 46 | rescale: float = 0.0, 47 | ): 48 | """ 49 | Apply classifier-free guidance. 50 | """ 51 | # Classifier-free guidance (https://arxiv.org/abs/2207.12598) 52 | cfg = neg + scale * (pos - neg) 53 | 54 | # Classifier-free guidance rescale (https://arxiv.org/pdf/2305.08891.pdf) 55 | if rescale != 0.0: 56 | pos_std = pos.std(dim=list(range(1, pos.ndim)), keepdim=True) 57 | cfg_std = cfg.std(dim=list(range(1, cfg.ndim)), keepdim=True) 58 | factor = pos_std / cfg_std 59 | factor = rescale * factor + (1 - rescale) 60 | cfg *= factor 61 | 62 | return cfg 63 | 64 | 65 | def classifier_free_guidance_dispatcher( 66 | pos: Callable, 67 | neg: Callable, 68 | scale: float, 69 | rescale: float = 0.0, 70 | ): 71 | """ 72 | Optionally execute models depending on classifer-free guidance scale. 73 | """ 74 | # If scale is 1, no need to execute neg model. 75 | if scale == 1.0: 76 | return pos() 77 | 78 | # Otherwise, execute both pos nad neg models and apply cfg. 79 | return classifier_free_guidance( 80 | pos=pos(), 81 | neg=neg(), 82 | scale=scale, 83 | rescale=rescale, 84 | ) 85 | -------------------------------------------------------------------------------- /src/utils/downloads.py: -------------------------------------------------------------------------------- 1 | """ 2 | Downloads utility module for SeedVR2 3 | Handles model and VAE downloads from HuggingFace repositories 4 | """ 5 | 6 | import os 7 | import urllib.error 8 | from typing import Optional 9 | from torchvision.datasets.utils import download_url 10 | 11 | from src.utils.model_registry import get_model_repo, DEFAULT_VAE 12 | from src.utils.constants import get_base_cache_dir 13 | 14 | # HuggingFace URL template 15 | HUGGINGFACE_BASE_URL = "https://huggingface.co/{repo}/resolve/main/{filename}" 16 | 17 | def download_weight(model: str, model_dir: Optional[str] = None, debug=None) -> bool: 18 | """ 19 | Download a SeedVR2 model and its associated VAE from HuggingFace Hub 20 | 21 | Args: 22 | model: Model filename to download 23 | model_dir: Optional custom directory for models 24 | debug: Optional Debug instance for logging 25 | """ 26 | # Setup paths 27 | cache_dir = model_dir or get_base_cache_dir() 28 | model_path = os.path.join(cache_dir, model) 29 | vae_path = os.path.join(cache_dir, DEFAULT_VAE) 30 | is_gguf = model.endswith('.gguf') 31 | # Download model if not exists 32 | if not os.path.exists(model_path): 33 | repo = get_model_repo(model, gguf=is_gguf) 34 | url = HUGGINGFACE_BASE_URL.format(repo=repo, filename=model) 35 | 36 | if debug: 37 | debug.log(f"Downloading {model} from HF {repo}...", category="download", force=True) 38 | try: 39 | download_url(url, cache_dir, filename=model) 40 | if debug: 41 | debug.log(f"Downloaded: {model}", category="success", force=True) 42 | except (urllib.error.HTTPError, urllib.error.URLError) as e: 43 | if debug: 44 | debug.log(f"Model download failed: {e}", level="ERROR", category="download", force=True) 45 | debug.log(f"Please download model manually from: https://huggingface.co/{repo}", category="info", force=True) 46 | debug.log(f"and place it in: {cache_dir}", category="info", force=True) 47 | return False 48 | 49 | # Download VAE 
if not exists 50 | if not os.path.exists(vae_path): 51 | vae_repo = get_model_repo(DEFAULT_VAE, gguf=is_gguf) 52 | vae_url = HUGGINGFACE_BASE_URL.format(repo=vae_repo, filename=DEFAULT_VAE) 53 | 54 | if debug: 55 | debug.log(f"Downloading VAE: {DEFAULT_VAE} from HF {vae_repo}...", category="download", force=True) 56 | try: 57 | download_url(vae_url, cache_dir, filename=DEFAULT_VAE) 58 | if debug: 59 | debug.log(f"Downloaded: {DEFAULT_VAE}", category="success", force=True) 60 | except (urllib.error.HTTPError, urllib.error.URLError) as e: 61 | if debug: 62 | debug.log(f"VAE download failed: {e}", level="ERROR", category="download", force=True) 63 | debug.log(f"Please download VAE manually from: https://huggingface.co/{vae_repo}", category="info", force=True) 64 | debug.log(f"and place it in: {cache_dir}", category="info", force=True) 65 | return False 66 | 67 | return True -------------------------------------------------------------------------------- /src/models/dit/window.py: -------------------------------------------------------------------------------- 1 | # // Copyright (c) 2025 Bytedance Ltd. and/or its affiliates 2 | # // 3 | # // Licensed under the Apache License, Version 2.0 (the "License"); 4 | # // you may not use this file except in compliance with the License. 5 | # // You may obtain a copy of the License at 6 | # // 7 | # // http://www.apache.org/licenses/LICENSE-2.0 8 | # // 9 | # // Unless required by applicable law or agreed to in writing, software 10 | # // distributed under the License is distributed on an "AS IS" BASIS, 11 | # // WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 12 | # // See the License for the specific language governing permissions and 13 | # // limitations under the License. 14 | 15 | from math import ceil 16 | from typing import Tuple 17 | import math 18 | 19 | def get_window_op(name: str): 20 | if name == "720pwin_by_size_bysize": 21 | return make_720Pwindows_bysize 22 | if name == "720pswin_by_size_bysize": 23 | return make_shifted_720Pwindows_bysize 24 | raise ValueError(f"Unknown windowing method: {name}") 25 | 26 | 27 | # -------------------------------- Windowing -------------------------------- # 28 | def make_720Pwindows_bysize(size: Tuple[int, int, int], num_windows: Tuple[int, int, int]): 29 | t, h, w = size 30 | resized_nt, resized_nh, resized_nw = num_windows 31 | #cal windows under 720p 32 | scale = math.sqrt((45 * 80) / (h * w)) 33 | resized_h, resized_w = round(h * scale), round(w * scale) 34 | wh, ww = ceil(resized_h / resized_nh), ceil(resized_w / resized_nw) # window size. 35 | wt = ceil(min(t, 30) / resized_nt) # window size. 36 | nt, nh, nw = ceil(t / wt), ceil(h / wh), ceil(w / ww) # window size. 37 | return [ 38 | ( 39 | slice(it * wt, min((it + 1) * wt, t)), 40 | slice(ih * wh, min((ih + 1) * wh, h)), 41 | slice(iw * ww, min((iw + 1) * ww, w)), 42 | ) 43 | for iw in range(nw) 44 | if min((iw + 1) * ww, w) > iw * ww 45 | for ih in range(nh) 46 | if min((ih + 1) * wh, h) > ih * wh 47 | for it in range(nt) 48 | if min((it + 1) * wt, t) > it * wt 49 | ] 50 | 51 | def make_shifted_720Pwindows_bysize(size: Tuple[int, int, int], num_windows: Tuple[int, int, int]): 52 | t, h, w = size 53 | resized_nt, resized_nh, resized_nw = num_windows 54 | #cal windows under 720p 55 | scale = math.sqrt((45 * 80) / (h * w)) 56 | resized_h, resized_w = round(h * scale), round(w * scale) 57 | wh, ww = ceil(resized_h / resized_nh), ceil(resized_w / resized_nw) # window size. 58 | wt = ceil(min(t, 30) / resized_nt) # window size. 
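    # Illustrative worked example (hypothetical sizes): for size (t, h, w) = (21, 90, 160),
    # scale = sqrt((45*80)/(90*160)) = 0.5, so the grid is laid out as if the frame were
    # 45x80; with num_windows = (1, 3, 5) this yields window sizes (wt, wh, ww) = (21, 15, 16)
    # in the original resolution. The shift sizes computed next offset this grid by half a
    # window along every axis that has more than one window, in the spirit of Swin-style
    # shifted windows.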
59 | 60 | st, sh, sw = ( # shift size. 61 | 0.5 if wt < t else 0, 62 | 0.5 if wh < h else 0, 63 | 0.5 if ww < w else 0, 64 | ) 65 | nt, nh, nw = ceil((t - st) / wt), ceil((h - sh) / wh), ceil((w - sw) / ww) # window size. 66 | nt, nh, nw = ( # number of window. 67 | nt + 1 if st > 0 else 1, 68 | nh + 1 if sh > 0 else 1, 69 | nw + 1 if sw > 0 else 1, 70 | ) 71 | return [ 72 | ( 73 | slice(max(int((it - st) * wt), 0), min(int((it - st + 1) * wt), t)), 74 | slice(max(int((ih - sh) * wh), 0), min(int((ih - sh + 1) * wh), h)), 75 | slice(max(int((iw - sw) * ww), 0), min(int((iw - sw + 1) * ww), w)), 76 | ) 77 | for iw in range(nw) 78 | if min(int((iw - sw + 1) * ww), w) > max(int((iw - sw) * ww), 0) 79 | for ih in range(nh) 80 | if min(int((ih - sh + 1) * wh), h) > max(int((ih - sh) * wh), 0) 81 | for it in range(nt) 82 | if min(int((it - st + 1) * wt), t) > max(int((it - st) * wt), 0) 83 | ] 84 | -------------------------------------------------------------------------------- /src/models/dit_v2/window.py: -------------------------------------------------------------------------------- 1 | # // Copyright (c) 2025 Bytedance Ltd. and/or its affiliates 2 | # // 3 | # // Licensed under the Apache License, Version 2.0 (the "License"); 4 | # // you may not use this file except in compliance with the License. 5 | # // You may obtain a copy of the License at 6 | # // 7 | # // http://www.apache.org/licenses/LICENSE-2.0 8 | # // 9 | # // Unless required by applicable law or agreed to in writing, software 10 | # // distributed under the License is distributed on an "AS IS" BASIS, 11 | # // WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 12 | # // See the License for the specific language governing permissions and 13 | # // limitations under the License. 14 | 15 | from math import ceil 16 | from typing import Tuple 17 | import math 18 | 19 | def get_window_op(name: str): 20 | if name == "720pwin_by_size_bysize": 21 | return make_720Pwindows_bysize 22 | if name == "720pswin_by_size_bysize": 23 | return make_shifted_720Pwindows_bysize 24 | raise ValueError(f"Unknown windowing method: {name}") 25 | 26 | 27 | # -------------------------------- Windowing -------------------------------- # 28 | def make_720Pwindows_bysize(size: Tuple[int, int, int], num_windows: Tuple[int, int, int]): 29 | t, h, w = size 30 | resized_nt, resized_nh, resized_nw = num_windows 31 | #cal windows under 720p 32 | scale = math.sqrt((45 * 80) / (h * w)) 33 | resized_h, resized_w = round(h * scale), round(w * scale) 34 | wh, ww = ceil(resized_h / resized_nh), ceil(resized_w / resized_nw) # window size. 35 | wt = ceil(min(t, 30) / resized_nt) # window size. 36 | nt, nh, nw = ceil(t / wt), ceil(h / wh), ceil(w / ww) # window size. 37 | return [ 38 | ( 39 | slice(it * wt, min((it + 1) * wt, t)), 40 | slice(ih * wh, min((ih + 1) * wh, h)), 41 | slice(iw * ww, min((iw + 1) * ww, w)), 42 | ) 43 | for iw in range(nw) 44 | if min((iw + 1) * ww, w) > iw * ww 45 | for ih in range(nh) 46 | if min((ih + 1) * wh, h) > ih * wh 47 | for it in range(nt) 48 | if min((it + 1) * wt, t) > it * wt 49 | ] 50 | 51 | def make_shifted_720Pwindows_bysize(size: Tuple[int, int, int], num_windows: Tuple[int, int, int]): 52 | t, h, w = size 53 | resized_nt, resized_nh, resized_nw = num_windows 54 | #cal windows under 720p 55 | scale = math.sqrt((45 * 80) / (h * w)) 56 | resized_h, resized_w = round(h * scale), round(w * scale) 57 | wh, ww = ceil(resized_h / resized_nh), ceil(resized_w / resized_nw) # window size. 
58 | wt = ceil(min(t, 30) / resized_nt) # window size. 59 | 60 | st, sh, sw = ( # shift size. 61 | 0.5 if wt < t else 0, 62 | 0.5 if wh < h else 0, 63 | 0.5 if ww < w else 0, 64 | ) 65 | nt, nh, nw = ceil((t - st) / wt), ceil((h - sh) / wh), ceil((w - sw) / ww) # window size. 66 | nt, nh, nw = ( # number of window. 67 | nt + 1 if st > 0 else 1, 68 | nh + 1 if sh > 0 else 1, 69 | nw + 1 if sw > 0 else 1, 70 | ) 71 | return [ 72 | ( 73 | slice(max(int((it - st) * wt), 0), min(int((it - st + 1) * wt), t)), 74 | slice(max(int((ih - sh) * wh), 0), min(int((ih - sh + 1) * wh), h)), 75 | slice(max(int((iw - sw) * ww), 0), min(int((iw - sw + 1) * ww), w)), 76 | ) 77 | for iw in range(nw) 78 | if min(int((iw - sw + 1) * ww), w) > max(int((iw - sw) * ww), 0) 79 | for ih in range(nh) 80 | if min(int((ih - sh + 1) * wh), h) > max(int((ih - sh) * wh), 0) 81 | for it in range(nt) 82 | if min(int((it - st + 1) * wt), t) > max(int((it - st) * wt), 0) 83 | ] 84 | -------------------------------------------------------------------------------- /src/common/diffusion/samplers/base.py: -------------------------------------------------------------------------------- 1 | # // Copyright (c) 2025 Bytedance Ltd. and/or its affiliates 2 | # // 3 | # // Licensed under the Apache License, Version 2.0 (the "License"); 4 | # // you may not use this file except in compliance with the License. 5 | # // You may obtain a copy of the License at 6 | # // 7 | # // http://www.apache.org/licenses/LICENSE-2.0 8 | # // 9 | # // Unless required by applicable law or agreed to in writing, software 10 | # // distributed under the License is distributed on an "AS IS" BASIS, 11 | # // WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 12 | # // See the License for the specific language governing permissions and 13 | # // limitations under the License. 14 | 15 | """ 16 | Sampler base class. 17 | """ 18 | 19 | from abc import ABC, abstractmethod 20 | from dataclasses import dataclass 21 | from typing import Callable 22 | import torch 23 | from tqdm import tqdm 24 | 25 | from ..schedules.base import Schedule 26 | from ..timesteps.base import SamplingTimesteps 27 | from ..types import PredictionType, SamplingDirection 28 | from ..utils import assert_schedule_timesteps_compatible 29 | 30 | 31 | @dataclass 32 | class SamplerModelArgs: 33 | x_t: torch.Tensor 34 | t: torch.Tensor 35 | i: int 36 | 37 | 38 | class Sampler(ABC): 39 | """ 40 | Samplers are ODE/SDE solvers. 41 | """ 42 | 43 | def __init__( 44 | self, 45 | schedule: Schedule, 46 | timesteps: SamplingTimesteps, 47 | prediction_type: PredictionType, 48 | return_endpoint: bool = True, 49 | ): 50 | assert_schedule_timesteps_compatible( 51 | schedule=schedule, 52 | timesteps=timesteps, 53 | ) 54 | self.schedule = schedule 55 | self.timesteps = timesteps 56 | self.prediction_type = prediction_type 57 | self.return_endpoint = return_endpoint 58 | 59 | @abstractmethod 60 | def sample( 61 | self, 62 | x: torch.Tensor, 63 | f: Callable[[SamplerModelArgs], torch.Tensor], 64 | ) -> torch.Tensor: 65 | """ 66 | Generate a new sample given the the intial sample x and score function f. 67 | """ 68 | 69 | def get_next_timestep( 70 | self, 71 | t: torch.Tensor, 72 | ) -> torch.Tensor: 73 | """ 74 | Get the next sample timestep. 75 | Support multiple different timesteps t in a batch. 76 | If no more steps, return out of bound value -1 or T+1. 
77 | """ 78 | T = self.timesteps.T 79 | steps = len(self.timesteps) 80 | curr_idx = self.timesteps.index(t) 81 | next_idx = curr_idx + 1 82 | bound = -1 if self.timesteps.direction == SamplingDirection.backward else T + 1 83 | 84 | s = self.timesteps[next_idx.clamp_max(steps - 1)] 85 | s = s.where(next_idx < steps, bound) 86 | return s 87 | 88 | def get_endpoint( 89 | self, 90 | pred: torch.Tensor, 91 | x_t: torch.Tensor, 92 | t: torch.Tensor, 93 | ) -> torch.Tensor: 94 | """ 95 | Get to the endpoint of the probability flow. 96 | """ 97 | x_0, x_T = self.schedule.convert_from_pred(pred, self.prediction_type, x_t, t) 98 | return x_0 if self.timesteps.direction == SamplingDirection.backward else x_T 99 | 100 | def get_progress_bar(self): 101 | """ 102 | Get progress bar for sampling. 103 | """ 104 | return tqdm( 105 | iterable=range(len(self.timesteps) - (0 if self.return_endpoint else 1)), 106 | dynamic_ncols=True, 107 | desc=self.__class__.__name__, 108 | ) 109 | -------------------------------------------------------------------------------- /src/models/dit/patch.py: -------------------------------------------------------------------------------- 1 | # // Copyright (c) 2025 Bytedance Ltd. and/or its affiliates 2 | # // 3 | # // Licensed under the Apache License, Version 2.0 (the "License"); 4 | # // you may not use this file except in compliance with the License. 5 | # // You may obtain a copy of the License at 6 | # // 7 | # // http://www.apache.org/licenses/LICENSE-2.0 8 | # // 9 | # // Unless required by applicable law or agreed to in writing, software 10 | # // distributed under the License is distributed on an "AS IS" BASIS, 11 | # // WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 12 | # // See the License for the specific language governing permissions and 13 | # // limitations under the License. 14 | 15 | from typing import Tuple, Union 16 | import torch 17 | from einops import rearrange 18 | from torch import nn 19 | from torch.nn.modules.utils import _triple 20 | 21 | from ...common.cache import Cache 22 | from ...common.distributed.ops import gather_outputs, slice_inputs 23 | 24 | from . 
import na 25 | 26 | 27 | class PatchIn(nn.Module): 28 | def __init__( 29 | self, 30 | in_channels: int, 31 | patch_size: Union[int, Tuple[int, int, int]], 32 | dim: int, 33 | ): 34 | super().__init__() 35 | t, h, w = _triple(patch_size) 36 | self.patch_size = t, h, w 37 | self.proj = nn.Linear(in_channels * t * h * w, dim) 38 | 39 | def forward( 40 | self, 41 | vid: torch.Tensor, 42 | ) -> torch.Tensor: 43 | t, h, w = self.patch_size 44 | vid = rearrange(vid, "b c (T t) (H h) (W w) -> b T H W (t h w c)", t=t, h=h, w=w) 45 | vid = self.proj(vid) 46 | return vid 47 | 48 | 49 | class PatchOut(nn.Module): 50 | def __init__( 51 | self, 52 | out_channels: int, 53 | patch_size: Union[int, Tuple[int, int, int]], 54 | dim: int, 55 | ): 56 | super().__init__() 57 | t, h, w = _triple(patch_size) 58 | self.patch_size = t, h, w 59 | self.proj = nn.Linear(dim, out_channels * t * h * w) 60 | 61 | def forward( 62 | self, 63 | vid: torch.Tensor, 64 | ) -> torch.Tensor: 65 | t, h, w = self.patch_size 66 | vid = self.proj(vid) 67 | vid = rearrange(vid, "b T H W (t h w c) -> b c (T t) (H h) (W w)", t=t, h=h, w=w) 68 | return vid 69 | 70 | 71 | class NaPatchIn(PatchIn): 72 | def forward( 73 | self, 74 | vid: torch.Tensor, # l c 75 | vid_shape: torch.LongTensor, 76 | ) -> torch.Tensor: 77 | t, h, w = self.patch_size 78 | if not (t == h == w == 1): 79 | vid, vid_shape = na.rearrange( 80 | vid, vid_shape, "(T t) (H h) (W w) c -> T H W (t h w c)", t=t, h=h, w=w 81 | ) 82 | # slice vid after patching in when using sequence parallelism 83 | vid = slice_inputs(vid, dim=0) 84 | vid = self.proj(vid) 85 | return vid, vid_shape 86 | 87 | 88 | class NaPatchOut(PatchOut): 89 | def forward( 90 | self, 91 | vid: torch.FloatTensor, # l c 92 | vid_shape: torch.LongTensor, 93 | cache: Cache = Cache(disable=True), 94 | ) -> Tuple[ 95 | torch.FloatTensor, 96 | torch.LongTensor, 97 | ]: 98 | t, h, w = self.patch_size 99 | vid = self.proj(vid) 100 | # gather vid before patching out when enabling sequence parallelism 101 | vid = gather_outputs( 102 | vid, 103 | gather_dim=0, 104 | padding_dim=0, 105 | unpad_shape=vid_shape, 106 | cache=cache.namespace("vid"), 107 | ) 108 | if not (t == h == w == 1): 109 | vid, vid_shape = na.rearrange( 110 | vid, vid_shape, "T H W (t h w c) -> (T t) (H h) (W w) c", t=t, h=h, w=w 111 | ) 112 | return vid, vid_shape 113 | -------------------------------------------------------------------------------- /src/models/dit/rope.py: -------------------------------------------------------------------------------- 1 | # // Copyright (c) 2025 Bytedance Ltd. and/or its affiliates 2 | # // 3 | # // Licensed under the Apache License, Version 2.0 (the "License"); 4 | # // you may not use this file except in compliance with the License. 5 | # // You may obtain a copy of the License at 6 | # // 7 | # // http://www.apache.org/licenses/LICENSE-2.0 8 | # // 9 | # // Unless required by applicable law or agreed to in writing, software 10 | # // distributed under the License is distributed on an "AS IS" BASIS, 11 | # // WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 12 | # // See the License for the specific language governing permissions and 13 | # // limitations under the License. 
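# Minimal usage sketch of the RotaryEmbedding3d module defined below (illustrative
# shapes; head_dim=128 is just an example):
#
#   rope = RotaryEmbedding3d(dim=128)          # head dim, divided across the 3 axes
#   q = torch.randn(1, 24, 4 * 8 * 8, 128)     # (batch, heads, T*H*W, head_dim)
#   k = torch.randn(1, 24, 4 * 8 * 8, 128)
#   q, k = rope(q, k, size=(4, 8, 8))          # axial RoPE over the (T, H, W) grid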
14 | 15 | from functools import lru_cache 16 | from typing import Tuple 17 | import torch 18 | from einops import rearrange 19 | from rotary_embedding_torch import RotaryEmbedding, apply_rotary_emb 20 | from torch import nn 21 | 22 | from ...common.cache import Cache 23 | 24 | 25 | class RotaryEmbeddingBase(nn.Module): 26 | def __init__(self, dim: int, rope_dim: int): 27 | super().__init__() 28 | self.rope = RotaryEmbedding( 29 | dim=dim // rope_dim, 30 | freqs_for="pixel", 31 | max_freq=256, 32 | ) 33 | # 1. Set model.requires_grad_(True) after model creation will make 34 | # the `requires_grad=False` for rope freqs no longer hold. 35 | # 2. Even if we don't set requires_grad_(True) explicitly, 36 | # FSDP is not memory efficient when handling fsdp_wrap 37 | # with mixed requires_grad=True/False. 38 | # With above consideration, it is easier just remove the freqs 39 | # out of nn.Parameters when `learned_freq=False` 40 | freqs = self.rope.freqs 41 | del self.rope.freqs 42 | self.rope.register_buffer("freqs", freqs.data) 43 | 44 | @lru_cache(maxsize=128) 45 | def get_axial_freqs(self, *dims): 46 | return self.rope.get_axial_freqs(*dims) 47 | 48 | 49 | class RotaryEmbedding3d(RotaryEmbeddingBase): 50 | def __init__(self, dim: int): 51 | super().__init__(dim, rope_dim=3) 52 | 53 | def forward( 54 | self, 55 | q: torch.FloatTensor, # b h l d 56 | k: torch.FloatTensor, # b h l d 57 | size: Tuple[int, int, int], 58 | ) -> Tuple[ 59 | torch.FloatTensor, 60 | torch.FloatTensor, 61 | ]: 62 | T, H, W = size 63 | freqs = self.get_axial_freqs(T, H, W) 64 | q = rearrange(q, "b h (T H W) d -> b h T H W d", T=T, H=H, W=W) 65 | k = rearrange(k, "b h (T H W) d -> b h T H W d", T=T, H=H, W=W) 66 | q = apply_rotary_emb(freqs, q) 67 | k = apply_rotary_emb(freqs, k) 68 | q = rearrange(q, "b h T H W d -> b h (T H W) d") 69 | k = rearrange(k, "b h T H W d -> b h (T H W) d") 70 | return q, k 71 | 72 | 73 | class NaRotaryEmbedding3d(RotaryEmbedding3d): 74 | def forward( 75 | self, 76 | q: torch.FloatTensor, # L h d 77 | k: torch.FloatTensor, # L h d 78 | shape: torch.LongTensor, 79 | cache: Cache, 80 | ) -> Tuple[ 81 | torch.FloatTensor, 82 | torch.FloatTensor, 83 | ]: 84 | freqs = cache("rope_freqs_3d", lambda: self.get_freqs(shape)) 85 | freqs = freqs.to(device=q.device, dtype=q.dtype) 86 | q = rearrange(q, "L h d -> h L d") 87 | k = rearrange(k, "L h d -> h L d") 88 | q = apply_rotary_emb(freqs, q.float()).to(q.dtype) 89 | k = apply_rotary_emb(freqs, k.float()).to(k.dtype) 90 | q = rearrange(q, "h L d -> L h d") 91 | k = rearrange(k, "h L d -> L h d") 92 | return q, k 93 | 94 | def get_freqs( 95 | self, 96 | shape: torch.LongTensor, 97 | ) -> torch.Tensor: 98 | freq_list = [] 99 | for f, h, w in shape.tolist(): 100 | freqs = self.get_axial_freqs(f, h, w) 101 | freq_list.append(freqs.view(-1, freqs.size(-1))) 102 | return torch.cat(freq_list, dim=0) 103 | -------------------------------------------------------------------------------- /src/models/dit_v2/attention.py: -------------------------------------------------------------------------------- 1 | # // Copyright (c) 2025 Bytedance Ltd. and/or its affiliates 2 | # // 3 | # // Licensed under the Apache License, Version 2.0 (the "License"); 4 | # // you may not use this file except in compliance with the License. 
5 | # // You may obtain a copy of the License at 6 | # // 7 | # // http://www.apache.org/licenses/LICENSE-2.0 8 | # // 9 | # // Unless required by applicable law or agreed to in writing, software 10 | # // distributed under the License is distributed on an "AS IS" BASIS, 11 | # // WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 12 | # // See the License for the specific language governing permissions and 13 | # // limitations under the License. 14 | 15 | import torch 16 | import torch.nn.functional as F 17 | 18 | #from flash_attn import flash_attn_varlen_func 19 | 20 | from torch import nn 21 | 22 | 23 | def pytorch_varlen_attention(q, k, v, cu_seqlens_q, cu_seqlens_k, max_seqlen_q, max_seqlen_k, dropout_p=0.0, softmax_scale=None, causal=False, deterministic=False): 24 | """ 25 | A PyTorch-based implementation of variable-length attention to replace flash_attn_varlen_func. 26 | It processes each sequence in the batch individually. 27 | """ 28 | # Create an empty tensor to store the output. 29 | output = torch.empty_like(q) 30 | 31 | # Iterate over each sequence in the batch. The batch size is the number of sequences. 32 | for i in range(len(cu_seqlens_q) - 1): 33 | # Determine the start and end indices for the current sequence. 34 | start_q, end_q = cu_seqlens_q[i], cu_seqlens_q[i+1] 35 | start_k, end_k = cu_seqlens_k[i], cu_seqlens_k[i+1] 36 | 37 | # Slice the q, k, and v tensors to get the data for the current sequence. 38 | # The shape is (seq_len, heads, head_dim). 39 | q_i = q[start_q:end_q] 40 | k_i = k[start_k:end_k] 41 | v_i = v[start_k:end_k] 42 | 43 | # Reshape for torch's scaled_dot_product_attention which expects (batch, heads, seq, dim). 44 | # Here, we treat each sequence as a batch of 1. 45 | q_i = q_i.permute(1, 0, 2).unsqueeze(0) # (1, heads, seq_len_q, head_dim) 46 | k_i = k_i.permute(1, 0, 2).unsqueeze(0) # (1, heads, seq_len_k, head_dim) 47 | v_i = v_i.permute(1, 0, 2).unsqueeze(0) # (1, heads, seq_len_k, head_dim) 48 | 49 | # Use PyTorch's built-in scaled dot-product attention. 50 | output_i = F.scaled_dot_product_attention( 51 | q_i, k_i, v_i, 52 | dropout_p=dropout_p if not deterministic else 0.0, 53 | is_causal=causal 54 | ) 55 | 56 | # Reshape the output back to the original format (seq_len, heads, head_dim) 57 | output_i = output_i.squeeze(0).permute(1, 0, 2) 58 | 59 | # Place the result for the current sequence into the main output tensor. 
60 | output[start_q:end_q] = output_i 61 | 62 | return output 63 | 64 | class TorchAttention(nn.Module): 65 | def tflops(self, args, kwargs, output) -> float: 66 | assert len(args) == 0 or len(args) > 2, "query, key should both provided by args / kwargs" 67 | q = kwargs.get("query") or args[0] 68 | k = kwargs.get("key") or args[1] 69 | b, h, sq, d = q.shape 70 | b, h, sk, d = k.shape 71 | return b * h * (4 * d * (sq / 1e6) * (sk / 1e6)) 72 | 73 | def forward(self, *args, **kwargs): 74 | return F.scaled_dot_product_attention(*args, **kwargs) 75 | 76 | 77 | class FlashAttentionVarlen(nn.Module): 78 | def tflops(self, args, kwargs, output) -> float: 79 | cu_seqlens_q = kwargs["cu_seqlens_q"] 80 | cu_seqlens_k = kwargs["cu_seqlens_k"] 81 | _, h, d = output.shape 82 | seqlens_q = (cu_seqlens_q[1:] - cu_seqlens_q[:-1]) / 1e6 83 | seqlens_k = (cu_seqlens_k[1:] - cu_seqlens_k[:-1]) / 1e6 84 | return h * (4 * d * (seqlens_q * seqlens_k).sum()) 85 | 86 | def forward(self, *args, **kwargs): 87 | kwargs["deterministic"] = torch.are_deterministic_algorithms_enabled() 88 | try: 89 | from flash_attn import flash_attn_varlen_func 90 | return flash_attn_varlen_func(*args, **kwargs) 91 | except ImportError: 92 | return pytorch_varlen_attention(*args, **kwargs) 93 | 94 | -------------------------------------------------------------------------------- /src/models/dit/attention.py: -------------------------------------------------------------------------------- 1 | # // Copyright (c) 2025 Bytedance Ltd. and/or its affiliates 2 | # // 3 | # // Licensed under the Apache License, Version 2.0 (the "License"); 4 | # // you may not use this file except in compliance with the License. 5 | # // You may obtain a copy of the License at 6 | # // 7 | # // http://www.apache.org/licenses/LICENSE-2.0 8 | # // 9 | # // Unless required by applicable law or agreed to in writing, software 10 | # // distributed under the License is distributed on an "AS IS" BASIS, 11 | # // WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 12 | # // See the License for the specific language governing permissions and 13 | # // limitations under the License. 14 | 15 | import torch 16 | import torch.nn.functional as F 17 | 18 | #from flash_attn import flash_attn_varlen_func 19 | 20 | from torch import nn 21 | 22 | 23 | def pytorch_varlen_attention(q, k, v, cu_seqlens_q, cu_seqlens_k, max_seqlen_q, max_seqlen_k, dropout_p=0.0, softmax_scale=None, causal=False, deterministic=False): 24 | """ 25 | A PyTorch-based implementation of variable-length attention to replace flash_attn_varlen_func. 26 | It processes each sequence in the batch individually. 27 | """ 28 | # Create an empty tensor to store the output. 29 | output = torch.empty_like(q) 30 | 31 | # Iterate over each sequence in the batch. The batch size is the number of sequences. 32 | for i in range(len(cu_seqlens_q) - 1): 33 | # Determine the start and end indices for the current sequence. 34 | start_q, end_q = cu_seqlens_q[i], cu_seqlens_q[i+1] 35 | start_k, end_k = cu_seqlens_k[i], cu_seqlens_k[i+1] 36 | 37 | # Slice the q, k, and v tensors to get the data for the current sequence. 38 | # The shape is (seq_len, heads, head_dim). 39 | q_i = q[start_q:end_q] 40 | k_i = k[start_k:end_k] 41 | v_i = v[start_k:end_k] 42 | 43 | # Reshape for torch's scaled_dot_product_attention which expects (batch, heads, seq, dim). 44 | # Here, we treat each sequence as a batch of 1. 
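        # Illustrative example (hypothetical lengths): with two packed sequences of
        # lengths 3 and 5, cu_seqlens_q = cu_seqlens_k = tensor([0, 3, 8]); iteration
        # i=0 attends within rows 0:3 and i=1 within rows 3:8, so tokens from different
        # sequences never attend to each other.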
45 | q_i = q_i.permute(1, 0, 2).unsqueeze(0) # (1, heads, seq_len_q, head_dim) 46 | k_i = k_i.permute(1, 0, 2).unsqueeze(0) # (1, heads, seq_len_k, head_dim) 47 | v_i = v_i.permute(1, 0, 2).unsqueeze(0) # (1, heads, seq_len_k, head_dim) 48 | 49 | # Use PyTorch's built-in scaled dot-product attention. 50 | output_i = F.scaled_dot_product_attention( 51 | q_i, k_i, v_i, 52 | dropout_p=dropout_p if not deterministic else 0.0, 53 | is_causal=causal 54 | ) 55 | 56 | # Reshape the output back to the original format (seq_len, heads, head_dim) 57 | output_i = output_i.squeeze(0).permute(1, 0, 2) 58 | 59 | # Place the result for the current sequence into the main output tensor. 60 | output[start_q:end_q] = output_i 61 | 62 | return output 63 | 64 | 65 | class TorchAttention(nn.Module): 66 | def tflops(self, args, kwargs, output) -> float: 67 | assert len(args) == 0 or len(args) > 2, "query, key should both provided by args / kwargs" 68 | q = kwargs.get("query") or args[0] 69 | k = kwargs.get("key") or args[1] 70 | b, h, sq, d = q.shape 71 | b, h, sk, d = k.shape 72 | return b * h * (4 * d * (sq / 1e6) * (sk / 1e6)) 73 | 74 | def forward(self, *args, **kwargs): 75 | #return pytorch_varlen_attention(*args, **kwargs) 76 | return F.scaled_dot_product_attention(*args, **kwargs) 77 | 78 | 79 | class FlashAttentionVarlen(nn.Module): 80 | def tflops(self, args, kwargs, output) -> float: 81 | cu_seqlens_q = kwargs["cu_seqlens_q"] 82 | cu_seqlens_k = kwargs["cu_seqlens_k"] 83 | _, h, d = output.shape 84 | seqlens_q = (cu_seqlens_q[1:] - cu_seqlens_q[:-1]) / 1e6 85 | seqlens_k = (cu_seqlens_k[1:] - cu_seqlens_k[:-1]) / 1e6 86 | return h * (4 * d * (seqlens_q * seqlens_k).sum()) 87 | 88 | def forward(self, *args, **kwargs): 89 | kwargs["deterministic"] = torch.are_deterministic_algorithms_enabled() 90 | try: 91 | from flash_attn import flash_attn_varlen_func 92 | return flash_attn_varlen_func(*args, **kwargs) 93 | except ImportError: 94 | return pytorch_varlen_attention(*args, **kwargs) -------------------------------------------------------------------------------- /src/models/video_vae_v3/modules/inflated_layers.py: -------------------------------------------------------------------------------- 1 | # // Copyright (c) 2025 Bytedance Ltd. and/or its affiliates 2 | # // 3 | # // Licensed under the Apache License, Version 2.0 (the "License"); 4 | # // you may not use this file except in compliance with the License. 5 | # // You may obtain a copy of the License at 6 | # // 7 | # // http://www.apache.org/licenses/LICENSE-2.0 8 | # // 9 | # // Unless required by applicable law or agreed to in writing, software 10 | # // distributed under the License is distributed on an "AS IS" BASIS, 11 | # // WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 12 | # // See the License for the specific language governing permissions and 13 | # // limitations under the License. 
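# Minimal usage sketch of the causal 3D convolution defined below (illustrative
# arguments only):
#
#   conv = init_causal_conv3d(16, 16, kernel_size=3, stride=1, padding=1,
#                             inflation_mode="tail")
#   y = conv(x, memory_state=MemoryState.INITIALIZING)   # first clip of a video
#   y = conv(x, memory_state=MemoryState.ACTIVE)         # later clips reuse cached frames
#
# Temporal padding is applied as an explicit head extension instead of symmetric
# conv padding, so frame t never sees future frames, and the trailing frames of
# each clip can be kept as memory for the next clip.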
14 | 15 | from functools import partial 16 | from typing import Literal, Optional 17 | from torch import Tensor 18 | from torch.nn import Conv3d 19 | 20 | from .inflated_lib import ( 21 | MemoryState, 22 | extend_head, 23 | inflate_bias, 24 | inflate_weight, 25 | modify_state_dict, 26 | ) 27 | 28 | _inflation_mode_t = Literal["none", "tail", "replicate"] 29 | _memory_device_t = Optional[Literal["cpu", "same"]] 30 | 31 | 32 | class InflatedCausalConv3d(Conv3d): 33 | def __init__( 34 | self, 35 | *args, 36 | inflation_mode: _inflation_mode_t, 37 | memory_device: _memory_device_t = "same", 38 | **kwargs, 39 | ): 40 | self.inflation_mode = inflation_mode 41 | self.memory = None 42 | super().__init__(*args, **kwargs) 43 | self.temporal_padding = self.padding[0] 44 | self.memory_device = memory_device 45 | self.padding = (0, *self.padding[1:]) # Remove temporal pad to keep causal. 46 | 47 | def set_memory_device(self, memory_device: _memory_device_t): 48 | self.memory_device = memory_device 49 | 50 | def forward(self, input: Tensor, memory_state: MemoryState = MemoryState.DISABLED) -> Tensor: 51 | mem_size = self.stride[0] - self.kernel_size[0] 52 | if (self.memory is not None) and (memory_state == MemoryState.ACTIVE): 53 | input = extend_head(input, memory=self.memory) 54 | else: 55 | input = extend_head(input, times=self.temporal_padding * 2) 56 | memory = ( 57 | input[:, :, mem_size:].detach() 58 | if (mem_size != 0 and memory_state != MemoryState.DISABLED) 59 | else None 60 | ) 61 | if ( 62 | memory_state != MemoryState.DISABLED 63 | and not self.training 64 | and (self.memory_device is not None) 65 | ): 66 | self.memory = memory 67 | if self.memory_device == "cpu" and self.memory is not None: 68 | self.memory = self.memory.to("cpu") 69 | return super().forward(input) 70 | 71 | def _load_from_state_dict( 72 | self, state_dict, prefix, local_metadata, strict, missing_keys, unexpected_keys, error_msgs 73 | ): 74 | if self.inflation_mode != "none": 75 | state_dict = modify_state_dict( 76 | self, 77 | state_dict, 78 | prefix, 79 | inflate_weight_fn=partial(inflate_weight, position="tail"), 80 | inflate_bias_fn=partial(inflate_bias, position="tail"), 81 | ) 82 | super()._load_from_state_dict( 83 | state_dict, 84 | prefix, 85 | local_metadata, 86 | (strict and self.inflation_mode == "none"), 87 | missing_keys, 88 | unexpected_keys, 89 | error_msgs, 90 | ) 91 | 92 | 93 | def init_causal_conv3d( 94 | *args, 95 | inflation_mode: _inflation_mode_t, 96 | **kwargs, 97 | ): 98 | """ 99 | Initialize a Causal-3D convolution layer. 100 | Parameters: 101 | inflation_mode: Listed as below. It's compatible with all the 3D-VAE checkpoints we have. 102 | - none: No inflation will be conducted. 103 | The loading logic of state dict will fall back to default. 104 | - tail / replicate: Refer to the definition of `InflatedCausalConv3d`. 105 | """ 106 | return InflatedCausalConv3d(*args, inflation_mode=inflation_mode, **kwargs) 107 | -------------------------------------------------------------------------------- /src/models/dit/modulation.py: -------------------------------------------------------------------------------- 1 | # // Copyright (c) 2025 Bytedance Ltd. and/or its affiliates 2 | # // 3 | # // Licensed under the Apache License, Version 2.0 (the "License"); 4 | # // you may not use this file except in compliance with the License. 
5 | # // You may obtain a copy of the License at 6 | # // 7 | # // http://www.apache.org/licenses/LICENSE-2.0 8 | # // 9 | # // Unless required by applicable law or agreed to in writing, software 10 | # // distributed under the License is distributed on an "AS IS" BASIS, 11 | # // WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 12 | # // See the License for the specific language governing permissions and 13 | # // limitations under the License. 14 | 15 | from typing import Callable, List, Optional 16 | import torch 17 | from einops import rearrange 18 | from torch import nn 19 | 20 | from ...common.cache import Cache 21 | from ...common.distributed.ops import slice_inputs 22 | 23 | # (dim: int, emb_dim: int) 24 | ada_layer_type = Callable[[int, int], nn.Module] 25 | 26 | 27 | def get_ada_layer(ada_layer: str) -> ada_layer_type: 28 | if ada_layer == "single": 29 | return AdaSingle 30 | raise NotImplementedError(f"{ada_layer} is not supported") 31 | 32 | 33 | def expand_dims(x: torch.Tensor, dim: int, ndim: int): 34 | """ 35 | Expand tensor "x" to "ndim" by adding empty dims at "dim". 36 | Example: x is (b d), target ndim is 5, add dim at 1, return (b 1 1 1 d). 37 | """ 38 | shape = x.shape 39 | shape = shape[:dim] + (1,) * (ndim - len(shape)) + shape[dim:] 40 | return x.reshape(shape) 41 | 42 | 43 | class AdaSingle(nn.Module): 44 | def __init__( 45 | self, 46 | dim: int, 47 | emb_dim: int, 48 | layers: List[str], 49 | ): 50 | assert emb_dim == 6 * dim, "AdaSingle requires emb_dim == 6 * dim" 51 | super().__init__() 52 | self.dim = dim 53 | self.emb_dim = emb_dim 54 | self.layers = layers 55 | for l in layers: 56 | self.register_parameter(f"{l}_shift", nn.Parameter(torch.randn(dim) / dim**0.5)) 57 | self.register_parameter(f"{l}_scale", nn.Parameter(torch.randn(dim) / dim**0.5 + 1)) 58 | self.register_parameter(f"{l}_gate", nn.Parameter(torch.randn(dim) / dim**0.5)) 59 | 60 | def forward( 61 | self, 62 | hid: torch.FloatTensor, # b ... 
c 63 | emb: torch.FloatTensor, # b d 64 | layer: str, 65 | mode: str, 66 | cache: Cache = Cache(disable=True), 67 | branch_tag: str = "", 68 | hid_len: Optional[torch.LongTensor] = None, # b 69 | ) -> torch.FloatTensor: 70 | idx = self.layers.index(layer) 71 | emb = rearrange(emb, "b (d l g) -> b d l g", l=len(self.layers), g=3)[..., idx, :] 72 | emb = expand_dims(emb, 1, hid.ndim + 1) 73 | 74 | if hid_len is not None: 75 | emb = cache( 76 | f"emb_repeat_{idx}_{branch_tag}", 77 | lambda: slice_inputs( 78 | torch.cat([e.repeat(l, *([1] * e.ndim)) for e, l in zip(emb, hid_len)]), 79 | dim=0, 80 | ), 81 | ) 82 | 83 | shiftA, scaleA, gateA = emb.unbind(-1) 84 | shiftB, scaleB, gateB = ( 85 | getattr(self, f"{layer}_shift"), 86 | getattr(self, f"{layer}_scale"), 87 | getattr(self, f"{layer}_gate"), 88 | ) 89 | 90 | # Handle potential FP8 parameters - convert to computation dtype 91 | if hasattr(torch, 'float8_e4m3fn'): 92 | fp8_types = (torch.float8_e4m3fn, torch.float8_e5m2) 93 | 94 | # Convert FP8 parameters to BFloat16 for arithmetic operations 95 | if shiftB.dtype in fp8_types: 96 | shiftB = shiftB.to(torch.bfloat16) 97 | if scaleB.dtype in fp8_types: 98 | scaleB = scaleB.to(torch.bfloat16) 99 | if gateB.dtype in fp8_types: 100 | gateB = gateB.to(torch.bfloat16) 101 | 102 | if mode == "in": 103 | return hid.mul_(scaleA + scaleB).add_(shiftA + shiftB) 104 | if mode == "out": 105 | return hid.mul_(gateA + gateB) 106 | 107 | raise NotImplementedError 108 | 109 | def extra_repr(self) -> str: 110 | return f"dim={self.dim}, emb_dim={self.emb_dim}, layers={self.layers}" -------------------------------------------------------------------------------- /src/common/diffusion/samplers/euler.py: -------------------------------------------------------------------------------- 1 | # // Copyright (c) 2025 Bytedance Ltd. and/or its affiliates 2 | # // 3 | # // Licensed under the Apache License, Version 2.0 (the "License"); 4 | # // you may not use this file except in compliance with the License. 5 | # // You may obtain a copy of the License at 6 | # // 7 | # // http://www.apache.org/licenses/LICENSE-2.0 8 | # // 9 | # // Unless required by applicable law or agreed to in writing, software 10 | # // distributed under the License is distributed on an "AS IS" BASIS, 11 | # // WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 12 | # // See the License for the specific language governing permissions and 13 | # // limitations under the License. 14 | 15 | 16 | """ 17 | Euler ODE solver. 18 | """ 19 | 20 | from typing import Callable 21 | import torch 22 | from einops import rearrange 23 | from torch.nn import functional as F 24 | 25 | #from ....models.dit_v2 import na 26 | 27 | from ..types import PredictionType 28 | from ..utils import expand_dims 29 | from .base import Sampler, SamplerModelArgs 30 | 31 | 32 | class EulerSampler(Sampler): 33 | """ 34 | The Euler method is the simplest ODE solver. 
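    Each update converts the model prediction into estimates of (x_0, x_T) through the
    schedule and re-interpolates them at the next timestep s (see `step_to` below):

        x_s = A(s) * pred_x_0 + B(s) * pred_x_T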
35 | 
36 |     """
37 | 
38 |     def sample(
39 |         self,
40 |         x: torch.Tensor,
41 |         f: Callable[[SamplerModelArgs], torch.Tensor],
42 |     ) -> torch.Tensor:
43 |         timesteps = self.timesteps.timesteps
44 |         progress = self.get_progress_bar()
45 |         i = 0
46 | 
47 |         # VRAM optimizations
48 |         original_dtype = x.dtype
49 |         device = x.device
50 | 
51 |         # Force FP16 to save VRAM
52 |         if x.dtype != torch.float16:
53 |             x = x.half()
54 | 
55 |         for t, s in zip(timesteps[:-1], timesteps[1:]):
56 |             # Force FP16 for the timesteps
57 |             if t.dtype != torch.float16:
58 |                 t = t.half()
59 |             if s.dtype != torch.float16:
60 |                 s = s.half()
61 | 
62 |             # Call the model
63 |             pred = f(SamplerModelArgs(x, t, i))
64 | 
65 |             # Force FP16 for the prediction
66 |             if pred.dtype != torch.float16:
67 |                 pred = pred.half()
68 | 
69 |             # Next step
70 |             x = self.step_to(pred, x, t, s)
71 | 
72 |             # Clean up temporary tensors and free the device cache
73 |             del pred
74 |             if torch.mps.is_available():
75 |                 # MPS backend: release cached Metal memory
76 |                 torch.mps.empty_cache()
77 |             elif torch.cuda.is_available():
78 |                 # CUDA backend: release cached memory and shared IPC handles
79 |                 torch.cuda.empty_cache()
80 |                 torch.cuda.ipc_collect()
81 | 
82 |             i += 1
83 |             progress.update()
84 | 
85 |         if self.return_endpoint:
86 |             t = timesteps[-1]
87 |             if t.dtype != torch.float16:
88 |                 t = t.half()
89 |             pred = f(SamplerModelArgs(x, t, i))
90 |             if pred.dtype != torch.float16:
91 |                 pred = pred.half()
92 |             x = self.get_endpoint(pred, x, t)
93 |             del pred
94 |             progress.update()
95 | 
96 |         # Restore the original dtype if needed
97 |         if original_dtype != torch.float16:
98 |             x = x.to(original_dtype)
99 | 
100 |         return x
101 | 
102 |     def step(
103 |         self,
104 |         pred: torch.Tensor,
105 |         x_t: torch.Tensor,
106 |         t: torch.Tensor,
107 |     ) -> torch.Tensor:
108 |         """
109 |         Step to the next timestep.
110 |         """
111 |         return self.step_to(pred, x_t, t, self.get_next_timestep(t))
112 | 
113 |     def step_to(
114 |         self,
115 |         pred: torch.Tensor,
116 |         x_t: torch.Tensor,
117 |         t: torch.Tensor,
118 |         s: torch.Tensor,
119 |     ) -> torch.Tensor:
120 |         """
121 |         Steps from x_t at timestep t to x_s at timestep s. Returns x_s.
122 |         """
123 |         t = expand_dims(t, x_t.ndim)
124 |         s = expand_dims(s, x_t.ndim)
125 |         T = self.schedule.T
126 |         # Step from x_t to x_s.
127 |         pred_x_0, pred_x_T = self.schedule.convert_from_pred(pred, self.prediction_type, x_t, t)
128 |         pred_x_s = self.schedule.forward(pred_x_0, pred_x_T, s.clamp(0, T))
129 |         # Clamp x_s to x_0 and x_T if s is out of bound.
130 |         pred_x_s = pred_x_s.where(s >= 0, pred_x_0)
131 |         pred_x_s = pred_x_s.where(s <= T, pred_x_T)
132 |         return pred_x_s
133 | 
--------------------------------------------------------------------------------
/src/common/diffusion/schedules/base.py:
--------------------------------------------------------------------------------
1 | # // Copyright (c) 2025 Bytedance Ltd. and/or its affiliates
2 | # //
3 | # // Licensed under the Apache License, Version 2.0 (the "License");
4 | # // you may not use this file except in compliance with the License.
5 | # // You may obtain a copy of the License at
6 | # //
7 | # //     http://www.apache.org/licenses/LICENSE-2.0
8 | # //
9 | # // Unless required by applicable law or agreed to in writing, software
10 | # // distributed under the License is distributed on an "AS IS" BASIS,
11 | # // WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12 | # // See the License for the specific language governing permissions and
13 | # // limitations under the License.
14 | 
15 | """
16 | Schedule base class.
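    Subclasses provide T, A(t) and B(t); `forward` builds x_t from (x_0, x_T), while
    `convert_from_pred` / `convert_to_pred` map between model predictions and (x_0, x_T).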
17 | """ 18 | 19 | from abc import ABC, abstractmethod, abstractproperty 20 | from typing import Tuple, Union 21 | import torch 22 | 23 | from ..types import PredictionType 24 | from ..utils import expand_dims 25 | 26 | 27 | class Schedule(ABC): 28 | """ 29 | Diffusion schedules are uniquely defined by T, A, B: 30 | 31 | x_t = A(t) * x_0 + B(t) * x_T, where t in [0, T] 32 | 33 | Schedules can be continuous or discrete. 34 | """ 35 | 36 | @abstractproperty 37 | def T(self) -> Union[int, float]: 38 | """ 39 | Maximum timestep inclusive. 40 | Schedule is continuous if float, discrete if int. 41 | """ 42 | 43 | @abstractmethod 44 | def A(self, t: torch.Tensor) -> torch.Tensor: 45 | """ 46 | Interpolation coefficient A. 47 | Returns tensor with the same shape as t. 48 | """ 49 | 50 | @abstractmethod 51 | def B(self, t: torch.Tensor) -> torch.Tensor: 52 | """ 53 | Interpolation coefficient B. 54 | Returns tensor with the same shape as t. 55 | """ 56 | 57 | # ---------------------------------------------------- 58 | 59 | def snr(self, t: torch.Tensor) -> torch.Tensor: 60 | """ 61 | Signal to noise ratio. 62 | Returns tensor with the same shape as t. 63 | """ 64 | return (self.A(t) ** 2) / (self.B(t) ** 2) 65 | 66 | def isnr(self, snr: torch.Tensor) -> torch.Tensor: 67 | """ 68 | Inverse signal to noise ratio. 69 | Returns tensor with the same shape as snr. 70 | Subclass may implement. 71 | """ 72 | raise NotImplementedError 73 | 74 | # ---------------------------------------------------- 75 | 76 | def is_continuous(self) -> bool: 77 | """ 78 | Whether the schedule is continuous. 79 | """ 80 | return isinstance(self.T, float) 81 | 82 | def forward(self, x_0: torch.Tensor, x_T: torch.Tensor, t: torch.Tensor) -> torch.Tensor: 83 | """ 84 | Diffusion forward function. 85 | """ 86 | t = expand_dims(t, x_0.ndim) 87 | return self.A(t) * x_0 + self.B(t) * x_T 88 | 89 | def convert_from_pred( 90 | self, pred: torch.Tensor, pred_type: PredictionType, x_t: torch.Tensor, t: torch.Tensor 91 | ) -> Tuple[torch.Tensor, torch.Tensor]: 92 | """ 93 | Convert from prediction. Return predicted x_0 and x_T. 94 | """ 95 | t = expand_dims(t, x_t.ndim) 96 | A_t = self.A(t) 97 | B_t = self.B(t) 98 | 99 | if pred_type == PredictionType.x_T: 100 | pred_x_T = pred 101 | pred_x_0 = (x_t - B_t * pred_x_T) / A_t 102 | elif pred_type == PredictionType.x_0: 103 | pred_x_0 = pred 104 | pred_x_T = (x_t - A_t * pred_x_0) / B_t 105 | elif pred_type == PredictionType.v_cos: 106 | pred_x_0 = A_t * x_t - B_t * pred 107 | pred_x_T = A_t * pred + B_t * x_t 108 | elif pred_type == PredictionType.v_lerp: 109 | pred_x_0 = (x_t - B_t * pred) / (A_t + B_t) 110 | pred_x_T = (x_t + A_t * pred) / (A_t + B_t) 111 | else: 112 | raise NotImplementedError 113 | 114 | return pred_x_0, pred_x_T 115 | 116 | def convert_to_pred( 117 | self, x_0: torch.Tensor, x_T: torch.Tensor, t: torch.Tensor, pred_type: PredictionType 118 | ) -> torch.FloatTensor: 119 | """ 120 | Convert to prediction target given x_0 and x_T. 
121 | """ 122 | if pred_type == PredictionType.x_T: 123 | return x_T 124 | if pred_type == PredictionType.x_0: 125 | return x_0 126 | if pred_type == PredictionType.v_cos: 127 | t = expand_dims(t, x_0.ndim) 128 | return self.A(t) * x_T - self.B(t) * x_0 129 | if pred_type == PredictionType.v_lerp: 130 | return x_T - x_0 131 | raise NotImplementedError 132 | -------------------------------------------------------------------------------- /src/models/dit_v2/modulation.py: -------------------------------------------------------------------------------- 1 | # // Copyright (c) 2025 Bytedance Ltd. and/or its affiliates 2 | # // 3 | # // Licensed under the Apache License, Version 2.0 (the "License"); 4 | # // you may not use this file except in compliance with the License. 5 | # // You may obtain a copy of the License at 6 | # // 7 | # // http://www.apache.org/licenses/LICENSE-2.0 8 | # // 9 | # // Unless required by applicable law or agreed to in writing, software 10 | # // distributed under the License is distributed on an "AS IS" BASIS, 11 | # // WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 12 | # // See the License for the specific language governing permissions and 13 | # // limitations under the License. 14 | 15 | from typing import Callable, List, Optional 16 | import torch 17 | from einops import rearrange 18 | from torch import nn 19 | 20 | from ...common.cache import Cache 21 | from ...common.distributed.ops import slice_inputs 22 | 23 | # (dim: int, emb_dim: int) 24 | ada_layer_type = Callable[[int, int], nn.Module] 25 | 26 | 27 | def get_ada_layer(ada_layer: str) -> ada_layer_type: 28 | if ada_layer == "single": 29 | return AdaSingle 30 | raise NotImplementedError(f"{ada_layer} is not supported") 31 | 32 | 33 | def expand_dims(x: torch.Tensor, dim: int, ndim: int): 34 | """ 35 | Expand tensor "x" to "ndim" by adding empty dims at "dim". 36 | Example: x is (b d), target ndim is 5, add dim at 1, return (b 1 1 1 d). 37 | """ 38 | shape = x.shape 39 | shape = shape[:dim] + (1,) * (ndim - len(shape)) + shape[dim:] 40 | return x.reshape(shape) 41 | 42 | 43 | class AdaSingle(nn.Module): 44 | def __init__( 45 | self, 46 | dim: int, 47 | emb_dim: int, 48 | layers: List[str], 49 | modes: List[str] = ["in", "out"], 50 | ): 51 | assert emb_dim == 6 * dim, "AdaSingle requires emb_dim == 6 * dim" 52 | super().__init__() 53 | self.dim = dim 54 | self.emb_dim = emb_dim 55 | self.layers = layers 56 | for l in layers: 57 | if "in" in modes: 58 | self.register_parameter(f"{l}_shift", nn.Parameter(torch.randn(dim) / dim**0.5)) 59 | self.register_parameter( 60 | f"{l}_scale", nn.Parameter(torch.randn(dim) / dim**0.5 + 1) 61 | ) 62 | if "out" in modes: 63 | self.register_parameter(f"{l}_gate", nn.Parameter(torch.randn(dim) / dim**0.5)) 64 | 65 | def forward( 66 | self, 67 | hid: torch.FloatTensor, # b ... 
c 68 | emb: torch.FloatTensor, # b d 69 | layer: str, 70 | mode: str, 71 | cache: Cache = Cache(disable=True), 72 | branch_tag: str = "", 73 | hid_len: Optional[torch.LongTensor] = None, # b 74 | ) -> torch.FloatTensor: 75 | idx = self.layers.index(layer) 76 | emb = rearrange(emb, "b (d l g) -> b d l g", l=len(self.layers), g=3)[..., idx, :] 77 | emb = expand_dims(emb, 1, hid.ndim + 1) 78 | 79 | if hid_len is not None: 80 | emb = cache( 81 | f"emb_repeat_{idx}_{branch_tag}", 82 | lambda: slice_inputs( 83 | torch.cat([e.repeat(l, *([1] * e.ndim)) for e, l in zip(emb, hid_len)]), 84 | dim=0, 85 | ), 86 | ) 87 | 88 | shiftA, scaleA, gateA = emb.unbind(-1) 89 | shiftB, scaleB, gateB = ( 90 | getattr(self, f"{layer}_shift", None), 91 | getattr(self, f"{layer}_scale", None), 92 | getattr(self, f"{layer}_gate", None), 93 | ) 94 | 95 | # Handle potential FP8 parameters - convert to computation dtype 96 | if hasattr(torch, 'float8_e4m3fn'): 97 | fp8_types = (torch.float8_e4m3fn, torch.float8_e5m2) 98 | 99 | # Convert FP8 parameters to BFloat16 for arithmetic operations 100 | if shiftB is not None and shiftB.dtype in fp8_types: 101 | shiftB = shiftB.to(torch.bfloat16) 102 | if scaleB is not None and scaleB.dtype in fp8_types: 103 | scaleB = scaleB.to(torch.bfloat16) 104 | if gateB is not None and gateB.dtype in fp8_types: 105 | gateB = gateB.to(torch.bfloat16) 106 | 107 | if mode == "in": 108 | return hid.mul_(scaleA + scaleB).add_(shiftA + shiftB) 109 | if mode == "out": 110 | if gateB is not None: 111 | return hid.mul_(gateA + gateB) 112 | else: 113 | # If no gate parameter, just use the embedding gate 114 | return hid.mul_(gateA) 115 | 116 | raise NotImplementedError 117 | 118 | def extra_repr(self) -> str: 119 | return f"dim={self.dim}, emb_dim={self.emb_dim}, layers={self.layers}" -------------------------------------------------------------------------------- /src/models/dit_v2/patch/patch_v1.py: -------------------------------------------------------------------------------- 1 | # // Copyright (c) 2025 Bytedance Ltd. and/or its affiliates 2 | # // 3 | # // Licensed under the Apache License, Version 2.0 (the "License"); 4 | # // you may not use this file except in compliance with the License. 5 | # // You may obtain a copy of the License at 6 | # // 7 | # // http://www.apache.org/licenses/LICENSE-2.0 8 | # // 9 | # // Unless required by applicable law or agreed to in writing, software 10 | # // distributed under the License is distributed on an "AS IS" BASIS, 11 | # // WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 12 | # // See the License for the specific language governing permissions and 13 | # // limitations under the License. 14 | 15 | from typing import Tuple, Union 16 | import torch 17 | from einops import rearrange 18 | from torch import nn 19 | from torch.nn.modules.utils import _triple 20 | 21 | from ....common.cache import Cache 22 | from ....common.distributed.ops import gather_outputs, slice_inputs 23 | 24 | from .. 
import na 25 | 26 | 27 | class PatchIn(nn.Module): 28 | def __init__( 29 | self, 30 | in_channels: int, 31 | patch_size: Union[int, Tuple[int, int, int]], 32 | dim: int, 33 | ): 34 | super().__init__() 35 | t, h, w = _triple(patch_size) 36 | self.patch_size = t, h, w 37 | self.proj = nn.Linear(in_channels * t * h * w, dim) 38 | 39 | def forward( 40 | self, 41 | vid: torch.Tensor, 42 | ) -> torch.Tensor: 43 | t, h, w = self.patch_size 44 | if t > 1: 45 | assert vid.size(2) % t == 1 46 | vid = torch.cat([vid[:, :, :1]] * (t - 1) + [vid], dim=2) 47 | vid = rearrange(vid, "b c (T t) (H h) (W w) -> b T H W (t h w c)", t=t, h=h, w=w) 48 | vid = self.proj(vid) 49 | return vid 50 | 51 | 52 | class PatchOut(nn.Module): 53 | def __init__( 54 | self, 55 | out_channels: int, 56 | patch_size: Union[int, Tuple[int, int, int]], 57 | dim: int, 58 | ): 59 | super().__init__() 60 | t, h, w = _triple(patch_size) 61 | self.patch_size = t, h, w 62 | self.proj = nn.Linear(dim, out_channels * t * h * w) 63 | 64 | def forward( 65 | self, 66 | vid: torch.Tensor, 67 | ) -> torch.Tensor: 68 | t, h, w = self.patch_size 69 | vid = self.proj(vid) 70 | vid = rearrange(vid, "b T H W (t h w c) -> b c (T t) (H h) (W w)", t=t, h=h, w=w) 71 | if t > 1: 72 | vid = vid[:, :, (t - 1) :] 73 | return vid 74 | 75 | 76 | class NaPatchIn(PatchIn): 77 | def forward( 78 | self, 79 | vid: torch.Tensor, # l c 80 | vid_shape: torch.LongTensor, 81 | cache: Cache = Cache(disable=True), # for test 82 | ) -> torch.Tensor: 83 | cache = cache.namespace("patch") 84 | vid_shape_before_patchify = cache("vid_shape_before_patchify", lambda: vid_shape) 85 | t, h, w = self.patch_size 86 | if not (t == h == w == 1): 87 | vid = na.unflatten(vid, vid_shape) 88 | for i in range(len(vid)): 89 | if t > 1 and vid_shape_before_patchify[i, 0] % t != 0: 90 | vid[i] = torch.cat([vid[i][:1]] * (t - vid[i].size(0) % t) + [vid[i]], dim=0) 91 | vid[i] = rearrange(vid[i], "(T t) (H h) (W w) c -> T H W (t h w c)", t=t, h=h, w=w) 92 | vid, vid_shape = na.flatten(vid) 93 | 94 | # slice vid after patching in when using sequence parallelism 95 | vid = slice_inputs(vid, dim=0) 96 | vid = self.proj(vid) 97 | return vid, vid_shape 98 | 99 | 100 | class NaPatchOut(PatchOut): 101 | def forward( 102 | self, 103 | vid: torch.FloatTensor, # l c 104 | vid_shape: torch.LongTensor, 105 | cache: Cache = Cache(disable=True), # for test 106 | ) -> Tuple[ 107 | torch.FloatTensor, 108 | torch.LongTensor, 109 | ]: 110 | cache = cache.namespace("patch") 111 | vid_shape_before_patchify = cache.get("vid_shape_before_patchify") 112 | 113 | t, h, w = self.patch_size 114 | vid = self.proj(vid) 115 | # gather vid before patching out when enabling sequence parallelism 116 | vid = gather_outputs( 117 | vid, gather_dim=0, padding_dim=0, unpad_shape=vid_shape, cache=cache.namespace("vid") 118 | ) 119 | if not (t == h == w == 1): 120 | vid = na.unflatten(vid, vid_shape) 121 | for i in range(len(vid)): 122 | vid[i] = rearrange(vid[i], "T H W (t h w c) -> (T t) (H h) (W w) c", t=t, h=h, w=w) 123 | if t > 1 and vid_shape_before_patchify[i, 0] % t != 0: 124 | vid[i] = vid[i][(t - vid_shape_before_patchify[i, 0] % t) :] 125 | vid, vid_shape = na.flatten(vid) 126 | 127 | return vid, vid_shape 128 | -------------------------------------------------------------------------------- /src/models/dit_v2/nablocks/mmsr_block.py: -------------------------------------------------------------------------------- 1 | # // Copyright (c) 2025 Bytedance Ltd. 
and/or its affiliates 2 | # // 3 | # // Licensed under the Apache License, Version 2.0 (the "License"); 4 | # // you may not use this file except in compliance with the License. 5 | # // You may obtain a copy of the License at 6 | # // 7 | # // http://www.apache.org/licenses/LICENSE-2.0 8 | # // 9 | # // Unless required by applicable law or agreed to in writing, software 10 | # // distributed under the License is distributed on an "AS IS" BASIS, 11 | # // WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 12 | # // See the License for the specific language governing permissions and 13 | # // limitations under the License. 14 | 15 | from typing import Tuple 16 | import torch 17 | import torch.nn as nn 18 | 19 | # from ..cache import Cache 20 | from ....common.cache import Cache 21 | 22 | from .attention.mmattn import NaSwinAttention 23 | from ..mm import MMArg 24 | from ..modulation import ada_layer_type 25 | from ..normalization import norm_layer_type 26 | from ..mm import MMArg, MMModule 27 | from ..mlp import get_mlp 28 | 29 | 30 | class NaMMSRTransformerBlock(nn.Module): 31 | def __init__( 32 | self, 33 | *, 34 | vid_dim: int, 35 | txt_dim: int, 36 | emb_dim: int, 37 | heads: int, 38 | head_dim: int, 39 | expand_ratio: int, 40 | norm: norm_layer_type, 41 | norm_eps: float, 42 | ada: ada_layer_type, 43 | qk_bias: bool, 44 | qk_norm: norm_layer_type, 45 | mlp_type: str, 46 | shared_weights: bool, 47 | rope_type: str, 48 | rope_dim: int, 49 | is_last_layer: bool, 50 | **kwargs, 51 | ): 52 | super().__init__() 53 | dim = MMArg(vid_dim, txt_dim) 54 | self.attn_norm = MMModule(norm, dim=dim, eps=norm_eps, elementwise_affine=False, shared_weights=shared_weights,) 55 | 56 | self.attn = NaSwinAttention( 57 | vid_dim=vid_dim, 58 | txt_dim=txt_dim, 59 | heads=heads, 60 | head_dim=head_dim, 61 | qk_bias=qk_bias, 62 | qk_norm=qk_norm, 63 | qk_norm_eps=norm_eps, 64 | rope_type=rope_type, 65 | rope_dim=rope_dim, 66 | shared_weights=shared_weights, 67 | window=kwargs.pop("window", None), 68 | window_method=kwargs.pop("window_method", None), 69 | ) 70 | 71 | self.mlp_norm = MMModule(norm, dim=dim, eps=norm_eps, elementwise_affine=False, shared_weights=shared_weights, vid_only=is_last_layer) 72 | self.mlp = MMModule( 73 | get_mlp(mlp_type), 74 | dim=dim, 75 | expand_ratio=expand_ratio, 76 | shared_weights=shared_weights, 77 | vid_only=is_last_layer 78 | ) 79 | self.ada = MMModule(ada, dim=dim, emb_dim=emb_dim, layers=["attn", "mlp"], shared_weights=shared_weights, vid_only=is_last_layer) 80 | self.is_last_layer = is_last_layer 81 | 82 | def forward( 83 | self, 84 | vid: torch.FloatTensor, # l c 85 | txt: torch.FloatTensor, # l c 86 | vid_shape: torch.LongTensor, # b 3 87 | txt_shape: torch.LongTensor, # b 1 88 | emb: torch.FloatTensor, 89 | cache: Cache, 90 | ) -> Tuple[ 91 | torch.FloatTensor, 92 | torch.FloatTensor, 93 | torch.LongTensor, 94 | torch.LongTensor, 95 | ]: 96 | hid_len = MMArg( 97 | cache("vid_len", lambda: vid_shape.prod(-1)), 98 | cache("txt_len", lambda: txt_shape.prod(-1)), 99 | ) 100 | ada_kwargs = { 101 | "emb": emb, 102 | "hid_len": hid_len, 103 | "cache": cache, 104 | "branch_tag": MMArg("vid", "txt"), 105 | } 106 | 107 | vid_attn, txt_attn = self.attn_norm(vid, txt) 108 | 109 | vid_attn, txt_attn = self.ada(vid_attn, txt_attn, layer="attn", mode="in", **ada_kwargs) 110 | vid_attn, txt_attn = self.attn(vid_attn, txt_attn, vid_shape, txt_shape, cache) 111 | vid_attn, txt_attn = self.ada(vid_attn, txt_attn, layer="attn", mode="out", **ada_kwargs) 112 | vid_attn, 
txt_attn = (vid_attn + vid), (txt_attn + txt) 113 | 114 | vid_mlp, txt_mlp = self.mlp_norm(vid_attn, txt_attn) 115 | # ADD BY NUMZ 116 | if vid_mlp.dtype != vid_attn.dtype: 117 | vid_mlp = vid_mlp.to(vid_attn.dtype) 118 | if txt_mlp.dtype != txt_attn.dtype: 119 | txt_mlp = txt_mlp.to(txt_attn.dtype) 120 | # END BY NUMZ 121 | vid_mlp, txt_mlp = self.ada(vid_mlp, txt_mlp, layer="mlp", mode="in", **ada_kwargs) 122 | vid_mlp, txt_mlp = self.mlp(vid_mlp, txt_mlp) 123 | vid_mlp, txt_mlp = self.ada(vid_mlp, txt_mlp, layer="mlp", mode="out", **ada_kwargs) 124 | vid_mlp, txt_mlp = (vid_mlp + vid_attn), (txt_mlp + txt_attn) 125 | 126 | return vid_mlp, txt_mlp, vid_shape, txt_shape 127 | -------------------------------------------------------------------------------- /src/utils/model_registry.py: -------------------------------------------------------------------------------- 1 | """ 2 | Model Registry for SeedVR2 3 | Central registry for model definitions, repositories, and metadata 4 | """ 5 | 6 | from typing import Dict, List, Optional 7 | from dataclasses import dataclass 8 | from src.utils.constants import SEEDVR2_MODEL_TYPE, is_supported_model_file, get_base_cache_dir 9 | 10 | @dataclass 11 | class ModelInfo: 12 | """Model metadata""" 13 | repo: str = "numz/SeedVR2_comfyUI" 14 | category: str = "model" # 'model' or 'vae' 15 | precision: str = "fp16" # 'fp16', 'fp8_e4m3fn', 'Q4_K_M', etc. 16 | size: str = "3B" # '3B', '7B', etc. 17 | variant: Optional[str] = None # 'sharp', etc. 18 | 19 | # Model registry with metadata 20 | MODEL_REGISTRY = { 21 | # 3B models 22 | "seedvr2_ema_3b_fp8_e4m3fn.safetensors": ModelInfo(size="3B", precision="fp8_e4m3fn"), 23 | "seedvr2_ema_3b_fp16.safetensors": ModelInfo(size="3B", precision="fp16"), 24 | 25 | # 7B models 26 | "seedvr2_ema_7b_fp8_e4m3fn_mixed_block35_fp16.safetensors": ModelInfo(repo="AInVFX/SeedVR2_comfyUI", size="7B", precision="fp8_e4m3fn_mixed_block35_fp16"), 27 | "seedvr2_ema_7b_fp16.safetensors": ModelInfo(size="7B", precision="fp16"), 28 | 29 | # 7B sharp variants 30 | "seedvr2_ema_7b_sharp_fp8_e4m3fn_mixed_block35_fp16.safetensors": ModelInfo(repo="AInVFX/SeedVR2_comfyUI", size="7B", precision="fp8_e4m3fn_mixed_block35_fp16", variant="sharp"), 31 | "seedvr2_ema_7b_sharp_fp16.safetensors": ModelInfo(size="7B", precision="fp16", variant="sharp"), 32 | 33 | # VAE models 34 | "ema_vae_fp16.safetensors": ModelInfo(category="vae", precision="fp16"), 35 | } 36 | 37 | GGUF_MODEL_REGISTRY = { 38 | # 3B models 39 | "seedvr2_ema_3b-Q3_K_M.gguf": ModelInfo(repo="cmeka/SeedVR2-GGUF", size="3B", precision="Q3_K_M"), 40 | "seedvr2_ema_3b-Q4_K_M.gguf": ModelInfo(repo="cmeka/SeedVR2-GGUF", size="3B", precision="Q4_K_M"), 41 | "seedvr2_ema_3b-Q5_K_M.gguf": ModelInfo(repo="cmeka/SeedVR2-GGUF", size="3B", precision="Q5_K_M"), 42 | "seedvr2_ema_3b-Q6_K_M.gguf": ModelInfo(repo="cmeka/SeedVR2-GGUF", size="3B", precision="Q6_K_M"), 43 | "seedvr2_ema_3b-Q8_K_M.gguf": ModelInfo(repo="cmeka/SeedVR2-GGUF", size="3B", precision="Q8_K_M"), 44 | 45 | # 7B models 46 | "seedvr2_ema_7b-Q3_K_M.gguf": ModelInfo(repo="cmeka/SeedVR2-GGUF", size="7B", precision="Q3_K_M"), 47 | "seedvr2_ema_7b-Q4_K_M.gguf": ModelInfo(repo="cmeka/SeedVR2-GGUF", size="7B", precision="Q4_K_M"), 48 | "seedvr2_ema_7b-Q5_K_M.gguf": ModelInfo(repo="cmeka/SeedVR2-GGUF", size="7B", precision="Q5_K_M"), 49 | "seedvr2_ema_7b-Q6_K_M.gguf": ModelInfo(repo="cmeka/SeedVR2-GGUF", size="7B", precision="Q6_K_M"), 50 | "seedvr2_ema_7b-Q8_K_M.gguf": ModelInfo(repo="cmeka/SeedVR2-GGUF", size="7B", precision="Q8_K_M"), 
51 | 52 | # 7B sharp variants 53 | "seedvr2_ema_7b_sharp-Q3_K_M.gguf": ModelInfo(repo="cmeka/SeedVR2-GGUF", size="7B", precision="Q3_K_M"), 54 | "seedvr2_ema_7b_sharp-Q4_K_M.gguf": ModelInfo(repo="cmeka/SeedVR2-GGUF", size="7B", precision="Q4_K_M"), 55 | "seedvr2_ema_7b_sharp-Q5_K_M.gguf": ModelInfo(repo="cmeka/SeedVR2-GGUF", size="7B", precision="Q5_K_M"), 56 | "seedvr2_ema_7b_sharp-Q6_K_M.gguf": ModelInfo(repo="cmeka/SeedVR2-GGUF", size="7B", precision="Q6_K_M"), 57 | "seedvr2_ema_7b_sharp-Q8_K_M.gguf": ModelInfo(repo="cmeka/SeedVR2-GGUF", size="7B", precision="Q8_K_M"), 58 | 59 | # VAE models 60 | "ema_vae_fp16.safetensors": ModelInfo(category="vae", precision="fp16"), 61 | } 62 | 63 | # Configuration constants 64 | DEFAULT_MODEL = "seedvr2_ema_3b_fp8_e4m3fn.safetensors" 65 | DEFAULT_GGUF_MODEL = "seedvr2_ema_3b-Q4_K_M.gguf" 66 | DEFAULT_VAE = "ema_vae_fp16.safetensors" 67 | 68 | def get_default_models(gguf=False) -> List[str]: 69 | """Get list of default models (non-VAE)""" 70 | if gguf: 71 | return [name for name, info in GGUF_MODEL_REGISTRY.items() if info.category == "model"] 72 | return [name for name, info in MODEL_REGISTRY.items() if info.category == "model"] 73 | 74 | def get_model_repo(model_name: str, gguf=False) -> str: 75 | """Get repository for a specific model""" 76 | if gguf: 77 | return GGUF_MODEL_REGISTRY.get(model_name, ModelInfo()).repo 78 | return MODEL_REGISTRY.get(model_name, ModelInfo()).repo 79 | 80 | def get_available_models(gguf=False) -> List[str]: 81 | """Get all available models including those discovered on disk""" 82 | model_list = get_default_models(gguf) 83 | 84 | try: 85 | import folder_paths # only works if comfyui is available 86 | # Ensure the folder is registered before trying to list files 87 | get_base_cache_dir() 88 | # Get all models from the SEEDVR2 folder using centralized constant 89 | available_models = folder_paths.get_filename_list(SEEDVR2_MODEL_TYPE) 90 | 91 | # Add any models not in the registry with supported extensions 92 | for model in available_models: 93 | if is_supported_model_file(model) and model not in MODEL_REGISTRY: 94 | model_list.append(model) 95 | except: 96 | pass 97 | 98 | return model_list -------------------------------------------------------------------------------- /src/data/image/transforms/area_resize.py: -------------------------------------------------------------------------------- 1 | # // Copyright (c) 2025 Bytedance Ltd. and/or its affiliates 2 | # // 3 | # // Licensed under the Apache License, Version 2.0 (the "License"); 4 | # // you may not use this file except in compliance with the License. 5 | # // You may obtain a copy of the License at 6 | # // 7 | # // http://www.apache.org/licenses/LICENSE-2.0 8 | # // 9 | # // Unless required by applicable law or agreed to in writing, software 10 | # // distributed under the License is distributed on an "AS IS" BASIS, 11 | # // WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 12 | # // See the License for the specific language governing permissions and 13 | # // limitations under the License. 
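# [Editor's note - illustrative numbers, not part of the original file]
# Worked example for `AreaResize` below: with max_area = 1024 * 1024 and a 1920x1080 frame,
# scale = sqrt(1048576 / 2073600) ~= 0.711, so the frame is resized to roughly 1365x768,
# preserving the aspect ratio while keeping the pixel count close to max_area. With
# downsample_only=True, frames whose area is already below max_area keep their original size.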
14 | 15 | import math 16 | import random 17 | from typing import Union 18 | import torch 19 | from PIL import Image 20 | from torchvision.transforms import functional as TVF 21 | from torchvision.transforms.functional import InterpolationMode 22 | 23 | 24 | class AreaResize: 25 | def __init__( 26 | self, 27 | max_area: float, 28 | downsample_only: bool = False, 29 | interpolation: InterpolationMode = InterpolationMode.BICUBIC, 30 | ): 31 | self.max_area = max_area 32 | self.downsample_only = downsample_only 33 | self.interpolation = interpolation 34 | if torch.mps.is_available(): 35 | self.interpolation = InterpolationMode.BILINEAR 36 | 37 | def __call__(self, image: Union[torch.Tensor, Image.Image]): 38 | 39 | if isinstance(image, torch.Tensor): 40 | height, width = image.shape[-2:] 41 | elif isinstance(image, Image.Image): 42 | width, height = image.size 43 | else: 44 | raise NotImplementedError 45 | 46 | scale = math.sqrt(self.max_area / (height * width)) 47 | 48 | # keep original height and width for small pictures. 49 | scale = 1 if scale >= 1 and self.downsample_only else scale 50 | 51 | resized_height, resized_width = round(height * scale), round(width * scale) 52 | 53 | return TVF.resize( 54 | image, 55 | size=(resized_height, resized_width), 56 | interpolation=self.interpolation, 57 | ) 58 | 59 | 60 | class AreaRandomCrop: 61 | def __init__( 62 | self, 63 | max_area: float, 64 | ): 65 | self.max_area = max_area 66 | 67 | def get_params(self, input_size, output_size): 68 | """Get parameters for ``crop`` for a random crop. 69 | 70 | Args: 71 | img (PIL Image): Image to be cropped. 72 | output_size (tuple): Expected output size of the crop. 73 | 74 | Returns: 75 | tuple: params (i, j, h, w) to be passed to ``crop`` for random crop. 76 | """ 77 | # w, h = _get_image_size(img) 78 | h, w = input_size 79 | th, tw = output_size 80 | if w <= tw and h <= th: 81 | return 0, 0, h, w 82 | 83 | i = random.randint(0, h - th) 84 | j = random.randint(0, w - tw) 85 | return i, j, th, tw 86 | 87 | def __call__(self, image: Union[torch.Tensor, Image.Image]): 88 | if isinstance(image, torch.Tensor): 89 | height, width = image.shape[-2:] 90 | elif isinstance(image, Image.Image): 91 | width, height = image.size 92 | else: 93 | raise NotImplementedError 94 | 95 | resized_height = math.sqrt(self.max_area / (width / height)) 96 | resized_width = (width / height) * resized_height 97 | 98 | # print('>>>>>>>>>>>>>>>>>>>>>') 99 | # print((height, width)) 100 | # print( (resized_height, resized_width)) 101 | 102 | resized_height, resized_width = round(resized_height), round(resized_width) 103 | i, j, h, w = self.get_params((height, width), (resized_height, resized_width)) 104 | image = TVF.crop(image, i, j, h, w) 105 | return image 106 | 107 | class ScaleResize: 108 | def __init__( 109 | self, 110 | scale: float, 111 | ): 112 | self.scale = scale 113 | 114 | def __call__(self, image: Union[torch.Tensor, Image.Image]): 115 | if isinstance(image, torch.Tensor): 116 | height, width = image.shape[-2:] 117 | interpolation_mode = InterpolationMode.BILINEAR 118 | antialias = True if image.ndim == 4 else "warn" 119 | elif isinstance(image, Image.Image): 120 | width, height = image.size 121 | interpolation_mode = InterpolationMode.LANCZOS 122 | antialias = "warn" 123 | else: 124 | raise NotImplementedError 125 | 126 | scale = self.scale 127 | 128 | # keep original height and width for small pictures 129 | 130 | resized_height, resized_width = round(height * scale), round(width * scale) 131 | image = TVF.resize( 132 | 
image, 133 | size=(resized_height, resized_width), 134 | interpolation=interpolation_mode, 135 | antialias=antialias, 136 | ) 137 | return image 138 | -------------------------------------------------------------------------------- /src/common/decorators.py: -------------------------------------------------------------------------------- 1 | # // Copyright (c) 2025 Bytedance Ltd. and/or its affiliates 2 | # // 3 | # // Licensed under the Apache License, Version 2.0 (the "License"); 4 | # // you may not use this file except in compliance with the License. 5 | # // You may obtain a copy of the License at 6 | # // 7 | # // http://www.apache.org/licenses/LICENSE-2.0 8 | # // 9 | # // Unless required by applicable law or agreed to in writing, software 10 | # // distributed under the License is distributed on an "AS IS" BASIS, 11 | # // WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 12 | # // See the License for the specific language governing permissions and 13 | # // limitations under the License. 14 | 15 | """ 16 | Decorators. 17 | """ 18 | 19 | import functools 20 | import threading 21 | import time 22 | from typing import Callable 23 | import torch 24 | 25 | from .distributed import barrier_if_distributed, get_global_rank, get_local_rank 26 | from .logger import get_logger 27 | 28 | logger = get_logger(__name__) 29 | 30 | 31 | def log_on_entry(func: Callable) -> Callable: 32 | """ 33 | Functions with this decorator will log the function name at entry. 34 | When using multiple decorators, this must be applied innermost to properly capture the name. 35 | """ 36 | 37 | def log_on_entry_wrapper(*args, **kwargs): 38 | logger.info(f"Entering {func.__name__}") 39 | return func(*args, **kwargs) 40 | 41 | return log_on_entry_wrapper 42 | 43 | 44 | def barrier_on_entry(func: Callable) -> Callable: 45 | """ 46 | Functions with this decorator will start executing when all ranks are ready to enter. 47 | """ 48 | 49 | def barrier_on_entry_wrapper(*args, **kwargs): 50 | barrier_if_distributed() 51 | return func(*args, **kwargs) 52 | 53 | return barrier_on_entry_wrapper 54 | 55 | 56 | def _conditional_execute_wrapper_factory(execute: bool, func: Callable) -> Callable: 57 | """ 58 | Helper function for local_rank_zero_only and global_rank_zero_only. 59 | """ 60 | 61 | def conditional_execute_wrapper(*args, **kwargs): 62 | # Only execute if needed. 63 | result = func(*args, **kwargs) if execute else None 64 | # All GPUs must wait. 65 | barrier_if_distributed() 66 | # Return results. 67 | return result 68 | 69 | return conditional_execute_wrapper 70 | 71 | 72 | def _asserted_wrapper_factory(condition: bool, func: Callable, err_msg: str = "") -> Callable: 73 | """ 74 | Helper function for some functions with special constraints, 75 | especially functions called by other global_rank_zero_only / local_rank_zero_only ones, 76 | in case they are wrongly invoked in other scenarios. 77 | """ 78 | 79 | def asserted_execute_wrapper(*args, **kwargs): 80 | assert condition, err_msg 81 | result = func(*args, **kwargs) 82 | return result 83 | 84 | return asserted_execute_wrapper 85 | 86 | 87 | def local_rank_zero_only(func: Callable) -> Callable: 88 | """ 89 | Functions with this decorator will only execute on local rank zero. 90 | """ 91 | return _conditional_execute_wrapper_factory(get_local_rank() == 0, func) 92 | 93 | 94 | def global_rank_zero_only(func: Callable) -> Callable: 95 | """ 96 | Functions with this decorator will only execute on global rank zero. 
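    Other ranks skip the call but still hit the barrier, so all processes stay in sync.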
97 | """ 98 | return _conditional_execute_wrapper_factory(get_global_rank() == 0, func) 99 | 100 | 101 | def assert_only_global_rank_zero(func: Callable) -> Callable: 102 | """ 103 | Functions with this decorator are only accessible to processes with global rank zero. 104 | """ 105 | return _asserted_wrapper_factory( 106 | get_global_rank() == 0, func, err_msg="Not accessible to processes with global_rank != 0" 107 | ) 108 | 109 | 110 | def assert_only_local_rank_zero(func: Callable) -> Callable: 111 | """ 112 | Functions with this decorator are only accessible to processes with local rank zero. 113 | """ 114 | return _asserted_wrapper_factory( 115 | get_local_rank() == 0, func, err_msg="Not accessible to processes with local_rank != 0" 116 | ) 117 | 118 | 119 | def new_thread(func: Callable) -> Callable: 120 | """ 121 | Functions with this decorator will run in a new thread. 122 | The function will return the thread, which can be joined to wait for completion. 123 | """ 124 | 125 | def new_thread_wrapper(*args, **kwargs): 126 | thread = threading.Thread(target=func, args=args, kwargs=kwargs) 127 | thread.start() 128 | return thread 129 | 130 | return new_thread_wrapper 131 | 132 | 133 | def log_runtime(func: Callable) -> Callable: 134 | """ 135 | Functions with this decorator will logging the runtime. 136 | """ 137 | 138 | @functools.wraps(func) 139 | def wrapped(*args, **kwargs): 140 | torch.distributed.barrier() 141 | start = time.perf_counter() 142 | result = func(*args, **kwargs) 143 | torch.distributed.barrier() 144 | logger.info(f"Completed {func.__name__} in {time.perf_counter() - start:.3f} seconds.") 145 | return result 146 | 147 | return wrapped 148 | -------------------------------------------------------------------------------- /src/__init__.py: -------------------------------------------------------------------------------- 1 | """ 2 | SeedVR2 Video Upscaler - Modular Architecture 3 | Refactored from monolithic seedvr2.py for better maintainability 4 | 5 | Author: Refactored codebase 6 | Version: 2.0.0 - Modular 7 | 8 | Available Modules: 9 | - utils: Download and path utilities 10 | - optimization: Memory, performance, and compatibility optimizations 11 | - core: Model management and generation pipeline (NEW) 12 | - processing: Video and tensor processing (coming next) 13 | - interfaces: ComfyUI integration 14 | """ 15 | ''' 16 | # Track which modules are available for progressive migration 17 | MODULES_AVAILABLE = { 18 | 'downloads': True, # ✅ Module 1 - Downloads and model management 19 | 'memory_manager': True, # ✅ Module 2 - Memory optimization 20 | 'performance': True, # ✅ Module 3 - Performance optimizations 21 | 'compatibility': True, # ✅ Module 4 - FP8/FP16 compatibility 22 | 'model_manager': True, # ✅ Module 5 - Model configuration and loading 23 | 'generation': True, # ✅ Module 6 - Generation loop and inference 24 | 'video_transforms': True, # ✅ Module 7 - Video processing and transforms 25 | 'comfyui_node': True, # ✅ Module 8 - ComfyUI node interface (COMPLETE!) 
26 | 'infer': True, # ✅ Module 9 - Infer 27 | } 28 | ''' 29 | # Core imports (always available) 30 | import os 31 | import sys 32 | 33 | # Add current directory to path for fallback imports 34 | current_dir = os.path.dirname(os.path.abspath(__file__)) 35 | parent_dir = os.path.dirname(current_dir) 36 | if parent_dir not in sys.path: 37 | sys.path.insert(0, parent_dir) 38 | ''' 39 | # Progressive import system with fallback 40 | # ===== MODULE 0: Constants ===== 41 | if MODULES_AVAILABLE['downloads']: 42 | from src.utils.constants import ( 43 | get_base_cache_dir, 44 | ) 45 | 46 | # ===== MODULE 1: Downloads ===== 47 | if MODULES_AVAILABLE['downloads']: 48 | from src.utils.downloads import ( 49 | download_weight, 50 | ) 51 | 52 | 53 | # ===== MODULE 2: Memory Manager ===== 54 | if MODULES_AVAILABLE['memory_manager']: 55 | from src.optimization.memory_manager import ( 56 | get_vram_usage, 57 | clear_vram_cache, 58 | reset_vram_peak, 59 | preinitialize_rope_cache, 60 | ) 61 | 62 | 63 | # ===== MODULE 3: Performance ===== 64 | if MODULES_AVAILABLE['performance']: 65 | from src.optimization.performance import ( 66 | optimized_video_rearrange, 67 | optimized_single_video_rearrange, 68 | optimized_sample_to_image_format, 69 | temporal_latent_blending, 70 | ) 71 | 72 | 73 | # ===== MODULE 4: Compatibility ===== 74 | if MODULES_AVAILABLE['compatibility']: 75 | from src.optimization.compatibility import ( 76 | FP8CompatibleDiT, 77 | ) 78 | 79 | 80 | # ===== MODULE 5: Model Manager ===== 81 | if MODULES_AVAILABLE['model_manager']: 82 | from src.core.model_manager import ( 83 | configure_runner, 84 | load_quantized_state_dict, 85 | configure_dit_model_inference, 86 | configure_vae_model_inference, 87 | ) 88 | 89 | 90 | # ===== MODULE 6: Generation ===== 91 | if MODULES_AVAILABLE['generation']: 92 | from src.core.generation import ( 93 | generation_step, 94 | generation_loop, 95 | load_text_embeddings, 96 | calculate_optimal_batch_params, 97 | prepare_video_transforms 98 | ) 99 | 100 | # ===== MODULE 7: Video Transforms ===== 101 | if MODULES_AVAILABLE['infer']: 102 | from src.core.infer import VideoDiffusionInfer 103 | 104 | 105 | # ===== MODULE 8: ComfyUI Node ===== 106 | if MODULES_AVAILABLE['comfyui_node']: 107 | try: 108 | from src.interfaces.comfyui_node import ( 109 | SeedVR2, 110 | NODE_CLASS_MAPPINGS, 111 | NODE_DISPLAY_NAME_MAPPINGS 112 | ) 113 | except: 114 | pass 115 | 116 | # Export all available functions 117 | __all__ = [ 118 | # Constants 119 | 'get_base_cache_dir', 120 | 121 | # Utils 122 | 'download_weight', 123 | 124 | # Memory Management 125 | 'get_vram_usage', 'clear_vram_cache', 'reset_vram_peak', 126 | 'preinitialize_rope_cache', 127 | 128 | # Performance & Video Processing 129 | 'optimized_video_rearrange', 'optimized_single_video_rearrange', 'optimized_sample_to_image_format', 130 | 'temporal_latent_blending', 131 | 'validate_video_format', 'ensure_4n_plus_1_format', 'calculate_padding_requirements', 'apply_wavelet_reconstruction', 'temporal_consistency_check', 132 | 133 | # Compatibility 134 | 'FP8CompatibleDiT', 135 | 136 | # Core Model & Generation & Infer 137 | 'configure_runner', 'load_quantized_state_dict', 'configure_dit_model_inference', 'configure_vae_model_inference', 138 | 'generation_step', 'generation_loop', 'load_text_embeddings', 'calculate_optimal_batch_params', 139 | 'prepare_video_transforms', 'VideoDiffusionInfer', 140 | 141 | # ComfyUI Interface 142 | 'SeedVR2', 'NODE_CLASS_MAPPINGS', 'NODE_DISPLAY_NAME_MAPPINGS', 'format_execution_results', 143 | 144 | # 
Progress tracking 145 | 'get_refactoring_progress' 146 | ] 147 | ''' -------------------------------------------------------------------------------- /src/models/dit/normalization.py: -------------------------------------------------------------------------------- 1 | # // Copyright (c) 2025 Bytedance Ltd. and/or its affiliates 2 | # // 3 | # // Licensed under the Apache License, Version 2.0 (the "License"); 4 | # // you may not use this file except in compliance with the License. 5 | # // You may obtain a copy of the License at 6 | # // 7 | # // http://www.apache.org/licenses/LICENSE-2.0 8 | # // 9 | # // Unless required by applicable law or agreed to in writing, software 10 | # // distributed under the License is distributed on an "AS IS" BASIS, 11 | # // WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 12 | # // See the License for the specific language governing permissions and 13 | # // limitations under the License. 14 | 15 | from typing import Callable, Optional 16 | from diffusers.models.normalization import RMSNorm 17 | from torch import nn 18 | import torch 19 | import torch.nn.functional as F 20 | import numbers 21 | from torch.nn.parameter import Parameter 22 | from torch.nn import init 23 | 24 | # (dim: int, eps: float, elementwise_affine: bool) 25 | norm_layer_type = Callable[[int, float, bool], nn.Module] 26 | 27 | 28 | class CustomLayerNorm(nn.Module): 29 | """ 30 | Custom LayerNorm implementation to replace Apex FusedLayerNorm 31 | """ 32 | def __init__(self, normalized_shape, eps=1e-5, elementwise_affine=True): 33 | super(CustomLayerNorm, self).__init__() 34 | 35 | if isinstance(normalized_shape, numbers.Integral): 36 | normalized_shape = (normalized_shape,) 37 | self.normalized_shape = torch.Size(normalized_shape) 38 | self.eps = eps 39 | self.elementwise_affine = elementwise_affine 40 | 41 | if self.elementwise_affine: 42 | self.weight = Parameter(torch.Tensor(*normalized_shape)) 43 | self.bias = Parameter(torch.Tensor(*normalized_shape)) 44 | else: 45 | self.register_parameter('weight', None) 46 | self.register_parameter('bias', None) 47 | self.reset_parameters() 48 | 49 | def reset_parameters(self): 50 | if self.elementwise_affine: 51 | init.ones_(self.weight) 52 | init.zeros_(self.bias) 53 | 54 | def forward(self, input): 55 | return F.layer_norm( 56 | input, self.normalized_shape, self.weight, self.bias, self.eps) 57 | 58 | 59 | class CustomRMSNorm(nn.Module): 60 | """ 61 | Custom RMSNorm implementation to replace Apex FusedRMSNorm 62 | """ 63 | def __init__(self, normalized_shape, eps=1e-5, elementwise_affine=True): 64 | super(CustomRMSNorm, self).__init__() 65 | 66 | if isinstance(normalized_shape, numbers.Integral): 67 | normalized_shape = (normalized_shape,) 68 | self.normalized_shape = torch.Size(normalized_shape) 69 | self.eps = eps 70 | self.elementwise_affine = elementwise_affine 71 | 72 | if self.elementwise_affine: 73 | self.weight = Parameter(torch.ones(*normalized_shape)) 74 | else: 75 | self.register_parameter('weight', None) 76 | 77 | def forward(self, input): 78 | # RMS normalization: x / sqrt(mean(x^2) + eps) * weight 79 | dims = tuple(range(-len(self.normalized_shape), 0)) 80 | 81 | # Calculate RMS: sqrt(mean(x^2)) 82 | variance = input.pow(2).mean(dim=dims, keepdim=True) 83 | rms = torch.sqrt(variance + self.eps) 84 | 85 | # Normalize 86 | normalized = input / rms 87 | 88 | if self.elementwise_affine: 89 | # Convert FP8 weight to BFloat16 for arithmetic operations 90 | if hasattr(torch, 'float8_e4m3fn'): 91 | fp8_types = 
(torch.float8_e4m3fn, torch.float8_e5m2) 92 | if self.weight.dtype in fp8_types: 93 | weight = self.weight.to(torch.bfloat16) 94 | return normalized * weight 95 | 96 | return normalized * self.weight 97 | return normalized 98 | 99 | 100 | def get_norm_layer(norm_type: Optional[str]) -> norm_layer_type: 101 | 102 | def _norm_layer(dim: int, eps: float, elementwise_affine: bool): 103 | if norm_type is None: 104 | return nn.Identity() 105 | 106 | if norm_type == "layer": 107 | return nn.LayerNorm( 108 | normalized_shape=dim, 109 | eps=eps, 110 | elementwise_affine=elementwise_affine, 111 | ) 112 | 113 | if norm_type == "rms": 114 | return RMSNorm( 115 | dim=dim, 116 | eps=eps, 117 | elementwise_affine=elementwise_affine, 118 | ) 119 | 120 | if norm_type == "fusedln": 121 | # Use custom LayerNorm instead of Apex FusedLayerNorm 122 | return CustomLayerNorm( 123 | normalized_shape=dim, 124 | elementwise_affine=elementwise_affine, 125 | eps=eps, 126 | ) 127 | 128 | if norm_type == "fusedrms": 129 | # Use custom RMSNorm instead of Apex FusedRMSNorm 130 | return CustomRMSNorm( 131 | normalized_shape=dim, 132 | elementwise_affine=elementwise_affine, 133 | eps=eps, 134 | ) 135 | 136 | raise NotImplementedError(f"{norm_type} is not supported") 137 | 138 | return _norm_layer -------------------------------------------------------------------------------- /src/common/config.py: -------------------------------------------------------------------------------- 1 | # // Copyright (c) 2025 Bytedance Ltd. and/or its affiliates 2 | # // 3 | # // Licensed under the Apache License, Version 2.0 (the "License"); 4 | # // you may not use this file except in compliance with the License. 5 | # // You may obtain a copy of the License at 6 | # // 7 | # // http://www.apache.org/licenses/LICENSE-2.0 8 | # // 9 | # // Unless required by applicable law or agreed to in writing, software 10 | # // distributed under the License is distributed on an "AS IS" BASIS, 11 | # // WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 12 | # // See the License for the specific language governing permissions and 13 | # // limitations under the License. 14 | 15 | """ 16 | Configuration utility functions 17 | """ 18 | 19 | import importlib 20 | from typing import Any, Callable, List, Union 21 | from omegaconf import DictConfig, ListConfig, OmegaConf 22 | 23 | try: 24 | OmegaConf.register_new_resolver("eval", eval) 25 | except Exception as e: 26 | if "already registered" not in str(e): 27 | raise 28 | 29 | 30 | 31 | def load_config(path: str, argv: List[str] = None) -> Union[DictConfig, ListConfig]: 32 | """ 33 | Load a configuration. Will resolve inheritance. 
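    Illustrative usage (paths and keys are placeholders):

        config = load_config("path/to/child.yaml", argv=["model.size=3B"])

    where child.yaml may declare `__inherit__: path/to/parent.yaml`; the parent is loaded
    first and the child's values override it (see `resolve_inheritance` below).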
34 | """ 35 | 36 | #print(path) 37 | config = OmegaConf.load(path) 38 | if argv is not None: 39 | config_argv = OmegaConf.from_dotlist(argv) 40 | config = OmegaConf.merge(config, config_argv) 41 | config = resolve_recursive(config, resolve_inheritance) 42 | return config 43 | 44 | 45 | def resolve_recursive( 46 | config: Any, 47 | resolver: Callable[[Union[DictConfig, ListConfig]], Union[DictConfig, ListConfig]], 48 | ) -> Any: 49 | config = resolver(config) 50 | if isinstance(config, DictConfig): 51 | for k in config.keys(): 52 | v = config.get(k) 53 | if isinstance(v, (DictConfig, ListConfig)): 54 | config[k] = resolve_recursive(v, resolver) 55 | if isinstance(config, ListConfig): 56 | for i in range(len(config)): 57 | v = config.get(i) 58 | if isinstance(v, (DictConfig, ListConfig)): 59 | config[i] = resolve_recursive(v, resolver) 60 | return config 61 | 62 | 63 | def resolve_inheritance(config: Union[DictConfig, ListConfig]) -> Any: 64 | """ 65 | Recursively resolve inheritance if the config contains: 66 | __inherit__: path/to/parent.yaml or a ListConfig of such paths. 67 | """ 68 | if isinstance(config, DictConfig): 69 | inherit = config.pop("__inherit__", None) 70 | 71 | if inherit: 72 | inherit_list = inherit if isinstance(inherit, ListConfig) else [inherit] 73 | 74 | parent_config = None 75 | for parent_path in inherit_list: 76 | assert isinstance(parent_path, str) 77 | parent_config = ( 78 | load_config(parent_path) 79 | if parent_config is None 80 | else OmegaConf.merge(parent_config, load_config(parent_path)) 81 | ) 82 | 83 | if len(config.keys()) > 0: 84 | config = OmegaConf.merge(parent_config, config) 85 | else: 86 | config = parent_config 87 | return config 88 | 89 | 90 | def import_item(path: Union[str, List[str]], name: str) -> Any: 91 | """ 92 | Import a python item with fallback support. 93 | 94 | Args: 95 | path: Single path string or list of paths to try (fallback order) 96 | name: Class/function name to import 97 | 98 | Returns: 99 | Imported object 100 | 101 | Example: 102 | import_item("path.to.file", "MyClass") -> MyClass 103 | import_item(["path1.to.file", "path2.to.file"], "MyClass") -> MyClass (first working path) 104 | """ 105 | if isinstance(path, str): 106 | # Single path - original behavior 107 | return getattr(importlib.import_module(path), name) 108 | 109 | elif isinstance(path, (list, ListConfig)): 110 | # Multiple paths - try each until one works 111 | last_error = None 112 | for single_path in path: 113 | try: 114 | return getattr(importlib.import_module(single_path), name) 115 | except ImportError as e: 116 | last_error = e 117 | continue 118 | 119 | # If we get here, none of the paths worked 120 | raise ImportError(f"Could not import '{name}' from any of the paths: {path}. Last error: {last_error}") 121 | 122 | else: 123 | raise ValueError(f"Path must be string or list of strings, got: {type(path)}") 124 | 125 | 126 | def create_object(config: DictConfig) -> Any: 127 | """ 128 | Create an object from config. 
129 | The config is expected to contains the following: 130 | __object__: 131 | path: path.to.module 132 | name: MyClass 133 | args: as_config | as_params (default to as_config) 134 | """ 135 | 136 | item = import_item( 137 | path=config.__object__.path, 138 | name=config.__object__.name, 139 | ) 140 | args = config.__object__.get("args", "as_config") 141 | if args == "as_config": 142 | return item(config) 143 | if args == "as_params": 144 | config = OmegaConf.to_object(config) 145 | config.pop("__object__") 146 | return item(**config) 147 | raise NotImplementedError(f"Unknown args type: {args}") -------------------------------------------------------------------------------- /src/models/dit_v2/normalization.py: -------------------------------------------------------------------------------- 1 | # // Copyright (c) 2025 Bytedance Ltd. and/or its affiliates 2 | # // 3 | # // Licensed under the Apache License, Version 2.0 (the "License"); 4 | # // you may not use this file except in compliance with the License. 5 | # // You may obtain a copy of the License at 6 | # // 7 | # // http://www.apache.org/licenses/LICENSE-2.0 8 | # // 9 | # // Unless required by applicable law or agreed to in writing, software 10 | # // distributed under the License is distributed on an "AS IS" BASIS, 11 | # // WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 12 | # // See the License for the specific language governing permissions and 13 | # // limitations under the License. 14 | 15 | from typing import Callable, Optional 16 | from diffusers.models.normalization import RMSNorm 17 | from torch import nn 18 | import torch 19 | import torch.nn.functional as F 20 | import numbers 21 | from torch.nn.parameter import Parameter 22 | from torch.nn import init 23 | 24 | # (dim: int, eps: float, elementwise_affine: bool) 25 | norm_layer_type = Callable[[int, float, bool], nn.Module] 26 | 27 | 28 | class CustomLayerNorm(nn.Module): 29 | """ 30 | Custom LayerNorm implementation to replace Apex FusedLayerNorm 31 | """ 32 | def __init__(self, normalized_shape, eps=1e-5, elementwise_affine=True): 33 | super(CustomLayerNorm, self).__init__() 34 | 35 | if isinstance(normalized_shape, numbers.Integral): 36 | normalized_shape = (normalized_shape,) 37 | self.normalized_shape = torch.Size(normalized_shape) 38 | self.eps = eps 39 | self.elementwise_affine = elementwise_affine 40 | 41 | if self.elementwise_affine: 42 | self.weight = Parameter(torch.Tensor(*normalized_shape)) 43 | self.bias = Parameter(torch.Tensor(*normalized_shape)) 44 | else: 45 | self.register_parameter('weight', None) 46 | self.register_parameter('bias', None) 47 | self.reset_parameters() 48 | 49 | def reset_parameters(self): 50 | if self.elementwise_affine: 51 | init.ones_(self.weight) 52 | init.zeros_(self.bias) 53 | 54 | def forward(self, input): 55 | # 🚀 FP8 COMPATIBILITY: Convert parameters to match input dtype 56 | # This prevents "Promotion for Float8 Types is not supported" errors 57 | weight = self.weight 58 | bias = self.bias 59 | 60 | if self.elementwise_affine and weight is not None: 61 | if weight.dtype != input.dtype: 62 | weight = weight.to(input.dtype) 63 | if bias is not None and bias.dtype != input.dtype: 64 | bias = bias.to(input.dtype) 65 | 66 | return F.layer_norm( 67 | input, self.normalized_shape, weight, bias, self.eps) 68 | 69 | 70 | class CustomRMSNorm(nn.Module): 71 | """ 72 | Custom RMSNorm implementation to replace Apex FusedRMSNorm 73 | """ 74 | def __init__(self, normalized_shape, eps=1e-5, 
elementwise_affine=True): 75 | super(CustomRMSNorm, self).__init__() 76 | 77 | if isinstance(normalized_shape, numbers.Integral): 78 | normalized_shape = (normalized_shape,) 79 | self.normalized_shape = torch.Size(normalized_shape) 80 | self.eps = eps 81 | self.elementwise_affine = elementwise_affine 82 | 83 | if self.elementwise_affine: 84 | self.weight = Parameter(torch.ones(*normalized_shape)) 85 | else: 86 | self.register_parameter('weight', None) 87 | 88 | def forward(self, input): 89 | # RMS normalization: x / sqrt(mean(x^2) + eps) * weight 90 | dims = tuple(range(-len(self.normalized_shape), 0)) 91 | 92 | # Calculate RMS: sqrt(mean(x^2)) 93 | variance = input.pow(2).mean(dim=dims, keepdim=True) 94 | rms = torch.sqrt(variance + self.eps) 95 | 96 | # Normalize 97 | normalized = input / rms 98 | 99 | if self.elementwise_affine: 100 | # Convert FP8 weight to BFloat16 for arithmetic operations 101 | if hasattr(torch, 'float8_e4m3fn'): 102 | fp8_types = (torch.float8_e4m3fn, torch.float8_e5m2) 103 | if self.weight.dtype in fp8_types: 104 | weight = self.weight.to(torch.bfloat16) 105 | return normalized * weight 106 | 107 | return normalized * self.weight 108 | return normalized 109 | 110 | 111 | def get_norm_layer(norm_type: Optional[str]) -> norm_layer_type: 112 | 113 | def _norm_layer(dim: int, eps: float, elementwise_affine: bool): 114 | if norm_type is None: 115 | return nn.Identity() 116 | 117 | if norm_type == "layer": 118 | return nn.LayerNorm( 119 | normalized_shape=dim, 120 | eps=eps, 121 | elementwise_affine=elementwise_affine, 122 | ) 123 | 124 | if norm_type == "rms": 125 | return RMSNorm( 126 | dim=dim, 127 | eps=eps, 128 | elementwise_affine=elementwise_affine, 129 | ) 130 | 131 | if norm_type == "fusedln": 132 | # Use custom LayerNorm instead of Apex FusedLayerNorm 133 | return CustomLayerNorm( 134 | normalized_shape=dim, 135 | elementwise_affine=elementwise_affine, 136 | eps=eps, 137 | ) 138 | 139 | if norm_type == "fusedrms": 140 | # Use custom RMSNorm instead of Apex FusedRMSNorm 141 | return CustomRMSNorm( 142 | normalized_shape=dim, 143 | elementwise_affine=elementwise_affine, 144 | eps=eps, 145 | ) 146 | 147 | raise NotImplementedError(f"{norm_type} is not supported") 148 | 149 | return _norm_layer -------------------------------------------------------------------------------- /src/optimization/performance.py: -------------------------------------------------------------------------------- 1 | """ 2 | Performance optimization module for SeedVR2 3 | Contains optimized tensor operations and video processing functions 4 | 5 | Extracted from: seedvr2.py (lines 1633-1730) 6 | """ 7 | 8 | import torch 9 | from typing import List, Union 10 | 11 | 12 | def optimized_video_rearrange(video_tensors: List[torch.Tensor]) -> List[torch.Tensor]: 13 | """ 14 | 🚀 OPTIMIZED version of video rearrangement 15 | Replaces slow loops with vectorized operations 16 | 17 | Transforms: 18 | - 3D: c h w -> t c h w (with t=1) 19 | - 4D: c t h w -> t c h w 20 | 21 | Expected gains: 5-10x faster than naive loops 22 | 23 | Args: 24 | video_tensors: List of video tensors to rearrange 25 | 26 | Returns: 27 | List of rearranged tensors in t c h w format 28 | """ 29 | if not video_tensors: 30 | return [] 31 | 32 | # 🔍 Analyze dimensions to optimize processing 33 | videos_3d = [] 34 | videos_4d = [] 35 | indices_3d = [] 36 | indices_4d = [] 37 | 38 | for i, video in enumerate(video_tensors): 39 | if video.ndim == 3: 40 | videos_3d.append(video) 41 | indices_3d.append(i) 42 | else: # ndim == 4 43 | 
videos_4d.append(video) 44 | indices_4d.append(i) 45 | 46 | # 🎯 Prepare final result 47 | samples = [None] * len(video_tensors) 48 | 49 | # 🚀 BATCH PROCESSING for 3D videos (c h w -> 1 c h w) 50 | if videos_3d: 51 | # Method 1: Stack + permute (faster than rearrange) 52 | # c h w -> c 1 h w -> 1 c h w 53 | batch_3d = torch.stack([v.unsqueeze(1) for v in videos_3d]) # [batch, c, 1, h, w] 54 | batch_3d = batch_3d.permute(0, 2, 1, 3, 4) # [batch, 1, c, h, w] 55 | 56 | for i, idx in enumerate(indices_3d): 57 | samples[idx] = batch_3d[i] # [1, c, h, w] 58 | 59 | # 🚀 BATCH PROCESSING for 4D videos (c t h w -> t c h w) 60 | if videos_4d: 61 | # Check if all 4D videos have the same shape for maximum optimization 62 | shapes = [v.shape for v in videos_4d] 63 | if len(set(shapes)) == 1: 64 | # 🎯 MAXIMUM OPTIMIZATION: All shapes identical 65 | # Stack + permute in single operation 66 | batch_4d = torch.stack(videos_4d) # [batch, c, t, h, w] 67 | batch_4d = batch_4d.permute(0, 2, 1, 3, 4) # [batch, t, c, h, w] 68 | 69 | for i, idx in enumerate(indices_4d): 70 | samples[idx] = batch_4d[i] # [t, c, h, w] 71 | else: 72 | # 🔄 FALLBACK: Different shapes, optimized individual processing 73 | for i, idx in enumerate(indices_4d): 74 | # Use permute instead of rearrange (faster) 75 | samples[idx] = videos_4d[i].permute(1, 0, 2, 3) # c t h w -> t c h w 76 | 77 | return samples 78 | 79 | 80 | def optimized_single_video_rearrange(video: torch.Tensor) -> torch.Tensor: 81 | """ 82 | 🚀 OPTIMIZED version for single video tensor 83 | Replaces rearrange() with native PyTorch operations 84 | 85 | Transforms: 86 | - 3D: c h w -> 1 c h w (add temporal dimension) 87 | - 4D: c t h w -> t c h w (permute dimensions) 88 | 89 | Expected gains: 2-5x faster than rearrange() 90 | 91 | Args: 92 | video: Input video tensor 93 | 94 | Returns: 95 | Rearranged tensor with temporal dimension first 96 | """ 97 | if video.ndim == 3: 98 | # c h w -> 1 c h w (add temporal dimension t=1) 99 | return video.unsqueeze(0) 100 | else: # ndim == 4 101 | # c t h w -> t c h w (permute channels and temporal) 102 | return video.permute(1, 0, 2, 3) 103 | 104 | 105 | def optimized_sample_to_image_format(sample: torch.Tensor) -> torch.Tensor: 106 | """ 107 | 🚀 OPTIMIZED version to convert sample to image format 108 | Replaces rearrange() with native PyTorch operations 109 | 110 | Transforms: 111 | - 3D: c h w -> 1 h w c (add temporal dimension + permute to image format) 112 | - 4D: t c h w -> t h w c (permute to image format) 113 | 114 | Expected gains: 2-5x faster than rearrange() 115 | 116 | Args: 117 | sample: Input sample tensor 118 | 119 | Returns: 120 | Tensor in image format (channels last) 121 | """ 122 | if sample.ndim == 3: 123 | # c h w -> 1 h w c (add temporal dimension then permute) 124 | return sample.unsqueeze(0).permute(0, 2, 3, 1) 125 | else: # ndim == 4 126 | # t c h w -> t h w c (permute channels to last) 127 | return sample.permute(0, 2, 3, 1) 128 | 129 | 130 | def temporal_latent_blending(latents1: torch.Tensor, latents2: torch.Tensor, blend_frames: int) -> torch.Tensor: 131 | """ 132 | 🎨 Temporal blending in latent space to avoid discontinuities 133 | 134 | Args: 135 | latents1: Latents from previous batch (end frames) 136 | latents2: Latents from current batch (start frames) 137 | blend_frames: Number of frames to blend 138 | 139 | Returns: 140 | Blended latents for smooth transition 141 | """ 142 | if latents1.shape[0] != latents2.shape[0]: 143 | # Adjust dimensions if necessary 144 | min_frames = min(latents1.shape[0], 
latents2.shape[0]) 145 | latents1 = latents1[:min_frames] 146 | latents2 = latents2[:min_frames] 147 | 148 | # Create linear blending weights 149 | # Frame 0: 100% latents1, 0% latents2 150 | # Frame n: 0% latents1, 100% latents2 151 | weights1 = torch.linspace(1.0, 0.0, blend_frames).view(-1, 1, 1, 1).to(latents1.device) 152 | weights2 = torch.linspace(0.0, 1.0, blend_frames).view(-1, 1, 1, 1).to(latents2.device) 153 | 154 | # Apply blending 155 | blended_latents = weights1 * latents1 + weights2 * latents2 156 | 157 | return blended_latents 158 | 159 | 160 | 161 | -------------------------------------------------------------------------------- /example_workflows/SeedVR2_Image_Upscaling.json: -------------------------------------------------------------------------------- 1 | { 2 | "id": "27decc95-db23-489f-bd36-91f56e7ec825", 3 | "revision": 0, 4 | "last_node_id": 10, 5 | "last_link_id": 10, 6 | "nodes": [ 7 | { 8 | "id": 3, 9 | "type": "LoadImage", 10 | "pos": [ 11 | 990, 12 | 130 13 | ], 14 | "size": [ 15 | 270, 16 | 314 17 | ], 18 | "flags": {}, 19 | "order": 0, 20 | "mode": 0, 21 | "inputs": [], 22 | "outputs": [ 23 | { 24 | "label": "IMAGE", 25 | "name": "IMAGE", 26 | "type": "IMAGE", 27 | "links": [ 28 | 5 29 | ] 30 | }, 31 | { 32 | "label": "MASK", 33 | "name": "MASK", 34 | "type": "MASK", 35 | "links": null 36 | } 37 | ], 38 | "properties": { 39 | "cnr_id": "comfy-core", 40 | "ver": "0.3.50", 41 | "Node name for S&R": "LoadImage" 42 | }, 43 | "widgets_values": [ 44 | "example.png", 45 | "image" 46 | ] 47 | }, 48 | { 49 | "id": 7, 50 | "type": "PreviewImage", 51 | "pos": [ 52 | 1980, 53 | 130 54 | ], 55 | "size": [ 56 | 140, 57 | 26 58 | ], 59 | "flags": {}, 60 | "order": 4, 61 | "mode": 0, 62 | "inputs": [ 63 | { 64 | "label": "images", 65 | "name": "images", 66 | "type": "IMAGE", 67 | "link": 8 68 | } 69 | ], 70 | "outputs": [], 71 | "properties": { 72 | "cnr_id": "comfy-core", 73 | "ver": "0.3.50", 74 | "Node name for S&R": "PreviewImage" 75 | }, 76 | "widgets_values": [] 77 | }, 78 | { 79 | "id": 9, 80 | "type": "SeedVR2BlockSwap", 81 | "pos": [ 82 | 1320, 83 | 200 84 | ], 85 | "size": [ 86 | 287.873046875, 87 | 106 88 | ], 89 | "flags": {}, 90 | "order": 1, 91 | "mode": 0, 92 | "inputs": [], 93 | "outputs": [ 94 | { 95 | "label": "block_swap_config", 96 | "name": "block_swap_config", 97 | "type": "block_swap_config", 98 | "links": [ 99 | 9 100 | ] 101 | } 102 | ], 103 | "properties": { 104 | "aux_id": "numz/ComfyUI-SeedVR2_VideoUpscaler", 105 | "ver": "12aefc08bb302fa595a5938e1520522dc7246d7c", 106 | "Node name for S&R": "SeedVR2BlockSwap" 107 | }, 108 | "widgets_values": [ 109 | 16, 110 | false, 111 | false 112 | ] 113 | }, 114 | { 115 | "id": 10, 116 | "type": "SeedVR2ExtraArgs", 117 | "pos": [ 118 | 1320, 119 | 370 120 | ], 121 | "size": [ 122 | 270, 123 | 202 124 | ], 125 | "flags": {}, 126 | "order": 2, 127 | "mode": 0, 128 | "inputs": [], 129 | "outputs": [ 130 | { 131 | "label": "extra_args", 132 | "name": "extra_args", 133 | "type": "extra_args", 134 | "links": [ 135 | 10 136 | ] 137 | } 138 | ], 139 | "properties": { 140 | "aux_id": "numz/ComfyUI-SeedVR2_VideoUpscaler", 141 | "ver": "12aefc08bb302fa595a5938e1520522dc7246d7c", 142 | "Node name for S&R": "SeedVR2ExtraArgs" 143 | }, 144 | "widgets_values": [ 145 | true, 146 | 512, 147 | 64, 148 | false, 149 | false, 150 | false, 151 | "cuda:0" 152 | ] 153 | }, 154 | { 155 | "id": 8, 156 | "type": "SeedVR2", 157 | "pos": [ 158 | 1650, 159 | 130 160 | ], 161 | "size": [ 162 | 270, 163 | 194 164 | ], 165 | "flags": {}, 166 
| "order": 3, 167 | "mode": 0, 168 | "inputs": [ 169 | { 170 | "label": "images", 171 | "name": "images", 172 | "type": "IMAGE", 173 | "link": 5 174 | }, 175 | { 176 | "label": "block_swap_config", 177 | "name": "block_swap_config", 178 | "shape": 7, 179 | "type": "block_swap_config", 180 | "link": 9 181 | }, 182 | { 183 | "label": "extra_args", 184 | "name": "extra_args", 185 | "shape": 7, 186 | "type": "extra_args", 187 | "link": 10 188 | } 189 | ], 190 | "outputs": [ 191 | { 192 | "label": "image", 193 | "name": "image", 194 | "type": "IMAGE", 195 | "links": [ 196 | 8 197 | ] 198 | } 199 | ], 200 | "properties": { 201 | "aux_id": "numz/ComfyUI-SeedVR2_VideoUpscaler", 202 | "ver": "12aefc08bb302fa595a5938e1520522dc7246d7c", 203 | "Node name for S&R": "SeedVR2" 204 | }, 205 | "widgets_values": [ 206 | "seedvr2_ema_3b_fp8_e4m3fn.safetensors", 207 | 100, 208 | "randomize", 209 | 1072, 210 | 1 211 | ] 212 | } 213 | ], 214 | "links": [ 215 | [ 216 | 5, 217 | 3, 218 | 0, 219 | 8, 220 | 0, 221 | "IMAGE" 222 | ], 223 | [ 224 | 8, 225 | 8, 226 | 0, 227 | 7, 228 | 0, 229 | "IMAGE" 230 | ], 231 | [ 232 | 9, 233 | 9, 234 | 0, 235 | 8, 236 | 1, 237 | "block_swap_config" 238 | ], 239 | [ 240 | 10, 241 | 10, 242 | 0, 243 | 8, 244 | 2, 245 | "extra_args" 246 | ] 247 | ], 248 | "groups": [], 249 | "config": {}, 250 | "extra": { 251 | "ds": { 252 | "scale": 1.1471148376625333, 253 | "offset": [ 254 | -888.8886822691388, 255 | -24.501608073694676 256 | ] 257 | }, 258 | "frontendVersion": "1.25.7", 259 | "VHS_latentpreview": false, 260 | "VHS_latentpreviewrate": 0, 261 | "VHS_MetadataImage": true, 262 | "VHS_KeepIntermediate": true 263 | }, 264 | "version": 0.4 265 | } -------------------------------------------------------------------------------- /src/models/video_vae_v3/modules/inflated_lib.py: -------------------------------------------------------------------------------- 1 | # // Copyright (c) 2025 Bytedance Ltd. and/or its affiliates 2 | # // 3 | # // Licensed under the Apache License, Version 2.0 (the "License"); 4 | # // you may not use this file except in compliance with the License. 5 | # // You may obtain a copy of the License at 6 | # // 7 | # // http://www.apache.org/licenses/LICENSE-2.0 8 | # // 9 | # // Unless required by applicable law or agreed to in writing, software 10 | # // distributed under the License is distributed on an "AS IS" BASIS, 11 | # // WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 12 | # // See the License for the specific language governing permissions and 13 | # // limitations under the License. 14 | 15 | from enum import Enum 16 | from typing import Optional 17 | import numpy as np 18 | import torch 19 | from diffusers.models.normalization import RMSNorm 20 | from einops import rearrange 21 | from torch import Tensor, nn 22 | 23 | from ....common.logger import get_logger 24 | 25 | logger = get_logger(__name__) 26 | 27 | 28 | class MemoryState(Enum): 29 | """ 30 | State[Disabled]: No memory bank will be enabled. 31 | State[Initializing]: The model is handling the first clip, 32 | need to reset / initialize the memory bank. 33 | State[Active]: There has been some data in the memory bank. 
34 | """ 35 | 36 | DISABLED = 0 37 | INITIALIZING = 1 38 | ACTIVE = 2 39 | 40 | 41 | def causal_norm_wrapper(norm_layer: nn.Module, x: torch.Tensor) -> torch.Tensor: 42 | if isinstance(norm_layer, (nn.LayerNorm, RMSNorm)): 43 | if x.ndim == 4: 44 | x = rearrange(x, "b c h w -> b h w c") 45 | x = norm_layer(x) 46 | x = rearrange(x, "b h w c -> b c h w") 47 | return x 48 | if x.ndim == 5: 49 | x = rearrange(x, "b c t h w -> b t h w c") 50 | x = norm_layer(x) 51 | x = rearrange(x, "b t h w c -> b c t h w") 52 | return x 53 | if isinstance(norm_layer, (nn.GroupNorm, nn.BatchNorm2d, nn.SyncBatchNorm)): 54 | if x.ndim <= 4: 55 | return norm_layer(x) 56 | if x.ndim == 5: 57 | t = x.size(2) 58 | x = rearrange(x, "b c t h w -> (b t) c h w") 59 | x = norm_layer(x) 60 | x = rearrange(x, "(b t) c h w -> b c t h w", t=t) 61 | return x 62 | raise NotImplementedError 63 | 64 | 65 | def remove_head(tensor: Tensor, times: int = 1) -> Tensor: 66 | """ 67 | Remove duplicated first frame features in the up-sampling process. 68 | """ 69 | if times == 0: 70 | return tensor 71 | return torch.cat(tensors=(tensor[:, :, :1], tensor[:, :, times + 1 :]), dim=2) 72 | 73 | 74 | def extend_head( 75 | tensor: Tensor, times: Optional[int] = 2, memory: Optional[Tensor] = None 76 | ) -> Tensor: 77 | """ 78 | When memory is None: 79 | - Duplicate first frame features in the down-sampling process. 80 | When memory is not None: 81 | - Concatenate memory features with the input features to keep temporal consistency. 82 | """ 83 | if times == 0: 84 | return tensor 85 | if memory is not None: 86 | return torch.cat((memory.to(tensor), tensor), dim=2) 87 | else: 88 | tile_repeat = np.ones(tensor.ndim).astype(int) 89 | tile_repeat[2] = times 90 | return torch.cat(tensors=(torch.tile(tensor[:, :, :1], list(tile_repeat)), tensor), dim=2) 91 | 92 | 93 | def inflate_weight(weight_2d: torch.Tensor, weight_3d: torch.Tensor, inflation_mode: str): 94 | """ 95 | Inflate a 2D convolution weight matrix to a 3D one. 96 | Parameters: 97 | weight_2d: The weight matrix of 2D conv to be inflated. 98 | weight_3d: The weight matrix of 3D conv to be initialized. 99 | inflation_mode: The mode of inflation, either "constant" or "replicate". 100 | """ 101 | assert inflation_mode in ["constant", "replicate"] 102 | assert weight_3d.shape[:2] == weight_2d.shape[:2] 103 | with torch.no_grad(): 104 | if inflation_mode == "replicate": 105 | depth = weight_3d.size(2) 106 | weight_3d.copy_(weight_2d.unsqueeze(2).repeat(1, 1, depth, 1, 1) / depth) 107 | else: 108 | weight_3d.fill_(0.0) 109 | weight_3d[:, :, -1].copy_(weight_2d) 110 | return weight_3d 111 | 112 | 113 | def inflate_bias(bias_2d: torch.Tensor, bias_3d: torch.Tensor, inflation_mode: str): 114 | """ 115 | Inflate a 2D convolution bias tensor to a 3D one. 116 | Parameters: 117 | bias_2d: The bias tensor of 2D conv to be inflated. 118 | bias_3d: The bias tensor of 3D conv to be initialized. 119 | inflation_mode: Placeholder to align with `inflate_weight`. 120 | """ 121 | assert bias_3d.shape == bias_2d.shape 122 | with torch.no_grad(): 123 | bias_3d.copy_(bias_2d) 124 | return bias_3d 125 | 126 | 127 | def modify_state_dict(layer, state_dict, prefix, inflate_weight_fn, inflate_bias_fn): 128 | """ 129 | The main function to inflate 2D parameters to 3D.
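    Parameters:
        layer: The target 3D layer, providing `weight`, `bias` and an `inflation_mode` attribute.
        state_dict: The state dict being loaded; modified in place and returned.
        prefix: Key prefix of this layer's parameters inside the state dict.
        inflate_weight_fn: Callable used to inflate 2D conv weights (e.g. `inflate_weight`).
        inflate_bias_fn: Callable used to inflate conv biases (e.g. `inflate_bias`).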
130 | """ 131 | weight_name = prefix + "weight" 132 | bias_name = prefix + "bias" 133 | if weight_name in state_dict: 134 | weight_2d = state_dict[weight_name] 135 | if weight_2d.dim() == 4: 136 | # Assuming the 2D weights are 4D tensors (out_channels, in_channels, h, w) 137 | weight_3d = inflate_weight_fn( 138 | weight_2d=weight_2d, 139 | weight_3d=layer.weight, 140 | inflation_mode=layer.inflation_mode, 141 | ) 142 | state_dict[weight_name] = weight_3d 143 | else: 144 | return state_dict 145 | # It's a 3d state dict, should not do inflation on both bias and weight. 146 | if bias_name in state_dict: 147 | bias_2d = state_dict[bias_name] 148 | if bias_2d.dim() == 1: 149 | # Assuming the 2D biases are 1D tensors (out_channels,) 150 | bias_3d = inflate_bias_fn( 151 | bias_2d=bias_2d, 152 | bias_3d=layer.bias, 153 | inflation_mode=layer.inflation_mode, 154 | ) 155 | state_dict[bias_name] = bias_3d 156 | return state_dict 157 | -------------------------------------------------------------------------------- /src/models/dit_v2/rope.py: -------------------------------------------------------------------------------- 1 | # // Copyright (c) 2025 Bytedance Ltd. and/or its affiliates 2 | # // 3 | # // Licensed under the Apache License, Version 2.0 (the "License"); 4 | # // you may not use this file except in compliance with the License. 5 | # // You may obtain a copy of the License at 6 | # // 7 | # // http://www.apache.org/licenses/LICENSE-2.0 8 | # // 9 | # // Unless required by applicable law or agreed to in writing, software 10 | # // distributed under the License is distributed on an "AS IS" BASIS, 11 | # // WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 12 | # // See the License for the specific language governing permissions and 13 | # // limitations under the License. 14 | 15 | from functools import lru_cache 16 | from typing import Optional, Tuple 17 | import torch 18 | from einops import rearrange 19 | from rotary_embedding_torch import RotaryEmbedding, apply_rotary_emb 20 | from torch import nn 21 | 22 | from src.common.cache import Cache 23 | 24 | 25 | class RotaryEmbeddingBase(nn.Module): 26 | def __init__(self, dim: int, rope_dim: int): 27 | super().__init__() 28 | self.rope = RotaryEmbedding( 29 | dim=dim // rope_dim, 30 | freqs_for="pixel", 31 | max_freq=256, 32 | ) 33 | # 1. Set model.requires_grad_(True) after model creation will make 34 | # the `requires_grad=False` for rope freqs no longer hold. 35 | # 2. Even if we don't set requires_grad_(True) explicitly, 36 | # FSDP is not memory efficient when handling fsdp_wrap 37 | # with mixed requires_grad=True/False. 
38 | # With above consideration, it is easier just remove the freqs 39 | # out of nn.Parameters when `learned_freq=False` 40 | freqs = self.rope.freqs 41 | del self.rope.freqs 42 | self.rope.register_buffer("freqs", freqs.data) 43 | 44 | @lru_cache(maxsize=128) 45 | def get_axial_freqs(self, *dims): 46 | return self.rope.get_axial_freqs(*dims) 47 | 48 | 49 | class RotaryEmbedding3d(RotaryEmbeddingBase): 50 | def __init__(self, dim: int): 51 | super().__init__(dim, rope_dim=3) 52 | self.mm = False 53 | 54 | def forward( 55 | self, 56 | q: torch.FloatTensor, # b h l d 57 | k: torch.FloatTensor, # b h l d 58 | size: Tuple[int, int, int], 59 | ) -> Tuple[ 60 | torch.FloatTensor, 61 | torch.FloatTensor, 62 | ]: 63 | T, H, W = size 64 | freqs = self.get_axial_freqs(T, H, W) 65 | q = rearrange(q, "b h (T H W) d -> b h T H W d", T=T, H=H, W=W) 66 | k = rearrange(k, "b h (T H W) d -> b h T H W d", T=T, H=H, W=W) 67 | q = apply_rotary_emb(freqs, q.float()).to(q.dtype) 68 | k = apply_rotary_emb(freqs, k.float()).to(k.dtype) 69 | q = rearrange(q, "b h T H W d -> b h (T H W) d") 70 | k = rearrange(k, "b h T H W d -> b h (T H W) d") 71 | return q, k 72 | 73 | 74 | class MMRotaryEmbeddingBase(RotaryEmbeddingBase): 75 | def __init__(self, dim: int, rope_dim: int): 76 | super().__init__(dim, rope_dim) 77 | self.rope = RotaryEmbedding( 78 | dim=dim // rope_dim, 79 | freqs_for="lang", 80 | theta=10000, 81 | ) 82 | freqs = self.rope.freqs 83 | del self.rope.freqs 84 | self.rope.register_buffer("freqs", freqs.data) 85 | self.mm = True 86 | 87 | 88 | class NaMMRotaryEmbedding3d(MMRotaryEmbeddingBase): 89 | def __init__(self, dim: int): 90 | super().__init__(dim, rope_dim=3) 91 | 92 | def forward( 93 | self, 94 | vid_q: torch.FloatTensor, # L h d 95 | vid_k: torch.FloatTensor, # L h d 96 | vid_shape: torch.LongTensor, # B 3 97 | txt_q: torch.FloatTensor, # L h d 98 | txt_k: torch.FloatTensor, # L h d 99 | txt_shape: torch.LongTensor, # B 1 100 | cache: Cache, 101 | ) -> Tuple[ 102 | torch.FloatTensor, 103 | torch.FloatTensor, 104 | torch.FloatTensor, 105 | torch.FloatTensor, 106 | ]: 107 | vid_freqs, txt_freqs = cache( 108 | "mmrope_freqs_3d", 109 | lambda: self.get_freqs(vid_shape, txt_shape), 110 | ) 111 | target_device = vid_q.device 112 | if vid_freqs.device != target_device: 113 | vid_freqs = vid_freqs.to(target_device) 114 | if txt_freqs.device != target_device: 115 | txt_freqs = txt_freqs.to(target_device) 116 | vid_q = rearrange(vid_q, "L h d -> h L d") 117 | vid_k = rearrange(vid_k, "L h d -> h L d") 118 | vid_q = apply_rotary_emb(vid_freqs, vid_q.float()).to(vid_q.dtype) 119 | vid_k = apply_rotary_emb(vid_freqs, vid_k.float()).to(vid_k.dtype) 120 | vid_q = rearrange(vid_q, "h L d -> L h d") 121 | vid_k = rearrange(vid_k, "h L d -> L h d") 122 | 123 | txt_q = rearrange(txt_q, "L h d -> h L d") 124 | txt_k = rearrange(txt_k, "L h d -> h L d") 125 | txt_q = apply_rotary_emb(txt_freqs, txt_q.float()).to(txt_q.dtype) 126 | txt_k = apply_rotary_emb(txt_freqs, txt_k.float()).to(txt_k.dtype) 127 | txt_q = rearrange(txt_q, "h L d -> L h d") 128 | txt_k = rearrange(txt_k, "h L d -> L h d") 129 | return vid_q, vid_k, txt_q, txt_k 130 | 131 | def get_freqs( 132 | self, 133 | vid_shape: torch.LongTensor, 134 | txt_shape: torch.LongTensor, 135 | ) -> Tuple[ 136 | torch.Tensor, 137 | torch.Tensor, 138 | ]: 139 | vid_freqs = self.get_axial_freqs(1024, 128, 128) 140 | txt_freqs = self.get_axial_freqs(1024) 141 | vid_freq_list, txt_freq_list = [], [] 142 | for (f, h, w), l in zip(vid_shape.tolist(), txt_shape[:, 
0].tolist()): 143 | vid_freq = vid_freqs[l : l + f, :h, :w].reshape(-1, vid_freqs.size(-1)) 144 | txt_freq = txt_freqs[:l].repeat(1, 3).reshape(-1, vid_freqs.size(-1)) 145 | vid_freq_list.append(vid_freq) 146 | txt_freq_list.append(txt_freq) 147 | return torch.cat(vid_freq_list, dim=0), torch.cat(txt_freq_list, dim=0) 148 | 149 | 150 | def get_na_rope(rope_type: Optional[str], dim: int): 151 | if rope_type is None: 152 | return None 153 | if rope_type == "mmrope3d": 154 | return NaMMRotaryEmbedding3d(dim=dim) 155 | raise NotImplementedError(f"{rope_type} is not supported.") 156 | -------------------------------------------------------------------------------- /src/common/distributed/advanced.py: -------------------------------------------------------------------------------- 1 | # // Copyright (c) 2025 Bytedance Ltd. and/or its affiliates 2 | # // 3 | # // Licensed under the Apache License, Version 2.0 (the "License"); 4 | # // you may not use this file except in compliance with the License. 5 | # // You may obtain a copy of the License at 6 | # // 7 | # // http://www.apache.org/licenses/LICENSE-2.0 8 | # // 9 | # // Unless required by applicable law or agreed to in writing, software 10 | # // distributed under the License is distributed on an "AS IS" BASIS, 11 | # // WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 12 | # // See the License for the specific language governing permissions and 13 | # // limitations under the License. 14 | 15 | """ 16 | Advanced distributed functions for sequence parallel. 17 | """ 18 | 19 | from typing import Optional, List 20 | import torch 21 | import torch.distributed as dist 22 | from torch.distributed.device_mesh import DeviceMesh, init_device_mesh 23 | from torch.distributed.fsdp import ShardingStrategy 24 | 25 | from .basic import get_global_rank, get_world_size 26 | 27 | 28 | _DATA_PARALLEL_GROUP = None 29 | _SEQUENCE_PARALLEL_GROUP = None 30 | _SEQUENCE_PARALLEL_CPU_GROUP = None 31 | _MODEL_SHARD_CPU_INTER_GROUP = None 32 | _MODEL_SHARD_CPU_INTRA_GROUP = None 33 | _MODEL_SHARD_INTER_GROUP = None 34 | _MODEL_SHARD_INTRA_GROUP = None 35 | _SEQUENCE_PARALLEL_GLOBAL_RANKS = None 36 | 37 | 38 | def get_data_parallel_group() -> Optional[dist.ProcessGroup]: 39 | """ 40 | Get data parallel process group. 41 | """ 42 | return _DATA_PARALLEL_GROUP 43 | 44 | 45 | def get_sequence_parallel_group() -> Optional[dist.ProcessGroup]: 46 | """ 47 | Get sequence parallel process group. 48 | """ 49 | return _SEQUENCE_PARALLEL_GROUP 50 | 51 | 52 | def get_sequence_parallel_cpu_group() -> Optional[dist.ProcessGroup]: 53 | """ 54 | Get sequence parallel CPU process group. 55 | """ 56 | return _SEQUENCE_PARALLEL_CPU_GROUP 57 | 58 | 59 | def get_data_parallel_rank() -> int: 60 | """ 61 | Get data parallel rank. 62 | """ 63 | group = get_data_parallel_group() 64 | return dist.get_rank(group) if group else get_global_rank() 65 | 66 | 67 | def get_data_parallel_world_size() -> int: 68 | """ 69 | Get data parallel world size. 70 | """ 71 | group = get_data_parallel_group() 72 | return dist.get_world_size(group) if group else get_world_size() 73 | 74 | 75 | def get_sequence_parallel_rank() -> int: 76 | """ 77 | Get sequence parallel rank. 78 | """ 79 | group = get_sequence_parallel_group() 80 | return dist.get_rank(group) if group else 0 81 | 82 | 83 | def get_sequence_parallel_world_size() -> int: 84 | """ 85 | Get sequence parallel world size. 
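    Returns 1 when no sequence parallel group has been initialized.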
86 | """ 87 | group = get_sequence_parallel_group() 88 | return dist.get_world_size(group) if group else 1 89 | 90 | 91 | def get_model_shard_cpu_intra_group() -> Optional[dist.ProcessGroup]: 92 | """ 93 | Get the CPU intra process group of model sharding. 94 | """ 95 | return _MODEL_SHARD_CPU_INTRA_GROUP 96 | 97 | 98 | def get_model_shard_cpu_inter_group() -> Optional[dist.ProcessGroup]: 99 | """ 100 | Get the CPU inter process group of model sharding. 101 | """ 102 | return _MODEL_SHARD_CPU_INTER_GROUP 103 | 104 | 105 | def get_model_shard_intra_group() -> Optional[dist.ProcessGroup]: 106 | """ 107 | Get the GPU intra process group of model sharding. 108 | """ 109 | return _MODEL_SHARD_INTRA_GROUP 110 | 111 | 112 | def get_model_shard_inter_group() -> Optional[dist.ProcessGroup]: 113 | """ 114 | Get the GPU inter process group of model sharding. 115 | """ 116 | return _MODEL_SHARD_INTER_GROUP 117 | 118 | 119 | def init_sequence_parallel(sequence_parallel_size: int): 120 | """ 121 | Initialize sequence parallel. 122 | """ 123 | global _DATA_PARALLEL_GROUP 124 | global _SEQUENCE_PARALLEL_GROUP 125 | global _SEQUENCE_PARALLEL_CPU_GROUP 126 | global _SEQUENCE_PARALLEL_GLOBAL_RANKS 127 | assert dist.is_initialized() 128 | world_size = dist.get_world_size() 129 | rank = dist.get_rank() 130 | data_parallel_size = world_size // sequence_parallel_size 131 | for i in range(data_parallel_size): 132 | start_rank = i * sequence_parallel_size 133 | end_rank = (i + 1) * sequence_parallel_size 134 | ranks = range(start_rank, end_rank) 135 | group = dist.new_group(ranks) 136 | cpu_group = dist.new_group(ranks, backend="gloo") 137 | if rank in ranks: 138 | _SEQUENCE_PARALLEL_GROUP = group 139 | _SEQUENCE_PARALLEL_CPU_GROUP = cpu_group 140 | _SEQUENCE_PARALLEL_GLOBAL_RANKS = list(ranks) 141 | 142 | 143 | def init_model_shard_group( 144 | *, 145 | sharding_strategy: ShardingStrategy, 146 | device_mesh: Optional[DeviceMesh] = None, 147 | ): 148 | """ 149 | Initialize process group of model sharding. 150 | """ 151 | global _MODEL_SHARD_INTER_GROUP 152 | global _MODEL_SHARD_INTRA_GROUP 153 | global _MODEL_SHARD_CPU_INTER_GROUP 154 | global _MODEL_SHARD_CPU_INTRA_GROUP 155 | assert dist.is_initialized() 156 | world_size = dist.get_world_size() 157 | if device_mesh is not None: 158 | num_shards_per_group = device_mesh.shape[1] 159 | elif sharding_strategy == ShardingStrategy.NO_SHARD: 160 | num_shards_per_group = 1 161 | elif sharding_strategy in [ 162 | ShardingStrategy.HYBRID_SHARD, 163 | ShardingStrategy._HYBRID_SHARD_ZERO2, 164 | ]: 165 | num_shards_per_group = torch.cuda.device_count() 166 | else: 167 | num_shards_per_group = world_size 168 | num_groups = world_size // num_shards_per_group 169 | device_mesh = (num_groups, num_shards_per_group) 170 | 171 | gpu_mesh_2d = init_device_mesh("cuda", device_mesh, mesh_dim_names=("inter", "intra")) 172 | cpu_mesh_2d = init_device_mesh("cpu", device_mesh, mesh_dim_names=("inter", "intra")) 173 | 174 | _MODEL_SHARD_INTER_GROUP = gpu_mesh_2d.get_group("inter") 175 | _MODEL_SHARD_INTRA_GROUP = gpu_mesh_2d.get_group("intra") 176 | _MODEL_SHARD_CPU_INTER_GROUP = cpu_mesh_2d.get_group("inter") 177 | _MODEL_SHARD_CPU_INTRA_GROUP = cpu_mesh_2d.get_group("intra") 178 | 179 | def get_sequence_parallel_global_ranks() -> List[int]: 180 | """ 181 | Get all global ranks of the sequence parallel process group 182 | that the caller rank belongs to. 
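    Falls back to a list containing only the caller's global rank
    when sequence parallel has not been initialized.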
183 | """ 184 | if _SEQUENCE_PARALLEL_GLOBAL_RANKS is None: 185 | return [dist.get_rank()] 186 | return _SEQUENCE_PARALLEL_GLOBAL_RANKS 187 | 188 | 189 | def get_next_sequence_parallel_rank() -> int: 190 | """ 191 | Get the next global rank of the sequence parallel process group 192 | that the caller rank belongs to. 193 | """ 194 | sp_global_ranks = get_sequence_parallel_global_ranks() 195 | sp_rank = get_sequence_parallel_rank() 196 | sp_size = get_sequence_parallel_world_size() 197 | return sp_global_ranks[(sp_rank + 1) % sp_size] 198 | 199 | 200 | def get_prev_sequence_parallel_rank() -> int: 201 | """ 202 | Get the previous global rank of the sequence parallel process group 203 | that the caller rank belongs to. 204 | """ 205 | sp_global_ranks = get_sequence_parallel_global_ranks() 206 | sp_rank = get_sequence_parallel_rank() 207 | sp_size = get_sequence_parallel_world_size() 208 | return sp_global_ranks[(sp_rank + sp_size - 1) % sp_size] -------------------------------------------------------------------------------- /example_workflows/SeedVR2_Video_Upscaling.json: -------------------------------------------------------------------------------- 1 | { 2 | "id": "27decc95-db23-489f-bd36-91f56e7ec825", 3 | "revision": 0, 4 | "last_node_id": 15, 5 | "last_link_id": 18, 6 | "nodes": [ 7 | { 8 | "id": 14, 9 | "type": "SeedVR2ExtraArgs", 10 | "pos": [ 11 | 1160, 12 | 430 13 | ], 14 | "size": [ 15 | 270, 16 | 202 17 | ], 18 | "flags": {}, 19 | "order": 0, 20 | "mode": 0, 21 | "inputs": [], 22 | "outputs": [ 23 | { 24 | "label": "extra_args", 25 | "name": "extra_args", 26 | "type": "extra_args", 27 | "links": [ 28 | 17 29 | ] 30 | } 31 | ], 32 | "properties": { 33 | "aux_id": "numz/ComfyUI-SeedVR2_VideoUpscaler", 34 | "ver": "12aefc08bb302fa595a5938e1520522dc7246d7c", 35 | "Node name for S&R": "SeedVR2ExtraArgs" 36 | }, 37 | "widgets_values": [ 38 | true, 39 | 512, 40 | 64, 41 | false, 42 | false, 43 | false, 44 | "cuda:0" 45 | ] 46 | }, 47 | { 48 | "id": 15, 49 | "type": "SeedVR2", 50 | "pos": [ 51 | 1490, 52 | 130 53 | ], 54 | "size": [ 55 | 270, 56 | 194 57 | ], 58 | "flags": {}, 59 | "order": 4, 60 | "mode": 0, 61 | "inputs": [ 62 | { 63 | "label": "images", 64 | "name": "images", 65 | "type": "IMAGE", 66 | "link": 15 67 | }, 68 | { 69 | "label": "block_swap_config", 70 | "name": "block_swap_config", 71 | "shape": 7, 72 | "type": "block_swap_config", 73 | "link": 16 74 | }, 75 | { 76 | "label": "extra_args", 77 | "name": "extra_args", 78 | "shape": 7, 79 | "type": "extra_args", 80 | "link": 17 81 | } 82 | ], 83 | "outputs": [ 84 | { 85 | "label": "image", 86 | "name": "image", 87 | "type": "IMAGE", 88 | "links": [ 89 | 18 90 | ] 91 | } 92 | ], 93 | "properties": { 94 | "aux_id": "numz/ComfyUI-SeedVR2_VideoUpscaler", 95 | "ver": "12aefc08bb302fa595a5938e1520522dc7246d7c", 96 | "Node name for S&R": "SeedVR2" 97 | }, 98 | "widgets_values": [ 99 | "seedvr2_ema_3b_fp8_e4m3fn.safetensors", 100 | 100, 101 | "randomize", 102 | 1072, 103 | 5 104 | ] 105 | }, 106 | { 107 | "id": 13, 108 | "type": "SeedVR2BlockSwap", 109 | "pos": [ 110 | 1160, 111 | 260 112 | ], 113 | "size": [ 114 | 270, 115 | 110 116 | ], 117 | "flags": {}, 118 | "order": 1, 119 | "mode": 0, 120 | "inputs": [], 121 | "outputs": [ 122 | { 123 | "label": "block_swap_config", 124 | "name": "block_swap_config", 125 | "type": "block_swap_config", 126 | "links": [ 127 | 16 128 | ] 129 | } 130 | ], 131 | "properties": { 132 | "aux_id": "numz/ComfyUI-SeedVR2_VideoUpscaler", 133 | "ver": "12aefc08bb302fa595a5938e1520522dc7246d7c", 134 | 
"Node name for S&R": "SeedVR2BlockSwap" 135 | }, 136 | "widgets_values": [ 137 | 16, 138 | false, 139 | false 140 | ] 141 | }, 142 | { 143 | "id": 12, 144 | "type": "GetVideoComponents", 145 | "pos": [ 146 | 1160, 147 | 130 148 | ], 149 | "size": [ 150 | 150, 151 | 70 152 | ], 153 | "flags": {}, 154 | "order": 3, 155 | "mode": 0, 156 | "inputs": [ 157 | { 158 | "label": "video", 159 | "name": "video", 160 | "type": "VIDEO", 161 | "link": 10 162 | } 163 | ], 164 | "outputs": [ 165 | { 166 | "label": "images", 167 | "name": "images", 168 | "type": "IMAGE", 169 | "links": [ 170 | 15 171 | ] 172 | }, 173 | { 174 | "label": "audio", 175 | "name": "audio", 176 | "type": "AUDIO", 177 | "links": null 178 | }, 179 | { 180 | "label": "fps", 181 | "name": "fps", 182 | "type": "FLOAT", 183 | "links": [ 184 | 12 185 | ] 186 | } 187 | ], 188 | "properties": { 189 | "cnr_id": "comfy-core", 190 | "ver": "0.3.50", 191 | "Node name for S&R": "GetVideoComponents" 192 | } 193 | }, 194 | { 195 | "id": 11, 196 | "type": "CreateVideo", 197 | "pos": [ 198 | 1820, 199 | 130 200 | ], 201 | "size": [ 202 | 210, 203 | 80 204 | ], 205 | "flags": {}, 206 | "order": 5, 207 | "mode": 0, 208 | "inputs": [ 209 | { 210 | "label": "images", 211 | "name": "images", 212 | "type": "IMAGE", 213 | "link": 18 214 | }, 215 | { 216 | "label": "audio", 217 | "name": "audio", 218 | "shape": 7, 219 | "type": "AUDIO", 220 | "link": null 221 | }, 222 | { 223 | "label": "fps", 224 | "name": "fps", 225 | "type": "FLOAT", 226 | "widget": { 227 | "name": "fps" 228 | }, 229 | "link": 12 230 | } 231 | ], 232 | "outputs": [ 233 | { 234 | "label": "VIDEO", 235 | "name": "VIDEO", 236 | "type": "VIDEO", 237 | "links": [ 238 | 7 239 | ] 240 | } 241 | ], 242 | "properties": { 243 | "cnr_id": "comfy-core", 244 | "ver": "0.3.50", 245 | "Node name for S&R": "CreateVideo" 246 | }, 247 | "widgets_values": [ 248 | 30 249 | ] 250 | }, 251 | { 252 | "id": 10, 253 | "type": "SaveVideo", 254 | "pos": [ 255 | 2090, 256 | 130 257 | ], 258 | "size": [ 259 | 270, 260 | 106 261 | ], 262 | "flags": {}, 263 | "order": 6, 264 | "mode": 0, 265 | "inputs": [ 266 | { 267 | "label": "video", 268 | "name": "video", 269 | "type": "VIDEO", 270 | "link": 7 271 | } 272 | ], 273 | "outputs": [], 274 | "properties": { 275 | "cnr_id": "comfy-core", 276 | "ver": "0.3.50", 277 | "Node name for S&R": "SaveVideo" 278 | }, 279 | "widgets_values": [ 280 | "video/ComfyUI", 281 | "mp4", 282 | "auto" 283 | ] 284 | }, 285 | { 286 | "id": 8, 287 | "type": "LoadVideo", 288 | "pos": [ 289 | 820, 290 | 130 291 | ], 292 | "size": [ 293 | 280, 294 | 500 295 | ], 296 | "flags": {}, 297 | "order": 2, 298 | "mode": 0, 299 | "inputs": [], 300 | "outputs": [ 301 | { 302 | "label": "VIDEO", 303 | "name": "VIDEO", 304 | "type": "VIDEO", 305 | "links": [ 306 | 10 307 | ] 308 | } 309 | ], 310 | "properties": { 311 | "cnr_id": "comfy-core", 312 | "ver": "0.3.50", 313 | "Node name for S&R": "LoadVideo" 314 | }, 315 | "widgets_values": [ 316 | "none", 317 | "image" 318 | ] 319 | } 320 | ], 321 | "links": [ 322 | [ 323 | 7, 324 | 11, 325 | 0, 326 | 10, 327 | 0, 328 | "VIDEO" 329 | ], 330 | [ 331 | 10, 332 | 8, 333 | 0, 334 | 12, 335 | 0, 336 | "VIDEO" 337 | ], 338 | [ 339 | 12, 340 | 12, 341 | 2, 342 | 11, 343 | 2, 344 | "FLOAT" 345 | ], 346 | [ 347 | 15, 348 | 12, 349 | 0, 350 | 15, 351 | 0, 352 | "IMAGE" 353 | ], 354 | [ 355 | 16, 356 | 13, 357 | 0, 358 | 15, 359 | 1, 360 | "block_swap_config" 361 | ], 362 | [ 363 | 17, 364 | 14, 365 | 0, 366 | 15, 367 | 2, 368 | "extra_args" 369 | ], 370 | [ 371 | 18, 
372 | 15, 373 | 0, 374 | 11, 375 | 0, 376 | "IMAGE" 377 | ] 378 | ], 379 | "groups": [], 380 | "config": {}, 381 | "extra": { 382 | "ds": { 383 | "scale": 0.9431135829263315, 384 | "offset": [ 385 | -724.0470189811382, 386 | 16.39777269620356 387 | ] 388 | }, 389 | "frontendVersion": "1.25.7", 390 | "VHS_latentpreview": false, 391 | "VHS_latentpreviewrate": 0, 392 | "VHS_MetadataImage": true, 393 | "VHS_KeepIntermediate": true 394 | }, 395 | "version": 0.4 396 | } --------------------------------------------------------------------------------