├── .gitattributes
├── pyramid_dit
│   ├── mmdit_modules
│   │   ├── __init__.py
│   │   ├── modeling_text_encoder.py
│   │   ├── modeling_normalization.py
│   │   ├── modeling_embedding.py
│   │   └── modeling_pyramid_mmdit.py
│   ├── flux_modules
│   │   ├── __init__.py
│   │   ├── modeling_embedding.py
│   │   ├── modeling_normalization.py
│   │   └── modeling_pyramid_flux.py
│   └── __init__.py
├── video_vae
│   ├── __init__.py
│   ├── modeling_causal_conv.py
│   ├── modeling_discriminator.py
│   ├── modeling_lpips.py
│   ├── context_parallel_ops.py
│   ├── modeling_loss.py
│   └── modeling_enc_dec.py
├── __init__.py
├── diffusion_schedulers
│   ├── __init__.py
│   ├── scheduling_cosine_ddpm.py
│   └── scheduling_flow_matching.py
├── requirements.txt
├── configs
│   ├── miniflux_transformer_config.json
│   ├── mmdit_transformer_config.json
│   └── causal_video_vae_config.json
├── pyproject.toml
├── .github
│   └── workflows
│       └── publish.yml
├── README.md
├── LICENSE
├── .gitignore
├── fp8_optimization.py
├── latent_preview.py
├── examples
│   ├── pyramid_flow_miniflux_text2vid_example_01.json
│   ├── pyramid_flow_miniflux_img2vid_example_01.json
│   └── pyramid_flow_miniflux_768_img2vid_example_01.json
└── utils.py
/.gitattributes: -------------------------------------------------------------------------------- 1 | # Auto detect text files and perform LF normalization 2 | * text=auto 3 | -------------------------------------------------------------------------------- /pyramid_dit/mmdit_modules/__init__.py: -------------------------------------------------------------------------------- 1 | from .modeling_pyramid_mmdit import PyramidDiffusionMMDiT -------------------------------------------------------------------------------- /video_vae/__init__.py: -------------------------------------------------------------------------------- 1 | from .modeling_loss import LPIPSWithDiscriminator 2 | from .modeling_causal_vae import CausalVideoVAE -------------------------------------------------------------------------------- /__init__.py: -------------------------------------------------------------------------------- 1 | from .nodes import NODE_CLASS_MAPPINGS, NODE_DISPLAY_NAME_MAPPINGS 2 | 3 | __all__ = ["NODE_CLASS_MAPPINGS", "NODE_DISPLAY_NAME_MAPPINGS"] -------------------------------------------------------------------------------- /diffusion_schedulers/__init__.py: -------------------------------------------------------------------------------- 1 | from .scheduling_cosine_ddpm import DDPMCosineScheduler 2 | from .scheduling_flow_matching import PyramidFlowMatchEulerDiscreteScheduler -------------------------------------------------------------------------------- /pyramid_dit/flux_modules/__init__.py: -------------------------------------------------------------------------------- 1 | from .modeling_pyramid_flux import PyramidFluxTransformer 2 | from .modeling_flux_block import FluxSingleTransformerBlock, FluxTransformerBlock -------------------------------------------------------------------------------- /requirements.txt: -------------------------------------------------------------------------------- 1 | diffusers>=0.30.1 2 | accelerate>=0.30.0 3 | einops 4 | packaging 5 | sentencepiece>=0.2.0 6 | timm>=0.6.12 7 | numpy<=1.26.4 8 | protobuf>=4.25.4 9 | transformers>=4.44.2 -------------------------------------------------------------------------------- /pyramid_dit/__init__.py: -------------------------------------------------------------------------------- 1 | from .pyramid_dit_for_video_gen_pipeline import PyramidDiTForVideoGeneration 2 | #from .flux_modules import FluxSingleTransformerBlock, FluxTransformerBlock, 
FluxTextEncoderWithMask 3 | #from .mmdit_modules import JointTransformerBlock, SD3TextEncoderWithMask -------------------------------------------------------------------------------- /configs/miniflux_transformer_config.json: -------------------------------------------------------------------------------- 1 | { 2 | "_class_name": "PyramidFluxTransformer", 3 | "_diffusers_version": "0.30.3", 4 | "attention_head_dim": 64, 5 | "axes_dims_rope": [ 6 | 16, 7 | 24, 8 | 24 9 | ], 10 | "in_channels": 64, 11 | "interp_condition_pos": true, 12 | "joint_attention_dim": 4096, 13 | "num_attention_heads": 30, 14 | "num_layers": 8, 15 | "num_single_layers": 16, 16 | "patch_size": 1, 17 | "pooled_projection_dim": 768, 18 | "use_flash_attn": false, 19 | "use_gradient_checkpointing": false, 20 | "use_temporal_causal": true 21 | } 22 | -------------------------------------------------------------------------------- /configs/mmdit_transformer_config.json: -------------------------------------------------------------------------------- 1 | { 2 | "_class_name": "PyramidDiffusionMMDiT", 3 | "_diffusers_version": "0.30.0", 4 | "attention_head_dim": 64, 5 | "caption_projection_dim": 1536, 6 | "in_channels": 16, 7 | "joint_attention_dim": 4096, 8 | "max_num_frames": 200, 9 | "num_attention_heads": 24, 10 | "num_layers": 24, 11 | "patch_size": 2, 12 | "pooled_projection_dim": 2048, 13 | "pos_embed_max_size": 192, 14 | "pos_embed_type": "sincos", 15 | "qk_norm": "rms_norm", 16 | "sample_size": 128, 17 | "use_flash_attn": false, 18 | "use_gradient_checkpointing": false, 19 | "use_temporal_causal": true 20 | } 21 | -------------------------------------------------------------------------------- /pyproject.toml: -------------------------------------------------------------------------------- 1 | [project] 2 | name = "comfyui-pyramidflowwrapper" 3 | description = "Wrapper for Pyramid-Flow models: [a/https://github.com/jy0205/Pyramid-Flow](https://github.com/jy0205/Pyramid-Flow)" 4 | version = "1.0.0" 5 | license = {file = "LICENSE"} 6 | dependencies = ["diffusers>=0.30.1", "accelerate>=0.30.0", "einops", "packaging", "sentencepiece", "timm>=0.6.12", "numpy<=1.26.4", "protobuf>=4.25.4"] 7 | 8 | [project.urls] 9 | Repository = "https://github.com/kijai/ComfyUI-PyramidFlowWrapper" 10 | # Used by Comfy Registry https://comfyregistry.org 11 | 12 | [tool.comfy] 13 | PublisherId = "kijai" 14 | DisplayName = "ComfyUI-PyramidFlowWrapper" 15 | Icon = "" 16 | -------------------------------------------------------------------------------- /.github/workflows/publish.yml: -------------------------------------------------------------------------------- 1 | name: Publish to Comfy registry 2 | on: 3 | workflow_dispatch: 4 | push: 5 | branches: 6 | - main 7 | - master 8 | paths: 9 | - "pyproject.toml" 10 | 11 | jobs: 12 | publish-node: 13 | name: Publish Custom Node to registry 14 | runs-on: ubuntu-latest 15 | # Skip the workflow if this is a forked repository. 16 | if: github.event.repository.fork == false 17 | steps: 18 | - name: Check out code 19 | uses: actions/checkout@v4 20 | - name: Publish Custom Node 21 | uses: Comfy-Org/publish-node-action@main 22 | with: 23 | ## Add your own personal access token to your GitHub repository secrets and reference it here. 
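## You can create the secret under Settings -> Secrets and variables -> Actions; the step below expects it to be named REGISTRY_ACCESS_TOKEN.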
24 | personal_access_token: ${{ secrets.REGISTRY_ACCESS_TOKEN }} 25 | -------------------------------------------------------------------------------- /README.md: -------------------------------------------------------------------------------- 1 | # ComfyUI wrapper nodes for [Pyramid-Flow](https://github.com/jy0205/Pyramid-Flow) 2 | 3 | ## UPDATE 4 | As the first Flux version is out, I'm dropping SD3 support and have refactored the whole thing. If you still want to use the old nodes, they are archived in the legacy branch. 5 | 6 | The fluxmini version can run with 7GB VRAM; currently it only supports 5 second videos (temp 16). 7 | Fp8 severely reduces quality and is not recommended; only use it if you must. 8 | 9 | Download models from: 10 | 11 | https://huggingface.co/Kijai/pyramid-flow-comfy/tree/main 12 | 13 | To `ComfyUI/models/diffusion_models` and `ComfyUI/models/vae` 14 | 15 | https://github.com/user-attachments/assets/1372549a-4b4e-4569-a062-8f72880e8c4e 16 | 17 | 18 | https://github.com/user-attachments/assets/d0bd38eb-6378-4cfa-ae55-1b4498b7ce84 19 | 20 | 21 | Original repo: https://github.com/jy0205/Pyramid-Flow 22 | -------------------------------------------------------------------------------- /LICENSE: -------------------------------------------------------------------------------- 1 | MIT License 2 | 3 | Copyright (c) 2024 Yang Jin 4 | 5 | Permission is hereby granted, free of charge, to any person obtaining a copy 6 | of this software and associated documentation files (the "Software"), to deal 7 | in the Software without restriction, including without limitation the rights 8 | to use, copy, modify, merge, publish, distribute, sublicense, and/or sell 9 | copies of the Software, and to permit persons to whom the Software is 10 | furnished to do so, subject to the following conditions: 11 | 12 | The above copyright notice and this permission notice shall be included in all 13 | copies or substantial portions of the Software. 14 | 15 | THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR 16 | IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, 17 | FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE 18 | AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER 19 | LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, 20 | OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE 21 | SOFTWARE. 22 | -------------------------------------------------------------------------------- /.gitignore: -------------------------------------------------------------------------------- 1 | # Xcode 2 | .DS_Store 3 | .idea 4 | 5 | # Byte-compiled / optimized / DLL files 6 | __pycache__/ 7 | *.py[cod] 8 | *$py.class 9 | # C extensions 10 | *.so 11 | onnx_model/*.onnx 12 | onnx_model/antelope/*.onnx 13 | 14 | 15 | logs/ 16 | prompts/ 17 | 18 | # Distribution / packaging 19 | .Python 20 | build/ 21 | develop-eggs/ 22 | downloads/ 23 | eggs/ 24 | .eggs/ 25 | lib/ 26 | lib64/ 27 | parts/ 28 | sdist/ 29 | wheels/ 30 | share/python-wheels/ 31 | *.egg-info/ 32 | .installed.cfg 33 | *.egg 34 | MANIFEST 35 | 36 | # PyInstaller 37 | # Usually these files are written by a python script from a template 38 | # before PyInstaller builds the exe, so as to inject date/other infos into it. 
39 | *.manifest 40 | *.spec 41 | 42 | 43 | # Unit test / coverage reports 44 | htmlcov/ 45 | .tox/ 46 | .nox/ 47 | .coverage 48 | .coverage.* 49 | .cache 50 | nosetests.xml 51 | coverage.xml 52 | *.cover 53 | .hypothesis/ 54 | .pytest_cache/ 55 | 56 | # Translations 57 | *.mo 58 | *.pot 59 | 60 | # Django stuff: 61 | *.log 62 | local_settings.py 63 | db.sqlite3 64 | 65 | # Flask stuff: 66 | instance/ 67 | .webassets-cache 68 | 69 | # Scrapy stuff: 70 | .scrapy 71 | 72 | # Sphinx documentation 73 | docs/_build/ 74 | 75 | # PyBuilder 76 | target/ 77 | 78 | # Jupyter Notebook 79 | .ipynb_checkpoints 80 | 81 | # IPython 82 | profile_default/ 83 | ipython_config.py 84 | 85 | # pyenv 86 | .python-version 87 | 88 | # celery beat schedule file 89 | celerybeat-schedule 90 | 91 | # SageMath parsed files 92 | *.sage.py 93 | 94 | # Environments 95 | .env 96 | .pt2/ 97 | .venv 98 | env/ 99 | venv/ 100 | ENV/ 101 | env.bak/ 102 | venv.bak/ 103 | 104 | # Spyder project settings 105 | .spyderproject 106 | .spyproject 107 | 108 | # Rope project settings 109 | .ropeproject 110 | 111 | # mkdocs documentation 112 | /site 113 | 114 | # mypy 115 | .mypy_cache/ 116 | .dmypy.json 117 | dmypy.json 118 | .bak 119 | 120 | -------------------------------------------------------------------------------- /configs/causal_video_vae_config.json: -------------------------------------------------------------------------------- 1 | { 2 | "_class_name": "CausalVideoVAE", 3 | "_diffusers_version": "0.29.2", 4 | "add_post_quant_conv": true, 5 | "decoder_act_fn": "silu", 6 | "decoder_block_dropout": [ 7 | 0.0, 8 | 0.0, 9 | 0.0, 10 | 0.0 11 | ], 12 | "decoder_block_out_channels": [ 13 | 128, 14 | 256, 15 | 512, 16 | 512 17 | ], 18 | "decoder_in_channels": 16, 19 | "decoder_layers_per_block": [ 20 | 3, 21 | 3, 22 | 3, 23 | 3 24 | ], 25 | "decoder_norm_num_groups": 32, 26 | "decoder_out_channels": 3, 27 | "decoder_spatial_up_sample": [ 28 | true, 29 | true, 30 | true, 31 | false 32 | ], 33 | "decoder_temporal_up_sample": [ 34 | true, 35 | true, 36 | true, 37 | false 38 | ], 39 | "decoder_type": "causal_vae_conv", 40 | "decoder_up_block_types": [ 41 | "UpDecoderBlockCausal3D", 42 | "UpDecoderBlockCausal3D", 43 | "UpDecoderBlockCausal3D", 44 | "UpDecoderBlockCausal3D" 45 | ], 46 | "downsample_scale": 8, 47 | "encoder_act_fn": "silu", 48 | "encoder_block_dropout": [ 49 | 0.0, 50 | 0.0, 51 | 0.0, 52 | 0.0 53 | ], 54 | "encoder_block_out_channels": [ 55 | 128, 56 | 256, 57 | 512, 58 | 512 59 | ], 60 | "encoder_double_z": true, 61 | "encoder_down_block_types": [ 62 | "DownEncoderBlockCausal3D", 63 | "DownEncoderBlockCausal3D", 64 | "DownEncoderBlockCausal3D", 65 | "DownEncoderBlockCausal3D" 66 | ], 67 | "encoder_in_channels": 3, 68 | "encoder_layers_per_block": [ 69 | 2, 70 | 2, 71 | 2, 72 | 2 73 | ], 74 | "encoder_norm_num_groups": 32, 75 | "encoder_out_channels": 16, 76 | "encoder_spatial_down_sample": [ 77 | true, 78 | true, 79 | true, 80 | false 81 | ], 82 | "encoder_temporal_down_sample": [ 83 | true, 84 | true, 85 | true, 86 | false 87 | ], 88 | "encoder_type": "causal_vae_conv", 89 | "interpolate": false, 90 | "sample_size": 256, 91 | "scaling_factor": 0.13025 92 | } 93 | -------------------------------------------------------------------------------- /fp8_optimization.py: -------------------------------------------------------------------------------- 1 | #based on ComfyUI's and MinusZoneAI's fp8_linear optimization 2 | 3 | import torch 4 | import torch.nn as nn 5 | 6 | def fp8_linear_forward(cls, original_dtype, input): 7 | 
weight_dtype = cls.weight.dtype 8 | if weight_dtype in [torch.float8_e4m3fn, torch.float8_e5m2]: 9 | if len(input.shape) == 3: 10 | if weight_dtype == torch.float8_e4m3fn: 11 | inn = input.reshape(-1, input.shape[2]).to(torch.float8_e5m2) 12 | else: 13 | inn = input.reshape(-1, input.shape[2]).to(torch.float8_e4m3fn) 14 | w = cls.weight.t() 15 | 16 | scale_weight = torch.ones((1), device=input.device, dtype=torch.float32) 17 | scale_input = scale_weight 18 | 19 | bias = cls.bias.to(original_dtype) if cls.bias is not None else None 20 | out_dtype = original_dtype 21 | 22 | if bias is not None: 23 | o = torch._scaled_mm(inn, w, out_dtype=out_dtype, bias=bias, scale_a=scale_input, scale_b=scale_weight) 24 | else: 25 | o = torch._scaled_mm(inn, w, out_dtype=out_dtype, scale_a=scale_input, scale_b=scale_weight) 26 | 27 | if isinstance(o, tuple): 28 | o = o[0] 29 | 30 | return o.reshape((-1, input.shape[1], cls.weight.shape[0])) 31 | else: 32 | cls.to(original_dtype) 33 | out = cls.original_forward(input.to(original_dtype)) 34 | cls.to(original_dtype) 35 | return out 36 | else: 37 | return cls.original_forward(input) 38 | 39 | def convert_fp8_linear(module, original_dtype, params_to_keep): 40 | setattr(module, "fp8_matmul_enabled", True) 41 | for name, module in module.named_modules(): 42 | if not any(keyword in name for keyword in params_to_keep): 43 | if isinstance(module, nn.Linear): 44 | original_forward = module.forward 45 | setattr(module, "original_forward", original_forward) 46 | setattr(module, "forward", lambda input, m=module: fp8_linear_forward(m, original_dtype, input)) 47 | -------------------------------------------------------------------------------- /latent_preview.py: -------------------------------------------------------------------------------- 1 | import io 2 | 3 | import torch 4 | from PIL import Image 5 | import struct 6 | import numpy as np 7 | from comfy.cli_args import args, LatentPreviewMethod 8 | from comfy.taesd.taesd import TAESD 9 | import comfy.model_management 10 | import folder_paths 11 | import comfy.utils 12 | import logging 13 | 14 | MAX_PREVIEW_RESOLUTION = args.preview_size 15 | 16 | def preview_to_image(latent_image): 17 | latents_ubyte = (((latent_image + 1.0) / 2.0).clamp(0, 1) # change scale from -1..1 to 0..1 18 | .mul(0xFF) # to 0..255 19 | ).to(device="cpu", dtype=torch.uint8, non_blocking=comfy.model_management.device_supports_non_blocking(latent_image.device)) 20 | 21 | return Image.fromarray(latents_ubyte.numpy()) 22 | 23 | class LatentPreviewer: 24 | def decode_latent_to_preview(self, x0): 25 | pass 26 | 27 | def decode_latent_to_preview_image(self, preview_format, x0): 28 | preview_image = self.decode_latent_to_preview(x0) 29 | return ("GIF", preview_image, MAX_PREVIEW_RESOLUTION) 30 | 31 | class Latent2RGBPreviewer(LatentPreviewer): 32 | def __init__(self, latent_rgb_factors, latent_rgb_factors_bias=None): 33 | latent_rgb_factors = [[0.05389399697934166, 0.025018778505575393, -0.009193515248318657], [0.02318250640590553, -0.026987363837713156, 0.040172639061236956], [0.046035451343323666, -0.02039565868920197, 0.01275569344290342], [-0.015559161155025095, 0.051403973219861246, 0.03179031307996347], [-0.02766167769640129, 0.03749545161530447, 0.003335141009473408], [0.05824598730479011, 0.021744367381243884, -0.01578925627951616], [0.05260929401500947, 0.0560165014956886, -0.027477296572565126], [0.018513891242931686, 0.041961785217662514, 0.004490763489747966], [0.024063060899760215, 0.065082853069653, 0.044343437673514896], 
[0.05250992323006226, 0.04361117432588933, 0.01030076055524387], [0.0038921710021782366, -0.025299228133723792, 0.019370764014574535], [-0.00011950534333568519, 0.06549370069727675, -0.03436712163379723], [-0.026020578032683626, -0.013341758571090847, -0.009119046570271953], [0.024412451175602937, 0.030135064560817174, -0.008355486384198006], [0.04002209845752687, -0.017341304390739463, 0.02818338690302971], [-0.032575108695213684, -0.009588338926775117, -0.03077312160940468]] 34 | self.latent_rgb_factors = torch.tensor(latent_rgb_factors, device="cpu").transpose(0, 1) 35 | self.latent_rgb_factors_bias = None 36 | # if latent_rgb_factors_bias is not None: 37 | # self.latent_rgb_factors_bias = torch.tensor(latent_rgb_factors_bias, device="cpu") 38 | 39 | def decode_latent_to_preview(self, x0): 40 | self.latent_rgb_factors = self.latent_rgb_factors.to(dtype=x0.dtype, device=x0.device) 41 | if self.latent_rgb_factors_bias is not None: 42 | self.latent_rgb_factors_bias = self.latent_rgb_factors_bias.to(dtype=x0.dtype, device=x0.device) 43 | 44 | latent_image = torch.nn.functional.linear(x0[0].permute(1, 2, 0), self.latent_rgb_factors, 45 | bias=self.latent_rgb_factors_bias) 46 | return preview_to_image(latent_image) 47 | 48 | 49 | def get_previewer(device, latent_format): 50 | previewer = None 51 | method = args.preview_method 52 | if method != LatentPreviewMethod.NoPreviews: 53 | # TODO previewer methods 54 | taesd_decoder_path = None 55 | if latent_format.taesd_decoder_name is not None: 56 | taesd_decoder_path = next( 57 | (fn for fn in folder_paths.get_filename_list("vae_approx") 58 | if fn.startswith(latent_format.taesd_decoder_name)), 59 | "" 60 | ) 61 | taesd_decoder_path = folder_paths.get_full_path("vae_approx", taesd_decoder_path) 62 | 63 | if method == LatentPreviewMethod.Auto: 64 | method = LatentPreviewMethod.Latent2RGB 65 | 66 | if previewer is None: 67 | if latent_format.latent_rgb_factors is not None: 68 | previewer = Latent2RGBPreviewer(latent_format.latent_rgb_factors, latent_format.latent_rgb_factors_bias) 69 | return previewer 70 | 71 | def prepare_callback(model, steps, x0_output_dict=None): 72 | preview_format = "JPEG" 73 | if preview_format not in ["JPEG", "PNG"]: 74 | preview_format = "JPEG" 75 | 76 | previewer = get_previewer(model.load_device, model.model.latent_format) 77 | 78 | pbar = comfy.utils.ProgressBar(steps) 79 | def callback(step, x0, x, total_steps): 80 | if x0_output_dict is not None: 81 | x0_output_dict["x0"] = x0 82 | preview_bytes = None 83 | if previewer: 84 | preview_bytes = previewer.decode_latent_to_preview_image(preview_format, x0) 85 | pbar.update_absolute(step + 1, total_steps, preview_bytes) 86 | return callback 87 | 88 | -------------------------------------------------------------------------------- /video_vae/modeling_causal_conv.py: -------------------------------------------------------------------------------- 1 | from typing import Tuple, Union 2 | import torch 3 | import torch.nn as nn 4 | import torch.nn.functional as F 5 | from collections import deque 6 | from einops import rearrange 7 | from timm.models.layers import trunc_normal_ 8 | from torch import Tensor 9 | 10 | from ..utils import ( 11 | is_context_parallel_initialized, 12 | get_context_parallel_rank, 13 | 14 | ) 15 | from .context_parallel_ops import ( 16 | cp_pass_from_previous_rank, 17 | ) 18 | 19 | 20 | def divisible_by(num, den): 21 | return (num % den) == 0 22 | 23 | def cast_tuple(t, length = 1): 24 | return t if isinstance(t, tuple) else ((t,) * length) 25 | 26 | def 
is_odd(n): 27 | return not divisible_by(n, 2) 28 | 29 | 30 | class CausalGroupNorm(nn.GroupNorm): 31 | 32 | def forward(self, x: Tensor) -> Tensor: 33 | t = x.shape[2] 34 | x = rearrange(x, 'b c t h w -> (b t) c h w') 35 | x = super().forward(x) 36 | x = rearrange(x, '(b t) c h w -> b c t h w', t=t) 37 | return x 38 | 39 | 40 | class CausalConv3d(nn.Module): 41 | 42 | def __init__( 43 | self, 44 | in_channels, 45 | out_channels, 46 | kernel_size: Union[int, Tuple[int, int, int]], 47 | stride: Union[int, Tuple[int, int, int]] = 1, 48 | pad_mode: str ='constant', 49 | **kwargs 50 | ): 51 | super().__init__() 52 | if isinstance(kernel_size, int): 53 | kernel_size = cast_tuple(kernel_size, 3) 54 | 55 | time_kernel_size, height_kernel_size, width_kernel_size = kernel_size 56 | self.time_kernel_size = time_kernel_size 57 | assert is_odd(height_kernel_size) and is_odd(width_kernel_size) 58 | dilation = kwargs.pop('dilation', 1) 59 | self.pad_mode = pad_mode 60 | 61 | if isinstance(stride, int): 62 | stride = (stride, 1, 1) 63 | 64 | time_pad = dilation * (time_kernel_size - 1) 65 | height_pad = height_kernel_size // 2 66 | width_pad = width_kernel_size // 2 67 | 68 | self.temporal_stride = stride[0] 69 | self.time_pad = time_pad 70 | self.time_causal_padding = (width_pad, width_pad, height_pad, height_pad, time_pad, 0) 71 | self.time_uncausal_padding = (width_pad, width_pad, height_pad, height_pad, 0, 0) 72 | 73 | self.conv = nn.Conv3d(in_channels, out_channels, kernel_size, stride=stride, padding=0, dilation=dilation, **kwargs) 74 | self.cache_front_feat = deque() 75 | 76 | def _clear_context_parallel_cache(self): 77 | del self.cache_front_feat 78 | self.cache_front_feat = deque() 79 | 80 | def _init_weights(self, m): 81 | if isinstance(m, (nn.Linear, nn.Conv2d, nn.Conv3d)): 82 | trunc_normal_(m.weight, std=.02) 83 | if m.bias is not None: 84 | nn.init.constant_(m.bias, 0) 85 | elif isinstance(m, (nn.LayerNorm, nn.GroupNorm)): 86 | nn.init.constant_(m.bias, 0) 87 | nn.init.constant_(m.weight, 1.0) 88 | 89 | def context_parallel_forward(self, x): 90 | x = cp_pass_from_previous_rank(x, dim=2, kernel_size=self.time_kernel_size) 91 | 92 | x = F.pad(x, self.time_uncausal_padding, mode='constant') 93 | 94 | cp_rank = get_context_parallel_rank() 95 | if cp_rank != 0: 96 | if self.temporal_stride == 2 and self.time_kernel_size == 3: 97 | x = x[:,:,1:] 98 | 99 | x = self.conv(x) 100 | return x 101 | 102 | def forward(self, x, is_init_image=True, temporal_chunk=False): 103 | # temporal_chunk: whether to use the temporal chunk 104 | 105 | if is_context_parallel_initialized(): 106 | return self.context_parallel_forward(x) 107 | 108 | pad_mode = self.pad_mode if self.time_pad < x.shape[2] else 'constant' 109 | 110 | if not temporal_chunk: 111 | x = F.pad(x, self.time_causal_padding, mode=pad_mode) 112 | else: 113 | assert not self.training, "The feature cache should not be used in training" 114 | if is_init_image: 115 | # Encode the first chunk 116 | x = F.pad(x, self.time_causal_padding, mode=pad_mode) 117 | self._clear_context_parallel_cache() 118 | self.cache_front_feat.append(x[:, :, -2:].clone().detach()) 119 | else: 120 | x = F.pad(x, self.time_uncausal_padding, mode=pad_mode) 121 | video_front_context = self.cache_front_feat.pop() 122 | self._clear_context_parallel_cache() 123 | 124 | if self.temporal_stride == 1 and self.time_kernel_size == 3: 125 | x = torch.cat([video_front_context, x], dim=2) 126 | elif self.temporal_stride == 2 and self.time_kernel_size == 3: 127 | x = 
torch.cat([video_front_context[:,:,-1:], x], dim=2) 128 | 129 | self.cache_front_feat.append(x[:, :, -2:].clone().detach()) 130 | 131 | x = self.conv(x) 132 | return x -------------------------------------------------------------------------------- /video_vae/modeling_discriminator.py: -------------------------------------------------------------------------------- 1 | import functools 2 | import torch.nn as nn 3 | 4 | 5 | def weights_init(m): 6 | classname = m.__class__.__name__ 7 | if classname.find('Conv') != -1: 8 | nn.init.normal_(m.weight.data, 0.0, 0.02) 9 | nn.init.constant_(m.bias.data, 0) 10 | elif classname.find('BatchNorm') != -1: 11 | nn.init.normal_(m.weight.data, 1.0, 0.02) 12 | nn.init.constant_(m.bias.data, 0) 13 | 14 | 15 | class NLayerDiscriminator(nn.Module): 16 | """Defines a PatchGAN discriminator as in Pix2Pix 17 | --> see https://github.com/junyanz/pytorch-CycleGAN-and-pix2pix/blob/master/models/networks.py 18 | """ 19 | def __init__(self, input_nc=3, ndf=64, n_layers=4): 20 | """Construct a PatchGAN discriminator 21 | Parameters: 22 | input_nc (int) -- the number of channels in input images 23 | ndf (int) -- the number of filters in the last conv layer 24 | n_layers (int) -- the number of conv layers in the discriminator 25 | norm_layer -- normalization layer 26 | """ 27 | super(NLayerDiscriminator, self).__init__() 28 | 29 | # norm_layer = nn.BatchNorm2d 30 | norm_layer = nn.InstanceNorm2d 31 | 32 | if type(norm_layer) == functools.partial: # no need to use bias as BatchNorm2d has affine parameters 33 | use_bias = norm_layer.func != nn.BatchNorm2d 34 | else: 35 | use_bias = norm_layer != nn.BatchNorm2d 36 | 37 | kw = 4 38 | padw = 1 39 | sequence = [nn.Conv2d(input_nc, ndf, kernel_size=kw, stride=2, padding=padw), nn.LeakyReLU(0.2, True)] 40 | nf_mult = 1 41 | nf_mult_prev = 1 42 | for n in range(1, n_layers): # gradually increase the number of filters 43 | nf_mult_prev = nf_mult 44 | nf_mult = min(2 ** n, 8) 45 | sequence += [ 46 | nn.Conv2d(ndf * nf_mult_prev, ndf * nf_mult, kernel_size=kw, stride=2, padding=padw, bias=use_bias), 47 | norm_layer(ndf * nf_mult), 48 | nn.LeakyReLU(0.2, True) 49 | ] 50 | 51 | nf_mult_prev = nf_mult 52 | nf_mult = min(2 ** n_layers, 8) 53 | sequence += [ 54 | nn.Conv2d(ndf * nf_mult_prev, ndf * nf_mult, kernel_size=kw, stride=1, padding=padw, bias=use_bias), 55 | norm_layer(ndf * nf_mult), 56 | nn.LeakyReLU(0.2, True) 57 | ] 58 | 59 | sequence += [ 60 | nn.Conv2d(ndf * nf_mult, 1, kernel_size=kw, stride=1, padding=padw)] # output 1 channel prediction map 61 | self.main = nn.Sequential(*sequence) 62 | 63 | def forward(self, input): 64 | """Standard forward.""" 65 | return self.main(input) 66 | 67 | 68 | class NLayerDiscriminator3D(nn.Module): 69 | """Defines a 3D PatchGAN discriminator as in Pix2Pix but for 3D inputs.""" 70 | def __init__(self, input_nc=3, ndf=64, n_layers=3, use_actnorm=False): 71 | """ 72 | Construct a 3D PatchGAN discriminator 73 | 74 | Parameters: 75 | input_nc (int) -- the number of channels in input volumes 76 | ndf (int) -- the number of filters in the last conv layer 77 | n_layers (int) -- the number of conv layers in the discriminator 78 | use_actnorm (bool) -- flag to use actnorm instead of batchnorm 79 | """ 80 | super(NLayerDiscriminator3D, self).__init__() 81 | # if not use_actnorm: 82 | # norm_layer = nn.BatchNorm3d 83 | # else: 84 | # raise NotImplementedError("Not implemented.") 85 | 86 | norm_layer = nn.InstanceNorm3d 87 | 88 | if type(norm_layer) == functools.partial: 89 | use_bias = norm_layer.func 
!= nn.BatchNorm3d 90 | else: 91 | use_bias = norm_layer != nn.BatchNorm3d 92 | 93 | kw = 4 94 | padw = 1 95 | sequence = [nn.Conv3d(input_nc, ndf, kernel_size=kw, stride=2, padding=padw), nn.LeakyReLU(0.2, True)] 96 | nf_mult = 1 97 | nf_mult_prev = 1 98 | for n in range(1, n_layers): # gradually increase the number of filters 99 | nf_mult_prev = nf_mult 100 | nf_mult = min(2 ** n, 8) 101 | sequence += [ 102 | nn.Conv3d(ndf * nf_mult_prev, ndf * nf_mult, kernel_size=(kw, kw, kw), stride=(1,2,2), padding=padw, bias=use_bias), 103 | norm_layer(ndf * nf_mult), 104 | nn.LeakyReLU(0.2, True) 105 | ] 106 | 107 | nf_mult_prev = nf_mult 108 | nf_mult = min(2 ** n_layers, 8) 109 | sequence += [ 110 | nn.Conv3d(ndf * nf_mult_prev, ndf * nf_mult, kernel_size=(kw, kw, kw), stride=1, padding=padw, bias=use_bias), 111 | norm_layer(ndf * nf_mult), 112 | nn.LeakyReLU(0.2, True) 113 | ] 114 | 115 | sequence += [nn.Conv3d(ndf * nf_mult, 1, kernel_size=kw, stride=1, padding=padw)] # output 1 channel prediction map 116 | self.main = nn.Sequential(*sequence) 117 | 118 | def forward(self, input): 119 | """Standard forward.""" 120 | return self.main(input) -------------------------------------------------------------------------------- /video_vae/modeling_lpips.py: -------------------------------------------------------------------------------- 1 | """Stripped version of https://github.com/richzhang/PerceptualSimilarity/tree/master/models""" 2 | 3 | import torch 4 | import torch.nn as nn 5 | from torchvision import models 6 | from collections import namedtuple 7 | 8 | 9 | class LPIPS(nn.Module): 10 | # Learned perceptual metric 11 | def __init__(self, use_dropout=True): 12 | super().__init__() 13 | self.scaling_layer = ScalingLayer() 14 | self.chns = [64, 128, 256, 512, 512] # vg16 features 15 | self.net = vgg16(pretrained=False, requires_grad=False) 16 | self.lin0 = NetLinLayer(self.chns[0], use_dropout=use_dropout) 17 | self.lin1 = NetLinLayer(self.chns[1], use_dropout=use_dropout) 18 | self.lin2 = NetLinLayer(self.chns[2], use_dropout=use_dropout) 19 | self.lin3 = NetLinLayer(self.chns[3], use_dropout=use_dropout) 20 | self.lin4 = NetLinLayer(self.chns[4], use_dropout=use_dropout) 21 | self.load_from_pretrained() 22 | for param in self.parameters(): 23 | param.requires_grad = False 24 | 25 | def load_from_pretrained(self): 26 | ckpt = "/home/jinyang/models/vae/video_vae_baseline/vgg_lpips.pth" # replace with your lpips 27 | self.load_state_dict(torch.load(ckpt, map_location=torch.device("cpu")), strict=True) 28 | print("loaded pretrained LPIPS loss from {}".format(ckpt)) 29 | 30 | def forward(self, input, target): 31 | in0_input, in1_input = (self.scaling_layer(input), self.scaling_layer(target)) 32 | outs0, outs1 = self.net(in0_input), self.net(in1_input) 33 | feats0, feats1, diffs = {}, {}, {} 34 | lins = [self.lin0, self.lin1, self.lin2, self.lin3, self.lin4] 35 | for kk in range(len(self.chns)): 36 | feats0[kk], feats1[kk] = normalize_tensor(outs0[kk]), normalize_tensor(outs1[kk]) 37 | diffs[kk] = (feats0[kk] - feats1[kk]) ** 2 38 | 39 | res = [spatial_average(lins[kk].model(diffs[kk]), keepdim=True) for kk in range(len(self.chns))] 40 | val = res[0] 41 | for l in range(1, len(self.chns)): 42 | val += res[l] 43 | return val 44 | 45 | 46 | class ScalingLayer(nn.Module): 47 | def __init__(self): 48 | super(ScalingLayer, self).__init__() 49 | self.register_buffer('shift', torch.Tensor([-.030, -.088, -.188])[None, :, None, None]) 50 | self.register_buffer('scale', torch.Tensor([.458, .448, .450])[None, :, 
None, None]) 51 | 52 | def forward(self, inp): 53 | return (inp - self.shift) / self.scale 54 | 55 | 56 | class NetLinLayer(nn.Module): 57 | """ A single linear layer which does a 1x1 conv """ 58 | def __init__(self, chn_in, chn_out=1, use_dropout=False): 59 | super(NetLinLayer, self).__init__() 60 | layers = [nn.Dropout(), ] if (use_dropout) else [] 61 | layers += [nn.Conv2d(chn_in, chn_out, 1, stride=1, padding=0, bias=False), ] 62 | self.model = nn.Sequential(*layers) 63 | 64 | 65 | class vgg16(torch.nn.Module): 66 | def __init__(self, requires_grad=False, pretrained=True): 67 | super(vgg16, self).__init__() 68 | vgg_pretrained_features = models.vgg16(pretrained=pretrained).features 69 | self.slice1 = torch.nn.Sequential() 70 | self.slice2 = torch.nn.Sequential() 71 | self.slice3 = torch.nn.Sequential() 72 | self.slice4 = torch.nn.Sequential() 73 | self.slice5 = torch.nn.Sequential() 74 | self.N_slices = 5 75 | for x in range(4): 76 | self.slice1.add_module(str(x), vgg_pretrained_features[x]) 77 | for x in range(4, 9): 78 | self.slice2.add_module(str(x), vgg_pretrained_features[x]) 79 | for x in range(9, 16): 80 | self.slice3.add_module(str(x), vgg_pretrained_features[x]) 81 | for x in range(16, 23): 82 | self.slice4.add_module(str(x), vgg_pretrained_features[x]) 83 | for x in range(23, 30): 84 | self.slice5.add_module(str(x), vgg_pretrained_features[x]) 85 | if not requires_grad: 86 | for param in self.parameters(): 87 | param.requires_grad = False 88 | 89 | def forward(self, X): 90 | h = self.slice1(X) 91 | h_relu1_2 = h 92 | h = self.slice2(h) 93 | h_relu2_2 = h 94 | h = self.slice3(h) 95 | h_relu3_3 = h 96 | h = self.slice4(h) 97 | h_relu4_3 = h 98 | h = self.slice5(h) 99 | h_relu5_3 = h 100 | vgg_outputs = namedtuple("VggOutputs", ['relu1_2', 'relu2_2', 'relu3_3', 'relu4_3', 'relu5_3']) 101 | out = vgg_outputs(h_relu1_2, h_relu2_2, h_relu3_3, h_relu4_3, h_relu5_3) 102 | return out 103 | 104 | 105 | def normalize_tensor(x,eps=1e-10): 106 | norm_factor = torch.sqrt(torch.sum(x**2,dim=1,keepdim=True)) 107 | return x/(norm_factor+eps) 108 | 109 | 110 | def spatial_average(x, keepdim=True): 111 | return x.mean([2,3],keepdim=keepdim) 112 | 113 | 114 | if __name__ == "__main__": 115 | model = LPIPS().eval() 116 | _ = torch.manual_seed(123) 117 | img1 = (torch.rand(10, 3, 100, 100) * 2) - 1 118 | img2 = (torch.rand(10, 3, 100, 100) * 2) - 1 119 | print(model(img1, img2).shape) 120 | # embed() -------------------------------------------------------------------------------- /diffusion_schedulers/scheduling_cosine_ddpm.py: -------------------------------------------------------------------------------- 1 | import math 2 | from dataclasses import dataclass 3 | from typing import List, Optional, Tuple, Union 4 | 5 | import torch 6 | 7 | from diffusers.configuration_utils import ConfigMixin, register_to_config 8 | from diffusers.utils import BaseOutput 9 | from diffusers.utils.torch_utils import randn_tensor 10 | from diffusers.schedulers.scheduling_utils import SchedulerMixin 11 | 12 | 13 | @dataclass 14 | class DDPMSchedulerOutput(BaseOutput): 15 | """ 16 | Output class for the scheduler's step function output. 17 | 18 | Args: 19 | prev_sample (`torch.Tensor` of shape `(batch_size, num_channels, height, width)` for images): 20 | Computed sample (x_{t-1}) of previous timestep. `prev_sample` should be used as next model input in the 21 | denoising loop. 
22 | """ 23 | 24 | prev_sample: torch.Tensor 25 | 26 | 27 | class DDPMCosineScheduler(SchedulerMixin, ConfigMixin): 28 | 29 | @register_to_config 30 | def __init__( 31 | self, 32 | scaler: float = 1.0, 33 | s: float = 0.008, 34 | ): 35 | self.scaler = scaler 36 | self.s = torch.tensor([s]) 37 | self._init_alpha_cumprod = torch.cos(self.s / (1 + self.s) * torch.pi * 0.5) ** 2 38 | 39 | # standard deviation of the initial noise distribution 40 | self.init_noise_sigma = 1.0 41 | 42 | def _alpha_cumprod(self, t, device): 43 | if self.scaler > 1: 44 | t = 1 - (1 - t) ** self.scaler 45 | elif self.scaler < 1: 46 | t = t**self.scaler 47 | alpha_cumprod = torch.cos( 48 | (t + self.s.to(device)) / (1 + self.s.to(device)) * torch.pi * 0.5 49 | ) ** 2 / self._init_alpha_cumprod.to(device) 50 | return alpha_cumprod.clamp(0.0001, 0.9999) 51 | 52 | def scale_model_input(self, sample: torch.Tensor, timestep: Optional[int] = None) -> torch.Tensor: 53 | """ 54 | Ensures interchangeability with schedulers that need to scale the denoising model input depending on the 55 | current timestep. 56 | 57 | Args: 58 | sample (`torch.Tensor`): input sample 59 | timestep (`int`, optional): current timestep 60 | 61 | Returns: 62 | `torch.Tensor`: scaled input sample 63 | """ 64 | return sample 65 | 66 | def set_timesteps( 67 | self, 68 | num_inference_steps: int = None, 69 | timesteps: Optional[List[int]] = None, 70 | device: Union[str, torch.device] = None, 71 | ): 72 | """ 73 | Sets the discrete timesteps used for the diffusion chain. Supporting function to be run before inference. 74 | 75 | Args: 76 | num_inference_steps (`Dict[float, int]`): 77 | the number of diffusion steps used when generating samples with a pre-trained model. If passed, then 78 | `timesteps` must be `None`. 79 | device (`str` or `torch.device`, optional): 80 | the device to which the timesteps are moved to. 
{2 / 3: 20, 0.0: 10} 81 | """ 82 | if timesteps is None: 83 | timesteps = torch.linspace(1.0, 0.0, num_inference_steps + 1, device=device) 84 | if not isinstance(timesteps, torch.Tensor): 85 | timesteps = torch.Tensor(timesteps).to(device) 86 | self.timesteps = timesteps 87 | 88 | def step( 89 | self, 90 | model_output: torch.Tensor, 91 | timestep: int, 92 | sample: torch.Tensor, 93 | generator=None, 94 | return_dict: bool = True, 95 | ) -> Union[DDPMSchedulerOutput, Tuple]: 96 | dtype = model_output.dtype 97 | device = model_output.device 98 | t = timestep 99 | 100 | prev_t = self.previous_timestep(t) 101 | 102 | alpha_cumprod = self._alpha_cumprod(t, device).view(t.size(0), *[1 for _ in sample.shape[1:]]) 103 | alpha_cumprod_prev = self._alpha_cumprod(prev_t, device).view(prev_t.size(0), *[1 for _ in sample.shape[1:]]) 104 | alpha = alpha_cumprod / alpha_cumprod_prev 105 | 106 | mu = (1.0 / alpha).sqrt() * (sample - (1 - alpha) * model_output / (1 - alpha_cumprod).sqrt()) 107 | 108 | std_noise = randn_tensor(mu.shape, generator=generator, device=model_output.device, dtype=model_output.dtype) 109 | std = ((1 - alpha) * (1.0 - alpha_cumprod_prev) / (1.0 - alpha_cumprod)).sqrt() * std_noise 110 | pred = mu + std * (prev_t != 0).float().view(prev_t.size(0), *[1 for _ in sample.shape[1:]]) 111 | 112 | if not return_dict: 113 | return (pred.to(dtype),) 114 | 115 | return DDPMSchedulerOutput(prev_sample=pred.to(dtype)) 116 | 117 | def add_noise( 118 | self, 119 | original_samples: torch.Tensor, 120 | noise: torch.Tensor, 121 | timesteps: torch.Tensor, 122 | ) -> torch.Tensor: 123 | device = original_samples.device 124 | dtype = original_samples.dtype 125 | alpha_cumprod = self._alpha_cumprod(timesteps, device=device).view( 126 | timesteps.size(0), *[1 for _ in original_samples.shape[1:]] 127 | ) 128 | noisy_samples = alpha_cumprod.sqrt() * original_samples + (1 - alpha_cumprod).sqrt() * noise 129 | return noisy_samples.to(dtype=dtype) 130 | 131 | def __len__(self): 132 | return self.config.num_train_timesteps 133 | 134 | def previous_timestep(self, timestep): 135 | index = (self.timesteps - timestep[0]).abs().argmin().item() 136 | prev_t = self.timesteps[index + 1][None].expand(timestep.shape[0]) 137 | return prev_t 138 | -------------------------------------------------------------------------------- /pyramid_dit/mmdit_modules/modeling_text_encoder.py: -------------------------------------------------------------------------------- 1 | import torch 2 | import torch.nn as nn 3 | import os 4 | 5 | from transformers import ( 6 | CLIPTextModelWithProjection, 7 | CLIPTokenizer, 8 | T5EncoderModel, 9 | T5TokenizerFast, 10 | ) 11 | 12 | from typing import Any, Callable, Dict, List, Optional, Union 13 | 14 | 15 | class SD3TextEncoderWithMask(nn.Module): 16 | def __init__(self, model_path, torch_dtype): 17 | super().__init__() 18 | # CLIP-L 19 | self.tokenizer = CLIPTokenizer.from_pretrained(os.path.join(model_path, 'tokenizer')) 20 | self.tokenizer_max_length = self.tokenizer.model_max_length 21 | self.text_encoder = CLIPTextModelWithProjection.from_pretrained(os.path.join(model_path, 'text_encoder'), torch_dtype=torch_dtype) 22 | 23 | # CLIP-G 24 | self.tokenizer_2 = CLIPTokenizer.from_pretrained(os.path.join(model_path, 'tokenizer_2')) 25 | self.text_encoder_2 = CLIPTextModelWithProjection.from_pretrained(os.path.join(model_path, 'text_encoder_2'), torch_dtype=torch_dtype) 26 | 27 | # T5 28 | self.tokenizer_3 = T5TokenizerFast.from_pretrained(os.path.join(model_path, 'tokenizer_3')) 29 | 
self.text_encoder_3 = T5EncoderModel.from_pretrained(os.path.join(model_path, 'text_encoder_3'), torch_dtype=torch_dtype) 30 | 31 | self._freeze() 32 | 33 | def _freeze(self): 34 | for param in self.parameters(): 35 | param.requires_grad = False 36 | 37 | def _get_t5_prompt_embeds( 38 | self, 39 | prompt: Union[str, List[str]] = None, 40 | num_images_per_prompt: int = 1, 41 | device: Optional[torch.device] = None, 42 | max_sequence_length: int = 128, 43 | ): 44 | prompt = [prompt] if isinstance(prompt, str) else prompt 45 | batch_size = len(prompt) 46 | 47 | text_inputs = self.tokenizer_3( 48 | prompt, 49 | padding="max_length", 50 | max_length=max_sequence_length, 51 | truncation=True, 52 | add_special_tokens=True, 53 | return_tensors="pt", 54 | ) 55 | text_input_ids = text_inputs.input_ids 56 | prompt_attention_mask = text_inputs.attention_mask 57 | prompt_attention_mask = prompt_attention_mask.to(device) 58 | prompt_embeds = self.text_encoder_3(text_input_ids.to(device), attention_mask=prompt_attention_mask)[0] 59 | dtype = self.text_encoder_3.dtype 60 | prompt_embeds = prompt_embeds.to(dtype=dtype, device=device) 61 | 62 | _, seq_len, _ = prompt_embeds.shape 63 | 64 | # duplicate text embeddings and attention mask for each generation per prompt, using mps friendly method 65 | prompt_embeds = prompt_embeds.repeat(1, num_images_per_prompt, 1) 66 | prompt_embeds = prompt_embeds.view(batch_size * num_images_per_prompt, seq_len, -1) 67 | prompt_attention_mask = prompt_attention_mask.view(batch_size, -1) 68 | prompt_attention_mask = prompt_attention_mask.repeat(num_images_per_prompt, 1) 69 | 70 | return prompt_embeds, prompt_attention_mask 71 | 72 | def _get_clip_prompt_embeds( 73 | self, 74 | prompt: Union[str, List[str]], 75 | num_images_per_prompt: int = 1, 76 | device: Optional[torch.device] = None, 77 | clip_skip: Optional[int] = None, 78 | clip_model_index: int = 0, 79 | ): 80 | 81 | clip_tokenizers = [self.tokenizer, self.tokenizer_2] 82 | clip_text_encoders = [self.text_encoder, self.text_encoder_2] 83 | 84 | tokenizer = clip_tokenizers[clip_model_index] 85 | text_encoder = clip_text_encoders[clip_model_index] 86 | 87 | batch_size = len(prompt) 88 | 89 | text_inputs = tokenizer( 90 | prompt, 91 | padding="max_length", 92 | max_length=self.tokenizer_max_length, 93 | truncation=True, 94 | return_tensors="pt", 95 | ) 96 | 97 | text_input_ids = text_inputs.input_ids 98 | prompt_embeds = text_encoder(text_input_ids.to(device), output_hidden_states=True) 99 | pooled_prompt_embeds = prompt_embeds[0] 100 | pooled_prompt_embeds = pooled_prompt_embeds.repeat(1, num_images_per_prompt, 1) 101 | pooled_prompt_embeds = pooled_prompt_embeds.view(batch_size * num_images_per_prompt, -1) 102 | 103 | return pooled_prompt_embeds 104 | 105 | def encode_prompt(self, 106 | prompt, 107 | num_images_per_prompt=1, 108 | clip_skip: Optional[int] = None, 109 | device=None, 110 | ): 111 | prompt = [prompt] if isinstance(prompt, str) else prompt 112 | 113 | pooled_prompt_embed = self._get_clip_prompt_embeds( 114 | prompt=prompt, 115 | device=device, 116 | num_images_per_prompt=num_images_per_prompt, 117 | clip_skip=clip_skip, 118 | clip_model_index=0, 119 | ) 120 | pooled_prompt_2_embed = self._get_clip_prompt_embeds( 121 | prompt=prompt, 122 | device=device, 123 | num_images_per_prompt=num_images_per_prompt, 124 | clip_skip=clip_skip, 125 | clip_model_index=1, 126 | ) 127 | pooled_prompt_embeds = torch.cat([pooled_prompt_embed, pooled_prompt_2_embed], dim=-1) 128 | 129 | prompt_embeds, prompt_attention_mask = 
self._get_t5_prompt_embeds( 130 | prompt=prompt, 131 | num_images_per_prompt=num_images_per_prompt, 132 | device=device, 133 | ) 134 | return prompt_embeds, prompt_attention_mask, pooled_prompt_embeds 135 | 136 | def forward(self, input_prompts, device): 137 | with torch.no_grad(): 138 | prompt_embeds, prompt_attention_mask, pooled_prompt_embeds = self.encode_prompt(input_prompts, 1, clip_skip=None, device=device) 139 | 140 | return prompt_embeds, prompt_attention_mask, pooled_prompt_embeds -------------------------------------------------------------------------------- /video_vae/context_parallel_ops.py: -------------------------------------------------------------------------------- 1 | # from cogvideoX 2 | import torch 3 | import torch.nn as nn 4 | import math 5 | 6 | from ..utils import ( 7 | get_context_parallel_group, 8 | get_context_parallel_rank, 9 | get_context_parallel_world_size, 10 | get_context_parallel_group_rank, 11 | ) 12 | 13 | 14 | def _conv_split(input_, dim=2, kernel_size=1): 15 | cp_world_size = get_context_parallel_world_size() 16 | 17 | # Bypass the function if context parallel is 1 18 | if cp_world_size == 1: 19 | return input_ 20 | 21 | # print('in _conv_split, cp_rank:', cp_rank, 'input_size:', input_.shape) 22 | 23 | cp_rank = get_context_parallel_rank() 24 | 25 | dim_size = (input_.size()[dim] - kernel_size) // cp_world_size 26 | 27 | if cp_rank == 0: 28 | output = input_.transpose(dim, 0)[: dim_size + kernel_size].transpose(dim, 0) 29 | else: 30 | # output = input_.transpose(dim, 0)[cp_rank * dim_size + 1:(cp_rank + 1) * dim_size + kernel_size].transpose(dim, 0) 31 | output = input_.transpose(dim, 0)[ 32 | cp_rank * dim_size + kernel_size : (cp_rank + 1) * dim_size + kernel_size 33 | ].transpose(dim, 0) 34 | output = output.contiguous() 35 | 36 | # print('out _conv_split, cp_rank:', cp_rank, 'input_size:', output.shape) 37 | 38 | return output 39 | 40 | 41 | def _conv_gather(input_, dim=2, kernel_size=1): 42 | cp_world_size = get_context_parallel_world_size() 43 | 44 | # Bypass the function if context parallel is 1 45 | if cp_world_size == 1: 46 | return input_ 47 | 48 | group = get_context_parallel_group() 49 | cp_rank = get_context_parallel_rank() 50 | 51 | # print('in _conv_gather, cp_rank:', cp_rank, 'input_size:', input_.shape) 52 | 53 | input_first_kernel_ = input_.transpose(0, dim)[:kernel_size].transpose(0, dim).contiguous() 54 | if cp_rank == 0: 55 | input_ = input_.transpose(0, dim)[kernel_size:].transpose(0, dim).contiguous() 56 | else: 57 | input_ = input_.transpose(0, dim)[max(kernel_size - 1, 0) :].transpose(0, dim).contiguous() 58 | 59 | tensor_list = [torch.empty_like(torch.cat([input_first_kernel_, input_], dim=dim))] + [ 60 | torch.empty_like(input_) for _ in range(cp_world_size - 1) 61 | ] 62 | if cp_rank == 0: 63 | input_ = torch.cat([input_first_kernel_, input_], dim=dim) 64 | 65 | tensor_list[cp_rank] = input_ 66 | torch.distributed.all_gather(tensor_list, input_, group=group) 67 | 68 | # Note: torch.cat already creates a contiguous tensor. 
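# Each rank contributes its local chunk (rank 0 keeps its leading `kernel_size` frames, the other ranks drop their first `kernel_size - 1` frames), so the concatenation below rebuilds the full-length tensor on every rank.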
69 | output = torch.cat(tensor_list, dim=dim).contiguous() 70 | 71 | # print('out _conv_gather, cp_rank:', cp_rank, 'input_size:', output.shape) 72 | 73 | return output 74 | 75 | 76 | def _cp_pass_from_previous_rank(input_, dim, kernel_size): 77 | # Bypass the function if kernel size is 1 78 | if kernel_size == 1: 79 | return input_ 80 | 81 | group = get_context_parallel_group() 82 | cp_rank = get_context_parallel_rank() 83 | cp_group_rank = get_context_parallel_group_rank() 84 | cp_world_size = get_context_parallel_world_size() 85 | 86 | # print('in _pass_from_previous_rank, cp_rank:', cp_rank, 'input_size:', input_.shape) 87 | 88 | global_rank = torch.distributed.get_rank() 89 | global_world_size = torch.distributed.get_world_size() 90 | 91 | input_ = input_.transpose(0, dim) 92 | 93 | # pass from last rank 94 | send_rank = global_rank + 1 95 | recv_rank = global_rank - 1 96 | if send_rank % cp_world_size == 0: 97 | send_rank -= cp_world_size 98 | if recv_rank % cp_world_size == cp_world_size - 1: 99 | recv_rank += cp_world_size 100 | 101 | recv_buffer = torch.empty_like(input_[-kernel_size + 1 :]).contiguous() 102 | if cp_rank < cp_world_size - 1: 103 | req_send = torch.distributed.isend(input_[-kernel_size + 1 :].contiguous(), send_rank, group=group) 104 | if cp_rank > 0: 105 | req_recv = torch.distributed.irecv(recv_buffer, recv_rank, group=group) 106 | 107 | if cp_rank == 0: 108 | input_ = torch.cat([torch.zeros_like(input_[:1])] * (kernel_size - 1) + [input_], dim=0) 109 | else: 110 | req_recv.wait() 111 | input_ = torch.cat([recv_buffer, input_], dim=0) 112 | 113 | input_ = input_.transpose(0, dim).contiguous() 114 | return input_ 115 | 116 | 117 | def _drop_from_previous_rank(input_, dim, kernel_size): 118 | input_ = input_.transpose(0, dim)[kernel_size - 1 :].transpose(0, dim) 119 | return input_ 120 | 121 | 122 | class _ConvolutionScatterToContextParallelRegion(torch.autograd.Function): 123 | @staticmethod 124 | def forward(ctx, input_, dim, kernel_size): 125 | ctx.dim = dim 126 | ctx.kernel_size = kernel_size 127 | return _conv_split(input_, dim, kernel_size) 128 | 129 | @staticmethod 130 | def backward(ctx, grad_output): 131 | return _conv_gather(grad_output, ctx.dim, ctx.kernel_size), None, None 132 | 133 | 134 | class _ConvolutionGatherFromContextParallelRegion(torch.autograd.Function): 135 | @staticmethod 136 | def forward(ctx, input_, dim, kernel_size): 137 | ctx.dim = dim 138 | ctx.kernel_size = kernel_size 139 | return _conv_gather(input_, dim, kernel_size) 140 | 141 | @staticmethod 142 | def backward(ctx, grad_output): 143 | return _conv_split(grad_output, ctx.dim, ctx.kernel_size), None, None 144 | 145 | 146 | class _CPConvolutionPassFromPreviousRank(torch.autograd.Function): 147 | @staticmethod 148 | def forward(ctx, input_, dim, kernel_size): 149 | ctx.dim = dim 150 | ctx.kernel_size = kernel_size 151 | return _cp_pass_from_previous_rank(input_, dim, kernel_size) 152 | 153 | @staticmethod 154 | def backward(ctx, grad_output): 155 | return _drop_from_previous_rank(grad_output, ctx.dim, ctx.kernel_size), None, None 156 | 157 | 158 | def conv_scatter_to_context_parallel_region(input_, dim, kernel_size): 159 | return _ConvolutionScatterToContextParallelRegion.apply(input_, dim, kernel_size) 160 | 161 | 162 | def conv_gather_from_context_parallel_region(input_, dim, kernel_size): 163 | return _ConvolutionGatherFromContextParallelRegion.apply(input_, dim, kernel_size) 164 | 165 | 166 | def cp_pass_from_previous_rank(input_, dim, kernel_size): 167 | return 
_CPConvolutionPassFromPreviousRank.apply(input_, dim, kernel_size) 168 | 169 | 170 | 171 | 172 | 173 | -------------------------------------------------------------------------------- /pyramid_dit/flux_modules/modeling_embedding.py: -------------------------------------------------------------------------------- 1 | import math 2 | from typing import List, Optional, Tuple, Union 3 | 4 | import numpy as np 5 | import torch 6 | import torch.nn.functional as F 7 | from torch import nn 8 | 9 | from diffusers.models.activations import get_activation, FP32SiLU 10 | 11 | def get_timestep_embedding( 12 | timesteps: torch.Tensor, 13 | embedding_dim: int, 14 | flip_sin_to_cos: bool = False, 15 | downscale_freq_shift: float = 1, 16 | scale: float = 1, 17 | max_period: int = 10000, 18 | ): 19 | """ 20 | This matches the implementation in Denoising Diffusion Probabilistic Models: Create sinusoidal timestep embeddings. 21 | 22 | Args 23 | timesteps (torch.Tensor): 24 | a 1-D Tensor of N indices, one per batch element. These may be fractional. 25 | embedding_dim (int): 26 | the dimension of the output. 27 | flip_sin_to_cos (bool): 28 | Whether the embedding order should be `cos, sin` (if True) or `sin, cos` (if False) 29 | downscale_freq_shift (float): 30 | Controls the delta between frequencies between dimensions 31 | scale (float): 32 | Scaling factor applied to the embeddings. 33 | max_period (int): 34 | Controls the maximum frequency of the embeddings 35 | Returns 36 | torch.Tensor: an [N x dim] Tensor of positional embeddings. 37 | """ 38 | assert len(timesteps.shape) == 1, "Timesteps should be a 1d-array" 39 | 40 | half_dim = embedding_dim // 2 41 | exponent = -math.log(max_period) * torch.arange( 42 | start=0, end=half_dim, dtype=torch.float32, device=timesteps.device 43 | ) 44 | exponent = exponent / (half_dim - downscale_freq_shift) 45 | 46 | emb = torch.exp(exponent) 47 | emb = timesteps[:, None].float() * emb[None, :] 48 | 49 | # scale embeddings 50 | emb = scale * emb 51 | 52 | # concat sine and cosine embeddings 53 | emb = torch.cat([torch.sin(emb), torch.cos(emb)], dim=-1) 54 | 55 | # flip sine and cosine embeddings 56 | if flip_sin_to_cos: 57 | emb = torch.cat([emb[:, half_dim:], emb[:, :half_dim]], dim=-1) 58 | 59 | # zero pad 60 | if embedding_dim % 2 == 1: 61 | emb = torch.nn.functional.pad(emb, (0, 1, 0, 0)) 62 | return emb 63 | 64 | 65 | class Timesteps(nn.Module): 66 | def __init__(self, num_channels: int, flip_sin_to_cos: bool, downscale_freq_shift: float, scale: int = 1): 67 | super().__init__() 68 | self.num_channels = num_channels 69 | self.flip_sin_to_cos = flip_sin_to_cos 70 | self.downscale_freq_shift = downscale_freq_shift 71 | self.scale = scale 72 | 73 | def forward(self, timesteps): 74 | t_emb = get_timestep_embedding( 75 | timesteps, 76 | self.num_channels, 77 | flip_sin_to_cos=self.flip_sin_to_cos, 78 | downscale_freq_shift=self.downscale_freq_shift, 79 | scale=self.scale, 80 | ) 81 | return t_emb 82 | 83 | 84 | class TimestepEmbedding(nn.Module): 85 | def __init__( 86 | self, 87 | in_channels: int, 88 | time_embed_dim: int, 89 | act_fn: str = "silu", 90 | out_dim: int = None, 91 | post_act_fn: Optional[str] = None, 92 | cond_proj_dim=None, 93 | sample_proj_bias=True, 94 | ): 95 | super().__init__() 96 | 97 | self.linear_1 = nn.Linear(in_channels, time_embed_dim, sample_proj_bias) 98 | 99 | if cond_proj_dim is not None: 100 | self.cond_proj = nn.Linear(cond_proj_dim, in_channels, bias=False) 101 | else: 102 | self.cond_proj = None 103 | 104 | self.act = 
get_activation(act_fn) 105 | 106 | if out_dim is not None: 107 | time_embed_dim_out = out_dim 108 | else: 109 | time_embed_dim_out = time_embed_dim 110 | self.linear_2 = nn.Linear(time_embed_dim, time_embed_dim_out, sample_proj_bias) 111 | 112 | if post_act_fn is None: 113 | self.post_act = None 114 | else: 115 | self.post_act = get_activation(post_act_fn) 116 | 117 | def forward(self, sample, condition=None): 118 | if condition is not None: 119 | sample = sample + self.cond_proj(condition) 120 | sample = self.linear_1(sample) 121 | 122 | if self.act is not None: 123 | sample = self.act(sample) 124 | 125 | sample = self.linear_2(sample) 126 | 127 | if self.post_act is not None: 128 | sample = self.post_act(sample) 129 | return sample 130 | 131 | 132 | class PixArtAlphaTextProjection(nn.Module): 133 | """ 134 | Projects caption embeddings. Also handles dropout for classifier-free guidance. 135 | 136 | Adapted from https://github.com/PixArt-alpha/PixArt-alpha/blob/master/diffusion/model/nets/PixArt_blocks.py 137 | """ 138 | 139 | def __init__(self, in_features, hidden_size, out_features=None, act_fn="gelu_tanh"): 140 | super().__init__() 141 | if out_features is None: 142 | out_features = hidden_size 143 | self.linear_1 = nn.Linear(in_features=in_features, out_features=hidden_size, bias=True) 144 | if act_fn == "gelu_tanh": 145 | self.act_1 = nn.GELU(approximate="tanh") 146 | elif act_fn == "silu": 147 | self.act_1 = nn.SiLU() 148 | elif act_fn == "silu_fp32": 149 | self.act_1 = FP32SiLU() 150 | else: 151 | raise ValueError(f"Unknown activation function: {act_fn}") 152 | self.linear_2 = nn.Linear(in_features=hidden_size, out_features=out_features, bias=True) 153 | 154 | def forward(self, caption): 155 | hidden_states = self.linear_1(caption) 156 | hidden_states = self.act_1(hidden_states) 157 | hidden_states = self.linear_2(hidden_states) 158 | return hidden_states 159 | 160 | 161 | class CombinedTimestepGuidanceTextProjEmbeddings(nn.Module): 162 | def __init__(self, embedding_dim, pooled_projection_dim): 163 | super().__init__() 164 | 165 | self.time_proj = Timesteps(num_channels=256, flip_sin_to_cos=True, downscale_freq_shift=0) 166 | self.timestep_embedder = TimestepEmbedding(in_channels=256, time_embed_dim=embedding_dim) 167 | self.guidance_embedder = TimestepEmbedding(in_channels=256, time_embed_dim=embedding_dim) 168 | self.text_embedder = PixArtAlphaTextProjection(pooled_projection_dim, embedding_dim, act_fn="silu") 169 | 170 | def forward(self, timestep, guidance, pooled_projection): 171 | timesteps_proj = self.time_proj(timestep) 172 | timesteps_emb = self.timestep_embedder(timesteps_proj.to(dtype=pooled_projection.dtype)) # (N, D) 173 | 174 | guidance_proj = self.time_proj(guidance) 175 | guidance_emb = self.guidance_embedder(guidance_proj.to(dtype=pooled_projection.dtype)) # (N, D) 176 | 177 | time_guidance_emb = timesteps_emb + guidance_emb 178 | 179 | pooled_projections = self.text_embedder(pooled_projection) 180 | conditioning = time_guidance_emb + pooled_projections 181 | 182 | return conditioning 183 | 184 | 185 | class CombinedTimestepTextProjEmbeddings(nn.Module): 186 | def __init__(self, embedding_dim, pooled_projection_dim): 187 | super().__init__() 188 | 189 | self.time_proj = Timesteps(num_channels=256, flip_sin_to_cos=True, downscale_freq_shift=0) 190 | self.timestep_embedder = TimestepEmbedding(in_channels=256, time_embed_dim=embedding_dim) 191 | self.text_embedder = PixArtAlphaTextProjection(pooled_projection_dim, embedding_dim, act_fn="silu") 192 | 193 | def 
forward(self, timestep, pooled_projection): 194 | timesteps_proj = self.time_proj(timestep) 195 | timesteps_emb = self.timestep_embedder(timesteps_proj.to(dtype=pooled_projection.dtype)) # (N, D) 196 | 197 | pooled_projections = self.text_embedder(pooled_projection) 198 | 199 | conditioning = timesteps_emb + pooled_projections 200 | 201 | return conditioning -------------------------------------------------------------------------------- /video_vae/modeling_loss.py: -------------------------------------------------------------------------------- 1 | import os 2 | import torch 3 | from torch import nn 4 | import torch.nn.functional as F 5 | from einops import rearrange 6 | from .modeling_lpips import LPIPS 7 | from .modeling_discriminator import NLayerDiscriminator, NLayerDiscriminator3D, weights_init 8 | #from IPython import embed 9 | 10 | 11 | class AdaptiveLossWeight: 12 | def __init__(self, timestep_range=[0, 1], buckets=300, weight_range=[1e-7, 1e7]): 13 | self.bucket_ranges = torch.linspace(timestep_range[0], timestep_range[1], buckets-1) 14 | self.bucket_losses = torch.ones(buckets) 15 | self.weight_range = weight_range 16 | 17 | def weight(self, timestep): 18 | indices = torch.searchsorted(self.bucket_ranges.to(timestep.device), timestep) 19 | return (1/self.bucket_losses.to(timestep.device)[indices]).clamp(*self.weight_range) 20 | 21 | def update_buckets(self, timestep, loss, beta=0.99): 22 | indices = torch.searchsorted(self.bucket_ranges.to(timestep.device), timestep).cpu() 23 | self.bucket_losses[indices] = self.bucket_losses[indices]*beta + loss.detach().cpu() * (1-beta) 24 | 25 | 26 | def hinge_d_loss(logits_real, logits_fake): 27 | loss_real = torch.mean(F.relu(1.0 - logits_real)) 28 | loss_fake = torch.mean(F.relu(1.0 + logits_fake)) 29 | d_loss = 0.5 * (loss_real + loss_fake) 30 | return d_loss 31 | 32 | 33 | def vanilla_d_loss(logits_real, logits_fake): 34 | d_loss = 0.5 * ( 35 | torch.mean(torch.nn.functional.softplus(-logits_real)) 36 | + torch.mean(torch.nn.functional.softplus(logits_fake)) 37 | ) 38 | return d_loss 39 | 40 | 41 | def adopt_weight(weight, global_step, threshold=0, value=0.0): 42 | if global_step < threshold: 43 | weight = value 44 | return weight 45 | 46 | 47 | class LPIPSWithDiscriminator(nn.Module): 48 | def __init__( 49 | self, 50 | disc_start, 51 | logvar_init=0.0, 52 | kl_weight=1.0, 53 | pixelloss_weight=1.0, 54 | perceptual_weight=1.0, 55 | # --- Discriminator Loss --- 56 | disc_num_layers=4, 57 | disc_in_channels=3, 58 | disc_factor=1.0, 59 | disc_weight=0.5, 60 | disc_loss="hinge", 61 | add_discriminator=True, 62 | using_3d_discriminator=False, 63 | ): 64 | 65 | super().__init__() 66 | assert disc_loss in ["hinge", "vanilla"] 67 | self.kl_weight = kl_weight 68 | self.pixel_weight = pixelloss_weight 69 | self.perceptual_loss = LPIPS().eval() 70 | self.perceptual_weight = perceptual_weight 71 | self.logvar = nn.Parameter(torch.ones(size=()) * logvar_init) 72 | 73 | if add_discriminator: 74 | disc_cls = NLayerDiscriminator3D if using_3d_discriminator else NLayerDiscriminator 75 | self.discriminator = disc_cls( 76 | input_nc=disc_in_channels, n_layers=disc_num_layers, 77 | ).apply(weights_init) 78 | else: 79 | self.discriminator = None 80 | 81 | self.discriminator_iter_start = disc_start 82 | self.disc_loss = hinge_d_loss if disc_loss == "hinge" else vanilla_d_loss 83 | self.disc_factor = disc_factor 84 | self.discriminator_weight = disc_weight 85 | self.using_3d_discriminator = using_3d_discriminator 86 | 87 | def calculate_adaptive_weight(self, 
nll_loss, g_loss, last_layer=None): 88 | if last_layer is not None: 89 | nll_grads = torch.autograd.grad(nll_loss, last_layer, retain_graph=True)[0] 90 | g_grads = torch.autograd.grad(g_loss, last_layer, retain_graph=True)[0] 91 | else: 92 | nll_grads = torch.autograd.grad( 93 | nll_loss, self.last_layer[0], retain_graph=True 94 | )[0] 95 | g_grads = torch.autograd.grad( 96 | g_loss, self.last_layer[0], retain_graph=True 97 | )[0] 98 | 99 | d_weight = torch.norm(nll_grads) / (torch.norm(g_grads) + 1e-4) 100 | d_weight = torch.clamp(d_weight, 0.0, 1e4).detach() 101 | d_weight = d_weight * self.discriminator_weight 102 | return d_weight 103 | 104 | def forward( 105 | self, 106 | inputs, 107 | reconstructions, 108 | posteriors, 109 | optimizer_idx, 110 | global_step, 111 | split="train", 112 | last_layer=None, 113 | ): 114 | t = reconstructions.shape[2] 115 | inputs = rearrange(inputs, "b c t h w -> (b t) c h w").contiguous() 116 | reconstructions = rearrange(reconstructions, "b c t h w -> (b t) c h w").contiguous() 117 | 118 | if optimizer_idx == 0: 119 | # rec_loss = torch.mean(torch.abs(inputs - reconstructions), dim=(1,2,3), keepdim=True) 120 | rec_loss = torch.mean(F.mse_loss(inputs, reconstructions, reduction='none'), dim=(1,2,3), keepdim=True) 121 | 122 | if self.perceptual_weight > 0: 123 | p_loss = self.perceptual_loss(inputs, reconstructions) 124 | nll_loss = self.pixel_weight * rec_loss + self.perceptual_weight * p_loss 125 | 126 | nll_loss = nll_loss / torch.exp(self.logvar) + self.logvar 127 | weighted_nll_loss = nll_loss 128 | weighted_nll_loss = torch.sum(weighted_nll_loss) / weighted_nll_loss.shape[0] 129 | nll_loss = torch.sum(nll_loss) / nll_loss.shape[0] 130 | 131 | kl_loss = posteriors.kl() 132 | kl_loss = torch.mean(kl_loss) 133 | 134 | disc_factor = adopt_weight( 135 | self.disc_factor, global_step, threshold=self.discriminator_iter_start 136 | ) 137 | 138 | if disc_factor > 0.0: 139 | if self.using_3d_discriminator: 140 | reconstructions = rearrange(reconstructions, '(b t) c h w -> b c t h w', t=t) 141 | 142 | logits_fake = self.discriminator(reconstructions.contiguous()) 143 | g_loss = -torch.mean(logits_fake) 144 | try: 145 | d_weight = self.calculate_adaptive_weight( 146 | nll_loss, g_loss, last_layer=last_layer 147 | ) 148 | except RuntimeError: 149 | assert not self.training 150 | d_weight = torch.tensor(0.0) 151 | else: 152 | d_weight = torch.tensor(0.0) 153 | g_loss = torch.tensor(0.0) 154 | 155 | 156 | loss = ( 157 | weighted_nll_loss 158 | + self.kl_weight * kl_loss 159 | + d_weight * disc_factor * g_loss 160 | ) 161 | log = { 162 | "{}/total_loss".format(split): loss.clone().detach().mean(), 163 | "{}/logvar".format(split): self.logvar.detach(), 164 | "{}/kl_loss".format(split): kl_loss.detach().mean(), 165 | "{}/nll_loss".format(split): nll_loss.detach().mean(), 166 | "{}/rec_loss".format(split): rec_loss.detach().mean(), 167 | "{}/perception_loss".format(split): p_loss.detach().mean(), 168 | "{}/d_weight".format(split): d_weight.detach(), 169 | "{}/disc_factor".format(split): torch.tensor(disc_factor), 170 | "{}/g_loss".format(split): g_loss.detach().mean(), 171 | } 172 | return loss, log 173 | 174 | if optimizer_idx == 1: 175 | if self.using_3d_discriminator: 176 | inputs = rearrange(inputs, '(b t) c h w -> b c t h w', t=t) 177 | reconstructions = rearrange(reconstructions, '(b t) c h w -> b c t h w', t=t) 178 | 179 | logits_real = self.discriminator(inputs.contiguous().detach()) 180 | logits_fake = self.discriminator(reconstructions.contiguous().detach()) 
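            # Both logits above are computed from detached tensors, so this branch only
            # updates the discriminator. With disc_loss="hinge" (the default), the loss
            # below is 0.5 * (mean(relu(1 - logits_real)) + mean(relu(1 + logits_fake))),
            # which reaches zero once real logits exceed +1 and fake logits fall below -1.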
181 | 182 | disc_factor = adopt_weight( 183 | self.disc_factor, global_step, threshold=self.discriminator_iter_start 184 | ) 185 | d_loss = disc_factor * self.disc_loss(logits_real, logits_fake) 186 | 187 | log = { 188 | "{}/disc_loss".format(split): d_loss.clone().detach().mean(), 189 | "{}/logits_real".format(split): logits_real.detach().mean(), 190 | "{}/logits_fake".format(split): logits_fake.detach().mean(), 191 | } 192 | return d_loss, log -------------------------------------------------------------------------------- /pyramid_dit/mmdit_modules/modeling_normalization.py: -------------------------------------------------------------------------------- 1 | import numbers 2 | from typing import Dict, Optional, Tuple 3 | 4 | import torch 5 | import torch.nn as nn 6 | import torch.nn.functional as F 7 | from einops import rearrange 8 | from diffusers.utils import is_torch_version 9 | 10 | 11 | if is_torch_version(">=", "2.1.0"): 12 | LayerNorm = nn.LayerNorm 13 | else: 14 | # Has optional bias parameter compared to torch layer norm 15 | # TODO: replace with torch layernorm once min required torch version >= 2.1 16 | class LayerNorm(nn.Module): 17 | def __init__(self, dim, eps: float = 1e-5, elementwise_affine: bool = True, bias: bool = True): 18 | super().__init__() 19 | 20 | self.eps = eps 21 | 22 | if isinstance(dim, numbers.Integral): 23 | dim = (dim,) 24 | 25 | self.dim = torch.Size(dim) 26 | 27 | if elementwise_affine: 28 | self.weight = nn.Parameter(torch.ones(dim)) 29 | self.bias = nn.Parameter(torch.zeros(dim)) if bias else None 30 | else: 31 | self.weight = None 32 | self.bias = None 33 | 34 | def forward(self, input): 35 | return F.layer_norm(input, self.dim, self.weight, self.bias, self.eps) 36 | 37 | 38 | class RMSNorm(nn.Module): 39 | def __init__(self, dim, eps: float, elementwise_affine: bool = True): 40 | super().__init__() 41 | 42 | self.eps = eps 43 | 44 | if isinstance(dim, numbers.Integral): 45 | dim = (dim,) 46 | 47 | self.dim = torch.Size(dim) 48 | 49 | if elementwise_affine: 50 | self.weight = nn.Parameter(torch.ones(dim)) 51 | else: 52 | self.weight = None 53 | 54 | def forward(self, hidden_states): 55 | input_dtype = hidden_states.dtype 56 | variance = hidden_states.to(torch.float32).pow(2).mean(-1, keepdim=True) 57 | hidden_states = hidden_states * torch.rsqrt(variance + self.eps) 58 | 59 | if self.weight is not None: 60 | # Handle the case where self.weight is torch.float8_e4m3fn 61 | if self.weight.dtype == torch.float8_e4m3fn: 62 | weight = self.weight.to(torch.float32) 63 | else: 64 | weight = self.weight 65 | 66 | # Convert hidden_states to the dtype of weight if necessary 67 | if hidden_states.dtype != weight.dtype: 68 | hidden_states = hidden_states.to(weight.dtype) 69 | hidden_states = hidden_states * weight 70 | 71 | hidden_states = hidden_states.to(input_dtype) 72 | 73 | return hidden_states 74 | 75 | 76 | class AdaLayerNormContinuous(nn.Module): 77 | def __init__( 78 | self, 79 | embedding_dim: int, 80 | conditioning_embedding_dim: int, 81 | # NOTE: It is a bit weird that the norm layer can be configured to have scale and shift parameters 82 | # because the output is immediately scaled and shifted by the projected conditioning embeddings. 83 | # Note that AdaLayerNorm does not let the norm layer have scale and shift parameters. 84 | # However, this is how it was implemented in the original code, and it's rather likely you should 85 | # set `elementwise_affine` to False. 
86 | elementwise_affine=True, 87 | eps=1e-5, 88 | bias=True, 89 | norm_type="layer_norm", 90 | ): 91 | super().__init__() 92 | self.silu = nn.SiLU() 93 | self.linear = nn.Linear(conditioning_embedding_dim, embedding_dim * 2, bias=bias) 94 | if norm_type == "layer_norm": 95 | self.norm = LayerNorm(embedding_dim, eps, elementwise_affine, bias) 96 | elif norm_type == "rms_norm": 97 | self.norm = RMSNorm(embedding_dim, eps, elementwise_affine) 98 | else: 99 | raise ValueError(f"unknown norm_type {norm_type}") 100 | 101 | def forward_with_pad(self, x: torch.Tensor, conditioning_embedding: torch.Tensor, hidden_length=None) -> torch.Tensor: 102 | assert hidden_length is not None 103 | 104 | emb = self.linear(self.silu(conditioning_embedding).to(x.dtype)) 105 | batch_emb = torch.zeros_like(x).repeat(1, 1, 2) 106 | 107 | i_sum = 0 108 | num_stages = len(hidden_length) 109 | for i_p, length in enumerate(hidden_length): 110 | batch_emb[:, i_sum:i_sum+length] = emb[i_p::num_stages][:,None] 111 | i_sum += length 112 | 113 | batch_scale, batch_shift = torch.chunk(batch_emb, 2, dim=2) 114 | x = self.norm(x) * (1 + batch_scale) + batch_shift 115 | return x 116 | 117 | def forward(self, x: torch.Tensor, conditioning_embedding: torch.Tensor, hidden_length=None) -> torch.Tensor: 118 | # convert back to the original dtype in case `conditioning_embedding`` is upcasted to float32 (needed for hunyuanDiT) 119 | if hidden_length is not None: 120 | return self.forward_with_pad(x, conditioning_embedding, hidden_length) 121 | emb = self.linear(self.silu(conditioning_embedding).to(x.dtype)) 122 | scale, shift = torch.chunk(emb, 2, dim=1) 123 | x = self.norm(x) * (1 + scale)[:, None, :] + shift[:, None, :] 124 | return x 125 | 126 | 127 | class AdaLayerNormZero(nn.Module): 128 | r""" 129 | Norm layer adaptive layer norm zero (adaLN-Zero). 130 | 131 | Parameters: 132 | embedding_dim (`int`): The size of each embedding vector. 133 | num_embeddings (`int`): The size of the embeddings dictionary. 
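            If `None`, no internal label embedder is created and the conditioning
            embedding must be supplied to `forward` through the `emb` argument.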
134 | """ 135 | 136 | def __init__(self, embedding_dim: int, num_embeddings: Optional[int] = None): 137 | super().__init__() 138 | if num_embeddings is not None: 139 | self.emb = CombinedTimestepLabelEmbeddings(num_embeddings, embedding_dim) 140 | else: 141 | self.emb = None 142 | 143 | self.silu = nn.SiLU() 144 | self.linear = nn.Linear(embedding_dim, 6 * embedding_dim, bias=True) 145 | self.norm = nn.LayerNorm(embedding_dim, elementwise_affine=False, eps=1e-6) 146 | 147 | def forward_with_pad( 148 | self, 149 | x: torch.Tensor, 150 | timestep: Optional[torch.Tensor] = None, 151 | class_labels: Optional[torch.LongTensor] = None, 152 | hidden_dtype: Optional[torch.dtype] = None, 153 | emb: Optional[torch.Tensor] = None, 154 | hidden_length: Optional[torch.Tensor] = None, 155 | ) -> Tuple[torch.Tensor, torch.Tensor, torch.Tensor, torch.Tensor, torch.Tensor]: 156 | # hidden_length: [[20, 30], [30, 40], [50, 60]] 157 | # x: [bs, seq_len, dim] 158 | if self.emb is not None: 159 | emb = self.emb(timestep, class_labels, hidden_dtype=hidden_dtype) 160 | 161 | emb = self.linear(self.silu(emb)) 162 | batch_emb = torch.zeros_like(x).repeat(1, 1, 6) 163 | 164 | i_sum = 0 165 | num_stages = len(hidden_length) 166 | for i_p, length in enumerate(hidden_length): 167 | batch_emb[:, i_sum:i_sum+length] = emb[i_p::num_stages][:,None] 168 | i_sum += length 169 | 170 | batch_shift_msa, batch_scale_msa, batch_gate_msa, batch_shift_mlp, batch_scale_mlp, batch_gate_mlp = batch_emb.chunk(6, dim=2) 171 | x = self.norm(x) * (1 + batch_scale_msa) + batch_shift_msa 172 | return x, batch_gate_msa, batch_shift_mlp, batch_scale_mlp, batch_gate_mlp 173 | 174 | def forward( 175 | self, 176 | x: torch.Tensor, 177 | timestep: Optional[torch.Tensor] = None, 178 | class_labels: Optional[torch.LongTensor] = None, 179 | hidden_dtype: Optional[torch.dtype] = None, 180 | emb: Optional[torch.Tensor] = None, 181 | hidden_length: Optional[torch.Tensor] = None, 182 | ) -> Tuple[torch.Tensor, torch.Tensor, torch.Tensor, torch.Tensor, torch.Tensor]: 183 | if hidden_length is not None: 184 | return self.forward_with_pad(x, timestep, class_labels, hidden_dtype, emb, hidden_length) 185 | if self.emb is not None: 186 | emb = self.emb(timestep, class_labels, hidden_dtype=hidden_dtype) 187 | emb = self.linear(self.silu(emb)) 188 | shift_msa, scale_msa, gate_msa, shift_mlp, scale_mlp, gate_mlp = emb.chunk(6, dim=1) 189 | x = self.norm(x) * (1 + scale_msa[:, None]) + shift_msa[:, None] 190 | return x, gate_msa, shift_mlp, scale_mlp, gate_mlp -------------------------------------------------------------------------------- /examples/pyramid_flow_miniflux_text2vid_example_01.json: -------------------------------------------------------------------------------- 1 | { 2 | "last_node_id": 53, 3 | "last_link_id": 81, 4 | "nodes": [ 5 | { 6 | "id": 39, 7 | "type": "Note", 8 | "pos": { 9 | "0": 30, 10 | "1": 650 11 | }, 12 | "size": { 13 | "0": 318.25567626953125, 14 | "1": 66.4825210571289 15 | }, 16 | "flags": {}, 17 | "order": 0, 18 | "mode": 0, 19 | "inputs": [], 20 | "outputs": [], 21 | "properties": {}, 22 | "widgets_values": [ 23 | "fp8 text encoder results are different from fp16!" 
24 | ], 25 | "color": "#432", 26 | "bgcolor": "#653" 27 | }, 28 | { 29 | "id": 37, 30 | "type": "DualCLIPLoader", 31 | "pos": { 32 | "0": -40, 33 | "1": 480 34 | }, 35 | "size": { 36 | "0": 407.1675720214844, 37 | "1": 106 38 | }, 39 | "flags": {}, 40 | "order": 1, 41 | "mode": 0, 42 | "inputs": [], 43 | "outputs": [ 44 | { 45 | "name": "CLIP", 46 | "type": "CLIP", 47 | "links": [ 48 | 80 49 | ], 50 | "slot_index": 0 51 | } 52 | ], 53 | "properties": { 54 | "Node name for S&R": "DualCLIPLoader" 55 | }, 56 | "widgets_values": [ 57 | "clip_l.safetensors", 58 | "t5\\t5xxl_fp16.safetensors", 59 | "flux" 60 | ] 61 | }, 62 | { 63 | "id": 50, 64 | "type": "PyramidFlowSampler", 65 | "pos": { 66 | "0": 1046, 67 | "1": 144 68 | }, 69 | "size": { 70 | "0": 315, 71 | "1": 314 72 | }, 73 | "flags": {}, 74 | "order": 5, 75 | "mode": 0, 76 | "inputs": [ 77 | { 78 | "name": "model", 79 | "type": "PYRAMIDFLOWMODEL", 80 | "link": 74 81 | }, 82 | { 83 | "name": "prompt_embeds", 84 | "type": "PYRAMIDFLOWPROMPT", 85 | "link": 81 86 | }, 87 | { 88 | "name": "input_latent", 89 | "type": "LATENT", 90 | "link": null, 91 | "shape": 7 92 | } 93 | ], 94 | "outputs": [ 95 | { 96 | "name": "samples", 97 | "type": "LATENT", 98 | "links": [ 99 | 76 100 | ] 101 | } 102 | ], 103 | "properties": { 104 | "Node name for S&R": "PyramidFlowSampler" 105 | }, 106 | "widgets_values": [ 107 | 640, 108 | 384, 109 | "20, 20, 20", 110 | "10, 10, 10", 111 | 16, 112 | 7, 113 | 5, 114 | 44664248661402, 115 | "fixed", 116 | false 117 | ] 118 | }, 119 | { 120 | "id": 43, 121 | "type": "PyramidFlowVAELoader", 122 | "pos": { 123 | "0": 250, 124 | "1": 282 125 | }, 126 | "size": { 127 | "0": 411.12652587890625, 128 | "1": 82 129 | }, 130 | "flags": {}, 131 | "order": 2, 132 | "mode": 0, 133 | "inputs": [ 134 | { 135 | "name": "compile_args", 136 | "type": "PYRAMIDFLOW_COMPILEARGS", 137 | "link": null, 138 | "shape": 7 139 | } 140 | ], 141 | "outputs": [ 142 | { 143 | "name": "pyramidflow_vae", 144 | "type": "PYRAMIDFLOWVAE", 145 | "links": [ 146 | 71 147 | ], 148 | "slot_index": 0 149 | } 150 | ], 151 | "properties": { 152 | "Node name for S&R": "PyramidFlowVAELoader" 153 | }, 154 | "widgets_values": [ 155 | "pyramidflow\\pyramid_flow_vae_bf16.safetensors", 156 | "bf16" 157 | ] 158 | }, 159 | { 160 | "id": 53, 161 | "type": "PyramidFlowTextEncode", 162 | "pos": { 163 | "0": 444, 164 | "1": 476 165 | }, 166 | "size": { 167 | "0": 437.19818115234375, 168 | "1": 269.9795837402344 169 | }, 170 | "flags": {}, 171 | "order": 4, 172 | "mode": 0, 173 | "inputs": [ 174 | { 175 | "name": "clip", 176 | "type": "CLIP", 177 | "link": 80 178 | } 179 | ], 180 | "outputs": [ 181 | { 182 | "name": "prompt_embeds", 183 | "type": "PYRAMIDFLOWPROMPT", 184 | "links": [ 185 | 81 186 | ], 187 | "slot_index": 0 188 | } 189 | ], 190 | "properties": { 191 | "Node name for S&R": "PyramidFlowTextEncode" 192 | }, 193 | "widgets_values": [ 194 | "Beautiful, snowy Tokyo city is bustling. The camera moves through the bustling city street, following several people enjoying the beautiful snowy weather and shopping at nearby stalls. 
Gorgeous sakura petals are flying through the wind along with snowflakes, hyper quality, Ultra HD, 8K", 195 | "cartoon style, worst quality, low quality, blurry, absolute black, absolute white, low res, extra limbs, extra digits, misplaced objects, mutated anatomy, monochrome, horror", 196 | true 197 | ] 198 | }, 199 | { 200 | "id": 51, 201 | "type": "VHS_VideoCombine", 202 | "pos": { 203 | "0": 1420, 204 | "1": 113 205 | }, 206 | "size": [ 207 | 1018.2306518554688, 208 | 310 209 | ], 210 | "flags": {}, 211 | "order": 7, 212 | "mode": 0, 213 | "inputs": [ 214 | { 215 | "name": "images", 216 | "type": "IMAGE", 217 | "link": 77 218 | }, 219 | { 220 | "name": "audio", 221 | "type": "AUDIO", 222 | "link": null, 223 | "shape": 7 224 | }, 225 | { 226 | "name": "meta_batch", 227 | "type": "VHS_BatchManager", 228 | "link": null, 229 | "shape": 7 230 | }, 231 | { 232 | "name": "vae", 233 | "type": "VAE", 234 | "link": null, 235 | "shape": 7 236 | } 237 | ], 238 | "outputs": [ 239 | { 240 | "name": "Filenames", 241 | "type": "VHS_FILENAMES", 242 | "links": null 243 | } 244 | ], 245 | "properties": { 246 | "Node name for S&R": "VHS_VideoCombine" 247 | }, 248 | "widgets_values": { 249 | "frame_rate": 24, 250 | "loop_count": 0, 251 | "filename_prefix": "PyramidFlow", 252 | "format": "video/h264-mp4", 253 | "pix_fmt": "yuv420p", 254 | "crf": 19, 255 | "save_metadata": true, 256 | "pingpong": false, 257 | "save_output": true, 258 | "videopreview": { 259 | "hidden": false, 260 | "paused": false, 261 | "params": { 262 | "filename": "PyramidFlow_00129.mp4", 263 | "subfolder": "", 264 | "type": "output", 265 | "format": "video/h264-mp4", 266 | "frame_rate": 24 267 | }, 268 | "muted": false 269 | } 270 | } 271 | }, 272 | { 273 | "id": 40, 274 | "type": "PyramidFlowTransformerLoader", 275 | "pos": { 276 | "0": 225, 277 | "1": 107 278 | }, 279 | "size": { 280 | "0": 444.05462646484375, 281 | "1": 106 282 | }, 283 | "flags": {}, 284 | "order": 3, 285 | "mode": 0, 286 | "inputs": [ 287 | { 288 | "name": "compile_args", 289 | "type": "PYRAMIDFLOW_COMPILEARGS", 290 | "link": null, 291 | "shape": 7 292 | } 293 | ], 294 | "outputs": [ 295 | { 296 | "name": "pyramidflow_model", 297 | "type": "PYRAMIDFLOWMODEL", 298 | "links": [ 299 | 74 300 | ], 301 | "slot_index": 0 302 | } 303 | ], 304 | "properties": { 305 | "Node name for S&R": "PyramidFlowTransformerLoader" 306 | }, 307 | "widgets_values": [ 308 | "pyramidflow\\pyramid_flow_miniflux_bf16_v2.safetensors", 309 | "bf16", 310 | false 311 | ] 312 | }, 313 | { 314 | "id": 48, 315 | "type": "PyramidFlowVAEDecode", 316 | "pos": { 317 | "0": 1051, 318 | "1": 534 319 | }, 320 | "size": { 321 | "0": 315, 322 | "1": 150 323 | }, 324 | "flags": {}, 325 | "order": 6, 326 | "mode": 0, 327 | "inputs": [ 328 | { 329 | "name": "vae", 330 | "type": "PYRAMIDFLOWVAE", 331 | "link": 71 332 | }, 333 | { 334 | "name": "samples", 335 | "type": "LATENT", 336 | "link": 76 337 | } 338 | ], 339 | "outputs": [ 340 | { 341 | "name": "images", 342 | "type": "IMAGE", 343 | "links": [ 344 | 77 345 | ], 346 | "slot_index": 0 347 | } 348 | ], 349 | "properties": { 350 | "Node name for S&R": "PyramidFlowVAEDecode" 351 | }, 352 | "widgets_values": [ 353 | 256, 354 | 0.25, 355 | true, 356 | true 357 | ] 358 | } 359 | ], 360 | "links": [ 361 | [ 362 | 71, 363 | 43, 364 | 0, 365 | 48, 366 | 0, 367 | "PYRAMIDFLOWVAE" 368 | ], 369 | [ 370 | 74, 371 | 40, 372 | 0, 373 | 50, 374 | 0, 375 | "PYRAMIDFLOWMODEL" 376 | ], 377 | [ 378 | 76, 379 | 50, 380 | 0, 381 | 48, 382 | 1, 383 | "LATENT" 384 | ], 385 | [ 386 
| 77, 387 | 48, 388 | 0, 389 | 51, 390 | 0, 391 | "IMAGE" 392 | ], 393 | [ 394 | 80, 395 | 37, 396 | 0, 397 | 53, 398 | 0, 399 | "CLIP" 400 | ], 401 | [ 402 | 81, 403 | 53, 404 | 0, 405 | 50, 406 | 1, 407 | "PYRAMIDFLOWPROMPT" 408 | ] 409 | ], 410 | "groups": [], 411 | "config": {}, 412 | "extra": { 413 | "ds": { 414 | "scale": 0.6934334949442648, 415 | "offset": [ 416 | 667.6026332109438, 417 | 191.25659609525596 418 | ] 419 | } 420 | }, 421 | "version": 0.4 422 | } -------------------------------------------------------------------------------- /pyramid_dit/flux_modules/modeling_normalization.py: -------------------------------------------------------------------------------- 1 | import numbers 2 | from typing import Dict, Optional, Tuple 3 | 4 | import torch 5 | import torch.nn as nn 6 | import torch.nn.functional as F 7 | from einops import rearrange 8 | from diffusers.utils import is_torch_version 9 | 10 | 11 | if is_torch_version(">=", "2.1.0"): 12 | LayerNorm = nn.LayerNorm 13 | else: 14 | # Has optional bias parameter compared to torch layer norm 15 | # TODO: replace with torch layernorm once min required torch version >= 2.1 16 | class LayerNorm(nn.Module): 17 | def __init__(self, dim, eps: float = 1e-5, elementwise_affine: bool = True, bias: bool = True): 18 | super().__init__() 19 | 20 | self.eps = eps 21 | 22 | if isinstance(dim, numbers.Integral): 23 | dim = (dim,) 24 | 25 | self.dim = torch.Size(dim) 26 | 27 | if elementwise_affine: 28 | self.weight = nn.Parameter(torch.ones(dim)) 29 | self.bias = nn.Parameter(torch.zeros(dim)) if bias else None 30 | else: 31 | self.weight = None 32 | self.bias = None 33 | 34 | def forward(self, input): 35 | return F.layer_norm(input, self.dim, self.weight, self.bias, self.eps) 36 | 37 | 38 | class FP32LayerNorm(nn.LayerNorm): 39 | def forward(self, inputs: torch.Tensor) -> torch.Tensor: 40 | origin_dtype = inputs.dtype 41 | return F.layer_norm( 42 | inputs.float(), 43 | self.normalized_shape, 44 | self.weight.float() if self.weight is not None else None, 45 | self.bias.float() if self.bias is not None else None, 46 | self.eps, 47 | ).to(origin_dtype) 48 | 49 | 50 | class RMSNorm(nn.Module): 51 | def __init__(self, dim, eps: float, elementwise_affine: bool = True): 52 | super().__init__() 53 | 54 | self.eps = eps 55 | 56 | if isinstance(dim, numbers.Integral): 57 | dim = (dim,) 58 | 59 | self.dim = torch.Size(dim) 60 | 61 | if elementwise_affine: 62 | self.weight = nn.Parameter(torch.ones(dim)) 63 | else: 64 | self.weight = None 65 | 66 | def forward(self, hidden_states): 67 | input_dtype = hidden_states.dtype 68 | variance = hidden_states.to(torch.float32).pow(2).mean(-1, keepdim=True) 69 | hidden_states = hidden_states * torch.rsqrt(variance + self.eps) 70 | 71 | if self.weight is not None: 72 | # Handle the case where self.weight is torch.float8_e4m3fn 73 | if self.weight.dtype == torch.float8_e4m3fn: 74 | weight = self.weight.to(torch.float32) 75 | else: 76 | weight = self.weight 77 | 78 | # Convert hidden_states to the dtype of weight if necessary 79 | if hidden_states.dtype != weight.dtype: 80 | hidden_states = hidden_states.to(weight.dtype) 81 | hidden_states = hidden_states * weight 82 | 83 | hidden_states = hidden_states.to(input_dtype) 84 | 85 | return hidden_states 86 | 87 | 88 | class AdaLayerNormContinuous(nn.Module): 89 | def __init__( 90 | self, 91 | embedding_dim: int, 92 | conditioning_embedding_dim: int, 93 | # NOTE: It is a bit weird that the norm layer can be configured to have scale and shift parameters 94 | # because the 
output is immediately scaled and shifted by the projected conditioning embeddings. 95 | # Note that AdaLayerNorm does not let the norm layer have scale and shift parameters. 96 | # However, this is how it was implemented in the original code, and it's rather likely you should 97 | # set `elementwise_affine` to False. 98 | elementwise_affine=True, 99 | eps=1e-5, 100 | bias=True, 101 | norm_type="layer_norm", 102 | ): 103 | super().__init__() 104 | self.silu = nn.SiLU() 105 | self.linear = nn.Linear(conditioning_embedding_dim, embedding_dim * 2, bias=bias) 106 | if norm_type == "layer_norm": 107 | self.norm = LayerNorm(embedding_dim, eps, elementwise_affine, bias) 108 | elif norm_type == "rms_norm": 109 | self.norm = RMSNorm(embedding_dim, eps, elementwise_affine) 110 | else: 111 | raise ValueError(f"unknown norm_type {norm_type}") 112 | 113 | def forward_with_pad(self, x: torch.Tensor, conditioning_embedding: torch.Tensor, hidden_length=None) -> torch.Tensor: 114 | assert hidden_length is not None 115 | 116 | emb = self.linear(self.silu(conditioning_embedding).to(x.dtype)) 117 | batch_emb = torch.zeros_like(x).repeat(1, 1, 2) 118 | 119 | i_sum = 0 120 | num_stages = len(hidden_length) 121 | for i_p, length in enumerate(hidden_length): 122 | batch_emb[:, i_sum:i_sum+length] = emb[i_p::num_stages][:,None] 123 | i_sum += length 124 | 125 | batch_scale, batch_shift = torch.chunk(batch_emb, 2, dim=2) 126 | x = self.norm(x) * (1 + batch_scale) + batch_shift 127 | return x 128 | 129 | def forward(self, x: torch.Tensor, conditioning_embedding: torch.Tensor, hidden_length=None) -> torch.Tensor: 130 | # convert back to the original dtype in case `conditioning_embedding`` is upcasted to float32 (needed for hunyuanDiT) 131 | if hidden_length is not None: 132 | return self.forward_with_pad(x, conditioning_embedding, hidden_length) 133 | emb = self.linear(self.silu(conditioning_embedding).to(x.dtype)) 134 | scale, shift = torch.chunk(emb, 2, dim=1) 135 | x = self.norm(x) * (1 + scale)[:, None, :] + shift[:, None, :] 136 | return x 137 | 138 | 139 | class AdaLayerNormZero(nn.Module): 140 | r""" 141 | Norm layer adaptive layer norm zero (adaLN-Zero). 142 | 143 | Parameters: 144 | embedding_dim (`int`): The size of each embedding vector. 145 | num_embeddings (`int`): The size of the embeddings dictionary. 
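            In this variant `self.emb` is always `None`, so the conditioning
            embedding must be passed to `forward` via the `emb` argument.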
146 | """ 147 | 148 | def __init__(self, embedding_dim: int, num_embeddings: Optional[int] = None): 149 | super().__init__() 150 | self.emb = None 151 | 152 | self.silu = nn.SiLU() 153 | self.linear = nn.Linear(embedding_dim, 6 * embedding_dim, bias=True) 154 | self.norm = nn.LayerNorm(embedding_dim, elementwise_affine=False, eps=1e-6) 155 | 156 | def forward_with_pad( 157 | self, 158 | x: torch.Tensor, 159 | timestep: Optional[torch.Tensor] = None, 160 | class_labels: Optional[torch.LongTensor] = None, 161 | hidden_dtype: Optional[torch.dtype] = None, 162 | emb: Optional[torch.Tensor] = None, 163 | hidden_length: Optional[torch.Tensor] = None, 164 | ) -> Tuple[torch.Tensor, torch.Tensor, torch.Tensor, torch.Tensor, torch.Tensor]: 165 | # hidden_length: [[20, 30], [30, 40], [50, 60]] 166 | # x: [bs, seq_len, dim] 167 | if self.emb is not None: 168 | emb = self.emb(timestep, class_labels, hidden_dtype=hidden_dtype) 169 | 170 | emb = self.linear(self.silu(emb)) 171 | batch_emb = torch.zeros_like(x).repeat(1, 1, 6) 172 | 173 | i_sum = 0 174 | num_stages = len(hidden_length) 175 | for i_p, length in enumerate(hidden_length): 176 | batch_emb[:, i_sum:i_sum+length] = emb[i_p::num_stages][:,None] 177 | i_sum += length 178 | 179 | batch_shift_msa, batch_scale_msa, batch_gate_msa, batch_shift_mlp, batch_scale_mlp, batch_gate_mlp = batch_emb.chunk(6, dim=2) 180 | x = self.norm(x) * (1 + batch_scale_msa) + batch_shift_msa 181 | return x, batch_gate_msa, batch_shift_mlp, batch_scale_mlp, batch_gate_mlp 182 | 183 | def forward( 184 | self, 185 | x: torch.Tensor, 186 | timestep: Optional[torch.Tensor] = None, 187 | class_labels: Optional[torch.LongTensor] = None, 188 | hidden_dtype: Optional[torch.dtype] = None, 189 | emb: Optional[torch.Tensor] = None, 190 | hidden_length: Optional[torch.Tensor] = None, 191 | ) -> Tuple[torch.Tensor, torch.Tensor, torch.Tensor, torch.Tensor, torch.Tensor]: 192 | if hidden_length is not None: 193 | return self.forward_with_pad(x, timestep, class_labels, hidden_dtype, emb, hidden_length) 194 | if self.emb is not None: 195 | emb = self.emb(timestep, class_labels, hidden_dtype=hidden_dtype) 196 | emb = self.linear(self.silu(emb)) 197 | shift_msa, scale_msa, gate_msa, shift_mlp, scale_mlp, gate_mlp = emb.chunk(6, dim=1) 198 | x = self.norm(x) * (1 + scale_msa[:, None]) + shift_msa[:, None] 199 | return x, gate_msa, shift_mlp, scale_mlp, gate_mlp 200 | 201 | 202 | class AdaLayerNormZeroSingle(nn.Module): 203 | r""" 204 | Norm layer adaptive layer norm zero (adaLN-Zero). 205 | 206 | Parameters: 207 | embedding_dim (`int`): The size of each embedding vector. 208 | num_embeddings (`int`): The size of the embeddings dictionary. 209 | """ 210 | 211 | def __init__(self, embedding_dim: int, norm_type="layer_norm", bias=True): 212 | super().__init__() 213 | 214 | self.silu = nn.SiLU() 215 | self.linear = nn.Linear(embedding_dim, 3 * embedding_dim, bias=bias) 216 | if norm_type == "layer_norm": 217 | self.norm = nn.LayerNorm(embedding_dim, elementwise_affine=False, eps=1e-6) 218 | else: 219 | raise ValueError( 220 | f"Unsupported `norm_type` ({norm_type}) provided. Supported ones are: 'layer_norm', 'fp32_layer_norm'." 
221 | ) 222 | 223 | def forward_with_pad( 224 | self, 225 | x: torch.Tensor, 226 | emb: Optional[torch.Tensor] = None, 227 | hidden_length: Optional[torch.Tensor] = None, 228 | ): 229 | emb = self.linear(self.silu(emb)) 230 | batch_emb = torch.zeros_like(x).repeat(1, 1, 3) 231 | 232 | i_sum = 0 233 | num_stages = len(hidden_length) 234 | for i_p, length in enumerate(hidden_length): 235 | batch_emb[:, i_sum:i_sum+length] = emb[i_p::num_stages][:,None] 236 | i_sum += length 237 | 238 | batch_shift_msa, batch_scale_msa, batch_gate_msa = batch_emb.chunk(3, dim=2) 239 | 240 | x = self.norm(x) * (1 + batch_scale_msa) + batch_shift_msa 241 | 242 | return x, batch_gate_msa 243 | 244 | def forward( 245 | self, 246 | x: torch.Tensor, 247 | emb: Optional[torch.Tensor] = None, 248 | hidden_length: Optional[torch.Tensor] = None, 249 | ) -> Tuple[torch.Tensor, torch.Tensor, torch.Tensor, torch.Tensor, torch.Tensor]: 250 | if hidden_length is not None: 251 | return self.forward_with_pad(x, emb, hidden_length) 252 | emb = self.linear(self.silu(emb)) 253 | shift_msa, scale_msa, gate_msa = emb.chunk(3, dim=1) 254 | x = self.norm(x) * (1 + scale_msa[:, None]) + shift_msa[:, None] 255 | return x, gate_msa 256 | -------------------------------------------------------------------------------- /diffusion_schedulers/scheduling_flow_matching.py: -------------------------------------------------------------------------------- 1 | from dataclasses import dataclass 2 | from typing import Optional, Tuple, Union, List 3 | import math 4 | import numpy as np 5 | import torch 6 | 7 | from diffusers.configuration_utils import ConfigMixin, register_to_config 8 | from diffusers.utils import BaseOutput 9 | from diffusers.schedulers.scheduling_utils import SchedulerMixin 10 | #from IPython import embed 11 | 12 | 13 | @dataclass 14 | class FlowMatchEulerDiscreteSchedulerOutput(BaseOutput): 15 | """ 16 | Output class for the scheduler's `step` function output. 17 | 18 | Args: 19 | prev_sample (`torch.FloatTensor` of shape `(batch_size, num_channels, height, width)` for images): 20 | Computed sample `(x_{t-1})` of previous timestep. `prev_sample` should be used as next model input in the 21 | denoising loop. 22 | """ 23 | 24 | prev_sample: torch.FloatTensor 25 | 26 | 27 | class PyramidFlowMatchEulerDiscreteScheduler(SchedulerMixin, ConfigMixin): 28 | """ 29 | Euler scheduler. 30 | 31 | This model inherits from [`SchedulerMixin`] and [`ConfigMixin`]. Check the superclass documentation for the generic 32 | methods the library implements for all schedulers such as loading and saving. 33 | 34 | Args: 35 | num_train_timesteps (`int`, defaults to 1000): 36 | The number of diffusion steps to train the model. 37 | timestep_spacing (`str`, defaults to `"linspace"`): 38 | The way the timesteps should be scaled. Refer to Table 2 of the [Common Diffusion Noise Schedules and 39 | Sample Steps are Flawed](https://huggingface.co/papers/2305.08891) for more information. 40 | shift (`float`, defaults to 1.0): 41 | The shift value for the timestep schedule. 
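        stages (`int`, defaults to 3):
            The number of pyramid stages the flow is split into; timesteps and sigmas
            are precomputed per stage in `init_sigmas_for_each_stage`.
        stage_range (`List`, defaults to `[0, 1/3, 2/3, 1]`):
            Fractions of the training schedule that mark the boundaries of each stage.
        gamma (`float`, defaults to 1/3):
            Correction factor applied to the start sigma of every stage after the first.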
42 | """ 43 | 44 | _compatibles = [] 45 | order = 1 46 | 47 | @register_to_config 48 | def __init__( 49 | self, 50 | num_train_timesteps: int = 1000, 51 | shift: float = 1.0, # Following Stable diffusion 3, 52 | stages: int = 3, 53 | stage_range: List = [0, 1/3, 2/3, 1], 54 | gamma: float = 1/3, 55 | ): 56 | 57 | self.timestep_ratios = {} # The timestep ratio for each stage 58 | self.timesteps_per_stage = {} # The detailed timesteps per stage 59 | self.sigmas_per_stage = {} 60 | self.start_sigmas = {} 61 | self.end_sigmas = {} 62 | self.ori_start_sigmas = {} 63 | 64 | # self.init_sigmas() 65 | self.init_sigmas_for_each_stage() 66 | self.sigma_min = self.sigmas[-1].item() 67 | self.sigma_max = self.sigmas[0].item() 68 | self.gamma = gamma 69 | 70 | def init_sigmas(self): 71 | """ 72 | initialize the global timesteps and sigmas 73 | """ 74 | num_train_timesteps = self.config.num_train_timesteps 75 | shift = self.config.shift 76 | 77 | timesteps = np.linspace(1, num_train_timesteps, num_train_timesteps, dtype=np.float32)[::-1].copy() 78 | timesteps = torch.from_numpy(timesteps).to(dtype=torch.float32) 79 | 80 | sigmas = timesteps / num_train_timesteps 81 | sigmas = shift * sigmas / (1 + (shift - 1) * sigmas) 82 | 83 | self.timesteps = sigmas * num_train_timesteps 84 | 85 | self._step_index = None 86 | self._begin_index = None 87 | 88 | self.sigmas = sigmas.to("cpu") # to avoid too much CPU/GPU communication 89 | 90 | def init_sigmas_for_each_stage(self): 91 | """ 92 | Init the timesteps for each stage 93 | """ 94 | self.init_sigmas() 95 | 96 | stage_distance = [] 97 | stages = self.config.stages 98 | training_steps = self.config.num_train_timesteps 99 | stage_range = self.config.stage_range 100 | 101 | # Init the start and end point of each stage 102 | for i_s in range(stages): 103 | # To decide the start and ends point 104 | start_indice = int(stage_range[i_s] * training_steps) 105 | start_indice = max(start_indice, 0) 106 | end_indice = int(stage_range[i_s+1] * training_steps) 107 | end_indice = min(end_indice, training_steps) 108 | start_sigma = self.sigmas[start_indice].item() 109 | end_sigma = self.sigmas[end_indice].item() if end_indice < training_steps else 0.0 110 | self.ori_start_sigmas[i_s] = start_sigma 111 | 112 | if i_s != 0: 113 | ori_sigma = 1 - start_sigma 114 | gamma = self.config.gamma 115 | corrected_sigma = (1 / (math.sqrt(1 + (1 / gamma)) * (1 - ori_sigma) + ori_sigma)) * ori_sigma 116 | # corrected_sigma = 1 / (2 - ori_sigma) * ori_sigma 117 | start_sigma = 1 - corrected_sigma 118 | 119 | stage_distance.append(start_sigma - end_sigma) 120 | self.start_sigmas[i_s] = start_sigma 121 | self.end_sigmas[i_s] = end_sigma 122 | 123 | # Determine the ratio of each stage according to flow length 124 | tot_distance = sum(stage_distance) 125 | for i_s in range(stages): 126 | if i_s == 0: 127 | start_ratio = 0.0 128 | else: 129 | start_ratio = sum(stage_distance[:i_s]) / tot_distance 130 | if i_s == stages - 1: 131 | end_ratio = 1.0 132 | else: 133 | end_ratio = sum(stage_distance[:i_s+1]) / tot_distance 134 | 135 | self.timestep_ratios[i_s] = (start_ratio, end_ratio) 136 | 137 | # Determine the timesteps and sigmas for each stage 138 | for i_s in range(stages): 139 | timestep_ratio = self.timestep_ratios[i_s] 140 | timestep_max = self.timesteps[int(timestep_ratio[0] * training_steps)] 141 | timestep_min = self.timesteps[min(int(timestep_ratio[1] * training_steps), training_steps - 1)] 142 | timesteps = np.linspace( 143 | timestep_max, timestep_min, training_steps + 1, 144 | ) 145 | 
self.timesteps_per_stage[i_s] = timesteps[:-1] if isinstance(timesteps, torch.Tensor) else torch.from_numpy(timesteps[:-1]) 146 | 147 | stage_sigmas = np.linspace( 148 | 1, 0, training_steps + 1, 149 | ) 150 | self.sigmas_per_stage[i_s] = torch.from_numpy(stage_sigmas[:-1]) 151 | 152 | @property 153 | def step_index(self): 154 | """ 155 | The index counter for current timestep. It will increase 1 after each scheduler step. 156 | """ 157 | return self._step_index 158 | 159 | @property 160 | def begin_index(self): 161 | """ 162 | The index for the first timestep. It should be set from pipeline with `set_begin_index` method. 163 | """ 164 | return self._begin_index 165 | 166 | # Copied from diffusers.schedulers.scheduling_dpmsolver_multistep.DPMSolverMultistepScheduler.set_begin_index 167 | def set_begin_index(self, begin_index: int = 0): 168 | """ 169 | Sets the begin index for the scheduler. This function should be run from pipeline before the inference. 170 | 171 | Args: 172 | begin_index (`int`): 173 | The begin index for the scheduler. 174 | """ 175 | self._begin_index = begin_index 176 | 177 | def _sigma_to_t(self, sigma): 178 | return sigma * self.config.num_train_timesteps 179 | 180 | def set_timesteps(self, num_inference_steps: int, stage_index: int, device: Union[str, torch.device] = None): 181 | """ 182 | Setting the timesteps and sigmas for each stage 183 | """ 184 | self.num_inference_steps = num_inference_steps 185 | training_steps = self.config.num_train_timesteps 186 | self.init_sigmas() 187 | 188 | stage_timesteps = self.timesteps_per_stage[stage_index] 189 | timestep_max = stage_timesteps[0].item() 190 | timestep_min = stage_timesteps[-1].item() 191 | 192 | timesteps = np.linspace( 193 | timestep_max, timestep_min, num_inference_steps, 194 | ) 195 | self.timesteps = torch.from_numpy(timesteps).to(device=device) 196 | 197 | stage_sigmas = self.sigmas_per_stage[stage_index] 198 | sigma_max = stage_sigmas[0].item() 199 | sigma_min = stage_sigmas[-1].item() 200 | 201 | ratios = np.linspace( 202 | sigma_max, sigma_min, num_inference_steps 203 | ) 204 | sigmas = torch.from_numpy(ratios).to(device=device) 205 | self.sigmas = torch.cat([sigmas, torch.zeros(1, device=sigmas.device)]) 206 | 207 | self._step_index = None 208 | 209 | def index_for_timestep(self, timestep, schedule_timesteps=None): 210 | if schedule_timesteps is None: 211 | schedule_timesteps = self.timesteps 212 | 213 | indices = (schedule_timesteps == timestep).nonzero() 214 | 215 | # The sigma index that is taken for the **very** first `step` 216 | # is always the second index (or the last index if there is only 1) 217 | # This way we can ensure we don't accidentally skip a sigma in 218 | # case we start in the middle of the denoising schedule (e.g. for image-to-image) 219 | pos = 1 if len(indices) > 1 else 0 220 | 221 | return indices[pos].item() 222 | 223 | def _init_step_index(self, timestep): 224 | if self.begin_index is None: 225 | if isinstance(timestep, torch.Tensor): 226 | timestep = timestep.to(self.timesteps.device) 227 | self._step_index = self.index_for_timestep(timestep) 228 | else: 229 | self._step_index = self._begin_index 230 | 231 | def step( 232 | self, 233 | model_output: torch.FloatTensor, 234 | timestep: Union[float, torch.FloatTensor], 235 | sample: torch.FloatTensor, 236 | generator: Optional[torch.Generator] = None, 237 | return_dict: bool = True, 238 | ) -> Union[FlowMatchEulerDiscreteSchedulerOutput, Tuple]: 239 | """ 240 | Predict the sample from the previous timestep by reversing the SDE. 
This function propagates the diffusion 241 | process from the learned model outputs (most often the predicted noise). 242 | 243 | Args: 244 | model_output (`torch.FloatTensor`): 245 | The direct output from learned diffusion model. 246 | timestep (`float`): 247 | The current discrete timestep in the diffusion chain. 248 | sample (`torch.FloatTensor`): 249 | A current instance of a sample created by the diffusion process. 250 | generator (`torch.Generator`, *optional*): 251 | A random number generator. 252 | return_dict (`bool`): 253 | Whether or not to return a [`~schedulers.scheduling_euler_discrete.EulerDiscreteSchedulerOutput`] or 254 | tuple. 255 | 256 | Returns: 257 | [`~schedulers.scheduling_euler_discrete.EulerDiscreteSchedulerOutput`] or `tuple`: 258 | If return_dict is `True`, [`~schedulers.scheduling_euler_discrete.EulerDiscreteSchedulerOutput`] is 259 | returned, otherwise a tuple is returned where the first element is the sample tensor. 260 | """ 261 | 262 | if ( 263 | isinstance(timestep, int) 264 | or isinstance(timestep, torch.IntTensor) 265 | or isinstance(timestep, torch.LongTensor) 266 | ): 267 | raise ValueError( 268 | ( 269 | "Passing integer indices (e.g. from `enumerate(timesteps)`) as timesteps to" 270 | " `EulerDiscreteScheduler.step()` is not supported. Make sure to pass" 271 | " one of the `scheduler.timesteps` as a timestep." 272 | ), 273 | ) 274 | 275 | if self.step_index is None: 276 | self._step_index = 0 277 | 278 | # Upcast to avoid precision issues when computing prev_sample 279 | sample = sample.to(torch.float32) 280 | 281 | sigma = self.sigmas[self.step_index] 282 | sigma_next = self.sigmas[self.step_index + 1] 283 | 284 | prev_sample = sample + (sigma_next - sigma) * model_output 285 | 286 | # Cast sample back to model compatible dtype 287 | prev_sample = prev_sample.to(model_output.dtype) 288 | 289 | # upon completion increase step index by one 290 | self._step_index += 1 291 | 292 | if not return_dict: 293 | return (prev_sample,) 294 | 295 | return FlowMatchEulerDiscreteSchedulerOutput(prev_sample=prev_sample) 296 | 297 | def __len__(self): 298 | return self.config.num_train_timesteps 299 | -------------------------------------------------------------------------------- /utils.py: -------------------------------------------------------------------------------- 1 | import os 2 | import torch 3 | import PIL.Image 4 | import numpy as np 5 | from torch import nn 6 | import torch.distributed as dist 7 | import timm.models.hub as timm_hub 8 | 9 | """Modified from https://github.com/CompVis/taming-transformers.git""" 10 | 11 | import hashlib 12 | import requests 13 | from tqdm import tqdm 14 | try: 15 | import piq 16 | except: 17 | pass 18 | 19 | _CONTEXT_PARALLEL_GROUP = None 20 | _CONTEXT_PARALLEL_SIZE = None 21 | 22 | 23 | def is_dist_avail_and_initialized(): 24 | if not dist.is_available(): 25 | return False 26 | if not dist.is_initialized(): 27 | return False 28 | return True 29 | 30 | 31 | def get_world_size(): 32 | if not is_dist_avail_and_initialized(): 33 | return 1 34 | return dist.get_world_size() 35 | 36 | 37 | def get_rank(): 38 | if not is_dist_avail_and_initialized(): 39 | return 0 40 | return dist.get_rank() 41 | 42 | 43 | def is_main_process(): 44 | return get_rank() == 0 45 | 46 | 47 | def is_context_parallel_initialized(): 48 | if _CONTEXT_PARALLEL_GROUP is None: 49 | return False 50 | else: 51 | return True 52 | 53 | 54 | def set_context_parallel_group(size, group): 55 | global _CONTEXT_PARALLEL_GROUP 56 | global _CONTEXT_PARALLEL_SIZE 57 
| _CONTEXT_PARALLEL_GROUP = group 58 | _CONTEXT_PARALLEL_SIZE = size 59 | 60 | 61 | def initialize_context_parallel(context_parallel_size): 62 | global _CONTEXT_PARALLEL_GROUP 63 | global _CONTEXT_PARALLEL_SIZE 64 | 65 | assert _CONTEXT_PARALLEL_GROUP is None, "context parallel group is already initialized" 66 | _CONTEXT_PARALLEL_SIZE = context_parallel_size 67 | 68 | rank = torch.distributed.get_rank() 69 | world_size = torch.distributed.get_world_size() 70 | 71 | for i in range(0, world_size, context_parallel_size): 72 | ranks = range(i, i + context_parallel_size) 73 | group = torch.distributed.new_group(ranks) 74 | if rank in ranks: 75 | _CONTEXT_PARALLEL_GROUP = group 76 | break 77 | 78 | 79 | def get_context_parallel_group(): 80 | assert _CONTEXT_PARALLEL_GROUP is not None, "context parallel group is not initialized" 81 | 82 | return _CONTEXT_PARALLEL_GROUP 83 | 84 | 85 | def get_context_parallel_world_size(): 86 | assert _CONTEXT_PARALLEL_SIZE is not None, "context parallel size is not initialized" 87 | 88 | return _CONTEXT_PARALLEL_SIZE 89 | 90 | 91 | def get_context_parallel_rank(): 92 | assert _CONTEXT_PARALLEL_SIZE is not None, "context parallel size is not initialized" 93 | 94 | rank = get_rank() 95 | cp_rank = rank % _CONTEXT_PARALLEL_SIZE 96 | return cp_rank 97 | 98 | 99 | def get_context_parallel_group_rank(): 100 | assert _CONTEXT_PARALLEL_SIZE is not None, "context parallel size is not initialized" 101 | 102 | rank = get_rank() 103 | cp_group_rank = rank // _CONTEXT_PARALLEL_SIZE 104 | 105 | return cp_group_rank 106 | 107 | 108 | def download_cached_file(url, check_hash=True, progress=False): 109 | """ 110 | Download a file from a URL and cache it locally. If the file already exists, it is not downloaded again. 111 | If distributed, only the main process downloads the file, and the other processes wait for the file to be downloaded. 
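    Example (the URL is illustrative):
        cached_path = download_cached_file("https://example.com/weights.pth", check_hash=True, progress=True)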
112 | """ 113 | 114 | def get_cached_file_path(): 115 | # a hack to sync the file path across processes 116 | parts = torch.hub.urlparse(url) 117 | filename = os.path.basename(parts.path) 118 | cached_file = os.path.join(timm_hub.get_cache_dir(), filename) 119 | 120 | return cached_file 121 | 122 | if is_main_process(): 123 | timm_hub.download_cached_file(url, check_hash, progress) 124 | 125 | if is_dist_avail_and_initialized(): 126 | dist.barrier() 127 | 128 | return get_cached_file_path() 129 | 130 | 131 | def convert_weights_to_fp16(model: nn.Module): 132 | """Convert applicable model parameters to fp16""" 133 | 134 | def _convert_weights_to_fp16(l): 135 | if isinstance(l, (nn.Conv1d, nn.Conv2d, nn.Conv3d, nn.Linear)): 136 | l.weight.data = l.weight.data.to(torch.float16) 137 | if l.bias is not None: 138 | l.bias.data = l.bias.data.to(torch.float16) 139 | 140 | model.apply(_convert_weights_to_fp16) 141 | 142 | 143 | def convert_weights_to_bf16(model: nn.Module): 144 | """Convert applicable model parameters to fp16""" 145 | 146 | def _convert_weights_to_bf16(l): 147 | if isinstance(l, (nn.Conv1d, nn.Conv2d, nn.Conv3d, nn.Linear)): 148 | l.weight.data = l.weight.data.to(torch.bfloat16) 149 | if l.bias is not None: 150 | l.bias.data = l.bias.data.to(torch.bfloat16) 151 | 152 | model.apply(_convert_weights_to_bf16) 153 | 154 | 155 | def save_result(result, result_dir, filename, remove_duplicate="", save_format='json'): 156 | import json 157 | import jsonlines 158 | print("Dump result") 159 | 160 | # Make the temp dir for saving results 161 | if not os.path.exists(result_dir): 162 | if is_main_process(): 163 | os.makedirs(result_dir) 164 | if is_dist_avail_and_initialized(): 165 | torch.distributed.barrier() 166 | 167 | result_file = os.path.join( 168 | result_dir, "%s_rank%d.json" % (filename, get_rank()) 169 | ) 170 | 171 | final_result_file = os.path.join(result_dir, f"{filename}.{save_format}") 172 | 173 | json.dump(result, open(result_file, "w")) 174 | 175 | if is_dist_avail_and_initialized(): 176 | torch.distributed.barrier() 177 | 178 | if is_main_process(): 179 | # print("rank %d starts merging results." 
% get_rank()) 180 | # combine results from all processes 181 | result = [] 182 | 183 | for rank in range(get_world_size()): 184 | result_file = os.path.join(result_dir, "%s_rank%d.json" % (filename, rank)) 185 | res = json.load(open(result_file, "r")) 186 | result += res 187 | 188 | # print("Remove duplicate") 189 | if remove_duplicate: 190 | result_new = [] 191 | id_set = set() 192 | for res in result: 193 | if res[remove_duplicate] not in id_set: 194 | id_set.add(res[remove_duplicate]) 195 | result_new.append(res) 196 | result = result_new 197 | 198 | if save_format == 'json': 199 | json.dump(result, open(final_result_file, "w")) 200 | else: 201 | assert save_format == 'jsonl', "Only support json adn jsonl format" 202 | with jsonlines.open(final_result_file, "w") as writer: 203 | writer.write_all(result) 204 | 205 | # print("result file saved to %s" % final_result_file) 206 | 207 | return final_result_file 208 | 209 | 210 | # resizing utils 211 | # TODO: clean up later 212 | def _resize_with_antialiasing(input, size, interpolation="bicubic", align_corners=True): 213 | h, w = input.shape[-2:] 214 | factors = (h / size[0], w / size[1]) 215 | 216 | # First, we have to determine sigma 217 | # Taken from skimage: https://github.com/scikit-image/scikit-image/blob/v0.19.2/skimage/transform/_warps.py#L171 218 | sigmas = ( 219 | max((factors[0] - 1.0) / 2.0, 0.001), 220 | max((factors[1] - 1.0) / 2.0, 0.001), 221 | ) 222 | 223 | # Now kernel size. Good results are for 3 sigma, but that is kind of slow. Pillow uses 1 sigma 224 | # https://github.com/python-pillow/Pillow/blob/master/src/libImaging/Resample.c#L206 225 | # But they do it in the 2 passes, which gives better results. Let's try 2 sigmas for now 226 | ks = int(max(2.0 * 2 * sigmas[0], 3)), int(max(2.0 * 2 * sigmas[1], 3)) 227 | 228 | # Make sure it is odd 229 | if (ks[0] % 2) == 0: 230 | ks = ks[0] + 1, ks[1] 231 | 232 | if (ks[1] % 2) == 0: 233 | ks = ks[0], ks[1] + 1 234 | 235 | input = _gaussian_blur2d(input, ks, sigmas) 236 | 237 | output = torch.nn.functional.interpolate(input, size=size, mode=interpolation, align_corners=align_corners) 238 | return output 239 | 240 | 241 | def _compute_padding(kernel_size): 242 | """Compute padding tuple.""" 243 | # 4 or 6 ints: (padding_left, padding_right,padding_top,padding_bottom) 244 | # https://pytorch.org/docs/stable/nn.html#torch.nn.functional.pad 245 | if len(kernel_size) < 2: 246 | raise AssertionError(kernel_size) 247 | computed = [k - 1 for k in kernel_size] 248 | 249 | # for even kernels we need to do asymmetric padding :( 250 | out_padding = 2 * len(kernel_size) * [0] 251 | 252 | for i in range(len(kernel_size)): 253 | computed_tmp = computed[-(i + 1)] 254 | 255 | pad_front = computed_tmp // 2 256 | pad_rear = computed_tmp - pad_front 257 | 258 | out_padding[2 * i + 0] = pad_front 259 | out_padding[2 * i + 1] = pad_rear 260 | 261 | return out_padding 262 | 263 | 264 | def _filter2d(input, kernel): 265 | # prepare kernel 266 | b, c, h, w = input.shape 267 | tmp_kernel = kernel[:, None, ...].to(device=input.device, dtype=input.dtype) 268 | 269 | tmp_kernel = tmp_kernel.expand(-1, c, -1, -1) 270 | 271 | height, width = tmp_kernel.shape[-2:] 272 | 273 | padding_shape: list[int] = _compute_padding([height, width]) 274 | input = torch.nn.functional.pad(input, padding_shape, mode="reflect") 275 | 276 | # kernel and input tensor reshape to align element-wise or batch-wise params 277 | tmp_kernel = tmp_kernel.reshape(-1, 1, height, width) 278 | input = input.view(-1, tmp_kernel.size(0), 
input.size(-2), input.size(-1)) 279 | 280 | # convolve the tensor with the kernel. 281 | output = torch.nn.functional.conv2d(input, tmp_kernel, groups=tmp_kernel.size(0), padding=0, stride=1) 282 | 283 | out = output.view(b, c, h, w) 284 | return out 285 | 286 | 287 | def _gaussian(window_size: int, sigma): 288 | if isinstance(sigma, float): 289 | sigma = torch.tensor([[sigma]]) 290 | 291 | batch_size = sigma.shape[0] 292 | 293 | x = (torch.arange(window_size, device=sigma.device, dtype=sigma.dtype) - window_size // 2).expand(batch_size, -1) 294 | 295 | if window_size % 2 == 0: 296 | x = x + 0.5 297 | 298 | gauss = torch.exp(-x.pow(2.0) / (2 * sigma.pow(2.0))) 299 | 300 | return gauss / gauss.sum(-1, keepdim=True) 301 | 302 | 303 | def _gaussian_blur2d(input, kernel_size, sigma): 304 | if isinstance(sigma, tuple): 305 | sigma = torch.tensor([sigma], dtype=input.dtype) 306 | else: 307 | sigma = sigma.to(dtype=input.dtype) 308 | 309 | ky, kx = int(kernel_size[0]), int(kernel_size[1]) 310 | bs = sigma.shape[0] 311 | kernel_x = _gaussian(kx, sigma[:, 1].view(bs, 1)) 312 | kernel_y = _gaussian(ky, sigma[:, 0].view(bs, 1)) 313 | out_x = _filter2d(input, kernel_x[..., None, :]) 314 | out = _filter2d(out_x, kernel_y[..., None]) 315 | 316 | return out 317 | 318 | 319 | URL_MAP = { 320 | "vgg_lpips": "https://heibox.uni-heidelberg.de/f/607503859c864bc1b30b/?dl=1" 321 | } 322 | 323 | CKPT_MAP = { 324 | "vgg_lpips": "vgg.pth" 325 | } 326 | 327 | MD5_MAP = { 328 | "vgg_lpips": "d507d7349b931f0638a25a48a722f98a" 329 | } 330 | 331 | 332 | def download(url, local_path, chunk_size=1024): 333 | os.makedirs(os.path.split(local_path)[0], exist_ok=True) 334 | with requests.get(url, stream=True) as r: 335 | total_size = int(r.headers.get("content-length", 0)) 336 | with tqdm(total=total_size, unit="B", unit_scale=True) as pbar: 337 | with open(local_path, "wb") as f: 338 | for data in r.iter_content(chunk_size=chunk_size): 339 | if data: 340 | f.write(data) 341 | pbar.update(chunk_size) 342 | 343 | 344 | def md5_hash(path): 345 | with open(path, "rb") as f: 346 | content = f.read() 347 | return hashlib.md5(content).hexdigest() 348 | 349 | 350 | def get_ckpt_path(name, root, check=False): 351 | assert name in URL_MAP 352 | path = os.path.join(root, CKPT_MAP[name]) 353 | print(md5_hash(path)) 354 | if not os.path.exists(path) or (check and not md5_hash(path) == MD5_MAP[name]): 355 | print("Downloading {} model from {} to {}".format(name, URL_MAP[name], path)) 356 | download(URL_MAP[name], path) 357 | md5 = md5_hash(path) 358 | assert md5 == MD5_MAP[name], md5 359 | return path 360 | 361 | 362 | class KeyNotFoundError(Exception): 363 | def __init__(self, cause, keys=None, visited=None): 364 | self.cause = cause 365 | self.keys = keys 366 | self.visited = visited 367 | messages = list() 368 | if keys is not None: 369 | messages.append("Key not found: {}".format(keys)) 370 | if visited is not None: 371 | messages.append("Visited: {}".format(visited)) 372 | messages.append("Cause:\n{}".format(cause)) 373 | message = "\n".join(messages) 374 | super().__init__(message) 375 | 376 | 377 | def retrieve( 378 | list_or_dict, key, splitval="/", default=None, expand=True, pass_success=False 379 | ): 380 | """Given a nested list or dict return the desired value at key expanding 381 | callable nodes if necessary and :attr:`expand` is ``True``. The expansion 382 | is done in-place. 383 | 384 | Parameters 385 | ---------- 386 | list_or_dict : list or dict 387 | Possibly nested list or dictionary. 
388 | key : str 389 | key/to/value, path like string describing all keys necessary to 390 | consider to get to the desired value. List indices can also be 391 | passed here. 392 | splitval : str 393 | String that defines the delimiter between keys of the 394 | different depth levels in `key`. 395 | default : obj 396 | Value returned if :attr:`key` is not found. 397 | expand : bool 398 | Whether to expand callable nodes on the path or not. 399 | 400 | Returns 401 | ------- 402 | The desired value or if :attr:`default` is not ``None`` and the 403 | :attr:`key` is not found returns ``default``. 404 | 405 | Raises 406 | ------ 407 | Exception if ``key`` not in ``list_or_dict`` and :attr:`default` is 408 | ``None``. 409 | """ 410 | 411 | keys = key.split(splitval) 412 | 413 | success = True 414 | try: 415 | visited = [] 416 | parent = None 417 | last_key = None 418 | for key in keys: 419 | if callable(list_or_dict): 420 | if not expand: 421 | raise KeyNotFoundError( 422 | ValueError( 423 | "Trying to get past callable node with expand=False." 424 | ), 425 | keys=keys, 426 | visited=visited, 427 | ) 428 | list_or_dict = list_or_dict() 429 | parent[last_key] = list_or_dict 430 | 431 | last_key = key 432 | parent = list_or_dict 433 | 434 | try: 435 | if isinstance(list_or_dict, dict): 436 | list_or_dict = list_or_dict[key] 437 | else: 438 | list_or_dict = list_or_dict[int(key)] 439 | except (KeyError, IndexError, ValueError) as e: 440 | raise KeyNotFoundError(e, keys=keys, visited=visited) 441 | 442 | visited += [key] 443 | # final expansion of retrieved value 444 | if expand and callable(list_or_dict): 445 | list_or_dict = list_or_dict() 446 | parent[last_key] = list_or_dict 447 | except KeyNotFoundError as e: 448 | if default is None: 449 | raise e 450 | else: 451 | list_or_dict = default 452 | success = False 453 | 454 | if not pass_success: 455 | return list_or_dict 456 | else: 457 | return list_or_dict, success -------------------------------------------------------------------------------- /examples/pyramid_flow_miniflux_img2vid_example_01.json: -------------------------------------------------------------------------------- 1 | { 2 | "last_node_id": 59, 3 | "last_link_id": 95, 4 | "nodes": [ 5 | { 6 | "id": 39, 7 | "type": "Note", 8 | "pos": { 9 | "0": 30, 10 | "1": 650 11 | }, 12 | "size": { 13 | "0": 318.25567626953125, 14 | "1": 66.4825210571289 15 | }, 16 | "flags": {}, 17 | "order": 0, 18 | "mode": 0, 19 | "inputs": [], 20 | "outputs": [], 21 | "properties": {}, 22 | "widgets_values": [ 23 | "fp8 text encoder results are different from fp16!" 
24 | ], 25 | "color": "#432", 26 | "bgcolor": "#653" 27 | }, 28 | { 29 | "id": 37, 30 | "type": "DualCLIPLoader", 31 | "pos": { 32 | "0": -40, 33 | "1": 480 34 | }, 35 | "size": { 36 | "0": 407.1675720214844, 37 | "1": 106 38 | }, 39 | "flags": {}, 40 | "order": 1, 41 | "mode": 0, 42 | "inputs": [], 43 | "outputs": [ 44 | { 45 | "name": "CLIP", 46 | "type": "CLIP", 47 | "links": [ 48 | 80 49 | ], 50 | "slot_index": 0 51 | } 52 | ], 53 | "properties": { 54 | "Node name for S&R": "DualCLIPLoader" 55 | }, 56 | "widgets_values": [ 57 | "clip_l.safetensors", 58 | "t5\\t5xxl_fp16.safetensors", 59 | "flux" 60 | ] 61 | }, 62 | { 63 | "id": 57, 64 | "type": "ImageScale", 65 | "pos": { 66 | "0": 635, 67 | "1": 833 68 | }, 69 | "size": { 70 | "0": 315, 71 | "1": 130 72 | }, 73 | "flags": {}, 74 | "order": 8, 75 | "mode": 0, 76 | "inputs": [ 77 | { 78 | "name": "image", 79 | "type": "IMAGE", 80 | "link": 88 81 | }, 82 | { 83 | "name": "width", 84 | "type": "INT", 85 | "link": 91, 86 | "widget": { 87 | "name": "width" 88 | } 89 | }, 90 | { 91 | "name": "height", 92 | "type": "INT", 93 | "link": 93, 94 | "widget": { 95 | "name": "height" 96 | } 97 | } 98 | ], 99 | "outputs": [ 100 | { 101 | "name": "IMAGE", 102 | "type": "IMAGE", 103 | "links": [ 104 | 89 105 | ], 106 | "slot_index": 0 107 | } 108 | ], 109 | "properties": { 110 | "Node name for S&R": "ImageScale" 111 | }, 112 | "widgets_values": [ 113 | "lanczos", 114 | 640, 115 | 384, 116 | "center" 117 | ] 118 | }, 119 | { 120 | "id": 58, 121 | "type": "PrimitiveNode", 122 | "pos": { 123 | "0": 347, 124 | "1": 817 125 | }, 126 | "size": { 127 | "0": 256.92181396484375, 128 | "1": 82 129 | }, 130 | "flags": {}, 131 | "order": 2, 132 | "mode": 0, 133 | "inputs": [], 134 | "outputs": [ 135 | { 136 | "name": "INT", 137 | "type": "INT", 138 | "links": [ 139 | 91, 140 | 94 141 | ], 142 | "slot_index": 0, 143 | "widget": { 144 | "name": "width" 145 | } 146 | } 147 | ], 148 | "title": "width", 149 | "properties": { 150 | "Run widget replace on values": false 151 | }, 152 | "widgets_values": [ 153 | 640, 154 | "fixed" 155 | ] 156 | }, 157 | { 158 | "id": 59, 159 | "type": "PrimitiveNode", 160 | "pos": { 161 | "0": 350, 162 | "1": 945 163 | }, 164 | "size": { 165 | "0": 251.2918701171875, 166 | "1": 82 167 | }, 168 | "flags": {}, 169 | "order": 3, 170 | "mode": 0, 171 | "inputs": [], 172 | "outputs": [ 173 | { 174 | "name": "INT", 175 | "type": "INT", 176 | "links": [ 177 | 93, 178 | 95 179 | ], 180 | "slot_index": 0, 181 | "widget": { 182 | "name": "height" 183 | } 184 | } 185 | ], 186 | "title": "height", 187 | "properties": { 188 | "Run widget replace on values": false 189 | }, 190 | "widgets_values": [ 191 | 384, 192 | "fixed" 193 | ] 194 | }, 195 | { 196 | "id": 54, 197 | "type": "PyramidFlowVAEEncode", 198 | "pos": { 199 | "0": 993, 200 | "1": 786 201 | }, 202 | "size": { 203 | "0": 315, 204 | "1": 102 205 | }, 206 | "flags": {}, 207 | "order": 9, 208 | "mode": 0, 209 | "inputs": [ 210 | { 211 | "name": "vae", 212 | "type": "PYRAMIDFLOWVAE", 213 | "link": 82 214 | }, 215 | { 216 | "name": "image", 217 | "type": "IMAGE", 218 | "link": 89 219 | } 220 | ], 221 | "outputs": [ 222 | { 223 | "name": "samples", 224 | "type": "LATENT", 225 | "links": [ 226 | 83 227 | ], 228 | "slot_index": 0 229 | } 230 | ], 231 | "properties": { 232 | "Node name for S&R": "PyramidFlowVAEEncode" 233 | }, 234 | "widgets_values": [ 235 | false, 236 | 0.25 237 | ] 238 | }, 239 | { 240 | "id": 55, 241 | "type": "LoadImage", 242 | "pos": { 243 | "0": -32, 244 | "1": 842 245 | }, 246 | 
"size": { 247 | "0": 315, 248 | "1": 314 249 | }, 250 | "flags": {}, 251 | "order": 4, 252 | "mode": 0, 253 | "inputs": [], 254 | "outputs": [ 255 | { 256 | "name": "IMAGE", 257 | "type": "IMAGE", 258 | "links": [ 259 | 88 260 | ] 261 | }, 262 | { 263 | "name": "MASK", 264 | "type": "MASK", 265 | "links": null 266 | } 267 | ], 268 | "properties": { 269 | "Node name for S&R": "LoadImage" 270 | }, 271 | "widgets_values": [ 272 | "videoframe_812.png", 273 | "image" 274 | ] 275 | }, 276 | { 277 | "id": 53, 278 | "type": "PyramidFlowTextEncode", 279 | "pos": { 280 | "0": 444, 281 | "1": 476 282 | }, 283 | "size": { 284 | "0": 437.19818115234375, 285 | "1": 269.9795837402344 286 | }, 287 | "flags": {}, 288 | "order": 7, 289 | "mode": 0, 290 | "inputs": [ 291 | { 292 | "name": "clip", 293 | "type": "CLIP", 294 | "link": 80 295 | } 296 | ], 297 | "outputs": [ 298 | { 299 | "name": "prompt_embeds", 300 | "type": "PYRAMIDFLOWPROMPT", 301 | "links": [ 302 | 81 303 | ], 304 | "slot_index": 0 305 | } 306 | ], 307 | "properties": { 308 | "Node name for S&R": "PyramidFlowTextEncode" 309 | }, 310 | "widgets_values": [ 311 | "FPV flying over seaside cliffs while the sun is setting, hyper quality, Ultra HD, 8K", 312 | "cartoon style, worst quality, low quality, blurry, absolute black, absolute white, low res, extra limbs, extra digits, misplaced objects, mutated anatomy, monochrome, horror", 313 | true 314 | ] 315 | }, 316 | { 317 | "id": 50, 318 | "type": "PyramidFlowSampler", 319 | "pos": { 320 | "0": 1046, 321 | "1": 144 322 | }, 323 | "size": { 324 | "0": 315, 325 | "1": 314 326 | }, 327 | "flags": {}, 328 | "order": 10, 329 | "mode": 0, 330 | "inputs": [ 331 | { 332 | "name": "model", 333 | "type": "PYRAMIDFLOWMODEL", 334 | "link": 74 335 | }, 336 | { 337 | "name": "prompt_embeds", 338 | "type": "PYRAMIDFLOWPROMPT", 339 | "link": 81 340 | }, 341 | { 342 | "name": "input_latent", 343 | "type": "LATENT", 344 | "link": 83, 345 | "shape": 7 346 | }, 347 | { 348 | "name": "width", 349 | "type": "INT", 350 | "link": 94, 351 | "widget": { 352 | "name": "width" 353 | } 354 | }, 355 | { 356 | "name": "height", 357 | "type": "INT", 358 | "link": 95, 359 | "widget": { 360 | "name": "height" 361 | } 362 | } 363 | ], 364 | "outputs": [ 365 | { 366 | "name": "samples", 367 | "type": "LATENT", 368 | "links": [ 369 | 76 370 | ] 371 | } 372 | ], 373 | "properties": { 374 | "Node name for S&R": "PyramidFlowSampler" 375 | }, 376 | "widgets_values": [ 377 | 640, 378 | 384, 379 | "20, 20, 20", 380 | "10, 10, 10", 381 | 16, 382 | 7, 383 | 4, 384 | 44664248661402, 385 | "fixed", 386 | false 387 | ] 388 | }, 389 | { 390 | "id": 51, 391 | "type": "VHS_VideoCombine", 392 | "pos": { 393 | "0": 1420, 394 | "1": 113 395 | }, 396 | "size": [ 397 | 1018.2306518554688, 398 | 310 399 | ], 400 | "flags": {}, 401 | "order": 12, 402 | "mode": 0, 403 | "inputs": [ 404 | { 405 | "name": "images", 406 | "type": "IMAGE", 407 | "link": 77 408 | }, 409 | { 410 | "name": "audio", 411 | "type": "AUDIO", 412 | "link": null, 413 | "shape": 7 414 | }, 415 | { 416 | "name": "meta_batch", 417 | "type": "VHS_BatchManager", 418 | "link": null, 419 | "shape": 7 420 | }, 421 | { 422 | "name": "vae", 423 | "type": "VAE", 424 | "link": null, 425 | "shape": 7 426 | } 427 | ], 428 | "outputs": [ 429 | { 430 | "name": "Filenames", 431 | "type": "VHS_FILENAMES", 432 | "links": null 433 | } 434 | ], 435 | "properties": { 436 | "Node name for S&R": "VHS_VideoCombine" 437 | }, 438 | "widgets_values": { 439 | "frame_rate": 24, 440 | "loop_count": 0, 441 | 
"filename_prefix": "PyramidFlow", 442 | "format": "video/h264-mp4", 443 | "pix_fmt": "yuv420p", 444 | "crf": 19, 445 | "save_metadata": true, 446 | "pingpong": false, 447 | "save_output": true, 448 | "videopreview": { 449 | "hidden": false, 450 | "paused": false, 451 | "params": { 452 | "filename": "PyramidFlow_00131.mp4", 453 | "subfolder": "", 454 | "type": "output", 455 | "format": "video/h264-mp4", 456 | "frame_rate": 24 457 | }, 458 | "muted": false 459 | } 460 | } 461 | }, 462 | { 463 | "id": 43, 464 | "type": "PyramidFlowVAELoader", 465 | "pos": { 466 | "0": 250, 467 | "1": 282 468 | }, 469 | "size": { 470 | "0": 411.12652587890625, 471 | "1": 82 472 | }, 473 | "flags": {}, 474 | "order": 5, 475 | "mode": 0, 476 | "inputs": [ 477 | { 478 | "name": "compile_args", 479 | "type": "PYRAMIDFLOW_COMPILEARGS", 480 | "link": null, 481 | "shape": 7 482 | } 483 | ], 484 | "outputs": [ 485 | { 486 | "name": "pyramidflow_vae", 487 | "type": "PYRAMIDFLOWVAE", 488 | "links": [ 489 | 71, 490 | 82 491 | ], 492 | "slot_index": 0 493 | } 494 | ], 495 | "properties": { 496 | "Node name for S&R": "PyramidFlowVAELoader" 497 | }, 498 | "widgets_values": [ 499 | "pyramidflow\\pyramid_flow_vae_bf16.safetensors", 500 | "bf16" 501 | ] 502 | }, 503 | { 504 | "id": 48, 505 | "type": "PyramidFlowVAEDecode", 506 | "pos": { 507 | "0": 1051, 508 | "1": 534 509 | }, 510 | "size": { 511 | "0": 315, 512 | "1": 150 513 | }, 514 | "flags": {}, 515 | "order": 11, 516 | "mode": 0, 517 | "inputs": [ 518 | { 519 | "name": "vae", 520 | "type": "PYRAMIDFLOWVAE", 521 | "link": 71 522 | }, 523 | { 524 | "name": "samples", 525 | "type": "LATENT", 526 | "link": 76 527 | } 528 | ], 529 | "outputs": [ 530 | { 531 | "name": "images", 532 | "type": "IMAGE", 533 | "links": [ 534 | 77 535 | ], 536 | "slot_index": 0 537 | } 538 | ], 539 | "properties": { 540 | "Node name for S&R": "PyramidFlowVAEDecode" 541 | }, 542 | "widgets_values": [ 543 | 256, 544 | 0.25, 545 | true, 546 | true 547 | ] 548 | }, 549 | { 550 | "id": 40, 551 | "type": "PyramidFlowTransformerLoader", 552 | "pos": { 553 | "0": 230, 554 | "1": 114 555 | }, 556 | "size": { 557 | "0": 444.05462646484375, 558 | "1": 106 559 | }, 560 | "flags": {}, 561 | "order": 6, 562 | "mode": 0, 563 | "inputs": [ 564 | { 565 | "name": "compile_args", 566 | "type": "PYRAMIDFLOW_COMPILEARGS", 567 | "link": null, 568 | "shape": 7 569 | } 570 | ], 571 | "outputs": [ 572 | { 573 | "name": "pyramidflow_model", 574 | "type": "PYRAMIDFLOWMODEL", 575 | "links": [ 576 | 74 577 | ], 578 | "slot_index": 0 579 | } 580 | ], 581 | "properties": { 582 | "Node name for S&R": "PyramidFlowTransformerLoader" 583 | }, 584 | "widgets_values": [ 585 | "pyramidflow\\pyramid_flow_miniflux_bf16_v2.safetensors", 586 | "bf16", 587 | false 588 | ] 589 | } 590 | ], 591 | "links": [ 592 | [ 593 | 71, 594 | 43, 595 | 0, 596 | 48, 597 | 0, 598 | "PYRAMIDFLOWVAE" 599 | ], 600 | [ 601 | 74, 602 | 40, 603 | 0, 604 | 50, 605 | 0, 606 | "PYRAMIDFLOWMODEL" 607 | ], 608 | [ 609 | 76, 610 | 50, 611 | 0, 612 | 48, 613 | 1, 614 | "LATENT" 615 | ], 616 | [ 617 | 77, 618 | 48, 619 | 0, 620 | 51, 621 | 0, 622 | "IMAGE" 623 | ], 624 | [ 625 | 80, 626 | 37, 627 | 0, 628 | 53, 629 | 0, 630 | "CLIP" 631 | ], 632 | [ 633 | 81, 634 | 53, 635 | 0, 636 | 50, 637 | 1, 638 | "PYRAMIDFLOWPROMPT" 639 | ], 640 | [ 641 | 82, 642 | 43, 643 | 0, 644 | 54, 645 | 0, 646 | "PYRAMIDFLOWVAE" 647 | ], 648 | [ 649 | 83, 650 | 54, 651 | 0, 652 | 50, 653 | 2, 654 | "LATENT" 655 | ], 656 | [ 657 | 88, 658 | 55, 659 | 0, 660 | 57, 661 | 0, 662 | "IMAGE" 663 
| ], 664 | [ 665 | 89, 666 | 57, 667 | 0, 668 | 54, 669 | 1, 670 | "IMAGE" 671 | ], 672 | [ 673 | 91, 674 | 58, 675 | 0, 676 | 57, 677 | 1, 678 | "INT" 679 | ], 680 | [ 681 | 93, 682 | 59, 683 | 0, 684 | 57, 685 | 2, 686 | "INT" 687 | ], 688 | [ 689 | 94, 690 | 58, 691 | 0, 692 | 50, 693 | 3, 694 | "INT" 695 | ], 696 | [ 697 | 95, 698 | 59, 699 | 0, 700 | 50, 701 | 4, 702 | "INT" 703 | ] 704 | ], 705 | "groups": [], 706 | "config": {}, 707 | "extra": { 708 | "ds": { 709 | "scale": 0.6934334949442648, 710 | "offset": [ 711 | 667.6026332109438, 712 | 191.25659609525596 713 | ] 714 | } 715 | }, 716 | "version": 0.4 717 | } -------------------------------------------------------------------------------- /examples/pyramid_flow_miniflux_768_img2vid_example_01.json: -------------------------------------------------------------------------------- 1 | { 2 | "last_node_id": 59, 3 | "last_link_id": 95, 4 | "nodes": [ 5 | { 6 | "id": 39, 7 | "type": "Note", 8 | "pos": { 9 | "0": 30, 10 | "1": 650 11 | }, 12 | "size": { 13 | "0": 318.25567626953125, 14 | "1": 66.4825210571289 15 | }, 16 | "flags": {}, 17 | "order": 0, 18 | "mode": 0, 19 | "inputs": [], 20 | "outputs": [], 21 | "properties": {}, 22 | "widgets_values": [ 23 | "fp8 text encoder results are different from fp16!" 24 | ], 25 | "color": "#432", 26 | "bgcolor": "#653" 27 | }, 28 | { 29 | "id": 37, 30 | "type": "DualCLIPLoader", 31 | "pos": { 32 | "0": -40, 33 | "1": 480 34 | }, 35 | "size": { 36 | "0": 407.1675720214844, 37 | "1": 106 38 | }, 39 | "flags": {}, 40 | "order": 1, 41 | "mode": 0, 42 | "inputs": [], 43 | "outputs": [ 44 | { 45 | "name": "CLIP", 46 | "type": "CLIP", 47 | "links": [ 48 | 80 49 | ], 50 | "slot_index": 0 51 | } 52 | ], 53 | "properties": { 54 | "Node name for S&R": "DualCLIPLoader" 55 | }, 56 | "widgets_values": [ 57 | "clip_l.safetensors", 58 | "t5\\t5xxl_fp16.safetensors", 59 | "flux" 60 | ] 61 | }, 62 | { 63 | "id": 57, 64 | "type": "ImageScale", 65 | "pos": { 66 | "0": 635, 67 | "1": 833 68 | }, 69 | "size": { 70 | "0": 315, 71 | "1": 130 72 | }, 73 | "flags": {}, 74 | "order": 8, 75 | "mode": 0, 76 | "inputs": [ 77 | { 78 | "name": "image", 79 | "type": "IMAGE", 80 | "link": 88 81 | }, 82 | { 83 | "name": "width", 84 | "type": "INT", 85 | "link": 91, 86 | "widget": { 87 | "name": "width" 88 | } 89 | }, 90 | { 91 | "name": "height", 92 | "type": "INT", 93 | "link": 93, 94 | "widget": { 95 | "name": "height" 96 | } 97 | } 98 | ], 99 | "outputs": [ 100 | { 101 | "name": "IMAGE", 102 | "type": "IMAGE", 103 | "links": [ 104 | 89 105 | ], 106 | "slot_index": 0 107 | } 108 | ], 109 | "properties": { 110 | "Node name for S&R": "ImageScale" 111 | }, 112 | "widgets_values": [ 113 | "lanczos", 114 | 1280, 115 | 768, 116 | "center" 117 | ] 118 | }, 119 | { 120 | "id": 54, 121 | "type": "PyramidFlowVAEEncode", 122 | "pos": { 123 | "0": 993, 124 | "1": 786 125 | }, 126 | "size": { 127 | "0": 315, 128 | "1": 102 129 | }, 130 | "flags": {}, 131 | "order": 9, 132 | "mode": 0, 133 | "inputs": [ 134 | { 135 | "name": "vae", 136 | "type": "PYRAMIDFLOWVAE", 137 | "link": 82 138 | }, 139 | { 140 | "name": "image", 141 | "type": "IMAGE", 142 | "link": 89 143 | } 144 | ], 145 | "outputs": [ 146 | { 147 | "name": "samples", 148 | "type": "LATENT", 149 | "links": [ 150 | 83 151 | ], 152 | "slot_index": 0 153 | } 154 | ], 155 | "properties": { 156 | "Node name for S&R": "PyramidFlowVAEEncode" 157 | }, 158 | "widgets_values": [ 159 | false, 160 | 0.25 161 | ] 162 | }, 163 | { 164 | "id": 55, 165 | "type": "LoadImage", 166 | "pos": { 167 | 
"0": -32, 168 | "1": 842 169 | }, 170 | "size": { 171 | "0": 315, 172 | "1": 314 173 | }, 174 | "flags": {}, 175 | "order": 2, 176 | "mode": 0, 177 | "inputs": [], 178 | "outputs": [ 179 | { 180 | "name": "IMAGE", 181 | "type": "IMAGE", 182 | "links": [ 183 | 88 184 | ] 185 | }, 186 | { 187 | "name": "MASK", 188 | "type": "MASK", 189 | "links": null 190 | } 191 | ], 192 | "properties": { 193 | "Node name for S&R": "LoadImage" 194 | }, 195 | "widgets_values": [ 196 | "videoframe_812.png", 197 | "image" 198 | ] 199 | }, 200 | { 201 | "id": 53, 202 | "type": "PyramidFlowTextEncode", 203 | "pos": { 204 | "0": 444, 205 | "1": 476 206 | }, 207 | "size": { 208 | "0": 437.19818115234375, 209 | "1": 269.9795837402344 210 | }, 211 | "flags": {}, 212 | "order": 7, 213 | "mode": 0, 214 | "inputs": [ 215 | { 216 | "name": "clip", 217 | "type": "CLIP", 218 | "link": 80 219 | } 220 | ], 221 | "outputs": [ 222 | { 223 | "name": "prompt_embeds", 224 | "type": "PYRAMIDFLOWPROMPT", 225 | "links": [ 226 | 81 227 | ], 228 | "slot_index": 0 229 | } 230 | ], 231 | "properties": { 232 | "Node name for S&R": "PyramidFlowTextEncode" 233 | }, 234 | "widgets_values": [ 235 | "FPV flying over seaside cliffs while the sun is setting, hyper quality, Ultra HD, 8K", 236 | "cartoon style, worst quality, low quality, blurry, absolute black, absolute white, low res, extra limbs, extra digits, misplaced objects, mutated anatomy, monochrome, horror", 237 | true 238 | ] 239 | }, 240 | { 241 | "id": 51, 242 | "type": "VHS_VideoCombine", 243 | "pos": { 244 | "0": 1420, 245 | "1": 113 246 | }, 247 | "size": [ 248 | 1018.2306518554688, 249 | 310 250 | ], 251 | "flags": {}, 252 | "order": 12, 253 | "mode": 0, 254 | "inputs": [ 255 | { 256 | "name": "images", 257 | "type": "IMAGE", 258 | "link": 77 259 | }, 260 | { 261 | "name": "audio", 262 | "type": "AUDIO", 263 | "link": null, 264 | "shape": 7 265 | }, 266 | { 267 | "name": "meta_batch", 268 | "type": "VHS_BatchManager", 269 | "link": null, 270 | "shape": 7 271 | }, 272 | { 273 | "name": "vae", 274 | "type": "VAE", 275 | "link": null, 276 | "shape": 7 277 | } 278 | ], 279 | "outputs": [ 280 | { 281 | "name": "Filenames", 282 | "type": "VHS_FILENAMES", 283 | "links": null 284 | } 285 | ], 286 | "properties": { 287 | "Node name for S&R": "VHS_VideoCombine" 288 | }, 289 | "widgets_values": { 290 | "frame_rate": 24, 291 | "loop_count": 0, 292 | "filename_prefix": "PyramidFlow", 293 | "format": "video/h264-mp4", 294 | "pix_fmt": "yuv420p", 295 | "crf": 19, 296 | "save_metadata": true, 297 | "pingpong": false, 298 | "save_output": true, 299 | "videopreview": { 300 | "hidden": false, 301 | "paused": false, 302 | "params": { 303 | "filename": "PyramidFlow_00131.mp4", 304 | "subfolder": "", 305 | "type": "output", 306 | "format": "video/h264-mp4", 307 | "frame_rate": 24 308 | }, 309 | "muted": false 310 | } 311 | } 312 | }, 313 | { 314 | "id": 43, 315 | "type": "PyramidFlowVAELoader", 316 | "pos": { 317 | "0": 250, 318 | "1": 282 319 | }, 320 | "size": { 321 | "0": 411.12652587890625, 322 | "1": 82 323 | }, 324 | "flags": {}, 325 | "order": 3, 326 | "mode": 0, 327 | "inputs": [ 328 | { 329 | "name": "compile_args", 330 | "type": "PYRAMIDFLOW_COMPILEARGS", 331 | "link": null, 332 | "shape": 7 333 | } 334 | ], 335 | "outputs": [ 336 | { 337 | "name": "pyramidflow_vae", 338 | "type": "PYRAMIDFLOWVAE", 339 | "links": [ 340 | 71, 341 | 82 342 | ], 343 | "slot_index": 0 344 | } 345 | ], 346 | "properties": { 347 | "Node name for S&R": "PyramidFlowVAELoader" 348 | }, 349 | "widgets_values": [ 350 
| "pyramidflow\\pyramid_flow_vae_bf16.safetensors", 351 | "bf16" 352 | ] 353 | }, 354 | { 355 | "id": 58, 356 | "type": "PrimitiveNode", 357 | "pos": { 358 | "0": 347, 359 | "1": 817 360 | }, 361 | "size": { 362 | "0": 256.92181396484375, 363 | "1": 82 364 | }, 365 | "flags": {}, 366 | "order": 4, 367 | "mode": 0, 368 | "inputs": [], 369 | "outputs": [ 370 | { 371 | "name": "INT", 372 | "type": "INT", 373 | "links": [ 374 | 91, 375 | 94 376 | ], 377 | "slot_index": 0, 378 | "widget": { 379 | "name": "width" 380 | } 381 | } 382 | ], 383 | "title": "width", 384 | "properties": { 385 | "Run widget replace on values": false 386 | }, 387 | "widgets_values": [ 388 | 1280, 389 | "fixed" 390 | ] 391 | }, 392 | { 393 | "id": 59, 394 | "type": "PrimitiveNode", 395 | "pos": { 396 | "0": 350, 397 | "1": 945 398 | }, 399 | "size": { 400 | "0": 251.2918701171875, 401 | "1": 82 402 | }, 403 | "flags": {}, 404 | "order": 5, 405 | "mode": 0, 406 | "inputs": [], 407 | "outputs": [ 408 | { 409 | "name": "INT", 410 | "type": "INT", 411 | "links": [ 412 | 93, 413 | 95 414 | ], 415 | "slot_index": 0, 416 | "widget": { 417 | "name": "height" 418 | } 419 | } 420 | ], 421 | "title": "height", 422 | "properties": { 423 | "Run widget replace on values": false 424 | }, 425 | "widgets_values": [ 426 | 768, 427 | "fixed" 428 | ] 429 | }, 430 | { 431 | "id": 48, 432 | "type": "PyramidFlowVAEDecode", 433 | "pos": { 434 | "0": 1032, 435 | "1": 571 436 | }, 437 | "size": { 438 | "0": 315, 439 | "1": 150 440 | }, 441 | "flags": {}, 442 | "order": 11, 443 | "mode": 0, 444 | "inputs": [ 445 | { 446 | "name": "vae", 447 | "type": "PYRAMIDFLOWVAE", 448 | "link": 71 449 | }, 450 | { 451 | "name": "samples", 452 | "type": "LATENT", 453 | "link": 76 454 | } 455 | ], 456 | "outputs": [ 457 | { 458 | "name": "images", 459 | "type": "IMAGE", 460 | "links": [ 461 | 77 462 | ], 463 | "slot_index": 0 464 | } 465 | ], 466 | "properties": { 467 | "Node name for S&R": "PyramidFlowVAEDecode" 468 | }, 469 | "widgets_values": [ 470 | 256, 471 | 0.25, 472 | true, 473 | true 474 | ] 475 | }, 476 | { 477 | "id": 40, 478 | "type": "PyramidFlowTransformerLoader", 479 | "pos": { 480 | "0": 226, 481 | "1": 92 482 | }, 483 | "size": { 484 | "0": 444.05462646484375, 485 | "1": 106 486 | }, 487 | "flags": {}, 488 | "order": 6, 489 | "mode": 0, 490 | "inputs": [ 491 | { 492 | "name": "compile_args", 493 | "type": "PYRAMIDFLOW_COMPILEARGS", 494 | "link": null, 495 | "shape": 7 496 | } 497 | ], 498 | "outputs": [ 499 | { 500 | "name": "pyramidflow_model", 501 | "type": "PYRAMIDFLOWMODEL", 502 | "links": [ 503 | 74 504 | ], 505 | "slot_index": 0 506 | } 507 | ], 508 | "properties": { 509 | "Node name for S&R": "PyramidFlowTransformerLoader" 510 | }, 511 | "widgets_values": [ 512 | "pyramidflow\\pyramid_flow_miniflux_768_bf16.safetensors", 513 | "bf16", 514 | false 515 | ] 516 | }, 517 | { 518 | "id": 50, 519 | "type": "PyramidFlowSampler", 520 | "pos": { 521 | "0": 1028, 522 | "1": -15 523 | }, 524 | "size": { 525 | "0": 315, 526 | "1": 518 527 | }, 528 | "flags": {}, 529 | "order": 10, 530 | "mode": 0, 531 | "inputs": [ 532 | { 533 | "name": "model", 534 | "type": "PYRAMIDFLOWMODEL", 535 | "link": 74 536 | }, 537 | { 538 | "name": "prompt_embeds", 539 | "type": "PYRAMIDFLOWPROMPT", 540 | "link": 81 541 | }, 542 | { 543 | "name": "input_latent", 544 | "type": "LATENT", 545 | "link": 83, 546 | "shape": 7 547 | }, 548 | { 549 | "name": "width", 550 | "type": "INT", 551 | "link": 94, 552 | "widget": { 553 | "name": "width" 554 | } 555 | }, 556 | { 557 | 
"name": "height", 558 | "type": "INT", 559 | "link": 95, 560 | "widget": { 561 | "name": "height" 562 | } 563 | } 564 | ], 565 | "outputs": [ 566 | { 567 | "name": "samples", 568 | "type": "LATENT", 569 | "links": [ 570 | 76 571 | ] 572 | } 573 | ], 574 | "properties": { 575 | "Node name for S&R": "PyramidFlowSampler" 576 | }, 577 | "widgets_values": [ 578 | 1280, 579 | 768, 580 | "20, 20, 20", 581 | "10, 10, 10", 582 | 16, 583 | 7, 584 | 5, 585 | 44664248661402, 586 | "fixed", 587 | false 588 | ] 589 | } 590 | ], 591 | "links": [ 592 | [ 593 | 71, 594 | 43, 595 | 0, 596 | 48, 597 | 0, 598 | "PYRAMIDFLOWVAE" 599 | ], 600 | [ 601 | 74, 602 | 40, 603 | 0, 604 | 50, 605 | 0, 606 | "PYRAMIDFLOWMODEL" 607 | ], 608 | [ 609 | 76, 610 | 50, 611 | 0, 612 | 48, 613 | 1, 614 | "LATENT" 615 | ], 616 | [ 617 | 77, 618 | 48, 619 | 0, 620 | 51, 621 | 0, 622 | "IMAGE" 623 | ], 624 | [ 625 | 80, 626 | 37, 627 | 0, 628 | 53, 629 | 0, 630 | "CLIP" 631 | ], 632 | [ 633 | 81, 634 | 53, 635 | 0, 636 | 50, 637 | 1, 638 | "PYRAMIDFLOWPROMPT" 639 | ], 640 | [ 641 | 82, 642 | 43, 643 | 0, 644 | 54, 645 | 0, 646 | "PYRAMIDFLOWVAE" 647 | ], 648 | [ 649 | 83, 650 | 54, 651 | 0, 652 | 50, 653 | 2, 654 | "LATENT" 655 | ], 656 | [ 657 | 88, 658 | 55, 659 | 0, 660 | 57, 661 | 0, 662 | "IMAGE" 663 | ], 664 | [ 665 | 89, 666 | 57, 667 | 0, 668 | 54, 669 | 1, 670 | "IMAGE" 671 | ], 672 | [ 673 | 91, 674 | 58, 675 | 0, 676 | 57, 677 | 1, 678 | "INT" 679 | ], 680 | [ 681 | 93, 682 | 59, 683 | 0, 684 | 57, 685 | 2, 686 | "INT" 687 | ], 688 | [ 689 | 94, 690 | 58, 691 | 0, 692 | 50, 693 | 3, 694 | "INT" 695 | ], 696 | [ 697 | 95, 698 | 59, 699 | 0, 700 | 50, 701 | 4, 702 | "INT" 703 | ] 704 | ], 705 | "groups": [], 706 | "config": {}, 707 | "extra": { 708 | "ds": { 709 | "scale": 0.7627768444386913, 710 | "offset": [ 711 | 86.13055347478304, 712 | 125.09485874927132 713 | ] 714 | } 715 | }, 716 | "version": 0.4 717 | } -------------------------------------------------------------------------------- /pyramid_dit/mmdit_modules/modeling_embedding.py: -------------------------------------------------------------------------------- 1 | from typing import Any, Dict, Optional, Union 2 | 3 | import torch 4 | import torch.nn as nn 5 | import numpy as np 6 | import math 7 | 8 | from diffusers.models.activations import get_activation 9 | from einops import rearrange 10 | 11 | 12 | def get_1d_sincos_pos_embed( 13 | embed_dim, num_frames, cls_token=False, extra_tokens=0, 14 | ): 15 | t = np.arange(num_frames, dtype=np.float32) 16 | pos_embed = get_1d_sincos_pos_embed_from_grid(embed_dim, t) # (T, D) 17 | if cls_token and extra_tokens > 0: 18 | pos_embed = np.concatenate([np.zeros([extra_tokens, embed_dim]), pos_embed], axis=0) 19 | return pos_embed 20 | 21 | 22 | def get_2d_sincos_pos_embed( 23 | embed_dim, grid_size, cls_token=False, extra_tokens=0, interpolation_scale=1.0, base_size=16 24 | ): 25 | """ 26 | grid_size: int of the grid height and width return: pos_embed: [grid_size*grid_size, embed_dim] or 27 | [1+grid_size*grid_size, embed_dim] (w/ or w/o cls_token) 28 | """ 29 | if isinstance(grid_size, int): 30 | grid_size = (grid_size, grid_size) 31 | 32 | grid_h = np.arange(grid_size[0], dtype=np.float32) / (grid_size[0] / base_size) / interpolation_scale 33 | grid_w = np.arange(grid_size[1], dtype=np.float32) / (grid_size[1] / base_size) / interpolation_scale 34 | grid = np.meshgrid(grid_w, grid_h) # here w goes first 35 | grid = np.stack(grid, axis=0) 36 | 37 | grid = grid.reshape([2, 1, grid_size[1], grid_size[0]]) 38 | pos_embed = 
get_2d_sincos_pos_embed_from_grid(embed_dim, grid) 39 | if cls_token and extra_tokens > 0: 40 | pos_embed = np.concatenate([np.zeros([extra_tokens, embed_dim]), pos_embed], axis=0) 41 | return pos_embed 42 | 43 | 44 | def get_2d_sincos_pos_embed_from_grid(embed_dim, grid): 45 | if embed_dim % 2 != 0: 46 | raise ValueError("embed_dim must be divisible by 2") 47 | 48 | # use half of dimensions to encode grid_h 49 | emb_h = get_1d_sincos_pos_embed_from_grid(embed_dim // 2, grid[0]) # (H*W, D/2) 50 | emb_w = get_1d_sincos_pos_embed_from_grid(embed_dim // 2, grid[1]) # (H*W, D/2) 51 | 52 | emb = np.concatenate([emb_h, emb_w], axis=1) # (H*W, D) 53 | return emb 54 | 55 | 56 | def get_1d_sincos_pos_embed_from_grid(embed_dim, pos): 57 | """ 58 | embed_dim: output dimension for each position pos: a list of positions to be encoded: size (M,) out: (M, D) 59 | """ 60 | if embed_dim % 2 != 0: 61 | raise ValueError("embed_dim must be divisible by 2") 62 | 63 | omega = np.arange(embed_dim // 2, dtype=np.float64) 64 | omega /= embed_dim / 2.0 65 | omega = 1.0 / 10000**omega # (D/2,) 66 | 67 | pos = pos.reshape(-1) # (M,) 68 | out = np.einsum("m,d->md", pos, omega) # (M, D/2), outer product 69 | 70 | emb_sin = np.sin(out) # (M, D/2) 71 | emb_cos = np.cos(out) # (M, D/2) 72 | 73 | emb = np.concatenate([emb_sin, emb_cos], axis=1) # (M, D) 74 | return emb 75 | 76 | 77 | def get_timestep_embedding( 78 | timesteps: torch.Tensor, 79 | embedding_dim: int, 80 | flip_sin_to_cos: bool = False, 81 | downscale_freq_shift: float = 1, 82 | scale: float = 1, 83 | max_period: int = 10000, 84 | ): 85 | """ 86 | This matches the implementation in Denoising Diffusion Probabilistic Models: Create sinusoidal timestep embeddings. 87 | :param timesteps: a 1-D Tensor of N indices, one per batch element. These may be fractional. 88 | :param embedding_dim: the dimension of the output. :param max_period: controls the minimum frequency of the 89 | embeddings. :return: an [N x dim] Tensor of positional embeddings. 
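    Illustrative example (values chosen here for illustration only): with timesteps = torch.tensor([0., 10., 999.])
    and embedding_dim = 256 the result is a (3, 256) tensor whose first 128 columns are the sin terms and the
    last 128 the cos terms; the two halves are swapped when flip_sin_to_cos=True.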
90 | """ 91 | assert len(timesteps.shape) == 1, "Timesteps should be a 1d-array" 92 | 93 | half_dim = embedding_dim // 2 94 | exponent = -math.log(max_period) * torch.arange( 95 | start=0, end=half_dim, dtype=torch.float32, device=timesteps.device 96 | ) 97 | exponent = exponent / (half_dim - downscale_freq_shift) 98 | 99 | emb = torch.exp(exponent) 100 | emb = timesteps[:, None].float() * emb[None, :] 101 | 102 | # scale embeddings 103 | emb = scale * emb 104 | 105 | # concat sine and cosine embeddings 106 | emb = torch.cat([torch.sin(emb), torch.cos(emb)], dim=-1) 107 | 108 | # flip sine and cosine embeddings 109 | if flip_sin_to_cos: 110 | emb = torch.cat([emb[:, half_dim:], emb[:, :half_dim]], dim=-1) 111 | 112 | # zero pad 113 | if embedding_dim % 2 == 1: 114 | emb = torch.nn.functional.pad(emb, (0, 1, 0, 0)) 115 | return emb 116 | 117 | 118 | class Timesteps(nn.Module): 119 | def __init__(self, num_channels: int, flip_sin_to_cos: bool, downscale_freq_shift: float): 120 | super().__init__() 121 | self.num_channels = num_channels 122 | self.flip_sin_to_cos = flip_sin_to_cos 123 | self.downscale_freq_shift = downscale_freq_shift 124 | 125 | def forward(self, timesteps): 126 | t_emb = get_timestep_embedding( 127 | timesteps, 128 | self.num_channels, 129 | flip_sin_to_cos=self.flip_sin_to_cos, 130 | downscale_freq_shift=self.downscale_freq_shift, 131 | ) 132 | return t_emb 133 | 134 | 135 | class TimestepEmbedding(nn.Module): 136 | def __init__( 137 | self, 138 | in_channels: int, 139 | time_embed_dim: int, 140 | act_fn: str = "silu", 141 | out_dim: int = None, 142 | post_act_fn: Optional[str] = None, 143 | sample_proj_bias=True, 144 | ): 145 | super().__init__() 146 | self.linear_1 = nn.Linear(in_channels, time_embed_dim, sample_proj_bias) 147 | self.act = get_activation(act_fn) 148 | self.linear_2 = nn.Linear(time_embed_dim, time_embed_dim, sample_proj_bias) 149 | 150 | def forward(self, sample): 151 | sample = self.linear_1(sample) 152 | sample = self.act(sample) 153 | sample = self.linear_2(sample) 154 | return sample 155 | 156 | 157 | class TextProjection(nn.Module): 158 | def __init__(self, in_features, hidden_size, act_fn="silu"): 159 | super().__init__() 160 | self.linear_1 = nn.Linear(in_features=in_features, out_features=hidden_size, bias=True) 161 | self.act_1 = get_activation(act_fn) 162 | self.linear_2 = nn.Linear(in_features=hidden_size, out_features=hidden_size, bias=True) 163 | 164 | def forward(self, caption): 165 | hidden_states = self.linear_1(caption) 166 | hidden_states = self.act_1(hidden_states) 167 | hidden_states = self.linear_2(hidden_states) 168 | return hidden_states 169 | 170 | 171 | class CombinedTimestepConditionEmbeddings(nn.Module): 172 | def __init__(self, embedding_dim, pooled_projection_dim): 173 | super().__init__() 174 | 175 | self.time_proj = Timesteps(num_channels=256, flip_sin_to_cos=True, downscale_freq_shift=0) 176 | self.timestep_embedder = TimestepEmbedding(in_channels=256, time_embed_dim=embedding_dim) 177 | self.text_embedder = TextProjection(pooled_projection_dim, embedding_dim, act_fn="silu") 178 | 179 | def forward(self, timestep, pooled_projection): 180 | timesteps_proj = self.time_proj(timestep) 181 | timesteps_emb = self.timestep_embedder(timesteps_proj.to(dtype=pooled_projection.dtype)) # (N, D) 182 | pooled_projections = self.text_embedder(pooled_projection) 183 | conditioning = timesteps_emb + pooled_projections 184 | return conditioning 185 | 186 | 187 | class CombinedTimestepEmbeddings(nn.Module): 188 | def __init__(self, 
embedding_dim): 189 | super().__init__() 190 | self.time_proj = Timesteps(num_channels=256, flip_sin_to_cos=True, downscale_freq_shift=0) 191 | self.timestep_embedder = TimestepEmbedding(in_channels=256, time_embed_dim=embedding_dim) 192 | 193 | def forward(self, timestep): 194 | timesteps_proj = self.time_proj(timestep) 195 | timesteps_emb = self.timestep_embedder(timesteps_proj) # (N, D) 196 | return timesteps_emb 197 | 198 | 199 | class PatchEmbed3D(nn.Module): 200 | """Support the 3D Tensor input""" 201 | 202 | def __init__( 203 | self, 204 | height=128, 205 | width=128, 206 | patch_size=2, 207 | in_channels=16, 208 | embed_dim=1536, 209 | layer_norm=False, 210 | bias=True, 211 | interpolation_scale=1, 212 | pos_embed_type="sincos", 213 | temp_pos_embed_type='rope', 214 | pos_embed_max_size=192, # For SD3 cropping 215 | max_num_frames=64, 216 | add_temp_pos_embed=False, 217 | interp_condition_pos=False, 218 | ): 219 | super().__init__() 220 | 221 | num_patches = (height // patch_size) * (width // patch_size) 222 | self.layer_norm = layer_norm 223 | self.pos_embed_max_size = pos_embed_max_size 224 | 225 | self.proj = nn.Conv2d( 226 | in_channels, embed_dim, kernel_size=(patch_size, patch_size), stride=patch_size, bias=bias 227 | ) 228 | if layer_norm: 229 | self.norm = nn.LayerNorm(embed_dim, elementwise_affine=False, eps=1e-6) 230 | else: 231 | self.norm = None 232 | 233 | self.patch_size = patch_size 234 | self.height, self.width = height // patch_size, width // patch_size 235 | self.base_size = height // patch_size 236 | self.interpolation_scale = interpolation_scale 237 | self.add_temp_pos_embed = add_temp_pos_embed 238 | 239 | # Calculate positional embeddings based on max size or default 240 | if pos_embed_max_size: 241 | grid_size = pos_embed_max_size 242 | else: 243 | grid_size = int(num_patches**0.5) 244 | 245 | if pos_embed_type is None: 246 | self.pos_embed = None 247 | 248 | elif pos_embed_type == "sincos": 249 | pos_embed = get_2d_sincos_pos_embed( 250 | embed_dim, grid_size, base_size=self.base_size, interpolation_scale=self.interpolation_scale 251 | ) 252 | persistent = True if pos_embed_max_size else False 253 | self.register_buffer("pos_embed", torch.from_numpy(pos_embed).float().unsqueeze(0), persistent=persistent) 254 | 255 | if add_temp_pos_embed and temp_pos_embed_type == 'sincos': 256 | time_pos_embed = get_1d_sincos_pos_embed(embed_dim, max_num_frames) 257 | self.register_buffer("temp_pos_embed", torch.from_numpy(time_pos_embed).float().unsqueeze(0), persistent=True) 258 | 259 | elif pos_embed_type == "rope": 260 | print("Using the rotary position embedding") 261 | 262 | else: 263 | raise ValueError(f"Unsupported pos_embed_type: {pos_embed_type}") 264 | 265 | self.pos_embed_type = pos_embed_type 266 | self.temp_pos_embed_type = temp_pos_embed_type 267 | self.interp_condition_pos = interp_condition_pos 268 | 269 | def cropped_pos_embed(self, height, width, ori_height, ori_width): 270 | """Crops positional embeddings for SD3 compatibility.""" 271 | if self.pos_embed_max_size is None: 272 | raise ValueError("`pos_embed_max_size` must be set for cropping.") 273 | 274 | height = height // self.patch_size 275 | width = width // self.patch_size 276 | ori_height = ori_height // self.patch_size 277 | ori_width = ori_width // self.patch_size 278 | 279 | assert ori_height >= height, "The ori_height needs >= height" 280 | assert ori_width >= width, "The ori_width needs >= width" 281 | 282 | if height > self.pos_embed_max_size: 283 | raise ValueError( 284 | f"Height ({height}) 
cannot be greater than `pos_embed_max_size`: {self.pos_embed_max_size}." 285 | ) 286 | if width > self.pos_embed_max_size: 287 | raise ValueError( 288 | f"Width ({width}) cannot be greater than `pos_embed_max_size`: {self.pos_embed_max_size}." 289 | ) 290 | 291 | if self.interp_condition_pos: 292 | top = (self.pos_embed_max_size - ori_height) // 2 293 | left = (self.pos_embed_max_size - ori_width) // 2 294 | spatial_pos_embed = self.pos_embed.reshape(1, self.pos_embed_max_size, self.pos_embed_max_size, -1) 295 | spatial_pos_embed = spatial_pos_embed[:, top : top + ori_height, left : left + ori_width, :] # [b h w c] 296 | if ori_height != height or ori_width != width: 297 | spatial_pos_embed = spatial_pos_embed.permute(0, 3, 1, 2) 298 | spatial_pos_embed = torch.nn.functional.interpolate(spatial_pos_embed, size=(height, width), mode='bilinear') 299 | spatial_pos_embed = spatial_pos_embed.permute(0, 2, 3, 1) 300 | else: 301 | top = (self.pos_embed_max_size - height) // 2 302 | left = (self.pos_embed_max_size - width) // 2 303 | spatial_pos_embed = self.pos_embed.reshape(1, self.pos_embed_max_size, self.pos_embed_max_size, -1) 304 | spatial_pos_embed = spatial_pos_embed[:, top : top + height, left : left + width, :] 305 | 306 | spatial_pos_embed = spatial_pos_embed.reshape(1, -1, spatial_pos_embed.shape[-1]) 307 | 308 | return spatial_pos_embed 309 | 310 | def forward_func(self, latent, time_index=0, ori_height=None, ori_width=None): 311 | if self.pos_embed_max_size is not None: 312 | height, width = latent.shape[-2:] 313 | else: 314 | height, width = latent.shape[-2] // self.patch_size, latent.shape[-1] // self.patch_size 315 | 316 | bs = latent.shape[0] 317 | temp = latent.shape[2] 318 | 319 | latent = rearrange(latent, 'b c t h w -> (b t) c h w') 320 | latent = self.proj(latent) 321 | latent = latent.flatten(2).transpose(1, 2) # (BT)CHW -> (BT)NC 322 | 323 | if self.layer_norm: 324 | latent = self.norm(latent) 325 | 326 | if self.pos_embed_type == 'sincos': 327 | # Spatial position embedding, Interpolate or crop positional embeddings as needed 328 | if self.pos_embed_max_size: 329 | pos_embed = self.cropped_pos_embed(height, width, ori_height, ori_width) 330 | else: 331 | raise NotImplementedError("Not implemented sincos pos embed without sd3 max pos crop") 332 | if self.height != height or self.width != width: 333 | pos_embed = get_2d_sincos_pos_embed( 334 | embed_dim=self.pos_embed.shape[-1], 335 | grid_size=(height, width), 336 | base_size=self.base_size, 337 | interpolation_scale=self.interpolation_scale, 338 | ) 339 | pos_embed = torch.from_numpy(pos_embed).float().unsqueeze(0).to(latent.device) 340 | else: 341 | pos_embed = self.pos_embed 342 | 343 | if self.add_temp_pos_embed and self.temp_pos_embed_type == 'sincos': 344 | latent_dtype = latent.dtype 345 | latent = latent + pos_embed 346 | latent = rearrange(latent, '(b t) n c -> (b n) t c', t=temp) 347 | latent = latent + self.temp_pos_embed[:, time_index:time_index + temp, :] 348 | latent = latent.to(latent_dtype) 349 | latent = rearrange(latent, '(b n) t c -> b t n c', b=bs) 350 | else: 351 | latent = (latent + pos_embed).to(latent.dtype) 352 | latent = rearrange(latent, '(b t) n c -> b t n c', b=bs, t=temp) 353 | 354 | else: 355 | assert self.pos_embed_type == "rope", "Only supporting the sincos and rope embedding" 356 | latent = rearrange(latent, '(b t) n c -> b t n c', b=bs, t=temp) 357 | 358 | return latent 359 | 360 | def forward(self, latent): 361 | """ 362 | Arguments: 363 | past_condition_latents (Torch.FloatTensor): The 
past latent during the generation 364 | flatten_input (bool): True indicate flatten the latent into 1D sequence 365 | """ 366 | 367 | if isinstance(latent, list): 368 | output_list = [] 369 | 370 | for latent_ in latent: 371 | if not isinstance(latent_, list): 372 | latent_ = [latent_] 373 | 374 | output_latent = [] 375 | time_index = 0 376 | ori_height, ori_width = latent_[-1].shape[-2:] 377 | for each_latent in latent_: 378 | hidden_state = self.forward_func(each_latent, time_index=time_index, ori_height=ori_height, ori_width=ori_width) 379 | time_index += each_latent.shape[2] 380 | hidden_state = rearrange(hidden_state, "b t n c -> b (t n) c") 381 | output_latent.append(hidden_state) 382 | 383 | output_latent = torch.cat(output_latent, dim=1) 384 | output_list.append(output_latent) 385 | 386 | return output_list 387 | else: 388 | hidden_states = self.forward_func(latent) 389 | hidden_states = rearrange(hidden_states, "b t n c -> b (t n) c") 390 | return hidden_states -------------------------------------------------------------------------------- /video_vae/modeling_enc_dec.py: -------------------------------------------------------------------------------- 1 | # Copyright 2023 The HuggingFace Team. All rights reserved. 2 | # 3 | # Licensed under the Apache License, Version 2.0 (the "License"); 4 | # you may not use this file except in compliance with the License. 5 | # You may obtain a copy of the License at 6 | # 7 | # http://www.apache.org/licenses/LICENSE-2.0 8 | # 9 | # Unless required by applicable law or agreed to in writing, software 10 | # distributed under the License is distributed on an "AS IS" BASIS, 11 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 12 | # See the License for the specific language governing permissions and 13 | # limitations under the License. 14 | from dataclasses import dataclass 15 | from typing import Optional, Tuple 16 | 17 | import numpy as np 18 | import torch 19 | import torch.nn as nn 20 | 21 | from diffusers.utils import BaseOutput, is_torch_version 22 | from diffusers.utils.torch_utils import randn_tensor 23 | from .modeling_block import ( 24 | UNetMidBlock2D, 25 | CausalUNetMidBlock2D, 26 | get_down_block, 27 | get_up_block, 28 | get_input_layer, 29 | get_output_layer, 30 | ) 31 | from .modeling_causal_conv import CausalConv3d, CausalGroupNorm 32 | 33 | 34 | @dataclass 35 | class DecoderOutput(BaseOutput): 36 | r""" 37 | Output of decoding method. 38 | 39 | Args: 40 | sample (`torch.FloatTensor` of shape `(batch_size, num_channels, height, width)`): 41 | The decoded output sample from the last layer of the model. 42 | """ 43 | 44 | sample: torch.FloatTensor 45 | 46 | 47 | class CausalVaeEncoder(nn.Module): 48 | r""" 49 | The `Encoder` layer of a variational autoencoder that encodes its input into a latent representation. 50 | 51 | Args: 52 | in_channels (`int`, *optional*, defaults to 3): 53 | The number of input channels. 54 | out_channels (`int`, *optional*, defaults to 3): 55 | The number of output channels. 56 | down_block_types (`Tuple[str, ...]`, *optional*, defaults to `("DownEncoderBlock2D",)`): 57 | The types of down blocks to use. See `~diffusers.models.unet_2d_blocks.get_down_block` for available 58 | options. 59 | block_out_channels (`Tuple[int, ...]`, *optional*, defaults to `(64,)`): 60 | The number of output channels for each block. 61 | layers_per_block (`int`, *optional*, defaults to 2): 62 | The number of layers per block. 
63 | norm_num_groups (`int`, *optional*, defaults to 32): 64 | The number of groups for normalization. 65 | act_fn (`str`, *optional*, defaults to `"silu"`): 66 | The activation function to use. See `~diffusers.models.activations.get_activation` for available options. 67 | double_z (`bool`, *optional*, defaults to `True`): 68 | Whether to double the number of output channels for the last block. 69 | """ 70 | 71 | def __init__( 72 | self, 73 | in_channels: int = 3, 74 | out_channels: int = 3, 75 | down_block_types: Tuple[str, ...] = ("DownEncoderBlockCausal3D",), 76 | spatial_down_sample: Tuple[bool, ...] = (True,), 77 | temporal_down_sample: Tuple[bool, ...] = (False,), 78 | block_out_channels: Tuple[int, ...] = (64,), 79 | layers_per_block: Tuple[int, ...] = (2,), 80 | norm_num_groups: int = 32, 81 | act_fn: str = "silu", 82 | double_z: bool = True, 83 | block_dropout: Tuple[int, ...] = (0.0,), 84 | mid_block_add_attention=True, 85 | ): 86 | super().__init__() 87 | self.layers_per_block = layers_per_block 88 | 89 | self.conv_in = CausalConv3d( 90 | in_channels, 91 | block_out_channels[0], 92 | kernel_size=3, 93 | stride=1, 94 | ) 95 | 96 | self.mid_block = None 97 | self.down_blocks = nn.ModuleList([]) 98 | 99 | # down 100 | output_channel = block_out_channels[0] 101 | for i, down_block_type in enumerate(down_block_types): 102 | input_channel = output_channel 103 | output_channel = block_out_channels[i] 104 | 105 | down_block = get_down_block( 106 | down_block_type, 107 | num_layers=self.layers_per_block[i], 108 | in_channels=input_channel, 109 | out_channels=output_channel, 110 | add_spatial_downsample=spatial_down_sample[i], 111 | add_temporal_downsample=temporal_down_sample[i], 112 | resnet_eps=1e-6, 113 | downsample_padding=0, 114 | resnet_act_fn=act_fn, 115 | resnet_groups=norm_num_groups, 116 | attention_head_dim=output_channel, 117 | temb_channels=None, 118 | dropout=block_dropout[i], 119 | ) 120 | self.down_blocks.append(down_block) 121 | 122 | # mid 123 | self.mid_block = CausalUNetMidBlock2D( 124 | in_channels=block_out_channels[-1], 125 | resnet_eps=1e-6, 126 | resnet_act_fn=act_fn, 127 | output_scale_factor=1, 128 | resnet_time_scale_shift="default", 129 | attention_head_dim=block_out_channels[-1], 130 | resnet_groups=norm_num_groups, 131 | temb_channels=None, 132 | add_attention=mid_block_add_attention, 133 | dropout=block_dropout[-1], 134 | ) 135 | 136 | # out 137 | 138 | self.conv_norm_out = CausalGroupNorm(num_channels=block_out_channels[-1], num_groups=norm_num_groups, eps=1e-6) 139 | self.conv_act = nn.SiLU() 140 | 141 | conv_out_channels = 2 * out_channels if double_z else out_channels 142 | self.conv_out = CausalConv3d(block_out_channels[-1], conv_out_channels, kernel_size=3, stride=1) 143 | 144 | self.gradient_checkpointing = False 145 | 146 | def forward(self, sample: torch.FloatTensor, is_init_image=True, temporal_chunk=False) -> torch.FloatTensor: 147 | r"""The forward method of the `Encoder` class.""" 148 | 149 | sample = self.conv_in(sample, is_init_image=is_init_image, temporal_chunk=temporal_chunk) 150 | 151 | if self.training and self.gradient_checkpointing: 152 | 153 | def create_custom_forward(module): 154 | def custom_forward(*inputs): 155 | return module(*inputs) 156 | 157 | return custom_forward 158 | 159 | # down 160 | if is_torch_version(">=", "1.11.0"): 161 | for down_block in self.down_blocks: 162 | sample = torch.utils.checkpoint.checkpoint( 163 | create_custom_forward(down_block), sample, is_init_image, 164 | temporal_chunk, use_reentrant=False 165 
| ) 166 | # middle 167 | sample = torch.utils.checkpoint.checkpoint( 168 | create_custom_forward(self.mid_block), sample, is_init_image, 169 | temporal_chunk, use_reentrant=False 170 | ) 171 | else: 172 | for down_block in self.down_blocks: 173 | sample = torch.utils.checkpoint.checkpoint(create_custom_forward(down_block), sample, is_init_image, temporal_chunk) 174 | # middle 175 | sample = torch.utils.checkpoint.checkpoint(create_custom_forward(self.mid_block), sample, is_init_image, temporal_chunk) 176 | 177 | else: 178 | # down 179 | for down_block in self.down_blocks: 180 | sample = down_block(sample, is_init_image=is_init_image, temporal_chunk=temporal_chunk) 181 | 182 | # middle 183 | sample = self.mid_block(sample, is_init_image=is_init_image, temporal_chunk=temporal_chunk) 184 | 185 | # post-process 186 | sample = self.conv_norm_out(sample) 187 | sample = self.conv_act(sample) 188 | sample = self.conv_out(sample, is_init_image=is_init_image, temporal_chunk=temporal_chunk) 189 | 190 | return sample 191 | 192 | 193 | class CausalVaeDecoder(nn.Module): 194 | r""" 195 | The `Decoder` layer of a variational autoencoder that decodes its latent representation into an output sample. 196 | 197 | Args: 198 | in_channels (`int`, *optional*, defaults to 3): 199 | The number of input channels. 200 | out_channels (`int`, *optional*, defaults to 3): 201 | The number of output channels. 202 | up_block_types (`Tuple[str, ...]`, *optional*, defaults to `("UpDecoderBlock2D",)`): 203 | The types of up blocks to use. See `~diffusers.models.unet_2d_blocks.get_up_block` for available options. 204 | block_out_channels (`Tuple[int, ...]`, *optional*, defaults to `(64,)`): 205 | The number of output channels for each block. 206 | layers_per_block (`int`, *optional*, defaults to 2): 207 | The number of layers per block. 208 | norm_num_groups (`int`, *optional*, defaults to 32): 209 | The number of groups for normalization. 210 | act_fn (`str`, *optional*, defaults to `"silu"`): 211 | The activation function to use. See `~diffusers.models.activations.get_activation` for available options. 212 | norm_type (`str`, *optional*, defaults to `"group"`): 213 | The normalization type to use. Can be either `"group"` or `"spatial"`. 214 | """ 215 | 216 | def __init__( 217 | self, 218 | in_channels: int = 3, 219 | out_channels: int = 3, 220 | up_block_types: Tuple[str, ...] = ("UpDecoderBlockCausal3D",), 221 | spatial_up_sample: Tuple[bool, ...] = (True,), 222 | temporal_up_sample: Tuple[bool, ...] = (False,), 223 | block_out_channels: Tuple[int, ...] = (64,), 224 | layers_per_block: Tuple[int, ...] = (2,), 225 | norm_num_groups: int = 32, 226 | act_fn: str = "silu", 227 | mid_block_add_attention=True, 228 | interpolate: bool = True, 229 | block_dropout: Tuple[int, ...] 
= (0.0,), 230 | ): 231 | super().__init__() 232 | self.layers_per_block = layers_per_block 233 | 234 | self.conv_in = CausalConv3d( 235 | in_channels, 236 | block_out_channels[-1], 237 | kernel_size=3, 238 | stride=1, 239 | ) 240 | 241 | self.mid_block = None 242 | self.up_blocks = nn.ModuleList([]) 243 | 244 | # mid 245 | self.mid_block = CausalUNetMidBlock2D( 246 | in_channels=block_out_channels[-1], 247 | resnet_eps=1e-6, 248 | resnet_act_fn=act_fn, 249 | output_scale_factor=1, 250 | resnet_time_scale_shift="default", 251 | attention_head_dim=block_out_channels[-1], 252 | resnet_groups=norm_num_groups, 253 | temb_channels=None, 254 | add_attention=mid_block_add_attention, 255 | dropout=block_dropout[-1], 256 | ) 257 | 258 | # up 259 | reversed_block_out_channels = list(reversed(block_out_channels)) 260 | output_channel = reversed_block_out_channels[0] 261 | for i, up_block_type in enumerate(up_block_types): 262 | prev_output_channel = output_channel 263 | output_channel = reversed_block_out_channels[i] 264 | 265 | is_final_block = i == len(block_out_channels) - 1 266 | 267 | up_block = get_up_block( 268 | up_block_type, 269 | num_layers=self.layers_per_block[i], 270 | in_channels=prev_output_channel, 271 | out_channels=output_channel, 272 | prev_output_channel=None, 273 | add_spatial_upsample=spatial_up_sample[i], 274 | add_temporal_upsample=temporal_up_sample[i], 275 | resnet_eps=1e-6, 276 | resnet_act_fn=act_fn, 277 | resnet_groups=norm_num_groups, 278 | attention_head_dim=output_channel, 279 | temb_channels=None, 280 | resnet_time_scale_shift='default', 281 | interpolate=interpolate, 282 | dropout=block_dropout[i], 283 | ) 284 | self.up_blocks.append(up_block) 285 | prev_output_channel = output_channel 286 | 287 | # out 288 | self.conv_norm_out = CausalGroupNorm(num_channels=block_out_channels[0], num_groups=norm_num_groups, eps=1e-6) 289 | self.conv_act = nn.SiLU() 290 | self.conv_out = CausalConv3d(block_out_channels[0], out_channels, kernel_size=3, stride=1) 291 | 292 | self.gradient_checkpointing = False 293 | 294 | def forward( 295 | self, 296 | sample: torch.FloatTensor, 297 | is_init_image=True, 298 | temporal_chunk=False, 299 | ) -> torch.FloatTensor: 300 | r"""The forward method of the `Decoder` class.""" 301 | 302 | sample = self.conv_in(sample, is_init_image=is_init_image, temporal_chunk=temporal_chunk) 303 | 304 | upscale_dtype = next(iter(self.up_blocks.parameters())).dtype 305 | if self.training and self.gradient_checkpointing: 306 | 307 | def create_custom_forward(module): 308 | def custom_forward(*inputs): 309 | return module(*inputs) 310 | 311 | return custom_forward 312 | 313 | if is_torch_version(">=", "1.11.0"): 314 | # middle 315 | sample = torch.utils.checkpoint.checkpoint( 316 | create_custom_forward(self.mid_block), 317 | sample, 318 | is_init_image=is_init_image, 319 | temporal_chunk=temporal_chunk, 320 | use_reentrant=False, 321 | ) 322 | sample = sample.to(upscale_dtype) 323 | 324 | # up 325 | for up_block in self.up_blocks: 326 | sample = torch.utils.checkpoint.checkpoint( 327 | create_custom_forward(up_block), 328 | sample, 329 | is_init_image=is_init_image, 330 | temporal_chunk=temporal_chunk, 331 | use_reentrant=False, 332 | ) 333 | else: 334 | # middle 335 | sample = torch.utils.checkpoint.checkpoint( 336 | create_custom_forward(self.mid_block), sample, is_init_image=is_init_image, temporal_chunk=temporal_chunk, 337 | ) 338 | sample = sample.to(upscale_dtype) 339 | 340 | # up 341 | for up_block in self.up_blocks: 342 | sample = 
torch.utils.checkpoint.checkpoint(create_custom_forward(up_block), sample, 343 | is_init_image=is_init_image, temporal_chunk=temporal_chunk,) 344 | else: 345 | # middle 346 | sample = self.mid_block(sample, is_init_image=is_init_image, temporal_chunk=temporal_chunk) 347 | sample = sample.to(upscale_dtype) 348 | 349 | # up 350 | for up_block in self.up_blocks: 351 | sample = up_block(sample, is_init_image=is_init_image, temporal_chunk=temporal_chunk,) 352 | 353 | # post-process 354 | sample = self.conv_norm_out(sample) 355 | sample = self.conv_act(sample) 356 | sample = self.conv_out(sample, is_init_image=is_init_image, temporal_chunk=temporal_chunk) 357 | 358 | return sample 359 | 360 | 361 | class DiagonalGaussianDistribution(object): 362 | def __init__(self, parameters: torch.Tensor, deterministic: bool = False): 363 | self.parameters = parameters 364 | self.mean, self.logvar = torch.chunk(parameters, 2, dim=1) 365 | self.logvar = torch.clamp(self.logvar, -30.0, 20.0) 366 | self.deterministic = deterministic 367 | self.std = torch.exp(0.5 * self.logvar) 368 | self.var = torch.exp(self.logvar) 369 | if self.deterministic: 370 | self.var = self.std = torch.zeros_like( 371 | self.mean, device=self.parameters.device, dtype=self.parameters.dtype 372 | ) 373 | 374 | def sample(self, generator: Optional[torch.Generator] = None) -> torch.FloatTensor: 375 | # make sure sample is on the same device as the parameters and has same dtype 376 | sample = randn_tensor( 377 | self.mean.shape, 378 | generator=generator, 379 | device=self.parameters.device, 380 | dtype=self.parameters.dtype, 381 | ) 382 | x = self.mean + self.std * sample 383 | return x 384 | 385 | def kl(self, other: "DiagonalGaussianDistribution" = None) -> torch.Tensor: 386 | if self.deterministic: 387 | return torch.Tensor([0.0]) 388 | else: 389 | if other is None: 390 | return 0.5 * torch.sum( 391 | torch.pow(self.mean, 2) + self.var - 1.0 - self.logvar, 392 | dim=[2, 3, 4], 393 | ) 394 | else: 395 | return 0.5 * torch.sum( 396 | torch.pow(self.mean - other.mean, 2) / other.var 397 | + self.var / other.var 398 | - 1.0 399 | - self.logvar 400 | + other.logvar, 401 | dim=[2, 3, 4], 402 | ) 403 | 404 | def nll(self, sample: torch.Tensor, dims: Tuple[int, ...] = [1, 2, 3]) -> torch.Tensor: 405 | if self.deterministic: 406 | return torch.Tensor([0.0]) 407 | logtwopi = np.log(2.0 * np.pi) 408 | return 0.5 * torch.sum( 409 | logtwopi + self.logvar + torch.pow(sample - self.mean, 2) / self.var, 410 | dim=dims, 411 | ) 412 | 413 | def mode(self) -> torch.Tensor: 414 | return self.mean -------------------------------------------------------------------------------- /pyramid_dit/flux_modules/modeling_pyramid_flux.py: -------------------------------------------------------------------------------- 1 | from typing import Any, Dict, List, Optional, Union 2 | 3 | import torch 4 | import torch.nn as nn 5 | import torch.nn.functional as F 6 | from einops import rearrange 7 | 8 | from diffusers.configuration_utils import ConfigMixin, register_to_config 9 | from diffusers.models.modeling_utils import ModelMixin 10 | 11 | from .modeling_normalization import AdaLayerNormContinuous 12 | from .modeling_embedding import CombinedTimestepTextProjEmbeddings 13 | from .modeling_flux_block import FluxTransformerBlock, FluxSingleTransformerBlock 14 | 15 | def rope(pos: torch.Tensor, dim: int, theta: int) -> torch.Tensor: 16 | assert dim % 2 == 0, "The dimension must be even." 
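    # Descriptive note on the computation below: the frequencies follow the standard
    # RoPE schedule, omega_k = theta ** (-2k / dim) for k = 0 .. dim/2 - 1, so
    # low-index channel pairs rotate quickly and high-index pairs rotate slowly.
    # The position/frequency outer product is then packed into per-pair 2x2 rotation
    # matrices [[cos, -sin], [sin, cos]], giving a tensor of shape
    # (batch, seq_len, dim // 2, 2, 2) that matches the view() call below.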
17 | 18 | scale = torch.arange(0, dim, 2, dtype=torch.float64, device=pos.device) / dim 19 | omega = 1.0 / (theta**scale) 20 | 21 | batch_size, seq_length = pos.shape 22 | out = torch.einsum("...n,d->...nd", pos, omega) 23 | cos_out = torch.cos(out) 24 | sin_out = torch.sin(out) 25 | 26 | stacked_out = torch.stack([cos_out, -sin_out, sin_out, cos_out], dim=-1) 27 | out = stacked_out.view(batch_size, -1, dim // 2, 2, 2) 28 | return out.float() 29 | 30 | 31 | class EmbedND(nn.Module): 32 | def __init__(self, dim: int, theta: int, axes_dim: List[int]): 33 | super().__init__() 34 | self.dim = dim 35 | self.theta = theta 36 | self.axes_dim = axes_dim 37 | 38 | def forward(self, ids: torch.Tensor) -> torch.Tensor: 39 | n_axes = ids.shape[-1] 40 | emb = torch.cat( 41 | [rope(ids[..., i], self.axes_dim[i], self.theta) for i in range(n_axes)], 42 | dim=-3, 43 | ) 44 | return emb.unsqueeze(2) 45 | 46 | 47 | class PyramidFluxTransformer(ModelMixin, ConfigMixin): 48 | """ 49 | The Transformer model introduced in Flux. 50 | 51 | Reference: https://blackforestlabs.ai/announcing-black-forest-labs/ 52 | 53 | Parameters: 54 | patch_size (`int`): Patch size to turn the input data into small patches. 55 | in_channels (`int`, *optional*, defaults to 16): The number of channels in the input. 56 | num_layers (`int`, *optional*, defaults to 18): The number of layers of MMDiT blocks to use. 57 | num_single_layers (`int`, *optional*, defaults to 18): The number of layers of single DiT blocks to use. 58 | attention_head_dim (`int`, *optional*, defaults to 64): The number of channels in each head. 59 | num_attention_heads (`int`, *optional*, defaults to 18): The number of heads to use for multi-head attention. 60 | joint_attention_dim (`int`, *optional*): The number of `encoder_hidden_states` dimensions to use. 61 | pooled_projection_dim (`int`): Number of dimensions to use when projecting the `pooled_projections`. 
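        Note: regardless of the documented defaults above, this class's __init__ defaults are
        num_layers=19, num_single_layers=38 and num_attention_heads=24; the transformer hidden size
        (`inner_dim`) is `num_attention_heads * attention_head_dim` (24 * 64 = 1536 with those defaults).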
62 | """ 63 | 64 | _supports_gradient_checkpointing = True 65 | 66 | @register_to_config 67 | def __init__( 68 | self, 69 | patch_size: int = 1, 70 | in_channels: int = 64, 71 | num_layers: int = 19, 72 | num_single_layers: int = 38, 73 | attention_head_dim: int = 64, 74 | num_attention_heads: int = 24, 75 | joint_attention_dim: int = 4096, 76 | pooled_projection_dim: int = 768, 77 | axes_dims_rope: List[int] = [16, 24, 24], 78 | use_flash_attn: bool = False, 79 | use_temporal_causal: bool = True, 80 | interp_condition_pos: bool = True, 81 | use_gradient_checkpointing: bool = False, 82 | gradient_checkpointing_ratio: float = 0.6, 83 | ): 84 | super().__init__() 85 | self.out_channels = in_channels 86 | self.inner_dim = self.config.num_attention_heads * self.config.attention_head_dim 87 | 88 | self.pos_embed = EmbedND(dim=self.inner_dim, theta=10000, axes_dim=axes_dims_rope) 89 | self.time_text_embed = CombinedTimestepTextProjEmbeddings( 90 | embedding_dim=self.inner_dim, pooled_projection_dim=self.config.pooled_projection_dim 91 | ) 92 | 93 | self.context_embedder = nn.Linear(self.config.joint_attention_dim, self.inner_dim) 94 | self.x_embedder = torch.nn.Linear(self.config.in_channels, self.inner_dim) 95 | 96 | self.transformer_blocks = nn.ModuleList( 97 | [ 98 | FluxTransformerBlock( 99 | dim=self.inner_dim, 100 | num_attention_heads=self.config.num_attention_heads, 101 | attention_head_dim=self.config.attention_head_dim, 102 | use_flash_attn=use_flash_attn, 103 | ) 104 | for i in range(self.config.num_layers) 105 | ] 106 | ) 107 | 108 | self.single_transformer_blocks = nn.ModuleList( 109 | [ 110 | FluxSingleTransformerBlock( 111 | dim=self.inner_dim, 112 | num_attention_heads=self.config.num_attention_heads, 113 | attention_head_dim=self.config.attention_head_dim, 114 | use_flash_attn=use_flash_attn, 115 | ) 116 | for i in range(self.config.num_single_layers) 117 | ] 118 | ) 119 | 120 | self.norm_out = AdaLayerNormContinuous(self.inner_dim, self.inner_dim, elementwise_affine=False, eps=1e-6) 121 | self.proj_out = nn.Linear(self.inner_dim, patch_size * patch_size * self.out_channels, bias=True) 122 | 123 | self.gradient_checkpointing = use_gradient_checkpointing 124 | self.gradient_checkpointing_ratio = gradient_checkpointing_ratio 125 | 126 | self.use_temporal_causal = use_temporal_causal 127 | if self.use_temporal_causal: 128 | print("Using temporal causal attention") 129 | 130 | self.use_flash_attn = use_flash_attn 131 | if self.use_flash_attn: 132 | print("Using Flash attention") 133 | 134 | self.patch_size = 2 # hard-code for now 135 | 136 | # init weights 137 | self.initialize_weights() 138 | 139 | def initialize_weights(self): 140 | # Initialize transformer layers: 141 | def _basic_init(module): 142 | if isinstance(module, (nn.Linear, nn.Conv2d, nn.Conv3d)): 143 | torch.nn.init.xavier_uniform_(module.weight) 144 | if module.bias is not None: 145 | nn.init.constant_(module.bias, 0) 146 | self.apply(_basic_init) 147 | 148 | # Initialize all the conditioning to normal init 149 | nn.init.normal_(self.time_text_embed.timestep_embedder.linear_1.weight, std=0.02) 150 | nn.init.normal_(self.time_text_embed.timestep_embedder.linear_2.weight, std=0.02) 151 | nn.init.normal_(self.time_text_embed.text_embedder.linear_1.weight, std=0.02) 152 | nn.init.normal_(self.time_text_embed.text_embedder.linear_2.weight, std=0.02) 153 | nn.init.normal_(self.context_embedder.weight, std=0.02) 154 | 155 | # Zero-out adaLN modulation layers in DiT blocks: 156 | for block in self.transformer_blocks: 157 | 
nn.init.constant_(block.norm1.linear.weight, 0) 158 | nn.init.constant_(block.norm1.linear.bias, 0) 159 | nn.init.constant_(block.norm1_context.linear.weight, 0) 160 | nn.init.constant_(block.norm1_context.linear.bias, 0) 161 | 162 | for block in self.single_transformer_blocks: 163 | nn.init.constant_(block.norm.linear.weight, 0) 164 | nn.init.constant_(block.norm.linear.bias, 0) 165 | 166 | # Zero-out output layers: 167 | nn.init.constant_(self.norm_out.linear.weight, 0) 168 | nn.init.constant_(self.norm_out.linear.bias, 0) 169 | nn.init.constant_(self.proj_out.weight, 0) 170 | nn.init.constant_(self.proj_out.bias, 0) 171 | 172 | @torch.no_grad() 173 | def _prepare_image_ids(self, batch_size, temp, height, width, train_height, train_width, device, start_time_stamp=0): 174 | latent_image_ids = torch.zeros(temp, height, width, 3) 175 | 176 | # Temporal Rope 177 | latent_image_ids[..., 0] = latent_image_ids[..., 0] + torch.arange(start_time_stamp, start_time_stamp + temp)[:, None, None] 178 | 179 | # height Rope 180 | if height != train_height: 181 | height_pos = F.interpolate(torch.arange(train_height)[None, None, :].float(), height, mode='linear').squeeze(0, 1) 182 | else: 183 | height_pos = torch.arange(train_height).float() 184 | 185 | latent_image_ids[..., 1] = latent_image_ids[..., 1] + height_pos[None, :, None] 186 | 187 | # width rope 188 | if width != train_width: 189 | width_pos = F.interpolate(torch.arange(train_width)[None, None, :].float(), width, mode='linear').squeeze(0, 1) 190 | else: 191 | width_pos = torch.arange(train_width).float() 192 | 193 | latent_image_ids[..., 2] = latent_image_ids[..., 2] + width_pos[None, None, :] 194 | 195 | latent_image_ids = latent_image_ids[None, :].repeat(batch_size, 1, 1, 1, 1) 196 | latent_image_ids = rearrange(latent_image_ids, 'b t h w c -> b (t h w) c') 197 | 198 | return latent_image_ids.to(device=device) 199 | 200 | @torch.no_grad() 201 | def _prepare_pyramid_image_ids(self, sample, batch_size, device): 202 | image_ids_list = [] 203 | 204 | for i_b, sample_ in enumerate(sample): 205 | if not isinstance(sample_, list): 206 | sample_ = [sample_] 207 | 208 | cur_image_ids = [] 209 | start_time_stamp = 0 210 | 211 | train_height = sample_[-1].shape[-2] // self.patch_size 212 | train_width = sample_[-1].shape[-1] // self.patch_size 213 | 214 | for clip_ in sample_: 215 | _, _, temp, height, width = clip_.shape 216 | height = height // self.patch_size 217 | width = width // self.patch_size 218 | cur_image_ids.append(self._prepare_image_ids(batch_size, temp, height, width, train_height, train_width, device, start_time_stamp=start_time_stamp)) 219 | start_time_stamp += temp 220 | 221 | cur_image_ids = torch.cat(cur_image_ids, dim=1) 222 | image_ids_list.append(cur_image_ids) 223 | 224 | return image_ids_list 225 | 226 | def merge_input(self, sample, encoder_hidden_length, encoder_attention_mask): 227 | """ 228 | Merge the input video with different resolutions into one sequence 229 | Sample: From low resolution to high resolution 230 | """ 231 | if isinstance(sample[0], list): 232 | device = sample[0][-1].device 233 | pad_batch_size = sample[0][-1].shape[0] 234 | else: 235 | device = sample[0].device 236 | pad_batch_size = sample[0].shape[0] 237 | 238 | num_stages = len(sample) 239 | height_list = [];width_list = [];temp_list = [] 240 | trainable_token_list = [] 241 | 242 | for i_b, sample_ in enumerate(sample): 243 | if isinstance(sample_, list): 244 | sample_ = sample_[-1] 245 | _, _, temp, height, width = sample_.shape 246 | height = height 
// self.patch_size 247 | width = width // self.patch_size 248 | temp_list.append(temp) 249 | height_list.append(height) 250 | width_list.append(width) 251 | trainable_token_list.append(height * width * temp) 252 | 253 | # prepare the RoPE IDs, 254 | image_ids_list = self._prepare_pyramid_image_ids(sample, pad_batch_size, device) 255 | text_ids = torch.zeros(pad_batch_size, encoder_attention_mask.shape[1], 3).to(device=device) 256 | input_ids_list = [torch.cat([text_ids, image_ids], dim=1) for image_ids in image_ids_list] 257 | image_rotary_emb = [self.pos_embed(input_ids) for input_ids in input_ids_list] # [bs, seq_len, 1, head_dim // 2, 2, 2] 258 | 259 | hidden_states, hidden_length = [], [] 260 | 261 | for sample_ in sample: 262 | video_tokens = [] 263 | 264 | for each_latent in sample_: 265 | each_latent = rearrange(each_latent, 'b c t h w -> b t h w c') 266 | each_latent = rearrange(each_latent, 'b t (h p1) (w p2) c -> b (t h w) (p1 p2 c)', p1=self.patch_size, p2=self.patch_size) 267 | video_tokens.append(each_latent) 268 | 269 | video_tokens = torch.cat(video_tokens, dim=1) 270 | video_tokens = self.x_embedder(video_tokens) 271 | hidden_states.append(video_tokens) 272 | hidden_length.append(video_tokens.shape[1]) 273 | 274 | # prepare the attention mask 275 | if self.use_flash_attn: 276 | attention_mask = None 277 | indices_list = [] 278 | for i_p, length in enumerate(hidden_length): 279 | pad_attention_mask = torch.ones((pad_batch_size, length), dtype=encoder_attention_mask.dtype).to(device) 280 | pad_attention_mask = torch.cat([encoder_attention_mask[i_p::num_stages], pad_attention_mask], dim=1) 281 | 282 | seqlens_in_batch = pad_attention_mask.sum(dim=-1, dtype=torch.int32) 283 | indices = torch.nonzero(pad_attention_mask.flatten(), as_tuple=False).flatten() 284 | 285 | indices_list.append( 286 | { 287 | 'indices': indices, 288 | 'seqlens_in_batch': seqlens_in_batch, 289 | } 290 | ) 291 | encoder_attention_mask = indices_list 292 | else: 293 | assert encoder_attention_mask.shape[1] == encoder_hidden_length 294 | real_batch_size = encoder_attention_mask.shape[0] 295 | 296 | # prepare text ids 297 | text_ids = torch.arange(1, real_batch_size + 1, dtype=encoder_attention_mask.dtype).unsqueeze(1).repeat(1, encoder_hidden_length) 298 | text_ids = text_ids.to(device) 299 | text_ids[encoder_attention_mask == 0] = 0 300 | 301 | # prepare image ids 302 | image_ids = torch.arange(1, real_batch_size + 1, dtype=encoder_attention_mask.dtype).unsqueeze(1).repeat(1, max(hidden_length)) 303 | image_ids = image_ids.to(device) 304 | image_ids_list = [] 305 | for i_p, length in enumerate(hidden_length): 306 | image_ids_list.append(image_ids[i_p::num_stages][:, :length]) 307 | 308 | attention_mask = [] 309 | for i_p in range(len(hidden_length)): 310 | image_ids = image_ids_list[i_p] 311 | token_ids = torch.cat([text_ids[i_p::num_stages], image_ids], dim=1) 312 | stage_attention_mask = rearrange(token_ids, 'b i -> b 1 i 1') == rearrange(token_ids, 'b j -> b 1 1 j') # [bs, 1, q_len, k_len] 313 | if self.use_temporal_causal: 314 | input_order_ids = input_ids_list[i_p][:,:,0] 315 | temporal_causal_mask = rearrange(input_order_ids, 'b i -> b 1 i 1') >= rearrange(input_order_ids, 'b j -> b 1 1 j') 316 | stage_attention_mask = stage_attention_mask & temporal_causal_mask 317 | attention_mask.append(stage_attention_mask) 318 | 319 | return hidden_states, hidden_length, temp_list, height_list, width_list, trainable_token_list, encoder_attention_mask, attention_mask, image_rotary_emb 320 | 321 | def 
split_output(self, batch_hidden_states, hidden_length, temps, heights, widths, trainable_token_list): 322 | # To split the hidden states 323 | batch_size = batch_hidden_states.shape[0] 324 | output_hidden_list = [] 325 | batch_hidden_states = torch.split(batch_hidden_states, hidden_length, dim=1) 326 | 327 | for i_p, length in enumerate(hidden_length): 328 | width, height, temp = widths[i_p], heights[i_p], temps[i_p] 329 | trainable_token_num = trainable_token_list[i_p] 330 | hidden_states = batch_hidden_states[i_p] 331 | 332 | # only the trainable tokens take part in the loss computation 333 | hidden_states = hidden_states[:, -trainable_token_num:] 334 | 335 | # unpatchify; out_channels // 4 recovers the latent channel count, since each token packs a patch_size x patch_size patch 336 | hidden_states = hidden_states.reshape( 337 | shape=(batch_size, temp, height, width, self.patch_size, self.patch_size, self.out_channels // 4) 338 | ) 339 | hidden_states = rearrange(hidden_states, "b t h w p1 p2 c -> b t (h p1) (w p2) c") 340 | hidden_states = rearrange(hidden_states, "b t h w c -> b c t h w") 341 | output_hidden_list.append(hidden_states) 342 | 343 | return output_hidden_list 344 | 345 | def forward( 346 | self, 347 | sample: torch.FloatTensor, # [num_stages] 348 | encoder_hidden_states: torch.Tensor = None, 349 | encoder_attention_mask: torch.FloatTensor = None, 350 | pooled_projections: torch.Tensor = None, 351 | timestep_ratio: torch.LongTensor = None, 352 | ): 353 | temb = self.time_text_embed(timestep_ratio, pooled_projections) 354 | encoder_hidden_states = self.context_embedder(encoder_hidden_states) 355 | encoder_hidden_length = encoder_hidden_states.shape[1] 356 | 357 | # Get the input sequence 358 | hidden_states, hidden_length, temps, heights, widths, trainable_token_list, encoder_attention_mask, attention_mask, \ 359 | image_rotary_emb = self.merge_input(sample, encoder_hidden_length, encoder_attention_mask) 360 | 361 | hidden_states = torch.cat(hidden_states, dim=1) 362 | 363 | for index_block, block in enumerate(self.transformer_blocks): 364 | encoder_hidden_states, hidden_states = block( 365 | hidden_states=hidden_states, 366 | encoder_hidden_states=encoder_hidden_states, 367 | encoder_attention_mask=encoder_attention_mask, 368 | temb=temb, 369 | attention_mask=attention_mask, 370 | hidden_length=hidden_length, 371 | image_rotary_emb=image_rotary_emb, 372 | ) 373 | 374 | # remerge for single attention block 375 | num_stages = len(hidden_length) 376 | batch_hidden_states = list(torch.split(hidden_states, hidden_length, dim=1)) 377 | concat_hidden_length = [] 378 | 379 | for i_p in range(len(hidden_length)): 380 | batch_hidden_states[i_p] = torch.cat([encoder_hidden_states[i_p::num_stages], batch_hidden_states[i_p]], dim=1) 381 | concat_hidden_length.append(batch_hidden_states[i_p].shape[1]) 382 | 383 | hidden_states = torch.cat(batch_hidden_states, dim=1) 384 | 385 | for index_block, block in enumerate(self.single_transformer_blocks): 386 | hidden_states = block( 387 | hidden_states=hidden_states, 388 | temb=temb, 389 | encoder_attention_mask=encoder_attention_mask, 390 | attention_mask=attention_mask, 391 | hidden_length=concat_hidden_length, 392 | image_rotary_emb=image_rotary_emb, 393 | ) 394 | 395 | batch_hidden_states = list(torch.split(hidden_states, concat_hidden_length, dim=1)) 396 | 397 | for i_p in range(len(concat_hidden_length)): 398 | batch_hidden_states[i_p] = batch_hidden_states[i_p][:, encoder_hidden_length :, ...]
399 | 400 | hidden_states = torch.cat(batch_hidden_states, dim=1) 401 | hidden_states = self.norm_out(hidden_states, temb, hidden_length=hidden_length) 402 | hidden_states = self.proj_out(hidden_states) 403 | 404 | output = self.split_output(hidden_states, hidden_length, temps, heights, widths, trainable_token_list) 405 | 406 | return output -------------------------------------------------------------------------------- /pyramid_dit/mmdit_modules/modeling_pyramid_mmdit.py: -------------------------------------------------------------------------------- 1 | import torch 2 | import torch.nn as nn 3 | import torch.nn.functional as F 4 | 5 | from einops import rearrange 6 | from diffusers.utils.torch_utils import randn_tensor 7 | from diffusers.models.modeling_utils import ModelMixin 8 | from diffusers.configuration_utils import ConfigMixin, register_to_config 9 | from diffusers.utils import is_torch_version 10 | from typing import Any, Callable, Dict, List, Optional, Union 11 | 12 | from .modeling_embedding import PatchEmbed3D, CombinedTimestepConditionEmbeddings 13 | from .modeling_normalization import AdaLayerNormContinuous 14 | from .modeling_mmdit_block import JointTransformerBlock 15 | 16 | 17 | def rope(pos: torch.Tensor, dim: int, theta: int) -> torch.Tensor: 18 | assert dim % 2 == 0, "The dimension must be even." 19 | 20 | scale = torch.arange(0, dim, 2, dtype=torch.float64, device=pos.device) / dim 21 | omega = 1.0 / (theta**scale) 22 | 23 | batch_size, seq_length = pos.shape 24 | out = torch.einsum("...n,d->...nd", pos, omega) 25 | cos_out = torch.cos(out) 26 | sin_out = torch.sin(out) 27 | 28 | stacked_out = torch.stack([cos_out, -sin_out, sin_out, cos_out], dim=-1) 29 | out = stacked_out.view(batch_size, -1, dim // 2, 2, 2) 30 | return out.float() 31 | 32 | 33 | class EmbedNDRoPE(nn.Module): 34 | def __init__(self, dim: int, theta: int, axes_dim: List[int]): 35 | super().__init__() 36 | self.dim = dim 37 | self.theta = theta 38 | self.axes_dim = axes_dim 39 | 40 | def forward(self, ids: torch.Tensor) -> torch.Tensor: 41 | n_axes = ids.shape[-1] 42 | emb = torch.cat( 43 | [rope(ids[..., i], self.axes_dim[i], self.theta) for i in range(n_axes)], 44 | dim=-3, 45 | ) 46 | return emb.unsqueeze(2) 47 | 48 | 49 | class PyramidDiffusionMMDiT(ModelMixin, ConfigMixin): 50 | _supports_gradient_checkpointing = True 51 | 52 | @register_to_config 53 | def __init__( 54 | self, 55 | sample_size: int = 128, 56 | patch_size: int = 2, 57 | in_channels: int = 16, 58 | num_layers: int = 24, 59 | attention_head_dim: int = 64, 60 | num_attention_heads: int = 24, 61 | caption_projection_dim: int = 1152, 62 | pooled_projection_dim: int = 2048, 63 | pos_embed_max_size: int = 192, 64 | max_num_frames: int = 200, 65 | qk_norm: str = 'rms_norm', 66 | pos_embed_type: str = 'rope', 67 | temp_pos_embed_type: str = 'sincos', 68 | joint_attention_dim: int = 4096, 69 | use_gradient_checkpointing: bool = False, 70 | use_flash_attn: bool = True, 71 | use_temporal_causal: bool = False, 72 | use_t5_mask: bool = False, 73 | add_temp_pos_embed: bool = False, 74 | interp_condition_pos: bool = False, 75 | ): 76 | super().__init__() 77 | 78 | self.out_channels = in_channels 79 | self.inner_dim = num_attention_heads * attention_head_dim 80 | assert temp_pos_embed_type in ['rope', 'sincos'] 81 | 82 | # The input latent embeder, using the name pos_embed to remain the same with SD# 83 | self.pos_embed = PatchEmbed3D( 84 | height=sample_size, 85 | width=sample_size, 86 | patch_size=patch_size, 87 | in_channels=in_channels, 88 | 
embed_dim=self.inner_dim, 89 | pos_embed_max_size=pos_embed_max_size, # hard-code for now. 90 | max_num_frames=max_num_frames, 91 | pos_embed_type=pos_embed_type, 92 | temp_pos_embed_type=temp_pos_embed_type, 93 | add_temp_pos_embed=add_temp_pos_embed, 94 | interp_condition_pos=interp_condition_pos, 95 | ) 96 | 97 | # The RoPE EMbedding 98 | if pos_embed_type == 'rope': 99 | self.rope_embed = EmbedNDRoPE(self.inner_dim, 10000, axes_dim=[16, 24, 24]) 100 | else: 101 | self.rope_embed = None 102 | 103 | if temp_pos_embed_type == 'rope': 104 | self.temp_rope_embed = EmbedNDRoPE(self.inner_dim, 10000, axes_dim=[attention_head_dim]) 105 | else: 106 | self.temp_rope_embed = None 107 | 108 | self.time_text_embed = CombinedTimestepConditionEmbeddings( 109 | embedding_dim=self.inner_dim, pooled_projection_dim=self.config.pooled_projection_dim, 110 | ) 111 | self.context_embedder = nn.Linear(self.config.joint_attention_dim, self.config.caption_projection_dim) 112 | 113 | self.transformer_blocks = nn.ModuleList( 114 | [ 115 | JointTransformerBlock( 116 | dim=self.inner_dim, 117 | num_attention_heads=num_attention_heads, 118 | attention_head_dim=self.inner_dim, 119 | qk_norm=qk_norm, 120 | context_pre_only=i == num_layers - 1, 121 | use_flash_attn=use_flash_attn, 122 | ) 123 | for i in range(num_layers) 124 | ] 125 | ) 126 | 127 | self.norm_out = AdaLayerNormContinuous(self.inner_dim, self.inner_dim, elementwise_affine=False, eps=1e-6) 128 | self.proj_out = nn.Linear(self.inner_dim, patch_size * patch_size * self.out_channels, bias=True) 129 | self.gradient_checkpointing = use_gradient_checkpointing 130 | self.patch_size = patch_size 131 | self.use_flash_attn = use_flash_attn 132 | self.use_temporal_causal = use_temporal_causal 133 | self.pos_embed_type = pos_embed_type 134 | self.temp_pos_embed_type = temp_pos_embed_type 135 | self.add_temp_pos_embed = add_temp_pos_embed 136 | 137 | if self.use_temporal_causal: 138 | print("Using temporal causal attention") 139 | assert self.use_flash_attn is False, "The flash attention does not support temporal causal" 140 | 141 | if interp_condition_pos: 142 | print("We interp the position embedding of condition latents") 143 | 144 | # init weights 145 | self.initialize_weights() 146 | 147 | def initialize_weights(self): 148 | # Initialize transformer layers: 149 | def _basic_init(module): 150 | if isinstance(module, (nn.Linear, nn.Conv2d, nn.Conv3d)): 151 | torch.nn.init.xavier_uniform_(module.weight) 152 | if module.bias is not None: 153 | nn.init.constant_(module.bias, 0) 154 | self.apply(_basic_init) 155 | 156 | # Initialize patch_embed like nn.Linear (instead of nn.Conv2d): 157 | w = self.pos_embed.proj.weight.data 158 | nn.init.xavier_uniform_(w.view([w.shape[0], -1])) 159 | nn.init.constant_(self.pos_embed.proj.bias, 0) 160 | 161 | # Initialize all the conditioning to normal init 162 | nn.init.normal_(self.time_text_embed.timestep_embedder.linear_1.weight, std=0.02) 163 | nn.init.normal_(self.time_text_embed.timestep_embedder.linear_2.weight, std=0.02) 164 | nn.init.normal_(self.time_text_embed.text_embedder.linear_1.weight, std=0.02) 165 | nn.init.normal_(self.time_text_embed.text_embedder.linear_2.weight, std=0.02) 166 | nn.init.normal_(self.context_embedder.weight, std=0.02) 167 | 168 | # Zero-out adaLN modulation layers in DiT blocks: 169 | for block in self.transformer_blocks: 170 | nn.init.constant_(block.norm1.linear.weight, 0) 171 | nn.init.constant_(block.norm1.linear.bias, 0) 172 | nn.init.constant_(block.norm1_context.linear.weight, 0) 173 | 
nn.init.constant_(block.norm1_context.linear.bias, 0) 174 | 175 | # Zero-out output layers: 176 | nn.init.constant_(self.norm_out.linear.weight, 0) 177 | nn.init.constant_(self.norm_out.linear.bias, 0) 178 | nn.init.constant_(self.proj_out.weight, 0) 179 | nn.init.constant_(self.proj_out.bias, 0) 180 | 181 | @torch.no_grad() 182 | def _prepare_latent_image_ids(self, batch_size, temp, height, width, device): 183 | latent_image_ids = torch.zeros(temp, height, width, 3) 184 | latent_image_ids[..., 0] = latent_image_ids[..., 0] + torch.arange(temp)[:, None, None] 185 | latent_image_ids[..., 1] = latent_image_ids[..., 1] + torch.arange(height)[None, :, None] 186 | latent_image_ids[..., 2] = latent_image_ids[..., 2] + torch.arange(width)[None, None, :] 187 | 188 | latent_image_ids = latent_image_ids[None, :].repeat(batch_size, 1, 1, 1, 1) 189 | latent_image_ids = rearrange(latent_image_ids, 'b t h w c -> b (t h w) c') 190 | return latent_image_ids.to(device=device) 191 | 192 | @torch.no_grad() 193 | def _prepare_pyramid_latent_image_ids(self, batch_size, temp_list, height_list, width_list, device): 194 | base_width = width_list[-1]; base_height = height_list[-1] 195 | assert base_width == max(width_list) 196 | assert base_height == max(height_list) 197 | 198 | image_ids_list = [] 199 | for temp, height, width in zip(temp_list, height_list, width_list): 200 | latent_image_ids = torch.zeros(temp, height, width, 3) 201 | 202 | if height != base_height: 203 | height_pos = F.interpolate(torch.arange(base_height)[None, None, :].float(), height, mode='linear').squeeze(0, 1) 204 | else: 205 | height_pos = torch.arange(base_height).float() 206 | if width != base_width: 207 | width_pos = F.interpolate(torch.arange(base_width)[None, None, :].float(), width, mode='linear').squeeze(0, 1) 208 | else: 209 | width_pos = torch.arange(base_width).float() 210 | 211 | latent_image_ids[..., 0] = latent_image_ids[..., 0] + torch.arange(temp)[:, None, None] 212 | latent_image_ids[..., 1] = latent_image_ids[..., 1] + height_pos[None, :, None] 213 | latent_image_ids[..., 2] = latent_image_ids[..., 2] + width_pos[None, None, :] 214 | latent_image_ids = latent_image_ids[None, :].repeat(batch_size, 1, 1, 1, 1) 215 | latent_image_ids = rearrange(latent_image_ids, 'b t h w c -> b (t h w) c').to(device) 216 | image_ids_list.append(latent_image_ids) 217 | 218 | return image_ids_list 219 | 220 | @torch.no_grad() 221 | def _prepare_temporal_rope_ids(self, batch_size, temp, height, width, device, start_time_stamp=0): 222 | latent_image_ids = torch.zeros(temp, height, width, 1) 223 | latent_image_ids[..., 0] = latent_image_ids[..., 0] + torch.arange(start_time_stamp, start_time_stamp + temp)[:, None, None] 224 | latent_image_ids = latent_image_ids[None, :].repeat(batch_size, 1, 1, 1, 1) 225 | latent_image_ids = rearrange(latent_image_ids, 'b t h w c -> b (t h w) c') 226 | return latent_image_ids.to(device=device) 227 | 228 | @torch.no_grad() 229 | def _prepare_pyramid_temporal_rope_ids(self, sample, batch_size, device): 230 | image_ids_list = [] 231 | 232 | for i_b, sample_ in enumerate(sample): 233 | if not isinstance(sample_, list): 234 | sample_ = [sample_] 235 | 236 | cur_image_ids = [] 237 | start_time_stamp = 0 238 | 239 | for clip_ in sample_: 240 | _, _, temp, height, width = clip_.shape 241 | height = height // self.patch_size 242 | width = width // self.patch_size 243 | cur_image_ids.append(self._prepare_temporal_rope_ids(batch_size, temp, height, width, device, start_time_stamp=start_time_stamp)) 244 | start_time_stamp 
+= temp 245 | 246 | cur_image_ids = torch.cat(cur_image_ids, dim=1) 247 | image_ids_list.append(cur_image_ids) 248 | 249 | return image_ids_list 250 | 251 | def merge_input(self, sample, encoder_hidden_length, encoder_attention_mask): 252 | """ 253 | Merge the input video with different resolutions into one sequence 254 | Sample: From low resolution to high resolution 255 | """ 256 | if isinstance(sample[0], list): 257 | device = sample[0][-1].device 258 | pad_batch_size = sample[0][-1].shape[0] 259 | else: 260 | device = sample[0].device 261 | pad_batch_size = sample[0].shape[0] 262 | 263 | num_stages = len(sample) 264 | height_list = [];width_list = [];temp_list = [] 265 | trainable_token_list = [] 266 | 267 | for i_b, sample_ in enumerate(sample): 268 | if isinstance(sample_, list): 269 | sample_ = sample_[-1] 270 | _, _, temp, height, width = sample_.shape 271 | height = height // self.patch_size 272 | width = width // self.patch_size 273 | temp_list.append(temp) 274 | height_list.append(height) 275 | width_list.append(width) 276 | trainable_token_list.append(height * width * temp) 277 | 278 | # prepare the RoPE embedding if needed 279 | if self.pos_embed_type == 'rope': 280 | # TODO: support the 3D Rope for video 281 | raise NotImplementedError("Not compatible with video generation now") 282 | text_ids = torch.zeros(pad_batch_size, encoder_hidden_length, 3).to(device=device) 283 | image_ids_list = self._prepare_pyramid_latent_image_ids(pad_batch_size, temp_list, height_list, width_list, device) 284 | input_ids_list = [torch.cat([text_ids, image_ids], dim=1) for image_ids in image_ids_list] 285 | image_rotary_emb = [self.rope_embed(input_ids) for input_ids in input_ids_list] # [bs, seq_len, 1, head_dim // 2, 2, 2] 286 | else: 287 | if self.temp_pos_embed_type == 'rope' and self.add_temp_pos_embed: 288 | image_ids_list = self._prepare_pyramid_temporal_rope_ids(sample, pad_batch_size, device) 289 | text_ids = torch.zeros(pad_batch_size, encoder_attention_mask.shape[1], 1).to(device=device) 290 | input_ids_list = [torch.cat([text_ids, image_ids], dim=1) for image_ids in image_ids_list] 291 | image_rotary_emb = [self.temp_rope_embed(input_ids) for input_ids in input_ids_list] # [bs, seq_len, 1, head_dim // 2, 2, 2] 292 | else: 293 | image_rotary_emb = None 294 | 295 | hidden_states = self.pos_embed(sample) # hidden states is a list of [b c t h w] b = real_b // num_stages 296 | hidden_length = [] 297 | 298 | for i_b in range(num_stages): 299 | hidden_length.append(hidden_states[i_b].shape[1]) 300 | 301 | # prepare the attention mask 302 | if self.use_flash_attn: 303 | attention_mask = None 304 | indices_list = [] 305 | for i_p, length in enumerate(hidden_length): 306 | pad_attention_mask = torch.ones((pad_batch_size, length), dtype=encoder_attention_mask.dtype).to(device) 307 | pad_attention_mask = torch.cat([encoder_attention_mask[i_p::num_stages], pad_attention_mask], dim=1) 308 | 309 | seqlens_in_batch = pad_attention_mask.sum(dim=-1, dtype=torch.int32) 310 | indices = torch.nonzero(pad_attention_mask.flatten(), as_tuple=False).flatten() 311 | 312 | indices_list.append( 313 | { 314 | 'indices': indices, 315 | 'seqlens_in_batch': seqlens_in_batch, 316 | } 317 | ) 318 | encoder_attention_mask = indices_list 319 | else: 320 | assert encoder_attention_mask.shape[1] == encoder_hidden_length 321 | real_batch_size = encoder_attention_mask.shape[0] 322 | # prepare text ids 323 | text_ids = torch.arange(1, real_batch_size + 1, dtype=encoder_attention_mask.dtype).unsqueeze(1).repeat(1, 
encoder_hidden_length) 324 | text_ids = text_ids.to(device) 325 | text_ids[encoder_attention_mask == 0] = 0 326 | 327 | # prepare image ids 328 | image_ids = torch.arange(1, real_batch_size + 1, dtype=encoder_attention_mask.dtype).unsqueeze(1).repeat(1, max(hidden_length)) 329 | image_ids = image_ids.to(device) 330 | image_ids_list = [] 331 | for i_p, length in enumerate(hidden_length): 332 | image_ids_list.append(image_ids[i_p::num_stages][:, :length]) 333 | 334 | attention_mask = [] 335 | for i_p in range(len(hidden_length)): 336 | image_ids = image_ids_list[i_p] 337 | token_ids = torch.cat([text_ids[i_p::num_stages], image_ids], dim=1) 338 | stage_attention_mask = rearrange(token_ids, 'b i -> b 1 i 1') == rearrange(token_ids, 'b j -> b 1 1 j') # [bs, 1, q_len, k_len] 339 | if self.use_temporal_causal: 340 | input_order_ids = input_ids_list[i_p].squeeze(2) 341 | temporal_causal_mask = rearrange(input_order_ids, 'b i -> b 1 i 1') >= rearrange(input_order_ids, 'b j -> b 1 1 j') 342 | stage_attention_mask = stage_attention_mask & temporal_causal_mask 343 | attention_mask.append(stage_attention_mask) 344 | 345 | return hidden_states, hidden_length, temp_list, height_list, width_list, trainable_token_list, encoder_attention_mask, attention_mask, image_rotary_emb 346 | 347 | def split_output(self, batch_hidden_states, hidden_length, temps, heights, widths, trainable_token_list): 348 | # To split the hidden states 349 | batch_size = batch_hidden_states.shape[0] 350 | output_hidden_list = [] 351 | batch_hidden_states = torch.split(batch_hidden_states, hidden_length, dim=1) 352 | 353 | for i_p, length in enumerate(hidden_length): 354 | width, height, temp = widths[i_p], heights[i_p], temps[i_p] 355 | trainable_token_num = trainable_token_list[i_p] 356 | hidden_states = batch_hidden_states[i_p] 357 | 358 | # only the trainable token are taking part in loss computation 359 | hidden_states = hidden_states[:, -trainable_token_num:] 360 | 361 | # unpatchify 362 | hidden_states = hidden_states.reshape( 363 | shape=(batch_size, temp, height, width, self.patch_size, self.patch_size, self.out_channels) 364 | ) 365 | hidden_states = rearrange(hidden_states, "b t h w p1 p2 c -> b t (h p1) (w p2) c") 366 | hidden_states = rearrange(hidden_states, "b t h w c -> b c t h w") 367 | output_hidden_list.append(hidden_states) 368 | 369 | return output_hidden_list 370 | 371 | def forward( 372 | self, 373 | sample: torch.FloatTensor, # [num_stages] 374 | encoder_hidden_states: torch.FloatTensor = None, 375 | encoder_attention_mask: torch.FloatTensor = None, 376 | pooled_projections: torch.FloatTensor = None, 377 | timestep_ratio: torch.FloatTensor = None, 378 | ): 379 | # Get the timestep embedding 380 | temb = self.time_text_embed(timestep_ratio, pooled_projections) 381 | encoder_hidden_states = self.context_embedder(encoder_hidden_states) 382 | encoder_hidden_length = encoder_hidden_states.shape[1] 383 | 384 | # Get the input sequence 385 | hidden_states, hidden_length, temps, heights, widths, trainable_token_list, encoder_attention_mask, \ 386 | attention_mask, image_rotary_emb = self.merge_input(sample, encoder_hidden_length, encoder_attention_mask) 387 | 388 | hidden_states = torch.cat(hidden_states, dim=1) 389 | 390 | # print(hidden_length) 391 | for i_b, block in enumerate(self.transformer_blocks): 392 | encoder_hidden_states, hidden_states = block( 393 | hidden_states=hidden_states, 394 | encoder_hidden_states=encoder_hidden_states, 395 | encoder_attention_mask=encoder_attention_mask, 396 | temb=temb, 397 | 
attention_mask=attention_mask, 398 | hidden_length=hidden_length, 399 | image_rotary_emb=image_rotary_emb, 400 | ) 401 | 402 | # nan_mask = torch.isnan(hidden_states) 403 | # if torch.any(nan_mask): 404 | # raise ValueError("nan in hidden_states") 405 | 406 | hidden_states = self.norm_out(hidden_states, temb, hidden_length=hidden_length) 407 | hidden_states = self.proj_out(hidden_states) 408 | 409 | output = self.split_output(hidden_states, hidden_length, temps, heights, widths, trainable_token_list) 410 | 411 | return output 412 | --------------------------------------------------------------------------------
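The RoPE construction shared by the two transformers above can be sanity-checked in isolation. The snippet below is a minimal standalone sketch, not a file from this repository: it re-implements the rope() helper as defined above, builds (t, h, w) position ids the way _prepare_latent_image_ids does, and uses the Flux variant's axes_dims_rope = [16, 24, 24] with theta = 10000; the batch, frame, and grid sizes are toy values chosen only for illustration.

import torch

def rope(pos: torch.Tensor, dim: int, theta: int) -> torch.Tensor:
    # Same computation as the rope() helper above; pos is cast to float64 here
    # so its dtype matches omega in this standalone sketch.
    scale = torch.arange(0, dim, 2, dtype=torch.float64, device=pos.device) / dim
    omega = 1.0 / (theta ** scale)
    batch_size, _ = pos.shape
    out = torch.einsum("...n,d->...nd", pos.to(torch.float64), omega)
    stacked = torch.stack([torch.cos(out), -torch.sin(out), torch.sin(out), torch.cos(out)], dim=-1)
    return stacked.view(batch_size, -1, dim // 2, 2, 2).float()

# Toy position ids for one stage: 2 latent frames over an 8x8 post-patchify grid.
t, h, w = 2, 8, 8
ids = torch.zeros(t, h, w, 3)
ids[..., 0] += torch.arange(t)[:, None, None]   # temporal axis
ids[..., 1] += torch.arange(h)[None, :, None]   # height axis
ids[..., 2] += torch.arange(w)[None, None, :]   # width axis
ids = ids.reshape(1, t * h * w, 3)              # [batch, seq_len, 3]

# EmbedND.forward: one rope() call per axis, concatenated along the frequency dim.
axes_dim = [16, 24, 24]                          # sums to attention_head_dim = 64
emb = torch.cat([rope(ids[..., i], axes_dim[i], 10000) for i in range(3)], dim=-3).unsqueeze(2)
print(emb.shape)                                 # torch.Size([1, 128, 1, 32, 2, 2])

The trailing [2, 2] pair is the per-frequency 2x2 rotation block consumed inside the attention blocks, matching the "[bs, seq_len, 1, head_dim // 2, 2, 2]" shape comment in merge_input above.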