├── requirements.txt ├── pyproject.toml ├── modules ├── convs.py ├── patch3d.py ├── downsample3d.py ├── upsample3d.py ├── cross_attn_down_block3d.py ├── cross_attn_up_block3d.py ├── transformer3d.py ├── resnet_block3d.py ├── transformer_block.py ├── fully_attention.py ├── flatten_cldm.py └── unet.py ├── nodes ├── load_flatten_controlnet_node.py ├── create_flow_noise_node.py ├── trajectory_node.py ├── load_flatten_model_node.py ├── apply_flatten_attention_node.py ├── flatten_unsampler_node.py └── flatten_ksampler_node.py ├── utils ├── batching_utils.py ├── flow_noise.py ├── injection_utils.py └── trajectories.py ├── __init__.py ├── .gitignore ├── README.md └── example_workflows └── example_flatten_batched.json /requirements.txt: -------------------------------------------------------------------------------- 1 | -------------------------------------------------------------------------------- /pyproject.toml: -------------------------------------------------------------------------------- 1 | [project] 2 | name = "comfyui-flatten" 3 | description = "ComfyUI nodes to use [FLATTEN: optical FLow-guided ATTENtion for consistent text-to-video editing](https://github.com/yrcong/flatten)." 4 | version = "1.0.0" 5 | license = "LICENSE" 6 | dependencies = [] 7 | 8 | [project.urls] 9 | Repository = "https://github.com/logtd/ComfyUI-FLATTEN" 10 | # Used by Comfy Registry https://comfyregistry.org 11 | 12 | [tool.comfy] 13 | PublisherId = "logtd" 14 | DisplayName = "ComfyUI-FLATTEN" 15 | Icon = "" 16 | -------------------------------------------------------------------------------- /modules/convs.py: -------------------------------------------------------------------------------- 1 | # Adapted from https://github.com/huggingface/diffusers/blob/main/src/diffusers/models/resnet.py 2 | import torch 3 | import torch.nn as nn 4 | import torch.nn.functional as F 5 | 6 | from einops import rearrange 7 | 8 | 9 | class InflatedConv3d(nn.Conv2d): 10 | 11 | def forward(self, x): 12 | video_length = x.shape[2] 13 | 14 | x = rearrange(x, "b c f h w -> (b f) c h w") 15 | x = super().forward(x) 16 | x = rearrange(x, "(b f) c h w -> b c f h w", f=video_length) 17 | 18 | return x 19 | -------------------------------------------------------------------------------- /modules/patch3d.py: -------------------------------------------------------------------------------- 1 | from einops import rearrange 2 | 3 | 4 | def transform_to_2d(x): 5 | b = x.shape[0] 6 | return rearrange(x, 'b c f h w -> (b f) c h w'), b 7 | 8 | 9 | def transformed_to_3d(x, b): 10 | return rearrange(x, '(b f) c h w -> b c f h w', b=b) 11 | 12 | 13 | def apply_unet_patch3d(h, hs, transformer_options, patch): 14 | h, bh = transform_to_2d(h) 15 | hs, bhs = transform_to_2d(hs) 16 | h, hs = patch(h, hs, transformer_options) 17 | h = transformed_to_3d(h, bh) 18 | hs = transformed_to_3d(hs, bhs) 19 | return h, hs 20 | 21 | 22 | def apply_patch3d(x, transformer_options, patch): 23 | x, b = transform_to_2d(x) 24 | x = patch(x, transformer_options) 25 | x = transformed_to_3d(x, b) 26 | return x 27 | -------------------------------------------------------------------------------- /nodes/load_flatten_controlnet_node.py: -------------------------------------------------------------------------------- 1 | import folder_paths 2 | import comfy.controlnet 3 | import comfy.cldm.cldm 4 | from ..modules.flatten_cldm import FlattenControlNet 5 | 6 | 7 | class FlattenControlNetLoader: 8 | @classmethod 9 | def INPUT_TYPES(s): 10 | return {"required": {"control_net_name": (folder_paths.get_filename_list("controlnet"), )}} 11 | 12 | RETURN_TYPES = ("CONTROL_NET",) 13 | FUNCTION = "load_controlnet" 14 | 15 | CATEGORY = "loaders" 16 | 17 | def load_controlnet(self, control_net_name): 18 | controlnet_path = folder_paths.get_full_path( 19 | "controlnet", control_net_name) 20 | original_controlnet = comfy.cldm.cldm.ControlNet 21 | # Hack 22 | comfy.cldm.cldm.ControlNet = FlattenControlNet 23 | controlnet = comfy.controlnet.load_controlnet(controlnet_path) 24 | comfy.cldm.cldm.ControlNet = original_controlnet 25 | return (controlnet,) 26 | -------------------------------------------------------------------------------- /utils/batching_utils.py: -------------------------------------------------------------------------------- 1 | # Adjusted from ADE: https://github.com/Kosinkadink/ComfyUI-AnimateDiff-Evolved 2 | 3 | def create_windows_static_standard(num_frames, context_length, overlap): 4 | windows = [] 5 | if num_frames <= context_length or context_length == 0: 6 | windows.append(list(range(num_frames))) 7 | return windows 8 | # always return the same set of windows 9 | delta = context_length - overlap 10 | for start_idx in range(0, num_frames, delta): 11 | # if past the end of frames, move start_idx back to allow same context_length 12 | ending = start_idx + context_length 13 | if ending >= num_frames: 14 | final_delta = ending - num_frames 15 | final_start_idx = start_idx - final_delta 16 | windows.append( 17 | list(range(final_start_idx, final_start_idx + context_length))) 18 | break 19 | windows.append(list(range(start_idx, start_idx + context_length))) 20 | return windows 21 | -------------------------------------------------------------------------------- /__init__.py: -------------------------------------------------------------------------------- 1 | from .nodes.load_flatten_model_node import FlattenCheckpointLoaderNode 2 | from .nodes.flatten_ksampler_node import KSamplerFlattenNode 3 | from .nodes.flatten_unsampler_node import UnsamplerFlattenNode 4 | from .nodes.trajectory_node import TrajectoryNode 5 | from .nodes.apply_flatten_attention_node import ApplyFlattenAttentionNode 6 | from .nodes.create_flow_noise_node import CreateFlowNoiseNode 7 | 8 | 9 | NODE_CLASS_MAPPINGS = { 10 | "FlattenCheckpointLoaderNode": FlattenCheckpointLoaderNode, 11 | "KSamplerFlattenNode": KSamplerFlattenNode, 12 | "UnsamplerFlattenNode": UnsamplerFlattenNode, 13 | "TrajectoryNode": TrajectoryNode, 14 | "ApplyFlattenAttentionNode": ApplyFlattenAttentionNode, 15 | "CreateFlowNoiseNode": CreateFlowNoiseNode, 16 | } 17 | 18 | NODE_DISPLAY_NAME_MAPPINGS = { 19 | "FlattenCheckpointLoaderNode": "Load Checkpoint with FLATTEN model", 20 | "KSamplerFlattenNode": "KSampler (Flatten)", 21 | "UnsamplerFlattenNode": "Unsampler (Flatten)", 22 | "TrajectoryNode": "Sample Trajectories", 23 | "ApplyFlattenAttentionNode": "Apply Flatten Attention", 24 | "CreateFlowNoiseNode": "Create Flow Noise" 25 | } 26 | -------------------------------------------------------------------------------- /utils/flow_noise.py: -------------------------------------------------------------------------------- 1 | import torch 2 | 3 | 4 | def create_noise_generator(directions_list, num_frames): 5 | def generator(latent_image): 6 | batch_size, c, h, w = latent_image.shape 7 | 8 | def create_noise(sigma, sigma_next): 9 | nonlocal latent_image 10 | visited = torch.zeros([num_frames, h, w], dtype=torch.bool) 11 | noise = torch.randn_like(latent_image[0]) 12 | noise = torch.cat([noise.unsqueeze(0)]*num_frames) 13 | for t in range(num_frames): 14 | for x in range(h): 15 | for y in range(w): 16 | if visited[t, x, y]: 17 | continue 18 | for directions in directions_list: 19 | if (t, x, y) in directions: 20 | for (pt, px, py) in directions[(t, x, y)]: 21 | noise[pt, :, px, py] = noise[t, :, x, y] 22 | visited[pt, px, py] = True 23 | break 24 | return noise 25 | return create_noise 26 | return generator 27 | -------------------------------------------------------------------------------- /nodes/create_flow_noise_node.py: -------------------------------------------------------------------------------- 1 | import torch 2 | 3 | from ..utils.flow_noise import create_noise_generator 4 | 5 | 6 | class CreateFlowNoiseNode: 7 | @classmethod 8 | def INPUT_TYPES(s): 9 | return {"required": 10 | {"latent": ("LATENT",), 11 | "trajectories": ("TRAJECTORY",), 12 | "add_noise_to_latent": ("BOOLEAN", {"default": False}) 13 | } 14 | } 15 | 16 | RETURN_TYPES = ("LATENT",) 17 | RETURN_NAMES = ("noise",) 18 | FUNCTION = "create" 19 | 20 | CATEGORY = "flatten" 21 | 22 | def create(self, latent, trajectories, add_noise_to_latent): 23 | 24 | noise = torch.zeros_like(latent['samples']) 25 | 26 | noise_gen = create_noise_generator( 27 | [traj['directions'] for traj in trajectories['trajectory_windows'].values()], noise.shape[0]) 28 | 29 | noise = noise_gen(noise)(None, None) 30 | 31 | if add_noise_to_latent: 32 | noise += latent['samples'] 33 | latent = latent.copy() 34 | latent['sampels'] = noise 35 | return (latent, ) 36 | 37 | return ({'samples': noise}, ) 38 | -------------------------------------------------------------------------------- /modules/downsample3d.py: -------------------------------------------------------------------------------- 1 | from .convs import InflatedConv3d 2 | import torch.nn as nn 3 | 4 | import comfy.ops 5 | ops = comfy.ops.disable_weight_init 6 | 7 | 8 | class Downsample3D(nn.Module): 9 | def __init__(self, 10 | channels, 11 | use_conv=False, 12 | dims=2, 13 | out_channels=None, 14 | padding=1, 15 | dtype=None, 16 | device=None, 17 | operations=ops, 18 | ): 19 | super().__init__() 20 | self.channels = channels 21 | self.out_channels = out_channels or channels 22 | self.use_conv = use_conv 23 | self.padding = padding 24 | self.dims = dims 25 | stride = 2 if dims != 3 else (1, 2, 2) # was always 2 26 | 27 | if use_conv: 28 | self.op = InflatedConv3d( 29 | self.channels, self.out_channels, 3, stride=stride, padding=padding).half() 30 | else: 31 | raise NotImplementedError 32 | 33 | def forward(self, hidden_states, **kwargs): 34 | assert hidden_states.shape[1] == self.channels 35 | if self.use_conv and self.padding == 0: 36 | raise NotImplementedError 37 | 38 | assert hidden_states.shape[1] == self.channels 39 | hidden_states = self.op(hidden_states) 40 | 41 | return hidden_states 42 | -------------------------------------------------------------------------------- /nodes/trajectory_node.py: -------------------------------------------------------------------------------- 1 | from ..utils.trajectories import sample_trajectories 2 | from ..utils.batching_utils import create_windows_static_standard 3 | import comfy.model_management 4 | from torchvision.models.optical_flow import Raft_Large_Weights 5 | from torchvision.models.optical_flow import raft_large 6 | 7 | 8 | class TrajectoryNode: 9 | @classmethod 10 | def INPUT_TYPES(s): 11 | return {"required": {"images": ("IMAGE", ), 12 | "context_length": ("INT", {"default": 20, "min": 0, "max": 40, "step": 1}), 13 | "context_overlap": ("INT", {"default": 10, "min": 0, "step": 1}), 14 | }} 15 | RETURN_TYPES = ("TRAJECTORY",) 16 | FUNCTION = "sample" 17 | 18 | CATEGORY = "flatten" 19 | 20 | def sample(self, images, context_length, context_overlap): 21 | device = comfy.model_management.get_torch_device() 22 | weights = Raft_Large_Weights.DEFAULT 23 | model = raft_large(weights=weights, 24 | progress=False).to(device) 25 | 26 | windows = create_windows_static_standard( 27 | images.shape[0], context_length, context_overlap) 28 | pbar = comfy.utils.ProgressBar(len(windows)) 29 | trajectory = { 30 | 'trajectory_windows': {}, 31 | 'context_windows': windows, 32 | 'height': images.shape[1], 33 | 'width': images.shape[2] 34 | } 35 | for i, window in enumerate(windows): 36 | traj = sample_trajectories(images[window], model, weights, device) 37 | trajectory['trajectory_windows'][window[0]] = traj 38 | pbar.update_absolute(i + 1, len(windows)) 39 | return (trajectory,) 40 | -------------------------------------------------------------------------------- /modules/upsample3d.py: -------------------------------------------------------------------------------- 1 | import torch 2 | import torch.nn as nn 3 | import torch.nn.functional as F 4 | 5 | from .convs import InflatedConv3d 6 | 7 | 8 | # DONE 9 | class Upsample3D(nn.Module): 10 | def __init__(self, 11 | channels, 12 | use_conv=False, 13 | dims=2, 14 | out_channels=None, 15 | padding=1, # using this instead of hardcoded flatten value of 1 16 | dtype=None, 17 | device=None, 18 | operations=None, 19 | ): 20 | super().__init__() 21 | self.channels = channels 22 | self.out_channels = out_channels or channels 23 | self.use_conv = use_conv 24 | self.dims = dims 25 | 26 | # conv = None 27 | if use_conv: 28 | self.conv = InflatedConv3d( 29 | self.channels, self.out_channels, 3, padding=padding).half() 30 | 31 | # self.Conv2d_0 = conv 32 | 33 | def forward(self, hidden_states, output_shape=None): 34 | assert hidden_states.shape[1] == self.channels 35 | 36 | # Cast to float32 to as 'upsample_nearest2d_out_frame' op does not support bfloat16 37 | dtype = hidden_states.dtype 38 | if dtype == torch.bfloat16: 39 | hidden_states = hidden_states.to(torch.float32) 40 | 41 | # upsample_nearest_nhwc fails with large batch sizes. see https://github.com/huggingface/diffusers/issues/984 42 | if hidden_states.shape[0] >= 64: 43 | hidden_states = hidden_states.contiguous() 44 | 45 | # if `output_shape` is passed we force the interpolation output 46 | # size and do not make use of `scale_factor=2` 47 | # if `output_shape` is passed we force the interpolation output 48 | # size and do not make use of `scale_factor=2` 49 | if output_shape is None: 50 | hidden_states = F.interpolate(hidden_states, scale_factor=[ 51 | 1.0, 2.0, 2.0], mode="nearest") 52 | else: 53 | hidden_states = F.interpolate( 54 | hidden_states, size=output_shape, mode="nearest") 55 | 56 | # If the input is bfloat16, we cast back to bfloat16 57 | if dtype == torch.bfloat16: 58 | hidden_states = hidden_states.to(dtype) 59 | 60 | if self.use_conv: 61 | hidden_states = self.conv(hidden_states) 62 | 63 | return hidden_states 64 | -------------------------------------------------------------------------------- /utils/injection_utils.py: -------------------------------------------------------------------------------- 1 | import torch 2 | 3 | 4 | def _get_injection_names(): 5 | names = ['features0', 'features1', 'features2'] 6 | for i in range(4, 10): 7 | names.append(f'q{i}') 8 | names.append(f'k{i}') 9 | 10 | return names 11 | 12 | 13 | def clear_injections(model): 14 | model = model.model.diffusion_model 15 | res_attn_dict = {1: [0, 1], 2: [0]} 16 | for res in res_attn_dict: 17 | for block in res_attn_dict[res]: 18 | model.output_blocks[3*res+block][0].out_layers_features = None 19 | attn_res_dict = {1: [1, 2], 2: [0, 1, 2], 3: [0]} 20 | for attn in attn_res_dict: 21 | for block in attn_res_dict[attn]: 22 | module = model.output_blocks[3*attn + 23 | block][1].transformer_blocks[0].attn1 24 | module.q = None 25 | module.k = None 26 | module.inject_q = None 27 | module.inject_k = None 28 | 29 | 30 | def get_blank_injection_dict(context_windows): 31 | names = _get_injection_names() 32 | 33 | injection_dict = {} 34 | 35 | for name in names: 36 | blank = {} 37 | for context_window in context_windows: 38 | blank[context_window[0]] = [] 39 | injection_dict[name] = blank 40 | return injection_dict 41 | 42 | 43 | def update_injections(model, injection, context_start, save_steps): 44 | model = model.model.diffusion_model 45 | 46 | res_dict = {1: [0, 1], 2: [0]} 47 | res_idx = 0 48 | for res in res_dict: 49 | for block in res_dict[res]: 50 | feature = model.output_blocks[3*res + 51 | block][0].out_layers_features.cpu() 52 | if len(injection[f'features{res_idx}'][context_start]) < save_steps: 53 | injection[f'features{res_idx}'][context_start].append(feature) 54 | res_idx += 1 55 | 56 | attn_dict = {1: [1, 2], 2: [0, 1, 2], 3: [0]} 57 | attn_idx = 4 58 | for attn in attn_dict: 59 | for block in attn_dict[attn]: 60 | module = model.output_blocks[3*attn + 61 | block][1].transformer_blocks[0].attn1 62 | if len(injection[f'q{attn_idx}'][context_start]) < save_steps: 63 | injection[f'q{attn_idx}'][context_start].append(module.q.cpu()) 64 | injection[f'k{attn_idx}'][context_start].append(module.k.cpu()) 65 | 66 | attn_idx += 1 67 | 68 | 69 | def inject_features(model, injection, device, step, context_start, len_conds): 70 | model = model.model.diffusion_model 71 | 72 | res_dict = {1: [0, 1], 2: [0]} 73 | res_idx = 0 74 | for res in res_dict: 75 | for block in res_dict[res]: 76 | feature = torch.cat( 77 | [injection[f'features{res_idx}'][context_start][step][0, :, :].unsqueeze(0)]*len_conds) 78 | model.output_blocks[3*res + 79 | block][0].out_layers_features = feature.to(device) 80 | res_idx += 1 81 | 82 | attn_dict = {1: [1, 2], 2: [0, 1, 2], 3: [0]} 83 | attn_idx = 4 84 | for attn in attn_dict: 85 | for block in attn_dict[attn]: 86 | module = model.output_blocks[3*attn + 87 | block][1].transformer_blocks[0].attn1 88 | q = torch.cat( 89 | [injection[f'q{attn_idx}'][context_start][step]] * len_conds) 90 | module.inject_q = q.to(device) 91 | k = torch.cat( 92 | [injection[f'k{attn_idx}'][context_start][step]] * len_conds) 93 | module.inject_k = k.to(device) 94 | attn_idx += 1 95 | -------------------------------------------------------------------------------- /.gitignore: -------------------------------------------------------------------------------- 1 | # Byte-compiled / optimized / DLL files 2 | __pycache__/ 3 | *.py[cod] 4 | *$py.class 5 | 6 | # C extensions 7 | *.so 8 | 9 | # Distribution / packaging 10 | .Python 11 | build/ 12 | develop-eggs/ 13 | dist/ 14 | downloads/ 15 | eggs/ 16 | .eggs/ 17 | lib/ 18 | lib64/ 19 | parts/ 20 | sdist/ 21 | var/ 22 | wheels/ 23 | share/python-wheels/ 24 | *.egg-info/ 25 | .installed.cfg 26 | *.egg 27 | MANIFEST 28 | 29 | # PyInstaller 30 | # Usually these files are written by a python script from a template 31 | # before PyInstaller builds the exe, so as to inject date/other infos into it. 32 | *.manifest 33 | *.spec 34 | 35 | # Installer logs 36 | pip-log.txt 37 | pip-delete-this-directory.txt 38 | 39 | # Unit test / coverage reports 40 | htmlcov/ 41 | .tox/ 42 | .nox/ 43 | .coverage 44 | .coverage.* 45 | .cache 46 | nosetests.xml 47 | coverage.xml 48 | *.cover 49 | *.py,cover 50 | .hypothesis/ 51 | .pytest_cache/ 52 | cover/ 53 | 54 | # Translations 55 | *.mo 56 | *.pot 57 | 58 | # Django stuff: 59 | *.log 60 | local_settings.py 61 | db.sqlite3 62 | db.sqlite3-journal 63 | 64 | # Flask stuff: 65 | instance/ 66 | .webassets-cache 67 | 68 | # Scrapy stuff: 69 | .scrapy 70 | 71 | # Sphinx documentation 72 | docs/_build/ 73 | 74 | # PyBuilder 75 | .pybuilder/ 76 | target/ 77 | 78 | # Jupyter Notebook 79 | .ipynb_checkpoints 80 | 81 | # IPython 82 | profile_default/ 83 | ipython_config.py 84 | 85 | # pyenv 86 | # For a library or package, you might want to ignore these files since the code is 87 | # intended to run in multiple environments; otherwise, check them in: 88 | # .python-version 89 | 90 | # pipenv 91 | # According to pypa/pipenv#598, it is recommended to include Pipfile.lock in version control. 92 | # However, in case of collaboration, if having platform-specific dependencies or dependencies 93 | # having no cross-platform support, pipenv may install dependencies that don't work, or not 94 | # install all needed dependencies. 95 | #Pipfile.lock 96 | 97 | # poetry 98 | # Similar to Pipfile.lock, it is generally recommended to include poetry.lock in version control. 99 | # This is especially recommended for binary packages to ensure reproducibility, and is more 100 | # commonly ignored for libraries. 101 | # https://python-poetry.org/docs/basic-usage/#commit-your-poetrylock-file-to-version-control 102 | #poetry.lock 103 | 104 | # pdm 105 | # Similar to Pipfile.lock, it is generally recommended to include pdm.lock in version control. 106 | #pdm.lock 107 | # pdm stores project-wide configurations in .pdm.toml, but it is recommended to not include it 108 | # in version control. 109 | # https://pdm.fming.dev/#use-with-ide 110 | .pdm.toml 111 | 112 | # PEP 582; used by e.g. github.com/David-OConnor/pyflow and github.com/pdm-project/pdm 113 | __pypackages__/ 114 | 115 | # Celery stuff 116 | celerybeat-schedule 117 | celerybeat.pid 118 | 119 | # SageMath parsed files 120 | *.sage.py 121 | 122 | # Environments 123 | .env 124 | .venv 125 | env/ 126 | venv/ 127 | ENV/ 128 | env.bak/ 129 | venv.bak/ 130 | 131 | # Spyder project settings 132 | .spyderproject 133 | .spyproject 134 | 135 | # Rope project settings 136 | .ropeproject 137 | 138 | # mkdocs documentation 139 | /site 140 | 141 | # mypy 142 | .mypy_cache/ 143 | .dmypy.json 144 | dmypy.json 145 | 146 | # Pyre type checker 147 | .pyre/ 148 | 149 | # pytype static type analyzer 150 | .pytype/ 151 | 152 | # Cython debug symbols 153 | cython_debug/ 154 | 155 | # PyCharm 156 | # JetBrains specific template is maintained in a separate JetBrains.gitignore that can 157 | # be found at https://github.com/github/gitignore/blob/main/Global/JetBrains.gitignore 158 | # and can be added to the global gitignore or merged into this file. For a more nuclear 159 | # option (not recommended) you can uncomment the following to ignore the entire idea folder. 160 | #.idea/ 161 | -------------------------------------------------------------------------------- /modules/cross_attn_down_block3d.py: -------------------------------------------------------------------------------- 1 | 2 | import torch 3 | from torch import nn 4 | 5 | from .transformer3d import Transformer3DModel 6 | from .resnet_block3d import ResnetBlock3D 7 | from .downsample3d import Downsample3D 8 | 9 | 10 | # Attn Blocks are wrappers around 11 | # 1. ResBlock3D 12 | # 2. Transformer3DModel -> SequentialTransformer 13 | # this needs to be changed into the Comfy format on laying flat 14 | class CrossAttnDownBlock3D(nn.Module): 15 | def __init__( 16 | self, 17 | in_channels: int, 18 | out_channels: int, 19 | temb_channels: int, 20 | dropout: float = 0.0, 21 | num_layers: int = 1, 22 | resnet_eps: float = 1e-6, 23 | resnet_time_scale_shift: str = "default", 24 | resnet_act_fn: str = "swish", 25 | resnet_groups: int = 32, 26 | resnet_pre_norm: bool = True, 27 | attn_num_head_channels=1, 28 | cross_attention_dim=1280, 29 | output_scale_factor=1.0, 30 | downsample_padding=1, 31 | add_downsample=True, 32 | use_linear_projection=False, 33 | only_cross_attention=False, 34 | upcast_attention=False, 35 | ): 36 | super().__init__() 37 | resnets = [] 38 | attentions = [] 39 | 40 | self.has_cross_attention = True 41 | self.attn_num_head_channels = attn_num_head_channels 42 | 43 | for i in range(num_layers): 44 | in_channels = in_channels if i == 0 else out_channels 45 | resnets.append( 46 | ResnetBlock3D( 47 | in_channels=in_channels, 48 | out_channels=out_channels, 49 | temb_channels=temb_channels, 50 | eps=resnet_eps, 51 | groups=resnet_groups, 52 | dropout=dropout, 53 | time_embedding_norm=resnet_time_scale_shift, 54 | non_linearity=resnet_act_fn, 55 | output_scale_factor=output_scale_factor, 56 | pre_norm=resnet_pre_norm, 57 | ) 58 | ) 59 | attentions.append( 60 | Transformer3DModel( 61 | attn_num_head_channels, 62 | out_channels // attn_num_head_channels, 63 | in_channels=out_channels, 64 | num_layers=1, 65 | cross_attention_dim=cross_attention_dim, 66 | norm_num_groups=resnet_groups, 67 | use_linear_projection=use_linear_projection, 68 | only_cross_attention=only_cross_attention, 69 | upcast_attention=upcast_attention, 70 | ) 71 | ) 72 | self.attentions = nn.ModuleList(attentions) 73 | self.resnets = nn.ModuleList(resnets) 74 | 75 | # if add_downsample: 76 | # self.downsamplers = nn.ModuleList( 77 | # [ 78 | # Downsample3D( 79 | # out_channels, use_conv=True, out_channels=out_channels, padding=downsample_padding, name="op" 80 | # ) 81 | # ] 82 | # ) 83 | # else: 84 | # self.downsamplers = None 85 | 86 | self.gradient_checkpointing = False 87 | 88 | def forward(self, hidden_states, temb=None, encoder_hidden_states=None, attention_mask=None, inter_frame=False, **kwargs): 89 | output_states = () 90 | 91 | for resnet, attn in zip(self.resnets, self.attentions): 92 | hidden_states = resnet(hidden_states, temb) 93 | hidden_states = attn( 94 | hidden_states, encoder_hidden_states=encoder_hidden_states, inter_frame=inter_frame, **kwargs).sample 95 | 96 | output_states += (hidden_states,) 97 | 98 | if self.downsamplers is not None: 99 | for downsampler in self.downsamplers: 100 | hidden_states = downsampler(hidden_states) 101 | 102 | output_states += (hidden_states,) 103 | 104 | return hidden_states, output_states 105 | -------------------------------------------------------------------------------- /modules/cross_attn_up_block3d.py: -------------------------------------------------------------------------------- 1 | 2 | # This needs to be cut up and put in the UNet code 3 | class CrossAttnUpBlock3D(nn.Module): 4 | def __init__( 5 | self, 6 | in_channels: int, 7 | out_channels: int, 8 | prev_output_channel: int, 9 | temb_channels: int, 10 | dropout: float = 0.0, 11 | num_layers: int = 1, 12 | resnet_eps: float = 1e-6, 13 | resnet_time_scale_shift: str = "default", 14 | resnet_act_fn: str = "swish", 15 | resnet_groups: int = 32, 16 | resnet_pre_norm: bool = True, 17 | attn_num_head_channels=1, 18 | cross_attention_dim=1280, 19 | output_scale_factor=1.0, 20 | add_upsample=True, 21 | dual_cross_attention=False, 22 | use_linear_projection=False, 23 | only_cross_attention=False, 24 | upcast_attention=False, 25 | ): 26 | super().__init__() 27 | resnets = [] 28 | attentions = [] 29 | 30 | self.has_cross_attention = True 31 | self.attn_num_head_channels = attn_num_head_channels 32 | 33 | for i in range(num_layers): 34 | res_skip_channels = in_channels if ( 35 | i == num_layers - 1) else out_channels 36 | resnet_in_channels = prev_output_channel if i == 0 else out_channels 37 | 38 | resnets.append( 39 | ResnetBlock3D( 40 | in_channels=resnet_in_channels + res_skip_channels, 41 | out_channels=out_channels, 42 | temb_channels=temb_channels, 43 | eps=resnet_eps, 44 | groups=resnet_groups, 45 | dropout=dropout, 46 | time_embedding_norm=resnet_time_scale_shift, 47 | non_linearity=resnet_act_fn, 48 | output_scale_factor=output_scale_factor, 49 | pre_norm=resnet_pre_norm, 50 | ) 51 | ) 52 | if dual_cross_attention: 53 | raise NotImplementedError 54 | attentions.append( 55 | Transformer3DModel( 56 | attn_num_head_channels, 57 | out_channels // attn_num_head_channels, 58 | in_channels=out_channels, 59 | num_layers=1, 60 | cross_attention_dim=cross_attention_dim, 61 | norm_num_groups=resnet_groups, 62 | use_linear_projection=use_linear_projection, 63 | only_cross_attention=only_cross_attention, 64 | upcast_attention=upcast_attention, 65 | ) 66 | ) 67 | 68 | self.attentions = nn.ModuleList(attentions) 69 | self.resnets = nn.ModuleList(resnets) 70 | 71 | if add_upsample: 72 | self.upsamplers = nn.ModuleList( 73 | [Upsample3D(out_channels, use_conv=True, out_channels=out_channels)]) 74 | else: 75 | self.upsamplers = None 76 | 77 | self.gradient_checkpointing = False 78 | 79 | def forward( 80 | self, 81 | hidden_states, 82 | res_hidden_states_tuple, 83 | temb=None, 84 | encoder_hidden_states=None, 85 | upsample_size=None, 86 | attention_mask=None, 87 | inter_frame=False, 88 | **kwargs, 89 | ): 90 | for resnet, attn in zip(self.resnets, self.attentions): 91 | # pop res hidden states 92 | res_hidden_states = res_hidden_states_tuple[-1] 93 | res_hidden_states_tuple = res_hidden_states_tuple[:-1] 94 | hidden_states = torch.cat( 95 | [hidden_states, res_hidden_states], dim=1) 96 | 97 | hidden_states = resnet(hidden_states, temb) 98 | hidden_states = attn( 99 | hidden_states, encoder_hidden_states=encoder_hidden_states, inter_frame=inter_frame, **kwargs).sample 100 | 101 | if self.upsamplers is not None: 102 | for upsampler in self.upsamplers: 103 | hidden_states = upsampler(hidden_states, upsample_size) 104 | 105 | return hidden_states 106 | -------------------------------------------------------------------------------- /nodes/load_flatten_model_node.py: -------------------------------------------------------------------------------- 1 | import torch 2 | from einops import rearrange 3 | import comfy.sd 4 | import comfy.model_base 5 | import comfy.model_management 6 | import folder_paths 7 | import comfy.ldm.modules.diffusionmodules.openaimodel as openaimodel 8 | from ..modules.unet import UNetModel as FlattenModel 9 | 10 | 11 | class PatchBaseModel(comfy.model_base.BaseModel): 12 | def __init__(self, model_config, *args, model_type=comfy.model_base.ModelType.EPS, device=None, unet_model=FlattenModel, **kwargs): 13 | super().__init__(model_config, model_type, device, FlattenModel) 14 | 15 | 16 | class FlattenCheckpointLoaderNode: 17 | @classmethod 18 | def INPUT_TYPES(s): 19 | return {"required": {"ckpt_name": (folder_paths.get_filename_list("checkpoints"), ), 20 | }} 21 | RETURN_TYPES = ("MODEL", "CLIP", "VAE") 22 | FUNCTION = "load_checkpoint" 23 | 24 | CATEGORY = "loaders" 25 | 26 | def load_checkpoint(self, ckpt_name, output_vae=True, output_clip=True): 27 | original_base = comfy.model_base.BaseModel 28 | comfy.model_base.BaseModel = PatchBaseModel 29 | ckpt_path = folder_paths.get_full_path("checkpoints", ckpt_name) 30 | out = comfy.sd.load_checkpoint_guess_config( 31 | ckpt_path, output_vae=True, output_clip=True, embedding_directory=folder_paths.get_folder_paths("embeddings")) 32 | comfy.model_base.BaseModel = original_base 33 | 34 | def model_function_wrapper(apply_model_func, apply_params): 35 | # Prepare 3D latent 36 | input_x = apply_params['input'] 37 | len_conds = len(apply_params['cond_or_uncond']) 38 | frame_count = input_x.shape[0] // len_conds 39 | input_x = rearrange( 40 | input_x, "(b f) c h w -> b c f h w", b=len_conds) 41 | timestep_ = apply_params['timestep'] 42 | timestep_ = timestep_[torch.arange( 43 | 0, timestep_.size(0), frame_count)] 44 | 45 | # Correct Flatten vars for any batching 46 | 47 | # Do injection if needed 48 | transformer_options = apply_params['c'].get( 49 | 'transformer_options', {}) 50 | flatten_options = transformer_options.get('flatten', {}) 51 | 52 | idxs = None 53 | context_start = 0 54 | if 'ad_params' in transformer_options and transformer_options['ad_params'].get('sub_idxs', None) is not None: 55 | idxs = transformer_options['ad_params']['sub_idxs'] 56 | context_start = idxs[0] 57 | else: 58 | idxs = list(range(frame_count)) 59 | transformer_options['flatten']['trajs'] = transformer_options['flatten']['trajs_windows'][0] 60 | transformer_options['flatten']['trajs'] = transformer_options['flatten']['trajs_windows'][context_start] 61 | 62 | transformer_options['flatten']['idxs'] = idxs 63 | transformer_options['flatten']['video_length'] = frame_count 64 | 65 | # Inject if sampling 66 | injection_handler = flatten_options.get('injection_handler', None) 67 | if injection_handler is not None: 68 | step = flatten_options['injection_handler']( 69 | timestep_[0], context_start, len_conds) 70 | flatten_options['step'] = step 71 | 72 | del apply_params['timestep'] 73 | conditioning = {} 74 | for key in apply_params['c']: 75 | value = apply_params['c'][key] 76 | if key == 'c_crossattn': 77 | value = value[torch.arange(0, value.size(0), frame_count)] 78 | 79 | conditioning[key] = value 80 | 81 | conditioning 82 | del apply_params['c'] 83 | del apply_params['input'] 84 | model_out = apply_model_func(input_x, timestep_, **conditioning) 85 | 86 | # Save injections if unsampling 87 | save_injections_handler = flatten_options.get( 88 | 'save_injections_handler', None) 89 | if save_injections_handler is not None: 90 | save_injections_handler(context_start) 91 | 92 | # Return 2D latent 93 | model_out = rearrange(model_out, 'b c f h w -> (b f) c h w') 94 | return model_out 95 | 96 | model = out[0] 97 | model.model_options['model_function_wrapper'] = model_function_wrapper 98 | 99 | load_device = comfy.model_management.get_torch_device() 100 | offload_device = comfy.model_management.unet_offload_device() 101 | model_patcher = comfy.model_patcher.ModelPatcher( 102 | model.model, load_device=load_device, offload_device=offload_device) 103 | 104 | out = list(out) 105 | out[0] = model_patcher 106 | model_patcher.model_options['model_function_wrapper'] = model_function_wrapper 107 | return out[:3] 108 | -------------------------------------------------------------------------------- /modules/transformer3d.py: -------------------------------------------------------------------------------- 1 | from comfy.ldm.util import exists 2 | from .transformer_block import BasicTransformerBlock 3 | from einops import rearrange, repeat 4 | import torch 5 | import torch.nn as nn 6 | 7 | import comfy.ops 8 | ops = comfy.ops.disable_weight_init 9 | 10 | 11 | class Transformer3DModel(nn.Module): 12 | def __init__(self, 13 | in_channels, 14 | n_heads, 15 | d_head: int = 88, 16 | depth=1, 17 | dropout=0., 18 | context_dim=None, 19 | disable_self_attn=False, 20 | use_linear=False, 21 | use_checkpoint=True, 22 | dtype=None, 23 | device=None, 24 | operations=ops, 25 | ): 26 | super().__init__() 27 | if exists(context_dim) and not isinstance(context_dim, list): 28 | context_dim = [context_dim] * depth 29 | self.in_channels = in_channels 30 | inner_dim = n_heads * d_head 31 | self.norm = operations.GroupNorm( 32 | num_groups=32, num_channels=in_channels, eps=1e-6, affine=True, dtype=dtype, device=device) 33 | if not use_linear: 34 | self.proj_in = operations.Conv2d(in_channels, 35 | inner_dim, 36 | kernel_size=1, 37 | stride=1, 38 | padding=0, dtype=dtype, device=device) 39 | else: 40 | self.proj_in = operations.Linear( 41 | in_channels, inner_dim, dtype=dtype, device=device) 42 | 43 | self.transformer_blocks = nn.ModuleList( 44 | [BasicTransformerBlock(inner_dim, n_heads, d_head, dropout=dropout, context_dim=context_dim[d], 45 | disable_self_attn=disable_self_attn, checkpoint=use_checkpoint, dtype=dtype, device=device, operations=operations) 46 | for d in range(depth)] 47 | ) 48 | if not use_linear: 49 | self.proj_out = operations.Conv2d(inner_dim, in_channels, 50 | kernel_size=1, 51 | stride=1, 52 | padding=0, dtype=dtype, device=device) 53 | else: 54 | self.proj_out = operations.Linear( 55 | in_channels, inner_dim, dtype=dtype, device=device) 56 | self.use_linear = use_linear 57 | 58 | def forward(self, hidden_states, context=None, transformer_options={}): 59 | # Input 60 | 61 | assert hidden_states.dim( 62 | ) == 5, f"Expected hidden_states to have ndim=5, but got ndim={hidden_states.dim()}." 63 | inter_frame = False 64 | if 'flatten' in transformer_options and 'inter_frame' in transformer_options["flatten"]: 65 | inter_frame = transformer_options["flatten"]['inter_frame'] 66 | video_length = hidden_states.shape[2] 67 | cond_size = hidden_states.shape[0] 68 | hidden_states = rearrange(hidden_states, "b c f h w -> (b f) c h w") 69 | context = repeat( 70 | context, 'b n c -> (b f) n c', f=video_length) 71 | 72 | batch, channel, height, weight = hidden_states.shape 73 | residual = hidden_states 74 | 75 | # check resolution 76 | 77 | resolu = hidden_states.shape[-2] # height 78 | height = resolu 79 | width = hidden_states.shape[-1] 80 | traj_options = {"resolution": resolu, 81 | "cond_size": cond_size, "height": height, "width": width} 82 | 83 | hidden_states = self.norm(hidden_states) 84 | if not self.use_linear: 85 | hidden_states = self.proj_in(hidden_states) 86 | inner_dim = hidden_states.shape[1] 87 | hidden_states = hidden_states.permute( 88 | 0, 2, 3, 1).reshape(batch, height * weight, inner_dim) 89 | else: 90 | inner_dim = hidden_states.shape[1] 91 | hidden_states = hidden_states.permute( 92 | 0, 2, 3, 1).reshape(batch, height * weight, inner_dim) 93 | hidden_states = self.proj_in(hidden_states) 94 | 95 | # Blocks 96 | for block in self.transformer_blocks: 97 | hidden_states = block( 98 | hidden_states, 99 | context=context, 100 | video_length=video_length, 101 | inter_frame=inter_frame, 102 | transformer_options=transformer_options, 103 | traj_options=traj_options 104 | ) 105 | 106 | # Output 107 | if not self.use_linear: 108 | hidden_states = ( 109 | hidden_states.reshape(batch, height, weight, inner_dim).permute( 110 | 0, 3, 1, 2).contiguous() 111 | ) 112 | hidden_states = self.proj_out(hidden_states) 113 | else: 114 | hidden_states = self.proj_out(hidden_states) 115 | hidden_states = ( 116 | hidden_states.reshape(batch, height, weight, inner_dim).permute( 117 | 0, 3, 1, 2).contiguous() 118 | ) 119 | 120 | output = hidden_states + residual 121 | 122 | output = rearrange(output, "(b f) c h w -> b c f h w", f=video_length) 123 | 124 | return output 125 | -------------------------------------------------------------------------------- /nodes/apply_flatten_attention_node.py: -------------------------------------------------------------------------------- 1 | import math 2 | import torch 3 | import torch.nn.functional as F 4 | from einops import rearrange 5 | 6 | from comfy.ldm.modules.attention import optimized_attention 7 | from comfy.model_patcher import ModelPatcher 8 | 9 | 10 | def reshape_heads_to_batch_dim3(tensor, head_size): 11 | batch_size1, batch_size2, seq_len, dim = tensor.shape 12 | tensor = tensor.reshape(batch_size1, batch_size2, 13 | seq_len, head_size, dim // head_size) 14 | tensor = tensor.permute(0, 3, 1, 2, 4) 15 | return tensor 16 | 17 | 18 | def apply_flow(query, 19 | key, 20 | value, 21 | trajectories, 22 | extra_options): 23 | # TODO: Hardcoded for SD1.5 24 | height = trajectories['height']//8 25 | width = trajectories['width']//8 26 | n_heads = extra_options['n_heads'] 27 | cond_size = len(extra_options['cond_or_uncond']) 28 | video_length = len(query) // cond_size 29 | 30 | ad_params = extra_options.get('ad_params', {}) 31 | sub_idxs = ad_params.get('sub_idxs', None) 32 | idx = 0 33 | if sub_idxs is not None: 34 | idx = sub_idxs[0] 35 | 36 | traj_window = trajectories['trajectory_windows'][idx] 37 | trajs = traj_window[f'traj{height}'] 38 | traj_mask = traj_window[f'mask{height}'] 39 | 40 | start = -video_length+1 41 | end = trajs.shape[2] 42 | 43 | traj_key_sequence_inds = torch.cat( 44 | [trajs[:, :, 0, :].unsqueeze(-2), trajs[:, :, start:end, :]], dim=-2) 45 | traj_mask = torch.cat([traj_mask[:, :, 0].unsqueeze(-1), 46 | traj_mask[:, :, start:end]], dim=-1) 47 | 48 | t_inds = traj_key_sequence_inds[:, :, :, 0] 49 | x_inds = traj_key_sequence_inds[:, :, :, 1] 50 | y_inds = traj_key_sequence_inds[:, :, :, 2] 51 | 52 | query_tempo = query.unsqueeze(-2) 53 | _key = rearrange(key, '(b f) (h w) d -> b f h w d', 54 | b=cond_size, h=height, w=width) 55 | _value = rearrange(value, '(b f) (h w) d -> b f h w d', 56 | b=cond_size, h=height, w=width) 57 | key_tempo = _key[:, t_inds, x_inds, y_inds] 58 | value_tempo = _value[:, t_inds, x_inds, y_inds] 59 | key_tempo = rearrange(key_tempo, 'b f n l d -> (b f) n l d') 60 | value_tempo = rearrange(value_tempo, 'b f n l d -> (b f) n l d') 61 | 62 | traj_mask = rearrange(torch.stack( 63 | [traj_mask] * cond_size), 'b f n l -> (b f) n l') 64 | traj_mask = traj_mask[:, None].repeat( 65 | 1, n_heads, 1, 1).unsqueeze(-2) 66 | attn_bias = torch.zeros_like( 67 | traj_mask, dtype=key_tempo.dtype, device=query.device) # regular zeros_like 68 | attn_bias[~traj_mask] = -torch.inf 69 | 70 | # flow attention 71 | query_tempo = reshape_heads_to_batch_dim3(query_tempo, n_heads) 72 | key_tempo = reshape_heads_to_batch_dim3(key_tempo, n_heads) 73 | value_tempo = reshape_heads_to_batch_dim3(value_tempo, n_heads) 74 | 75 | attn_matrix2 = query_tempo @ key_tempo.transpose(-2, -1) / math.sqrt( 76 | query_tempo.size(-1)) + attn_bias 77 | attn_matrix2 = F.softmax(attn_matrix2, dim=-1) 78 | out = (attn_matrix2@value_tempo).squeeze(-2) 79 | 80 | hidden_states = rearrange(out, 'b k r d -> b r (k d)') 81 | 82 | return hidden_states 83 | 84 | 85 | def get_flatten_attention(trajectories, use_old_qk=False): 86 | def flatten_attention(q, k, v, extra_options): 87 | n_heads = extra_options['n_heads'] 88 | 89 | hidden_states = optimized_attention(q, k, v, n_heads, mask=None) 90 | 91 | _, hw, _ = q.shape 92 | # TODO: Hardcoded for SD1.5 93 | target_height = trajectories['height']//8 94 | target_width = trajectories['width']//8 95 | 96 | if target_height * target_width == hw: 97 | if use_old_qk is True: 98 | query = q 99 | key = k 100 | else: 101 | query = hidden_states 102 | key = hidden_states 103 | hidden_states = apply_flow( 104 | query, 105 | key, 106 | hidden_states, 107 | trajectories, 108 | extra_options 109 | ) 110 | 111 | return hidden_states 112 | return flatten_attention 113 | 114 | 115 | class ApplyFlattenAttentionNode: 116 | @classmethod 117 | def INPUT_TYPES(s): 118 | return {"required": 119 | {"model": ("MODEL",), 120 | "trajectories": ("TRAJECTORY",), 121 | "use_old_qk": ("BOOLEAN", {"default": False}), 122 | "input_attn_1": ("BOOLEAN", {"default": True}), 123 | "input_attn_2": ("BOOLEAN", {"default": True}), 124 | "output_attn_9": ("BOOLEAN", {"default": True}), 125 | "output_attn_10": ("BOOLEAN", {"default": True}), 126 | "output_attn_11": ("BOOLEAN", {"default": True}), 127 | } 128 | } 129 | 130 | RETURN_TYPES = ("MODEL",) 131 | FUNCTION = "apply" 132 | 133 | CATEGORY = "flatten" 134 | 135 | def apply(self, model, trajectories, use_old_qk, 136 | input_attn_1, input_attn_2, output_attn_9, output_attn_10, output_attn_11): 137 | model: ModelPatcher = model.clone() 138 | 139 | # TODO: Hardcoded for SD1.5 140 | attn = get_flatten_attention(trajectories, use_old_qk) 141 | if input_attn_1: 142 | model.set_model_patch_replace(attn, 'attn1', 143 | 'input', 1) 144 | if input_attn_2: 145 | model.set_model_patch_replace(attn, 'attn1', 146 | 'input', 2) 147 | if output_attn_9: 148 | model.set_model_patch_replace(attn, 'attn1', 149 | 'output', 9) 150 | if output_attn_10: 151 | model.set_model_patch_replace(attn, 'attn1', 152 | 'output', 10) 153 | if output_attn_11: 154 | model.set_model_patch_replace(attn, 'attn1', 155 | 'output', 11) 156 | 157 | return (model, ) 158 | -------------------------------------------------------------------------------- /README.md: -------------------------------------------------------------------------------- 1 | # ComfyUI-FLATTEN 2 | ComfyUI nodes to use FLATTEN. 3 | 4 | Original research repo: [FLATTEN](https://github.com/yrcong/flatten) 5 | 6 | https://github.com/logtd/ComfyUI-FLATTEN/assets/160989552/518865fe-8bf3-44aa-ab05-edaaff92c3e0 7 | 8 | ## Table of Contents 9 | - [Installation](#installation) 10 | - [How to Install](#how-to-install) 11 | - [Nodes](#nodes) 12 | - [Accompanying Node Repos](#accompanying-node-repos) 13 | - [Examples](#examples) 14 | - [Acknowledgements](#acknowledgements) 15 | 16 | ## Installation 17 | 18 | ### How to Install 19 | Clone or download this repo into your `ComfyUI/custom_nodes/` directory or use the ComfyUI-Manager to automatically install the nodes. No additional Python packages outside of ComfyUI requirements should be necessary. 20 | 21 | ## Nodes 22 | flatten_nodes_screenshot 23 | 24 | * Node: Load Checkpoint with FLATTEN model 25 | * Loads any given SD1.5 checkpoint with the FLATTEN optical flow model. Use the `sdxl` branch of this repo to load SDXL models 26 | * The loaded model only works with the Flatten KSampler and a standard ComfyUI checkpoint loader is required for other KSamplers 27 | 28 | * Node: Sample Trajectories 29 | * Takes the input images and samples their optical flow into trajectories. Trajectories are created for the dimensions of the input image and must match the latent size Flatten processes. 30 | * Context Length and Overlap for Batching with AnimateDiff-Evolved 31 | * Context Length defines the window size Flatten processes at a time. Flatten is not limitted to a certain frame count, but this can be used to reduce VRAM usage at a single time 32 | * Context Overlap is the overlap between windows 33 | * Can only use Standard Static from AnimateDiff-Evolved and these values must match the values given to AnimateDiff's Evolved Sampling context 34 | * Currently does not support Views 35 | 36 | * Node: Unsampler (Flatten) 37 | * Unsamples the input latent and creates the needed injections required for sampling 38 | * Only use Euler or ddpm2m as the sampling method since this process creates noise from the input images 39 | 40 | * Node: KSampler (Flatten) 41 | * Samples the unsampled latents and uses the injections from the Unsampler 42 | * Can use any sampling method, but use Euler or ddpm2m for editing pieces of the video or another sampling method to get drastic changes in the video 43 | 44 | * Node: Apply Flatten Attention (SD1.5 Only) 45 | * Use Flatten's Optical Flow attention mechanism without the rest of Flatten's model -- can be used to combine with other models 46 | * Warning: Flatten's attention requires "Flow Noise" so it does not always work with methods that add normal noise 47 | 48 | * Node: Create Flow Noise 49 | * Creates flow noise given a latent and trajectories 50 | * Can be used to add initial noise to a latent instead of using normal noise from a traditional KSampler 51 | 52 | 53 | ## Accompanying Node Repos 54 | * [Video Helper Suite](https://github.com/Kosinkadink/ComfyUI-VideoHelperSuite) for loading and combining videos 55 | * [AnimateDiff-Evolved](https://github.com/Kosinkadink/ComfyUI-AnimateDiff-Evolved) for batching options 56 | 57 | ## Examples 58 | For working ComfyUI example workflows see the `example_workflows/` directory. 59 | 60 | ### Video Editing 61 | FLATTEN excels at editing videos with temporal consistency. The recommended settings for this are to use an Unsampler and KSampler with `old_qk = 0`. The Unsampler should use the euler sampler and the KSampler should use the dpmpp_2m sampler. Users may experiment with `old_qk` depending on their use case, but it is not recommended to use other samplers or `add_noise` for video editing. Style transfer nodes such as IP-Adapter may have difficulty making quality edits without the additional noise and will require fine tuning. 62 | 63 | ### Scene Editing (Experimental) 64 | Inspired by the optical flow use in FLATTEN, these nodes can utilize noise that is driven by optical flow. The current implementation is experimental and allows the user to create highly altered scenes, however it can lose some of the consistency and does not work well with high motion scenes. 65 | 66 | To use this, it is recommended to use LCM on the KSampler (not the Unsampler) alongside setting `old_qk = 1` on the KSampler. Ancestral sampling methods also work well. Users may experiment with toggling the `add_noise` setting on the KSampler when using a sampling method that injects noise (e.g. anything besides Euler and dpmpp2). Using IPAdapter can help guide these generations towards a specific look. 67 | 68 | https://github.com/logtd/ComfyUI-FLATTEN/assets/160989552/18b49cbb-9647-48c0-9f3d-b58440fc9c1a 69 | 70 | https://github.com/logtd/ComfyUI-FLATTEN/assets/160989552/13769f9a-05f0-4669-ba80-556a8169e3df 71 | 72 | https://github.com/logtd/ComfyUI-FLATTEN/assets/160989552/f6fcf5c4-df0e-4ca4-8411-388520442d6c 73 | 74 | https://github.com/logtd/ComfyUI-FLATTEN/assets/160989552/d9942a82-aadb-49a6-92f4-9bf95de390ed 75 | 76 | ## ComfyUI Support 77 | The ComfyUI-FLATTEN implementation can support most ComfyUI nodes, including ControlNets, IP-Adapter, LCM, InstanceDiffusion/GLIGEN, and many more. 78 | 79 | ### Batching 80 | Currently batching for large amount of frames results in a loss in consistency and a possible solution is under consideration. 81 | 82 | The current batching mechanism utilizes the AnimateDiff-Evolved batching nodes and is required to batch. See the example workflow for a working example. 83 | 84 | ### SDXL Support 85 | Experiments for supporting SDXL were made and resulted in generating somewhat consistent videos, but not up-to-par with the SD1.5 implementation. 86 | Feel free to check out the `sdxl` branch, but there will be no further development in this direction. 87 | 88 | ### Unsupported 89 | Currently the known unsupported custom ComfyUI features are: 90 | * Scheduled Prompting 91 | * Context Views for advanced batching 92 | 93 | ## Acknowledgements 94 | * [Cong, Yuren and Xu, Mengmeng and Simon, Christian and Chen, Shoufa and Ren, Jiawei and Xie, Yanping and Perez-Rua, Juan-Manuel and Rosenhahn, Bodo and Xiang, Tao and He, Sen](https://github.com/yrcong/flatten) for their research on FLATTEN, producing the original repo, and contributing to open source. 95 | * [Kosinkadink](https://github.com/Kosinkadink) for creating Video Helper Suite and AnimateDiff-Evolved 96 | * [Kijai](https://github.com/kijai) for making helpful nodes 97 | * [@AIWarper](https://twitter.com/AIWarper) for testing and making amazing content 98 | -------------------------------------------------------------------------------- /modules/resnet_block3d.py: -------------------------------------------------------------------------------- 1 | # Adapted from https://github.com/huggingface/diffusers/blob/main/src/diffusers/models/resnet.py 2 | 3 | from .downsample3d import Downsample3D 4 | from .upsample3d import Upsample3D 5 | from .convs import InflatedConv3d 6 | from einops import rearrange 7 | import torch 8 | import torch.nn as nn 9 | import torch.nn.functional as F 10 | 11 | import comfy.ops 12 | ops = comfy.ops.disable_weight_init 13 | 14 | 15 | class ResnetBlock3D(nn.Module): 16 | def __init__( 17 | self, 18 | channels, 19 | emb_channels, 20 | dropout, 21 | out_channels=None, 22 | use_conv=False, 23 | use_scale_shift_norm=False, 24 | dims=2, 25 | use_checkpoint=False, 26 | up=False, 27 | down=False, 28 | kernel_size=3, 29 | exchange_temb_dims=False, 30 | skip_t_emb=False, 31 | dtype=None, 32 | device=None, 33 | operations=ops, 34 | groups=32, 35 | groups_out=None, 36 | pre_norm=True, 37 | eps=1e-6, 38 | ): 39 | super().__init__() 40 | self.pre_norm = pre_norm 41 | self.pre_norm = True 42 | self.channels = channels 43 | out_channels = channels if out_channels is None else out_channels 44 | self.out_channels = out_channels 45 | self.use_conv_shortcut = use_conv 46 | # comfy setup 47 | self.channels = channels 48 | self.emb_channels = emb_channels 49 | self.dropout = dropout 50 | self.out_channels = out_channels or channels 51 | self.use_conv = use_conv 52 | self.use_checkpoint = use_checkpoint 53 | self.use_scale_shift_norm = use_scale_shift_norm 54 | self.exchange_temb_dims = exchange_temb_dims 55 | 56 | if groups_out is None: 57 | groups_out = groups 58 | 59 | if isinstance(kernel_size, list): 60 | padding = [k // 2 for k in kernel_size] 61 | else: 62 | padding = kernel_size // 2 63 | self.in_layers = nn.Sequential( 64 | operations.GroupNorm(32, channels, dtype=dtype, device=device), 65 | nn.SiLU(), # comfy 66 | InflatedConv3d(channels, out_channels, kernel_size=3, 67 | stride=1, padding=padding).half() 68 | ) 69 | 70 | self.updown = up or down 71 | if up: 72 | self.h_upd = Upsample3D( 73 | channels, False, dims, dtype=dtype, device=device) 74 | self.x_upd = Upsample3D( 75 | channels, False, dims, dtype=dtype, device=device) 76 | elif down: 77 | downsample_padding = 1 78 | self.h_upd = Downsample3D( 79 | out_channels, use_conv=True, out_channels=out_channels, padding=downsample_padding, name="op" 80 | ) 81 | self.x_upd = Downsample3D( 82 | channels, False, dims, dtype=dtype, device=device) 83 | else: 84 | self.h_upd = self.x_upd = nn.Identity() 85 | 86 | self.skip_t_emb = skip_t_emb 87 | if self.skip_t_emb: 88 | self.emb_layers = None 89 | self.exchange_temb_dims = False 90 | else: 91 | self.emb_layers = nn.Sequential( 92 | nn.SiLU(), 93 | operations.Linear( 94 | emb_channels, 95 | 2 * self.out_channels if use_scale_shift_norm else self.out_channels, dtype=dtype, device=device 96 | ), 97 | ) 98 | 99 | self.out_layers = nn.Sequential( 100 | operations.GroupNorm(32, self.out_channels, 101 | dtype=dtype, device=device), 102 | nn.SiLU(), 103 | nn.Dropout(p=dropout), 104 | InflatedConv3d(out_channels, out_channels, kernel_size=3, 105 | stride=1, padding=1, dtype=dtype, device=device), 106 | ) 107 | 108 | if self.out_channels == channels: 109 | self.skip_connection = nn.Identity() 110 | elif use_conv: 111 | self.skip_connection = InflatedConv3d( 112 | channels, out_channels, kernel_size=kernel_size, padding=padding).to(dtype) 113 | else: 114 | self.skip_connection = InflatedConv3d(channels, out_channels, kernel_size=1).half( 115 | ).to(dtype) 116 | 117 | # save features 118 | self.out_layers_features = None 119 | self.out_layers_inject_features = None 120 | 121 | def forward(self, x, temb, transformer_options={}, **kwargs): 122 | input_tensor = x 123 | if self.updown: 124 | in_rest, in_conv = self.in_layers[:-1], self.in_layers[-1] 125 | h = in_rest(x) 126 | h = self.h_upd(h) 127 | x = self.x_upd(x) 128 | h = in_conv(h) 129 | else: 130 | h = self.in_layers(x) 131 | 132 | emb = temb 133 | emb_out = None 134 | if not self.skip_t_emb: 135 | emb_out = self.emb_layers(emb).type(h.dtype) 136 | while len(emb_out.shape) < len(h.shape): 137 | emb_out = emb_out[..., None] 138 | if self.use_scale_shift_norm: 139 | out_norm, out_rest = self.out_layers[0], self.out_layers[1:] 140 | h = out_norm(h) 141 | if emb_out is not None: 142 | scale, shift = torch.chunk(emb_out, 2, dim=1) 143 | h *= (1 + scale) 144 | h += shift 145 | h = out_rest(h) 146 | else: 147 | if emb_out is not None: 148 | if self.exchange_temb_dims: 149 | emb_out = rearrange(emb_out, "b t c ... -> b c t ...") 150 | if emb_out.shape[0] != h.shape[0]: # ControlNet Hack TODO 151 | video_length = transformer_options['flatten']['original_shape'][0] 152 | emb_out = rearrange( 153 | emb_out, '(b f) t c h w -> b t (c f) h w', f=video_length) 154 | h = h + emb_out # (2, 320, 10, 64, 64) + (2, 320, 1, 1, 1) 155 | h = self.out_layers(h) 156 | 157 | if self.skip_connection is not None: 158 | input_tensor = self.skip_connection(input_tensor) 159 | 160 | self.out_layers_features = h 161 | if self.out_layers_inject_features is not None: 162 | h = self.out_layers_inject_features 163 | 164 | output_tensor = input_tensor + h 165 | 166 | return output_tensor 167 | -------------------------------------------------------------------------------- /nodes/flatten_unsampler_node.py: -------------------------------------------------------------------------------- 1 | from tqdm import trange 2 | import comfy.samplers 3 | import torch 4 | import comfy.k_diffusion.sampling 5 | import comfy.sample 6 | 7 | from ..utils.injection_utils import get_blank_injection_dict, clear_injections, update_injections 8 | from ..utils.flow_noise import create_noise_generator 9 | 10 | 11 | @torch.no_grad() 12 | def sample_inversed_euler(model, x, sigmas, extra_args=None, callback=None, disable=None, s_churn=0., s_tmin=0., s_tmax=float('inf'), s_noise=1.): 13 | """Implements Algorithm 2 (Euler steps) from Karras et al. (2022).""" 14 | extra_args = {} if extra_args is None else extra_args 15 | s_in = x.new_ones([x.shape[0]]) 16 | latents = [] 17 | for i in trange(1, len(sigmas), disable=disable): 18 | sigma_in = sigmas[i-1] 19 | 20 | if i == 1: 21 | sigma_t = sigmas[i] 22 | else: 23 | sigma_t = sigma_in 24 | 25 | denoised = model(x, sigma_t * s_in, **extra_args) 26 | 27 | if i == 1: 28 | d = (x - denoised) / (2 * sigmas[i]) 29 | else: 30 | d = (x - denoised) / sigmas[i-1] 31 | 32 | dt = sigmas[i] - sigmas[i-1] 33 | x = x + d * dt 34 | if callback is not None: 35 | callback( 36 | {'x': x, 'i': i, 'sigma': sigmas[i], 'sigma_hat': sigmas[i], 'denoised': denoised}) 37 | 38 | return x / sigmas[-1] 39 | 40 | 41 | class UnsamplerFlattenNode: 42 | @classmethod 43 | def INPUT_TYPES(s): 44 | return {"required": 45 | {"model": ("MODEL",), 46 | "steps": ("INT", {"default": 20, "min": 1, "max": 10000}), 47 | "save_steps": ("INT", {"default": 8, "min": 0, "max": 10000}), 48 | "sampler_name": (comfy.samplers.KSampler.SAMPLERS, ), 49 | "scheduler": (comfy.samplers.KSampler.SCHEDULERS, ), 50 | "normalize": (["disable", "enable"], ), 51 | "positive": ("CONDITIONING", ), 52 | "latent_image": ("LATENT", ), 53 | "trajectories": ("TRAJECTORY", ), 54 | "old_qk": ("INT", {"default": 0, "min": 0, "max": 1}), 55 | }} 56 | 57 | RETURN_TYPES = ("LATENT", "INJECTIONS") 58 | FUNCTION = "unsampler" 59 | 60 | CATEGORY = "sampling" 61 | 62 | def unsampler(self, model, sampler_name, steps, save_steps, scheduler, normalize, positive, latent_image, trajectories, old_qk): 63 | # DEFAULTS 64 | device = comfy.model_management.get_torch_device() 65 | 66 | cfg = 1 # hardcoded to make attention injection faster and simpler 67 | noise_seed = 777 # no noise is added 68 | negative = [] 69 | normalize = normalize == 'enable' 70 | 71 | latent = latent_image 72 | latent_image = latent["samples"] 73 | original_shape = latent_image.shape 74 | 75 | # SETUP TRANSFORMER OPTIONS 76 | injection_dict = get_blank_injection_dict( 77 | trajectories['context_windows']) 78 | 79 | def save_injections_handler(context_start): 80 | update_injections(model, injection_dict, context_start, save_steps) 81 | 82 | original_transformer_options = model.model_options.get( 83 | 'transformer_options', {}) 84 | 85 | transformer_options = { 86 | **original_transformer_options, 87 | 'flatten': { 88 | 'trajs_windows': trajectories['trajectory_windows'], 89 | 'old_qk': old_qk, 90 | 'input_shape': original_shape, 91 | 'stage': 'inversion', 92 | 'save_injections_handler': save_injections_handler 93 | } 94 | } 95 | model.model_options['transformer_options'] = transformer_options 96 | 97 | # SETUP NOISE 98 | default_noise_sampler = comfy.k_diffusion.sampling.default_noise_sampler 99 | comfy.k_diffusion.sampling.default_noise_sampler = create_noise_generator( 100 | [traj['directions'] for traj in trajectories['trajectory_windows'].values()], latent_image.shape[0]) 101 | noise = torch.zeros(latent_image.size( 102 | ), dtype=latent_image.dtype, layout=latent_image.layout, device="cpu") 103 | 104 | noise_mask = None 105 | if "noise_mask" in latent: 106 | noise_mask = latent["noise_mask"] 107 | 108 | # SETUP SAMPLING 109 | inversed_euler = sampler_name == 'inverse_euler' 110 | if inversed_euler: 111 | sampler_name = 'euler' 112 | sampler = comfy.samplers.KSampler(model, steps=steps, device=device, sampler=sampler_name, 113 | scheduler=scheduler, denoise=1.0, model_options=model.model_options) 114 | ksampler = comfy.samplers.ksampler(sampler_name) 115 | 116 | if inversed_euler: 117 | ksampler.sampler_function = sample_inversed_euler 118 | sigmas = sampler.sigmas.flip(0) 119 | else: 120 | sigmas = sampler.sigmas.flip(0) + 0.0001 121 | 122 | pbar = comfy.utils.ProgressBar(steps) 123 | 124 | def callback(step, x0, x, total_steps): 125 | pbar.update_absolute(step + 1, total_steps) 126 | disable_pbar = not comfy.utils.PROGRESS_BAR_ENABLED 127 | 128 | # UNSAMPLE MODEL 129 | try: 130 | clear_injections(model) 131 | samples = comfy.sample.sample_custom(model, noise, cfg, ksampler, sigmas, positive, negative, 132 | latent_image, noise_mask=noise_mask, callback=callback, disable_pbar=disable_pbar, seed=noise_seed) 133 | except Exception as e: 134 | print('Flatten Unsampler error encountereed:', e) 135 | raise e 136 | finally: 137 | # CLEANUP 138 | clear_injections(model) 139 | comfy.k_diffusion.sampling.default_noise_sampler = default_noise_sampler 140 | model.model_options['transformer_options'] = original_transformer_options 141 | del transformer_options 142 | del callback 143 | del save_injections_handler 144 | 145 | # RETURN SAMPLES 146 | if normalize: 147 | # technically doesn't normalize because unsampling is not guaranteed to end at a std given by the schedule 148 | samples -= samples.mean() 149 | samples /= samples.std() 150 | 151 | out = latent.copy() 152 | out['samples'] = samples 153 | return (out, injection_dict) 154 | -------------------------------------------------------------------------------- /nodes/flatten_ksampler_node.py: -------------------------------------------------------------------------------- 1 | import torch 2 | import comfy.sd 3 | import comfy.model_base 4 | import comfy.samplers 5 | import comfy.sample 6 | import comfy.k_diffusion.sampling 7 | 8 | from ..utils.injection_utils import inject_features, clear_injections 9 | from ..utils.flow_noise import create_noise_generator 10 | 11 | 12 | class KSamplerFlattenNode: 13 | @classmethod 14 | def INPUT_TYPES(s): 15 | return {"required": 16 | {"model": ("MODEL",), 17 | "add_noise": (["disable", "enable"], ), 18 | "noise_seed": ("INT", {"default": 0, "min": 0, "max": 0xffffffffffffffff}), 19 | "steps": ("INT", {"default": 10, "min": 1, "max": 10000}), 20 | "injection_steps": ("INT", {"default": 8, "min": 0, "max": 10000}), 21 | "old_qk": ("INT", {"default": 0, "min": 0, "max": 1}), 22 | "trajectories": ("TRAJECTORY",), 23 | "cfg": ("FLOAT", {"default": 8.0, "min": 0.0, "max": 100.0, "step": 0.1, "round": 0.01}), 24 | "sampler_name": (comfy.samplers.KSampler.SAMPLERS, ), 25 | "scheduler": (comfy.samplers.KSampler.SCHEDULERS, ), 26 | "positive": ("CONDITIONING", ), 27 | "negative": ("CONDITIONING", ), 28 | "latent_image": ("LATENT", ), 29 | "injections": ("INJECTIONS",), 30 | "start_at_step": ("INT", {"default": 0, "min": 0, "max": 10000}), 31 | "end_at_step": ("INT", {"default": 10000, "min": 0, "max": 10000}), 32 | "return_with_leftover_noise": (["disable", "enable"], ), 33 | } 34 | } 35 | 36 | RETURN_TYPES = ("LATENT",) 37 | FUNCTION = "sample" 38 | 39 | CATEGORY = "sampling" 40 | 41 | injection_step = 0 42 | previous_timestep = None 43 | 44 | def sample(self, model, add_noise, noise_seed, steps, injection_steps, old_qk, trajectories, cfg, sampler_name, scheduler, positive, negative, latent_image, injections, start_at_step, end_at_step, return_with_leftover_noise, denoise=1.0): 45 | # DEFAULTS 46 | device = comfy.model_management.get_torch_device() 47 | 48 | latent = latent_image 49 | latent_image = latent["samples"] 50 | original_shape = latent_image.shape 51 | 52 | # SETUP NOISE 53 | noise_mask = None 54 | if "noise_mask" in latent: 55 | noise_mask = latent["noise_mask"] 56 | 57 | add_noise = add_noise == 'enable' 58 | if not add_noise: 59 | noise = torch.zeros(latent_image.size( 60 | ), dtype=latent_image.dtype, layout=latent_image.layout, device="cpu") 61 | else: 62 | batch_inds = latent["batch_index"] if "batch_index" in latent else None 63 | noise = comfy.sample.prepare_noise( 64 | latent_image, noise_seed, batch_inds) 65 | noise = torch.cat([noise[0].unsqueeze(0)] * original_shape[0]) 66 | 67 | # SETUP SIGMAS AND STEPS 68 | sampler = comfy.samplers.KSampler(model, steps=steps, device=device, sampler=sampler_name, 69 | scheduler=scheduler, denoise=1.0, model_options=model.model_options) 70 | 71 | sigmas = sampler.sigmas 72 | timestep_to_step = {} 73 | for i, sigma in enumerate(sigmas): 74 | t = int(model.model.model_sampling.timestep(sigma)) 75 | timestep_to_step[t] = i 76 | 77 | # FLATTEN TRANSFORMER OPTIONS 78 | original_transformer_options = model.model_options.get( 79 | 'transformer_options', {}) 80 | 81 | # step hack 82 | self.previous_timestep = None 83 | self.injection_step = -1 84 | 85 | def injection_handler(sigma, context_start, len_conds): 86 | clear_injections(model) 87 | t = int(model.model.model_sampling.timestep(sigma)) 88 | if self.previous_timestep != t: 89 | self.previous_timestep = t 90 | self.injection_step += 1 91 | if self.injection_step < injection_steps: 92 | inject_features(model, injections, device, 93 | self.injection_step, context_start, len_conds) 94 | else: 95 | clear_injections(model) 96 | return self.injection_step 97 | 98 | transformer_options = { 99 | **original_transformer_options, 100 | 'flatten': { 101 | 'trajs_windows': trajectories['trajectory_windows'], 102 | 'old_qk': old_qk, 103 | 'injection_handler': injection_handler, 104 | 'input_shape': original_shape, 105 | 'stage': 'sampling', 106 | 'injection_steps': injection_steps 107 | } 108 | } 109 | model.model_options['transformer_options'] = transformer_options 110 | 111 | # HACK NOISE 112 | default_noise_sampler = comfy.k_diffusion.sampling.default_noise_sampler 113 | comfy.k_diffusion.sampling.default_noise_sampler = create_noise_generator( 114 | [traj['directions'] for traj in trajectories['trajectory_windows'].values()], latent_image.shape[0]) 115 | 116 | # SAMPLE MODEL 117 | pbar = comfy.utils.ProgressBar(steps) 118 | 119 | def callback(step, x0, x, total_steps): 120 | pbar.update_absolute(step + 1, total_steps) 121 | 122 | disable_pbar = not comfy.utils.PROGRESS_BAR_ENABLED 123 | try: 124 | clear_injections(model) 125 | samples = comfy.sample.sample(model, noise, steps, cfg, sampler_name, scheduler, positive, negative, latent_image, 126 | denoise=denoise, disable_noise=False, start_step=start_at_step, last_step=end_at_step, 127 | force_full_denoise=not return_with_leftover_noise, noise_mask=noise_mask, callback=callback, disable_pbar=disable_pbar, seed=noise_seed) 128 | except Exception as e: 129 | print('Flatten KSampler error encountereed:', e) 130 | raise e 131 | finally: 132 | # CLEANUP 133 | clear_injections(model) 134 | comfy.k_diffusion.sampling.default_noise_sampler = default_noise_sampler 135 | model.model_options['transformer_options'] = original_transformer_options 136 | self.previous_timestep = None 137 | self.injection_step = 0 138 | 139 | del injection_handler 140 | del transformer_options 141 | 142 | # RETURN 143 | out = {} 144 | out["samples"] = samples 145 | return (out, ) 146 | -------------------------------------------------------------------------------- /modules/transformer_block.py: -------------------------------------------------------------------------------- 1 | from einops import rearrange, repeat 2 | from .fully_attention import FullyFrameAttention 3 | from .patch3d import apply_patch3d 4 | from comfy.ldm.modules.attention import FeedForward, CrossAttention 5 | import torch 6 | import torch.nn as nn 7 | 8 | import comfy.ops 9 | ops = comfy.ops.disable_weight_init 10 | 11 | 12 | class BasicTransformerBlock(nn.Module): 13 | def __init__( 14 | self, 15 | dim: int, 16 | n_heads: int, 17 | d_head: int, 18 | dropout=0.0, 19 | context_dim=None, 20 | gated_ff=True, 21 | checkpoint=True, 22 | ff_in=False, 23 | inner_dim=None, 24 | disable_self_attn=False, 25 | disable_temporal_crossattention=False, 26 | switch_temporal_ca_to_sa=False, 27 | dtype=None, 28 | device=None, 29 | operations=ops, 30 | num_embeds_ada_norm=None, 31 | attention_bias: bool = False, 32 | only_cross_attention: bool = False, 33 | ): 34 | super().__init__() 35 | # comfy setup 36 | self.ff_in = ff_in or inner_dim is not None 37 | if inner_dim is None: 38 | inner_dim = dim 39 | self.is_res = inner_dim == dim 40 | 41 | # flatten setup 42 | self.only_cross_attention = only_cross_attention 43 | self.use_ada_layer_norm = num_embeds_ada_norm is not None 44 | 45 | if self.ff_in: 46 | self.norm_in = operations.LayerNorm( 47 | dim, dtype=dtype, device=device) 48 | self.ff_in = FeedForward(dim, dim_out=inner_dim, dropout=dropout, 49 | glu=gated_ff, dtype=dtype, device=device, operations=operations) 50 | 51 | self.disable_self_attn = disable_self_attn 52 | # Fully 53 | self.attn1 = FullyFrameAttention( 54 | query_dim=dim, 55 | heads=n_heads, 56 | dim_head=d_head, 57 | dropout=dropout, 58 | bias=attention_bias, 59 | context_dim=context_dim if self.disable_self_attn else None, 60 | dtype=dtype, 61 | device=device, 62 | ) 63 | self.ff = FeedForward(inner_dim, dim_out=dim, dropout=dropout, 64 | glu=gated_ff, dtype=dtype, device=device, operations=operations) 65 | 66 | if disable_temporal_crossattention: 67 | if switch_temporal_ca_to_sa: 68 | raise ValueError 69 | else: 70 | self.norm2 = None 71 | else: 72 | context_dim_attn2 = None 73 | if not switch_temporal_ca_to_sa: 74 | context_dim_attn2 = context_dim 75 | 76 | self.attn2 = CrossAttention(query_dim=inner_dim, context_dim=context_dim_attn2, 77 | heads=n_heads, dim_head=d_head, dropout=dropout, dtype=dtype, device=device, operations=operations) # is self-attn if context is none 78 | self.norm2 = operations.LayerNorm( 79 | inner_dim, dtype=dtype, device=device) 80 | 81 | self.norm1 = operations.LayerNorm( 82 | inner_dim, dtype=dtype, device=device) 83 | self.norm3 = operations.LayerNorm( 84 | inner_dim, dtype=dtype, device=device) 85 | self.checkpoint = checkpoint 86 | self.n_heads = n_heads 87 | self.d_head = d_head 88 | self.switch_temporal_ca_to_sa = switch_temporal_ca_to_sa 89 | 90 | def forward( 91 | self, 92 | x, 93 | context=None, 94 | transformer_options={}, 95 | attention_mask=None, 96 | video_length=None, 97 | inter_frame=False, 98 | traj_options=None 99 | ): 100 | # Comfy setup 101 | extra_options = {} 102 | block = transformer_options.get("block", None) 103 | block_index = transformer_options.get("block_index", 0) 104 | transformer_patches = {} 105 | transformer_patches_replace = {} 106 | 107 | if block is not None: 108 | transformer_block = (block[0], block[1], block_index) 109 | else: 110 | transformer_block = None 111 | 112 | for k in transformer_options: 113 | if k == "patches": 114 | transformer_patches = transformer_options[k] 115 | elif k == "patches_replace": 116 | transformer_patches_replace = transformer_options[k] 117 | else: 118 | extra_options[k] = transformer_options[k] 119 | 120 | extra_options["n_heads"] = self.n_heads 121 | extra_options["dim_head"] = self.d_head 122 | 123 | hidden_states = x 124 | encoder_hidden_states = context 125 | norm_hidden_states = self.norm1(hidden_states) 126 | 127 | if self.only_cross_attention: 128 | hidden_states = ( 129 | self.attn1(norm_hidden_states, encoder_hidden_states, attention_mask=attention_mask, video_length=video_length, 130 | inter_frame=inter_frame, transformer_options=transformer_options, traj_options=traj_options) + hidden_states 131 | ) 132 | else: 133 | hidden_states = self.attn1(norm_hidden_states, attention_mask=attention_mask, 134 | video_length=video_length, inter_frame=inter_frame, transformer_options=transformer_options, traj_options=traj_options) + hidden_states 135 | 136 | if "middle_patch" in transformer_patches: 137 | patch = transformer_patches["middle_patch"] 138 | for p in patch: 139 | hidden_states = p(hidden_states, extra_options) 140 | 141 | if self.attn2 is not None: 142 | # Cross-Attention 143 | norm_hidden_states = self.norm2(hidden_states) 144 | # switch cross attention to self attention 145 | if self.switch_temporal_ca_to_sa: 146 | context_attn2 = norm_hidden_states 147 | else: 148 | context_attn2 = encoder_hidden_states 149 | 150 | value_attn2 = None 151 | if "attn2_patch" in transformer_patches: 152 | patch = transformer_patches["attn2_patch"] 153 | value_attn2 = context_attn2 154 | for p in patch: 155 | norm_hidden_states, context_attn2, value_attn2 = p( 156 | norm_hidden_states, context_attn2, value_attn2, extra_options) 157 | 158 | attn2_replace_patch = transformer_patches_replace.get("attn2", {}) 159 | block_attn2 = transformer_block 160 | if block_attn2 not in attn2_replace_patch: 161 | block_attn2 = block 162 | 163 | if block_attn2 is not None and block_attn2 in attn2_replace_patch: 164 | if value_attn2 is None: 165 | value_attn2 = context_attn2 166 | norm_hidden_states = self.attn2.to_q(norm_hidden_states) 167 | context_attn2 = self.attn2.to_k(context_attn2) 168 | value_attn2 = self.attn2.to_v(value_attn2) 169 | attn2_hidden_states = attn2_replace_patch[block_attn2]( 170 | norm_hidden_states, context_attn2, value_attn2, extra_options) 171 | hidden_states = self.attn2.to_out( 172 | attn2_hidden_states) + hidden_states 173 | else: 174 | # Flatten adds the hidden states here and after the feed-forward 175 | hidden_states = self.attn2( 176 | norm_hidden_states, context=context_attn2, mask=attention_mask) + hidden_states 177 | 178 | if "attn2_output_patch" in transformer_patches: 179 | patch = transformer_patches["attn2_output_patch"] 180 | for p in patch: 181 | norm_hidden_states = p(norm_hidden_states, extra_options) 182 | 183 | # Feed-forward 184 | hidden_states = self.ff(self.norm3(hidden_states)) + hidden_states 185 | 186 | return hidden_states 187 | -------------------------------------------------------------------------------- /utils/trajectories.py: -------------------------------------------------------------------------------- 1 | import random 2 | from einops import rearrange 3 | import torch 4 | 5 | import torchvision.transforms.functional as F 6 | 7 | 8 | # TODO hard coded 512 9 | def preprocess(img1_batch, img2_batch, transforms): 10 | img1_batch = F.resize(img1_batch, size=[512, 512], antialias=False) 11 | img2_batch = F.resize(img2_batch, size=[512, 512], antialias=False) 12 | return transforms(img1_batch, img2_batch) 13 | 14 | 15 | def keys_with_same_value(dictionary): 16 | result = {} 17 | for key, value in dictionary.items(): 18 | if value not in result: 19 | result[value] = [key] 20 | else: 21 | result[value].append(key) 22 | 23 | conflict_points = {} 24 | for k in result.keys(): 25 | if len(result[k]) > 1: 26 | conflict_points[k] = result[k] 27 | return conflict_points 28 | 29 | 30 | def find_duplicates(input_list): 31 | seen = set() 32 | duplicates = set() 33 | 34 | for item in input_list: 35 | if item in seen: 36 | duplicates.add(item) 37 | else: 38 | seen.add(item) 39 | 40 | return list(duplicates) 41 | 42 | 43 | def neighbors_index(point, h_size, w_size, H, W): 44 | """return the spatial neighbor indices""" 45 | t, x, y = point 46 | neighbors = [] 47 | for i in range(-h_size, h_size + 1): 48 | for j in range(-w_size, w_size + 1): 49 | if i == 0 and j == 0: 50 | continue 51 | if x + i < 0 or x + i >= H or y + j < 0 or y + j >= W: 52 | continue 53 | neighbors.append((t, x + i, y + j)) 54 | return neighbors 55 | 56 | 57 | def get_window_size(resolution): 58 | # this isn't always correct and needs an actual calculation 59 | if resolution > 64: 60 | return 4 61 | elif resolution > 32: 62 | return 2 63 | else: 64 | return 1 65 | 66 | 67 | @torch.no_grad() 68 | def sample_trajectories(frames, model, weights, device): 69 | model.eval() 70 | image_height = frames.shape[1] 71 | image_width = frames.shape[2] 72 | 73 | clips = list(range(len(frames))) 74 | frames = rearrange(frames, "f h w c -> f c h w") 75 | current_frames, next_frames = frames[clips[:-1]], frames[clips[1:]] 76 | list_of_flows = model(current_frames.to(device), next_frames.to(device)) 77 | predicted_flows = list_of_flows[-1] 78 | 79 | predicted_flows[:, 0] = predicted_flows[:, 0]/image_width 80 | predicted_flows[:, 1] = predicted_flows[:, 1]/image_height 81 | 82 | height_reso = image_height//8 83 | height_resoultions = [height_reso] 84 | 85 | width_reso = image_width//8 86 | width_resolutions = [width_reso] 87 | 88 | res = {} 89 | 90 | for height_resolution, width_resolution in zip(height_resoultions, width_resolutions): 91 | trajectories = {} 92 | x_flows = torch.round(height_resolution*torch.nn.functional.interpolate( 93 | predicted_flows[:, 1].unsqueeze(1), scale_factor=(height_resolution/image_height, width_resolution/image_width))) 94 | y_flows = torch.round(width_resolution*torch.nn.functional.interpolate( 95 | predicted_flows[:, 0].unsqueeze(1), scale_factor=(height_resolution/image_height, width_resolution/image_width))) 96 | 97 | predicted_flow_resolu = torch.cat([y_flows, x_flows], dim=1) 98 | 99 | T = predicted_flow_resolu.shape[0]+1 100 | H = predicted_flow_resolu.shape[2] 101 | W = predicted_flow_resolu.shape[3] 102 | 103 | is_activated = torch.zeros([T, H, W], dtype=torch.bool) 104 | 105 | for t in range(T-1): 106 | flow = predicted_flow_resolu[t] 107 | for h in range(H): 108 | for w in range(W): 109 | 110 | if not is_activated[t, h, w]: 111 | is_activated[t, h, w] = True 112 | # this point has not been traversed, start new trajectory 113 | x = h + int(flow[1, h, w]) 114 | y = w + int(flow[0, h, w]) 115 | if x >= 0 and x < H and y >= 0 and y < W: 116 | # trajectories.append([(t, h, w), (t+1, x, y)]) 117 | trajectories[(t, h, w)] = (t+1, x, y) 118 | 119 | conflict_points = keys_with_same_value(trajectories) 120 | for k in conflict_points: 121 | index_to_pop = random.randint(0, len(conflict_points[k]) - 1) 122 | conflict_points[k].pop(index_to_pop) 123 | for point in conflict_points[k]: 124 | if point[0] != T-1: 125 | trajectories[point] = (-1, -1, -1) 126 | 127 | active_traj = [] 128 | all_traj = [] 129 | for t in range(T): 130 | pixel_set = {(t, x//H, x % H): 0 for x in range(H*W)} 131 | new_active_traj = [] 132 | for traj in active_traj: 133 | if traj[-1] in trajectories: 134 | v = trajectories[traj[-1]] 135 | new_active_traj.append(traj + [v]) 136 | pixel_set[v] = 1 137 | else: 138 | all_traj.append(traj) 139 | active_traj = new_active_traj 140 | active_traj += [[pixel] 141 | for pixel in pixel_set if pixel_set[pixel] == 0] 142 | # these are vectors from point start to point end [(t,x,y), (t+1, x,y)...] 143 | all_traj += active_traj 144 | 145 | useful_traj = [segment for segment in all_traj if len(segment) > 1] 146 | for idx in range(len(useful_traj)): 147 | if useful_traj[idx][-1] == (-1, -1, -1): 148 | useful_traj[idx] = useful_traj[idx][:-1] 149 | trajs = [] 150 | for traj in useful_traj: 151 | trajs = trajs + traj 152 | assert len(find_duplicates( 153 | trajs)) == 0, "There should not be duplicates in the useful trajectories." 154 | 155 | all_points = set([(t, x, y) for t in range(T) 156 | for x in range(H) for y in range(W)]) 157 | left_points = all_points - set(trajs) 158 | for p in list(left_points): # add points that are missing 159 | useful_traj.append([p]) 160 | 161 | longest_length = max([len(traj) for traj in useful_traj]) 162 | h_size = get_window_size(height_resolution) 163 | w_size = get_window_size(width_resolution) 164 | window_size = (h_size*2+1) * (w_size*2+1) 165 | sequence_length = window_size + longest_length - 1 166 | 167 | seqs = [] 168 | masks = [] 169 | 170 | # create a dictionary to facilitate checking the trajectories to which each point belongs. 171 | directions = {} 172 | point_to_traj = {} # point to vector/segmeent 173 | for traj in useful_traj: 174 | for p in traj: 175 | point_to_traj[p] = traj 176 | cut_traj = list(traj) 177 | while len(cut_traj) > 0 and cut_traj[0] != p: 178 | cut_traj.pop(0) 179 | directions[p] = cut_traj 180 | 181 | for t in range(T): 182 | for x in range(H): 183 | for y in range(W): 184 | neighbours = neighbors_index( 185 | (t, x, y), h_size, w_size, H, W) 186 | sequence = [(t, x, y)]+neighbours + [(0, 0, 0) 187 | for i in range(window_size-1-len(neighbours))] 188 | sequence_mask = torch.zeros( 189 | sequence_length, dtype=torch.bool) 190 | sequence_mask[:len(neighbours)+1] = True 191 | 192 | traj = point_to_traj[(t, x, y)].copy() 193 | traj.remove((t, x, y)) 194 | sequence = sequence + traj + \ 195 | [(0, 0, 0) for k in range(longest_length-1-len(traj)) 196 | ] # add (0,0,0) to fill in gaps 197 | sequence_mask[window_size:window_size + len(traj)] = True 198 | 199 | seqs.append(sequence) 200 | masks.append(sequence_mask) 201 | 202 | seqs = torch.tensor(seqs) 203 | seqs = torch.cat([seqs[:, 0, :].unsqueeze( 204 | 1), seqs[:, -len(frames)+1:, :]], dim=1) 205 | seqs = rearrange(seqs, '(f n) l d -> f n l d', f=len(frames)) 206 | masks = torch.stack(masks) 207 | masks = torch.cat([masks[:, 0].unsqueeze( 208 | 1), masks[:, -len(frames)+1:]], dim=1) 209 | masks = rearrange(masks, '(f n) l -> f n l', f=len(frames)) 210 | res["traj{}".format(height_resolution)] = seqs.cpu() 211 | res["mask{}".format(height_resolution)] = masks.cpu() 212 | res['directions'] = directions 213 | return res 214 | -------------------------------------------------------------------------------- /modules/fully_attention.py: -------------------------------------------------------------------------------- 1 | # Adapted from https://github.com/huggingface/diffusers/blob/main/src/diffusers/models/attention.py 2 | from comfy.ldm.modules.attention import optimized_attention, optimized_attention_masked, attention_basic 3 | import torch 4 | import torch.nn 5 | import torch.nn.functional as F 6 | 7 | from typing import Optional 8 | import math 9 | from torch import nn 10 | from einops import rearrange 11 | 12 | import comfy.model_management 13 | if comfy.model_management.xformers_enabled(): 14 | import xformers 15 | import xformers.ops 16 | from comfy.cli_args import args 17 | import comfy.ops 18 | ops = comfy.ops.disable_weight_init 19 | 20 | 21 | if args.dont_upcast_attention: 22 | print("disabling upcasting of attention") 23 | _ATTN_PRECISION = "fp16" 24 | else: 25 | _ATTN_PRECISION = "fp32" 26 | 27 | 28 | class FullyFrameAttention(nn.Module): 29 | r""" 30 | A cross attention layer. 31 | 32 | Parameters: 33 | query_dim (`int`): The number of channels in the query. 34 | context_dim (`int`, *optional*): 35 | The number of channels in the context. If not given, defaults to `query_dim`. 36 | heads (`int`, *optional*, defaults to 8): The number of heads to use for multi-head attention. 37 | dim_head (`int`, *optional*, defaults to 64): The number of channels in each head. 38 | dropout (`float`, *optional*, defaults to 0.0): The dropout probability to use. 39 | bias (`bool`, *optional*, defaults to False): 40 | Set to `True` for the query, key, and value linear layers to contain a bias parameter. 41 | """ 42 | 43 | def __init__( 44 | self, 45 | query_dim: int, 46 | context_dim: Optional[int] = None, 47 | heads: int = 8, 48 | dim_head: int = 64, 49 | dropout: float = 0.0, 50 | dtype=None, 51 | device=None, 52 | operations=ops, # is ops in original CrossAttention module 53 | # Flatten params 54 | bias=False, 55 | norm_num_groups: Optional[int] = None, 56 | ): 57 | super().__init__() 58 | inner_dim = dim_head * heads 59 | context_dim = context_dim if context_dim is not None else query_dim 60 | self.upcast_attention = _ATTN_PRECISION == 'fp32' # upcast_attention 61 | 62 | self.scale = dim_head**-0.5 63 | 64 | self.heads = heads 65 | # for slice_size > 0 the attention score computation 66 | # is split across the batch axis to save memory 67 | # You can set slice_size with `set_attention_slice` 68 | self.sliceable_head_dim = heads 69 | self._slice_size = None 70 | self._use_memory_efficient_attention_xformers = True 71 | 72 | if norm_num_groups is not None: 73 | self.group_norm = nn.GroupNorm( 74 | num_channels=inner_dim, num_groups=norm_num_groups, eps=1e-5, affine=True) 75 | else: 76 | self.group_norm = None 77 | 78 | self.to_q = operations.Linear( 79 | query_dim, inner_dim, bias=bias, dtype=dtype, device=device) 80 | self.to_k = operations.Linear( 81 | context_dim, inner_dim, bias=bias, dtype=dtype, device=device) 82 | self.to_v = operations.Linear( 83 | context_dim, inner_dim, bias=bias, dtype=dtype, device=device) 84 | 85 | self.to_out = nn.ModuleList([]) 86 | self.to_out.append(operations.Linear( 87 | inner_dim, query_dim, dtype=dtype, device=device)) 88 | self.to_out.append(nn.Dropout(dropout)) 89 | 90 | self.q = None 91 | self.inject_q = None 92 | self.k = None 93 | self.inject_k = None 94 | 95 | def reshape_heads_to_batch_dim(self, tensor): 96 | batch_size, seq_len, dim = tensor.shape 97 | head_size = self.heads 98 | tensor = tensor.reshape(batch_size, seq_len, 99 | head_size, dim // head_size) 100 | tensor = tensor.permute(0, 2, 1, 3).reshape( 101 | batch_size * head_size, seq_len, dim // head_size) 102 | return tensor 103 | 104 | def reshape_heads_to_batch_dim3(self, tensor): 105 | batch_size1, batch_size2, seq_len, dim = tensor.shape 106 | head_size = self.heads 107 | tensor = tensor.reshape(batch_size1, batch_size2, 108 | seq_len, head_size, dim // head_size) 109 | tensor = tensor.permute(0, 3, 1, 2, 4) 110 | return tensor 111 | 112 | def reshape_batch_dim_to_heads(self, tensor): 113 | batch_size, seq_len, dim = tensor.shape 114 | head_size = self.heads 115 | tensor = tensor.reshape(batch_size // head_size, 116 | head_size, seq_len, dim) 117 | tensor = tensor.permute(0, 2, 1, 3).reshape( 118 | batch_size // head_size, seq_len, dim * head_size) 119 | return tensor 120 | 121 | def _memory_efficient_attention_xformers(self, query, key, value, attention_mask): 122 | query = query.contiguous() 123 | key = key.contiguous() 124 | value = value.contiguous() 125 | hidden_states = xformers.ops.memory_efficient_attention( 126 | query, key, value, attn_bias=attention_mask) 127 | hidden_states = self.reshape_batch_dim_to_heads(hidden_states) 128 | return hidden_states 129 | 130 | def _attention_mechanism(self, query, key, value, attention_mask): 131 | # Comfy default attention mechanism 132 | if attention_mask is not None: 133 | hidden_states = optimized_attention_masked( 134 | query, key, value, self.heads, attention_mask) 135 | else: 136 | hidden_states = optimized_attention( 137 | query, key, value, self.heads) 138 | return hidden_states 139 | 140 | def forward(self, hidden_states, context=None, value=None, attention_mask=None, video_length=None, inter_frame=False, transformer_options={}, traj_options={}): 141 | batch_size, sequence_length, _ = hidden_states.shape 142 | flatten_options = transformer_options['flatten'] 143 | 144 | transformer_block = transformer_options.get('block', ('', -1))[0] 145 | transformer_index = transformer_options.get('transformer_index', -1) 146 | patches_replace = transformer_options.get('patches_replace', {}) 147 | attn1_replace = patches_replace.get('attn1', {}) 148 | block = (transformer_block, transformer_index) 149 | if block in attn1_replace: 150 | replace_fn = attn1_replace[block] 151 | hidden_states = replace_fn( 152 | self.to_q(hidden_states), 153 | self.to_k(hidden_states), 154 | self.to_v(hidden_states), 155 | extra_options=transformer_options 156 | ) 157 | hidden_states = self.to_out[0](hidden_states) 158 | hidden_states = self.to_out[1](hidden_states) 159 | return hidden_states 160 | 161 | h = traj_options['height'] 162 | w = traj_options['width'] 163 | target_resolution = flatten_options['input_shape'][-2] 164 | if self.group_norm is not None: 165 | hidden_states = self.group_norm( 166 | hidden_states.transpose(1, 2)).transpose(1, 2) 167 | 168 | query = self.to_q(hidden_states) # (bf) x d(hw) x c 169 | self.q = query 170 | if self.inject_q is not None: 171 | query = self.inject_q 172 | 173 | query_old = None 174 | if flatten_options['old_qk'] == 1: 175 | query_old = query.clone() 176 | 177 | context = context if context is not None else hidden_states 178 | key = self.to_k(context) 179 | self.k = key 180 | if self.inject_k is not None: 181 | key = self.inject_k 182 | 183 | key_old = None 184 | if flatten_options['old_qk'] == 1: 185 | key_old = key.clone() 186 | value = self.to_v(context) 187 | 188 | query = rearrange(query, "(b f) d c -> b (f d) c", f=video_length) 189 | key = rearrange(key, "(b f) d c -> b (f d) c", f=video_length) 190 | value = rearrange(value, "(b f) d c -> b (f d) c", f=video_length) 191 | 192 | if attention_mask is not None: 193 | if attention_mask.shape[-1] != query.shape[1]: 194 | target_length = query.shape[1] 195 | attention_mask = F.pad( 196 | attention_mask, (0, target_length), value=0.0) 197 | attention_mask = attention_mask.repeat_interleave( 198 | self.heads, dim=0) 199 | 200 | hidden_states = self._attention_mechanism( 201 | query, key, value, attention_mask) 202 | query = self.reshape_heads_to_batch_dim(query) 203 | key = self.reshape_heads_to_batch_dim(key) 204 | value = self.reshape_heads_to_batch_dim(value) 205 | 206 | if h in [target_resolution]: 207 | hidden_states = rearrange( 208 | hidden_states, "b (f d) c -> (b f) d c", f=video_length) 209 | if self.group_norm is not None: 210 | hidden_states = self.group_norm( 211 | hidden_states.transpose(1, 2)).transpose(1, 2) 212 | 213 | if flatten_options['old_qk'] == 1: 214 | query = query_old 215 | key = key_old 216 | else: 217 | query = hidden_states 218 | key = hidden_states 219 | value = hidden_states 220 | 221 | cond_size = traj_options['cond_size'] 222 | resolu = traj_options['resolution'] 223 | trajs = flatten_options['trajs'][f'traj{resolu}'] 224 | traj_mask = flatten_options['trajs'][f'mask{resolu}'] 225 | 226 | start = -video_length+1 227 | end = trajs.shape[2] 228 | 229 | traj_key_sequence_inds = torch.cat( 230 | [trajs[:, :, 0, :].unsqueeze(-2), trajs[:, :, start:end, :]], dim=-2) 231 | traj_mask = torch.cat([traj_mask[:, :, 0].unsqueeze(-1), 232 | traj_mask[:, :, start:end]], dim=-1) 233 | 234 | t_inds = traj_key_sequence_inds[:, :, :, 0] 235 | x_inds = traj_key_sequence_inds[:, :, :, 1] 236 | y_inds = traj_key_sequence_inds[:, :, :, 2] 237 | 238 | query_tempo = query.unsqueeze(-2) 239 | _key = rearrange(key, '(b f) (h w) d -> b f h w d', 240 | b=int(batch_size/video_length), f=video_length, h=h, w=w) 241 | _value = rearrange(value, '(b f) (h w) d -> b f h w d', 242 | b=int(batch_size/video_length), f=video_length, h=h, w=w) 243 | key_tempo = _key[:, t_inds, x_inds, y_inds] # This fails 244 | value_tempo = _value[:, t_inds, x_inds, y_inds] 245 | key_tempo = rearrange(key_tempo, 'b f n l d -> (b f) n l d') 246 | value_tempo = rearrange(value_tempo, 'b f n l d -> (b f) n l d') 247 | 248 | traj_mask = rearrange(torch.stack( 249 | [traj_mask] * cond_size), 'b f n l -> (b f) n l') 250 | traj_mask = traj_mask[:, None].repeat( 251 | 1, self.heads, 1, 1).unsqueeze(-2) 252 | attn_bias = torch.zeros_like( 253 | traj_mask, dtype=key_tempo.dtype, device=query.device) # regular zeros_like 254 | attn_bias[~traj_mask] = -torch.inf 255 | 256 | # flow attention 257 | query_tempo = self.reshape_heads_to_batch_dim3(query_tempo) 258 | key_tempo = self.reshape_heads_to_batch_dim3(key_tempo) 259 | value_tempo = self.reshape_heads_to_batch_dim3(value_tempo) 260 | 261 | attn_matrix2 = query_tempo @ key_tempo.transpose(-2, -1) / math.sqrt( 262 | query_tempo.size(-1)) + attn_bias 263 | attn_matrix2 = F.softmax(attn_matrix2, dim=-1) 264 | out = (attn_matrix2@value_tempo).squeeze(-2) 265 | 266 | hidden_states = rearrange(out, '(b f) k (h w) d -> b (f h w) (k d)', b=int( 267 | batch_size/video_length), f=video_length, h=h, w=w) 268 | 269 | # linear proj 270 | hidden_states = self.to_out[0](hidden_states) 271 | 272 | # dropout 273 | hidden_states = self.to_out[1](hidden_states) 274 | 275 | # All frames 276 | hidden_states = rearrange( 277 | hidden_states, "b (f d) c -> (b f) d c", f=video_length) 278 | 279 | return hidden_states 280 | -------------------------------------------------------------------------------- /modules/flatten_cldm.py: -------------------------------------------------------------------------------- 1 | # taken from: https://github.com/lllyasviel/ControlNet 2 | # and modified with Flatten modules 3 | # Mostly an experiment 4 | 5 | from einops import rearrange 6 | import torch 7 | import torch as th 8 | import torch.nn as nn 9 | 10 | import comfy.ops 11 | from comfy.ldm.modules.diffusionmodules.util import ( 12 | zero_module, 13 | timestep_embedding, 14 | ) 15 | from comfy.ldm.util import exists 16 | 17 | from .unet import TimestepEmbedSequential 18 | from .convs import InflatedConv3d 19 | from .transformer3d import Transformer3DModel 20 | from .downsample3d import Downsample3D 21 | from .resnet_block3d import ResnetBlock3D 22 | 23 | 24 | class FlattenControlNet(nn.Module): 25 | def __init__( 26 | self, 27 | image_size, 28 | in_channels, 29 | model_channels, 30 | hint_channels, 31 | num_res_blocks, 32 | dropout=0, 33 | channel_mult=(1, 2, 4, 8), 34 | conv_resample=True, 35 | dims=2, 36 | num_classes=None, 37 | use_checkpoint=False, 38 | dtype=torch.float32, 39 | num_heads=-1, 40 | num_head_channels=-1, 41 | num_heads_upsample=-1, 42 | use_scale_shift_norm=False, 43 | resblock_updown=False, 44 | use_new_attention_order=False, 45 | use_spatial_transformer=False, # custom transformer support 46 | transformer_depth=1, # custom transformer support 47 | context_dim=None, # custom transformer support 48 | # custom support for prediction of discrete ids into codebook of first stage vq model 49 | n_embed=None, 50 | legacy=True, 51 | disable_self_attentions=None, 52 | num_attention_blocks=None, 53 | disable_middle_self_attn=False, 54 | use_linear_in_transformer=False, 55 | adm_in_channels=None, 56 | transformer_depth_middle=None, 57 | transformer_depth_output=None, 58 | device=None, 59 | operations=comfy.ops.disable_weight_init, 60 | **kwargs, 61 | ): 62 | super().__init__() 63 | assert use_spatial_transformer == True, "use_spatial_transformer has to be true" 64 | if use_spatial_transformer: 65 | assert context_dim is not None, 'Fool!! You forgot to include the dimension of your cross-attention conditioning...' 66 | 67 | if context_dim is not None: 68 | assert use_spatial_transformer, 'Fool!! You forgot to use the spatial transformer for your cross-attention conditioning...' 69 | # from omegaconf.listconfig import ListConfig 70 | # if type(context_dim) == ListConfig: 71 | # context_dim = list(context_dim) 72 | 73 | if num_heads_upsample == -1: 74 | num_heads_upsample = num_heads 75 | 76 | if num_heads == -1: 77 | assert num_head_channels != -1, 'Either num_heads or num_head_channels has to be set' 78 | 79 | if num_head_channels == -1: 80 | assert num_heads != -1, 'Either num_heads or num_head_channels has to be set' 81 | 82 | self.dims = dims 83 | self.image_size = image_size 84 | self.in_channels = in_channels 85 | self.model_channels = model_channels 86 | 87 | if isinstance(num_res_blocks, int): 88 | self.num_res_blocks = len(channel_mult) * [num_res_blocks] 89 | else: 90 | if len(num_res_blocks) != len(channel_mult): 91 | raise ValueError("provide num_res_blocks either as an int (globally constant) or " 92 | "as a list/tuple (per-level) with the same length as channel_mult") 93 | self.num_res_blocks = num_res_blocks 94 | 95 | if disable_self_attentions is not None: 96 | # should be a list of booleans, indicating whether to disable self-attention in TransformerBlocks or not 97 | assert len(disable_self_attentions) == len(channel_mult) 98 | if num_attention_blocks is not None: 99 | assert len(num_attention_blocks) == len(self.num_res_blocks) 100 | assert all(map(lambda i: self.num_res_blocks[i] >= num_attention_blocks[i], range( 101 | len(num_attention_blocks)))) 102 | 103 | transformer_depth = transformer_depth[:] 104 | 105 | self.dropout = dropout 106 | self.channel_mult = channel_mult 107 | self.conv_resample = conv_resample 108 | self.num_classes = num_classes 109 | self.use_checkpoint = use_checkpoint 110 | self.dtype = dtype 111 | self.num_heads = num_heads 112 | self.num_head_channels = num_head_channels 113 | self.num_heads_upsample = num_heads_upsample 114 | self.predict_codebook_ids = n_embed is not None 115 | 116 | time_embed_dim = model_channels * 4 117 | self.time_embed = nn.Sequential( 118 | operations.Linear(model_channels, time_embed_dim, 119 | dtype=self.dtype, device=device), 120 | nn.SiLU(), 121 | operations.Linear(time_embed_dim, time_embed_dim, 122 | dtype=self.dtype, device=device), 123 | ) 124 | 125 | if self.num_classes is not None: 126 | if isinstance(self.num_classes, int): 127 | self.label_emb = nn.Embedding(num_classes, time_embed_dim) 128 | elif self.num_classes == "continuous": 129 | print("setting up linear c_adm embedding layer") 130 | self.label_emb = nn.Linear(1, time_embed_dim) 131 | elif self.num_classes == "sequential": 132 | assert adm_in_channels is not None 133 | self.label_emb = nn.Sequential( 134 | nn.Sequential( 135 | operations.Linear( 136 | adm_in_channels, time_embed_dim, dtype=self.dtype, device=device), 137 | nn.SiLU(), 138 | operations.Linear( 139 | time_embed_dim, time_embed_dim, dtype=self.dtype, device=device), 140 | ) 141 | ) 142 | else: 143 | raise ValueError() 144 | 145 | self.input_blocks = nn.ModuleList( 146 | [ 147 | TimestepEmbedSequential( 148 | InflatedConv3d(in_channels, model_channels, 149 | 3, padding=1).half() 150 | ) 151 | ] 152 | ) 153 | self.zero_convs = nn.ModuleList([self.make_zero_conv( 154 | model_channels, operations=operations, dtype=self.dtype, device=device)]) 155 | 156 | self.input_hint_block = TimestepEmbedSequential( 157 | InflatedConv3d(hint_channels, 16, 3, padding=1).half(), 158 | nn.SiLU(), 159 | InflatedConv3d(16, 16, 3, padding=1).half(), 160 | nn.SiLU(), 161 | InflatedConv3d(16, 32, 3, padding=1, stride=2).half(), 162 | nn.SiLU(), 163 | InflatedConv3d(32, 32, 3, padding=1).half(), 164 | nn.SiLU(), 165 | InflatedConv3d(32, 96, 3, padding=1, stride=2).half(), 166 | nn.SiLU(), 167 | InflatedConv3d(96, 96, 3, padding=1).half(), 168 | nn.SiLU(), 169 | InflatedConv3d(96, 256, 3, padding=1, stride=2).half(), 170 | nn.SiLU(), 171 | InflatedConv3d(256, model_channels, 3, padding=1).half(), 172 | ) 173 | 174 | self._feature_size = model_channels 175 | input_block_chans = [model_channels] 176 | ch = model_channels 177 | ds = 1 178 | for level, mult in enumerate(channel_mult): 179 | for nr in range(self.num_res_blocks[level]): 180 | layers = [ 181 | ResnetBlock3D( 182 | ch, 183 | time_embed_dim, 184 | dropout, 185 | out_channels=mult * model_channels, 186 | dims=dims, 187 | use_checkpoint=use_checkpoint, 188 | use_scale_shift_norm=use_scale_shift_norm, 189 | dtype=self.dtype, 190 | device=device, 191 | operations=operations, 192 | ) 193 | ] 194 | ch = mult * model_channels 195 | num_transformers = transformer_depth.pop(0) 196 | if num_transformers > 0: 197 | if num_head_channels == -1: 198 | dim_head = ch // num_heads 199 | else: 200 | num_heads = ch // num_head_channels 201 | dim_head = num_head_channels 202 | if legacy: 203 | # num_heads = 1 204 | dim_head = ch // num_heads if use_spatial_transformer else num_head_channels 205 | if exists(disable_self_attentions): 206 | disabled_sa = disable_self_attentions[level] 207 | else: 208 | disabled_sa = False 209 | 210 | if not exists(num_attention_blocks) or nr < num_attention_blocks[level]: 211 | layers.append( 212 | Transformer3DModel( 213 | ch, num_heads, dim_head, depth=num_transformers, context_dim=context_dim, 214 | disable_self_attn=disabled_sa, use_linear=use_linear_in_transformer, 215 | use_checkpoint=use_checkpoint, dtype=self.dtype, device=device, operations=operations 216 | ) 217 | ) 218 | self.input_blocks.append(TimestepEmbedSequential(*layers)) 219 | self.zero_convs.append(self.make_zero_conv( 220 | ch, operations=operations, dtype=self.dtype, device=device)) 221 | self._feature_size += ch 222 | input_block_chans.append(ch) 223 | if level != len(channel_mult) - 1: 224 | out_ch = ch 225 | self.input_blocks.append( 226 | TimestepEmbedSequential( 227 | ResnetBlock3D( 228 | ch, 229 | time_embed_dim, 230 | dropout, 231 | out_channels=out_ch, 232 | dims=dims, 233 | use_checkpoint=use_checkpoint, 234 | use_scale_shift_norm=use_scale_shift_norm, 235 | down=True, 236 | dtype=self.dtype, 237 | device=device, 238 | operations=operations 239 | ) 240 | if resblock_updown 241 | else Downsample3D( 242 | ch, conv_resample, dims=dims, out_channels=out_ch, dtype=self.dtype, device=device, operations=operations 243 | ) 244 | ) 245 | ) 246 | ch = out_ch 247 | input_block_chans.append(ch) 248 | self.zero_convs.append(self.make_zero_conv( 249 | ch, operations=operations, dtype=self.dtype, device=device)) 250 | ds *= 2 251 | self._feature_size += ch 252 | 253 | if num_head_channels == -1: 254 | dim_head = ch // num_heads 255 | else: 256 | num_heads = ch // num_head_channels 257 | dim_head = num_head_channels 258 | if legacy: 259 | # num_heads = 1 260 | dim_head = ch // num_heads if use_spatial_transformer else num_head_channels 261 | mid_block = [ 262 | ResnetBlock3D( 263 | ch, 264 | time_embed_dim, 265 | dropout, 266 | dims=dims, 267 | use_checkpoint=use_checkpoint, 268 | use_scale_shift_norm=use_scale_shift_norm, 269 | dtype=self.dtype, 270 | device=device, 271 | operations=operations 272 | )] 273 | if transformer_depth_middle >= 0: 274 | mid_block += [Transformer3DModel( # always uses a self-attn 275 | ch, num_heads, dim_head, depth=transformer_depth_middle, context_dim=context_dim, 276 | disable_self_attn=disable_middle_self_attn, use_linear=use_linear_in_transformer, 277 | use_checkpoint=use_checkpoint, dtype=self.dtype, device=device, operations=operations 278 | ), 279 | ResnetBlock3D( 280 | ch, 281 | time_embed_dim, 282 | dropout, 283 | dims=dims, 284 | use_checkpoint=use_checkpoint, 285 | use_scale_shift_norm=use_scale_shift_norm, 286 | dtype=self.dtype, 287 | device=device, 288 | operations=operations 289 | )] 290 | self.middle_block = TimestepEmbedSequential(*mid_block) 291 | self.middle_block_out = self.make_zero_conv( 292 | ch, operations=operations, dtype=self.dtype, device=device) 293 | self._feature_size += ch 294 | self.injection = {} 295 | self.timestep = -1 296 | self.step = 0 297 | 298 | def make_zero_conv(self, channels, operations=None, dtype=None, device=None): 299 | return TimestepEmbedSequential(InflatedConv3d(channels, channels, 1, padding=0).half()) 300 | 301 | def forward(self, x, hint, timesteps, context, y=None, model_options={}, batched_number=1, **kwargs): 302 | transformer_options = model_options['transformer_options'] 303 | original_shape = transformer_options['flatten']['original_shape'] 304 | video_length = original_shape[0] 305 | 306 | stage = transformer_options['flatten'].get('stage', None) 307 | injection_steps = transformer_options['flatten'].get( 308 | 'injection_steps', 0) 309 | step = None 310 | if stage == 'inversion' or stage is None: 311 | self.injection = { 312 | 'input': {}, 313 | 'middle': {} 314 | } 315 | for i in range(len(self.input_blocks)): 316 | self.injection['input'][i] = {} 317 | elif stage == 'sampling': 318 | if int(timesteps[0]) > self.timestep: 319 | self.timestep = int(timesteps[0]) 320 | self.step = 0 321 | elif int(timesteps[0]) < self.timestep: 322 | self.timestep = int(timesteps[0]) 323 | self.step += 1 324 | step = self.step 325 | 326 | incoming_shape = x.shape 327 | if len(incoming_shape) == 4: 328 | # TODO get video_length in here 329 | x = rearrange(x, '(b f) c h w -> b c f h w', f=video_length) 330 | if len(hint.shape) == 4: 331 | hint = rearrange(hint, '(b f) c h w -> b c f h w', f=video_length) 332 | 333 | cond_length = x.shape[0] 334 | t_emb = timestep_embedding( 335 | timesteps, self.model_channels, repeat_only=False).to(x.dtype) 336 | emb = self.time_embed(t_emb) 337 | 338 | guided_hint = self.input_hint_block(hint, emb, context) 339 | 340 | outs = [] 341 | 342 | hs = [] 343 | if self.num_classes is not None: 344 | assert y.shape[0] == x.shape[0] 345 | emb = emb + self.label_emb(y) 346 | 347 | h = x 348 | i = 0 349 | for k, attn in enumerate(self.named_modules()): 350 | if hasattr(attn, 'inject_q') and hasattr(attn, 'inject_k'): 351 | attn.inject_q = None 352 | attn.inject_k = None 353 | 354 | for module, zero_conv in zip(self.input_blocks, self.zero_convs): 355 | if stage == 'sampling': 356 | for k, attn in enumerate(self.named_modules()): 357 | if hasattr(attn, 'inject_q') and hasattr(attn, 'inject_k'): 358 | if step < injection_steps: 359 | attn.inject_q = torch.cat( 360 | [self.injection['input'][i][k]['q'][step]]*cond_length).to('cuda') 361 | attn.inject_k = torch.cat( 362 | [self.injection['input'][i][k]['k'][step]]*cond_length).to('cuda') 363 | if guided_hint is not None: 364 | h = module(h, emb, context, 365 | transformer_options=transformer_options) 366 | h += guided_hint 367 | guided_hint = None 368 | else: 369 | h = module(h, emb, context, 370 | transformer_options=transformer_options) 371 | result = rearrange(zero_conv(h, emb, context), 372 | 'b c f h w -> (b f) c h w') 373 | outs.append(result) 374 | if stage == 'inversion': 375 | for k, attn in enumerate(self.input_blocks.named_modules()): 376 | if hasattr(attn, 'q') and hasattr(attn, 'k'): 377 | if k not in self.injection['input'][i]: 378 | self.injection['input'][i][k] = {'q': [], 'k': []} 379 | self.injection['input'][i][k]['q'].append(attn.q.cpu()) 380 | self.injection['input'][i][k]['k'].append(attn.k.cpu()) 381 | i += 1 382 | 383 | if stage == 'sampling' and step < injection_steps: 384 | for k, attn in enumerate(self.middle_block.named_modules()): 385 | if hasattr(attn, 'inject_q') and hasattr(attn, 'inject_k'): 386 | attn.inject_q = torch.cat( 387 | [self.injection['middle'][k]['q'][step]]*cond_length).to('cuda') 388 | attn.inject_k = torch.cat( 389 | [self.injection['middle'][k]['k'][step]]*cond_length).to('cuda') 390 | h = self.middle_block( 391 | h, emb, context, transformer_options=transformer_options) 392 | if stage == 'inversion': 393 | for k, attn in enumerate(self.named_modules()): 394 | if hasattr(attn, 'q') and hasattr(attn, 'k'): 395 | if k not in self.injection['middle']: 396 | self.injection['middle'][k] = {'q': [], 'k': []} 397 | self.injection['middle'][k]['q'].append(attn.q.cpu()) 398 | self.injection['middle'][k]['k'].append(attn.k.cpu()) 399 | result = rearrange(self.middle_block_out( 400 | h, emb, context), 'b c f h w -> (b f) c h w') 401 | outs.append(result) 402 | 403 | return outs 404 | -------------------------------------------------------------------------------- /modules/unet.py: -------------------------------------------------------------------------------- 1 | from einops import rearrange 2 | from .resnet_block3d import ResnetBlock3D 3 | from .downsample3d import Downsample3D 4 | from .upsample3d import Upsample3D 5 | from .convs import InflatedConv3d 6 | from .transformer3d import Transformer3DModel 7 | from .patch3d import apply_unet_patch3d, apply_patch3d 8 | 9 | import torch as th 10 | import torch.nn as nn 11 | 12 | from comfy.ldm.modules.diffusionmodules.util import ( 13 | zero_module, 14 | timestep_embedding, 15 | ) 16 | from comfy.ldm.util import exists 17 | from comfy.ldm.modules.diffusionmodules.openaimodel import TimestepBlock, apply_control as apply_control_2d 18 | import comfy.ops 19 | ops = comfy.ops.disable_weight_init 20 | 21 | 22 | def apply_control(hsp, control, block_type): 23 | b = hsp.shape[0] 24 | hsp = rearrange(hsp, 'b c f h w -> (b f) c h w') 25 | hsp = apply_control_2d(hsp, control, block_type) 26 | hsp = rearrange(hsp, '(b f) c h w -> b c f h w', b=b) 27 | return hsp 28 | 29 | 30 | def forward_timestep_embed(ts, x, emb, context=None, transformer_options={}, output_shape=None, time_context=None, num_video_frames=None, image_only_indicator=None): 31 | for layer in ts: 32 | if isinstance(layer, TimestepBlock): 33 | x = layer(x, emb) 34 | elif isinstance(layer, ResnetBlock3D): 35 | x = layer(x, emb, transformer_options=transformer_options) 36 | elif isinstance(layer, Transformer3DModel): 37 | x = layer(x, context, transformer_options) 38 | if "transformer_index" in transformer_options: 39 | transformer_options["transformer_index"] += 1 40 | elif isinstance(layer, Upsample3D): 41 | x = layer(x, output_shape=output_shape) 42 | else: 43 | x = layer(x) 44 | return x 45 | 46 | 47 | class TimestepEmbedSequential(nn.Sequential, TimestepBlock): 48 | """ 49 | A sequential module that passes timestep embeddings to the children that 50 | support it as an extra input. 51 | """ 52 | 53 | def forward(self, *args, **kwargs): 54 | return forward_timestep_embed(self, *args, **kwargs) 55 | 56 | 57 | class UNetModel(nn.Module): 58 | """ 59 | The full UNet model with attention and timestep embedding. 60 | :param in_channels: channels in the input Tensor. 61 | :param model_channels: base channel count for the model. 62 | :param out_channels: channels in the output Tensor. 63 | :param num_res_blocks: number of residual blocks per downsample. 64 | :param dropout: the dropout probability. 65 | :param channel_mult: channel multiplier for each level of the UNet. 66 | :param conv_resample: if True, use learned convolutions for upsampling and 67 | downsampling. 68 | :param dims: determines if the signal is 1D, 2D, or 3D. 69 | :param num_classes: if specified (as an int), then this model will be 70 | class-conditional with `num_classes` classes. 71 | :param use_checkpoint: use gradient checkpointing to reduce memory usage. 72 | :param num_heads: the number of attention heads in each attention layer. 73 | :param num_heads_channels: if specified, ignore num_heads and instead use 74 | a fixed channel width per attention head. 75 | :param num_heads_upsample: works with num_heads to set a different number 76 | of heads for upsampling. Deprecated. 77 | :param use_scale_shift_norm: use a FiLM-like conditioning mechanism. 78 | :param resblock_updown: use residual blocks for up/downsampling. 79 | :param use_new_attention_order: use a different attention pattern for potentially 80 | increased efficiency. 81 | """ 82 | 83 | def __init__( 84 | self, 85 | image_size, 86 | in_channels, 87 | model_channels, 88 | out_channels, 89 | num_res_blocks, 90 | dropout=0, 91 | channel_mult=(1, 2, 4, 8), 92 | conv_resample=True, 93 | dims=2, 94 | num_classes=None, 95 | use_checkpoint=False, 96 | dtype=th.float32, 97 | num_heads=-1, 98 | num_head_channels=-1, 99 | num_heads_upsample=-1, 100 | use_scale_shift_norm=False, 101 | resblock_updown=False, 102 | use_new_attention_order=False, 103 | use_spatial_transformer=False, # custom transformer support 104 | transformer_depth=1, # custom transformer support 105 | context_dim=None, # custom transformer support 106 | # custom support for prediction of discrete ids into codebook of first stage vq model 107 | n_embed=None, 108 | legacy=True, 109 | disable_self_attentions=None, 110 | num_attention_blocks=None, 111 | disable_middle_self_attn=False, 112 | use_linear_in_transformer=False, 113 | adm_in_channels=None, 114 | transformer_depth_middle=None, 115 | transformer_depth_output=None, 116 | use_temporal_resblock=False, 117 | use_temporal_attention=False, 118 | time_context_dim=None, 119 | extra_ff_mix_layer=False, 120 | use_spatial_context=False, 121 | merge_strategy=None, 122 | merge_factor=0.0, 123 | video_kernel_size=None, 124 | disable_temporal_crossattention=False, 125 | max_ddpm_temb_period=10000, 126 | device=None, 127 | operations=ops, 128 | ): 129 | super().__init__() 130 | 131 | if context_dim is not None: 132 | assert use_spatial_transformer, 'You forgot to use the spatial transformer for your cross-attention conditioning...' 133 | 134 | if num_heads_upsample == -1: 135 | num_heads_upsample = num_heads 136 | 137 | if num_heads == -1: 138 | assert num_head_channels != -1, 'Either num_heads or num_head_channels has to be set' 139 | 140 | if num_head_channels == -1: 141 | assert num_heads != -1, 'Either num_heads or num_head_channels has to be set' 142 | 143 | self.in_channels = in_channels 144 | self.model_channels = model_channels 145 | self.out_channels = out_channels 146 | 147 | if isinstance(num_res_blocks, int): 148 | self.num_res_blocks = len(channel_mult) * [num_res_blocks] 149 | else: 150 | if len(num_res_blocks) != len(channel_mult): 151 | raise ValueError("provide num_res_blocks either as an int (globally constant) or " 152 | "as a list/tuple (per-level) with the same length as channel_mult") 153 | self.num_res_blocks = num_res_blocks 154 | 155 | if disable_self_attentions is not None: 156 | # should be a list of booleans, indicating whether to disable self-attention in TransformerBlocks or not 157 | assert len(disable_self_attentions) == len(channel_mult) 158 | if num_attention_blocks is not None: 159 | assert len(num_attention_blocks) == len(self.num_res_blocks) 160 | 161 | transformer_depth = transformer_depth[:] 162 | transformer_depth_output = transformer_depth_output[:] 163 | 164 | self.dropout = dropout 165 | self.channel_mult = channel_mult 166 | self.conv_resample = conv_resample 167 | self.num_classes = num_classes 168 | self.use_checkpoint = use_checkpoint 169 | self.dtype = dtype 170 | self.num_heads = num_heads 171 | self.num_head_channels = num_head_channels 172 | self.num_heads_upsample = num_heads_upsample 173 | self.use_temporal_resblocks = use_temporal_resblock 174 | self.predict_codebook_ids = n_embed is not None 175 | 176 | self.default_num_video_frames = None 177 | self.default_image_only_indicator = None 178 | 179 | time_embed_dim = model_channels * 4 180 | self.time_embed = nn.Sequential( 181 | operations.Linear(model_channels, time_embed_dim, 182 | dtype=self.dtype, device=device), 183 | nn.SiLU(), 184 | operations.Linear(time_embed_dim, time_embed_dim, 185 | dtype=self.dtype, device=device), 186 | ) 187 | 188 | if self.num_classes is not None: 189 | if isinstance(self.num_classes, int): 190 | self.label_emb = nn.Embedding( 191 | num_classes, time_embed_dim, dtype=self.dtype, device=device) 192 | elif self.num_classes == "continuous": 193 | print("setting up linear c_adm embedding layer") 194 | self.label_emb = nn.Linear(1, time_embed_dim) 195 | elif self.num_classes == "sequential": 196 | assert adm_in_channels is not None 197 | self.label_emb = nn.Sequential( 198 | nn.Sequential( 199 | operations.Linear( 200 | adm_in_channels, time_embed_dim, dtype=self.dtype, device=device), 201 | nn.SiLU(), 202 | operations.Linear( 203 | time_embed_dim, time_embed_dim, dtype=self.dtype, device=device), 204 | ) 205 | ) 206 | else: 207 | raise ValueError() 208 | 209 | self.input_blocks = nn.ModuleList( 210 | [ 211 | TimestepEmbedSequential( 212 | InflatedConv3d( 213 | in_channels, model_channels, kernel_size=3, padding=(1, 1)) 214 | ).half() 215 | ] 216 | ) 217 | self._feature_size = model_channels 218 | input_block_chans = [model_channels] 219 | ch = model_channels 220 | ds = 1 221 | 222 | def get_attention_layer( 223 | ch, 224 | num_heads, 225 | dim_head, 226 | depth=1, 227 | context_dim=None, 228 | use_checkpoint=False, 229 | disable_self_attn=False, 230 | ): 231 | 232 | return Transformer3DModel( 233 | ch, num_heads, dim_head, depth=depth, context_dim=context_dim, 234 | disable_self_attn=disable_self_attn, use_linear=use_linear_in_transformer, 235 | use_checkpoint=use_checkpoint, dtype=self.dtype, device=device, operations=operations 236 | ) 237 | 238 | def get_resblock( 239 | ch, 240 | time_embed_dim, 241 | dropout, 242 | out_channels, 243 | dims, 244 | use_checkpoint, 245 | use_scale_shift_norm, 246 | down=False, 247 | up=False, 248 | dtype=None, 249 | device=None, 250 | operations=ops 251 | ): 252 | return ResnetBlock3D( 253 | channels=ch, 254 | emb_channels=time_embed_dim, 255 | dropout=dropout, 256 | out_channels=out_channels, 257 | use_checkpoint=use_checkpoint, 258 | dims=dims, 259 | use_scale_shift_norm=use_scale_shift_norm, 260 | down=down, 261 | up=up, 262 | dtype=dtype, 263 | device=device, 264 | operations=operations 265 | ) 266 | 267 | for level, mult in enumerate(channel_mult): 268 | for nr in range(self.num_res_blocks[level]): 269 | layers = [ 270 | get_resblock( 271 | ch=ch, 272 | time_embed_dim=time_embed_dim, 273 | dropout=dropout, 274 | out_channels=mult * model_channels, 275 | dims=dims, 276 | use_checkpoint=use_checkpoint, 277 | use_scale_shift_norm=use_scale_shift_norm, 278 | dtype=self.dtype, 279 | device=device, 280 | operations=operations, 281 | ) 282 | ] 283 | ch = mult * model_channels 284 | num_transformers = transformer_depth.pop(0) 285 | if num_transformers > 0: 286 | if num_head_channels == -1: 287 | dim_head = ch // num_heads 288 | else: 289 | num_heads = ch // num_head_channels 290 | dim_head = num_head_channels 291 | if legacy: 292 | # num_heads = 1 293 | dim_head = ch // num_heads if use_spatial_transformer else num_head_channels 294 | if exists(disable_self_attentions): 295 | disabled_sa = disable_self_attentions[level] 296 | else: 297 | disabled_sa = False 298 | 299 | if not exists(num_attention_blocks) or nr < num_attention_blocks[level]: 300 | layers.append(get_attention_layer( 301 | ch, num_heads, dim_head, depth=num_transformers, context_dim=context_dim, 302 | disable_self_attn=disabled_sa, use_checkpoint=use_checkpoint) 303 | ) 304 | self.input_blocks.append(TimestepEmbedSequential(*layers)) 305 | self._feature_size += ch 306 | input_block_chans.append(ch) 307 | if level != len(channel_mult) - 1: 308 | out_ch = ch 309 | self.input_blocks.append( 310 | TimestepEmbedSequential( 311 | get_resblock( 312 | ch=ch, 313 | time_embed_dim=time_embed_dim, 314 | dropout=dropout, 315 | out_channels=out_ch, 316 | dims=dims, 317 | use_checkpoint=use_checkpoint, 318 | use_scale_shift_norm=use_scale_shift_norm, 319 | down=True, 320 | dtype=self.dtype, 321 | device=device, 322 | operations=operations 323 | ) 324 | if resblock_updown 325 | else Downsample3D( 326 | ch, conv_resample, dims=dims, out_channels=out_ch, dtype=self.dtype, device=device, operations=operations 327 | ) 328 | ) 329 | ) 330 | ch = out_ch 331 | input_block_chans.append(ch) 332 | ds *= 2 333 | self._feature_size += ch 334 | 335 | if num_head_channels == -1: 336 | dim_head = ch // num_heads 337 | else: 338 | num_heads = ch // num_head_channels 339 | dim_head = num_head_channels 340 | if legacy: 341 | # num_heads = 1 342 | dim_head = ch // num_heads if use_spatial_transformer else num_head_channels 343 | mid_block = [ 344 | get_resblock( 345 | ch=ch, 346 | time_embed_dim=time_embed_dim, 347 | dropout=dropout, 348 | out_channels=None, 349 | dims=dims, 350 | use_checkpoint=use_checkpoint, 351 | use_scale_shift_norm=use_scale_shift_norm, 352 | dtype=self.dtype, 353 | device=device, 354 | operations=operations 355 | )] 356 | if transformer_depth_middle >= 0: 357 | mid_block += [get_attention_layer( # always uses a self-attn 358 | ch, num_heads, dim_head, depth=transformer_depth_middle, context_dim=context_dim, 359 | disable_self_attn=disable_middle_self_attn, use_checkpoint=use_checkpoint 360 | ), 361 | get_resblock( 362 | ch=ch, 363 | time_embed_dim=time_embed_dim, 364 | dropout=dropout, 365 | out_channels=None, 366 | dims=dims, 367 | use_checkpoint=use_checkpoint, 368 | use_scale_shift_norm=use_scale_shift_norm, 369 | dtype=self.dtype, 370 | device=device, 371 | operations=operations 372 | )] 373 | self.middle_block = TimestepEmbedSequential(*mid_block) 374 | self._feature_size += ch 375 | 376 | self.num_upsamplers = 0 377 | self.output_blocks = nn.ModuleList([]) 378 | for level, mult in list(enumerate(channel_mult))[::-1]: 379 | for i in range(self.num_res_blocks[level] + 1): 380 | ich = input_block_chans.pop() 381 | layers = [ 382 | get_resblock( 383 | ch=ch + ich, 384 | time_embed_dim=time_embed_dim, 385 | dropout=dropout, 386 | out_channels=model_channels * mult, 387 | dims=dims, 388 | use_checkpoint=use_checkpoint, 389 | use_scale_shift_norm=use_scale_shift_norm, 390 | dtype=self.dtype, 391 | device=device, 392 | operations=operations 393 | ) 394 | ] 395 | ch = model_channels * mult 396 | num_transformers = transformer_depth_output.pop() 397 | if num_transformers > 0: 398 | if num_head_channels == -1: 399 | dim_head = ch // num_heads 400 | else: 401 | num_heads = ch // num_head_channels 402 | dim_head = num_head_channels 403 | if legacy: 404 | # num_heads = 1 405 | dim_head = ch // num_heads if use_spatial_transformer else num_head_channels 406 | if exists(disable_self_attentions): 407 | disabled_sa = disable_self_attentions[level] 408 | else: 409 | disabled_sa = False 410 | 411 | if not exists(num_attention_blocks) or i < num_attention_blocks[level]: 412 | layers.append( 413 | get_attention_layer( 414 | ch, num_heads, dim_head, depth=num_transformers, context_dim=context_dim, 415 | disable_self_attn=disabled_sa, use_checkpoint=use_checkpoint 416 | ) 417 | ) 418 | if level and i == self.num_res_blocks[level]: 419 | out_ch = ch 420 | layers.append( 421 | get_resblock( 422 | ch=ch, 423 | time_embed_dim=time_embed_dim, 424 | dropout=dropout, 425 | out_channels=out_ch, 426 | dims=dims, 427 | use_checkpoint=use_checkpoint, 428 | use_scale_shift_norm=use_scale_shift_norm, 429 | up=True, 430 | dtype=self.dtype, 431 | device=device, 432 | operations=operations 433 | ) 434 | if resblock_updown 435 | else Upsample3D(ch, conv_resample, dims=dims, out_channels=out_ch, dtype=self.dtype, device=device, operations=operations) 436 | ) 437 | if not resblock_updown: 438 | self.num_upsamplers += 1 439 | ds //= 2 440 | self.output_blocks.append(TimestepEmbedSequential(*layers)) 441 | self._feature_size += ch 442 | 443 | self.out = nn.Sequential( 444 | operations.GroupNorm(32, ch, dtype=self.dtype, device=device), 445 | nn.SiLU(), 446 | zero_module(InflatedConv3d(model_channels, 447 | out_channels, 3, padding=1, dtype=self.dtype, device=device).half()), 448 | ) 449 | if self.predict_codebook_ids: 450 | self.id_predictor = nn.Sequential( 451 | operations.GroupNorm(32, ch, dtype=self.dtype, device=device), 452 | InflatedConv3d(model_channels, n_embed, 453 | 1, dtype=self.dtype, device=device).half(), 454 | # nn.LogSoftmax(dim=1) # change to cross_entropy and produce non-normalized logits 455 | ) 456 | 457 | def forward(self, x, timesteps=None, context=None, y=None, control=None, transformer_options={}, **kwargs): 458 | """ 459 | Apply the model to an input batch. 460 | :param x: an [N x C x ...] Tensor of inputs. 461 | :param timesteps: a 1-D batch of timesteps. 462 | :param context: conditioning plugged in via crossattn 463 | :param y: an [N] Tensor of labels, if class-conditional. 464 | :return: an [N x C x ...] Tensor of outputs. 465 | """ 466 | default_overall_up_factor = 2**self.num_upsamplers 467 | forward_upsample_size = False 468 | upsample_size = None 469 | 470 | if any(s % default_overall_up_factor != 0 for s in x.shape[-2:]): 471 | print('Upsampling True') 472 | forward_upsample_size = True 473 | 474 | flatten_shape = x.shape 475 | original_shape = (flatten_shape[0]*flatten_shape[2], 476 | flatten_shape[1], flatten_shape[3], flatten_shape[4]) 477 | transformer_options["original_shape"] = original_shape 478 | transformer_options["flatten_shape"] = flatten_shape 479 | transformer_options["transformer_index"] = 0 480 | transformer_patches = transformer_options.get("patches", {}) 481 | 482 | num_video_frames = kwargs.get( 483 | "num_video_frames", self.default_num_video_frames) 484 | image_only_indicator = kwargs.get( 485 | "image_only_indicator", self.default_image_only_indicator) 486 | time_context = kwargs.get("time_context", None) 487 | 488 | assert (y is not None) == ( 489 | self.num_classes is not None 490 | ), "must specify y if and only if the model is class-conditional" 491 | hs = [] 492 | t_emb = timestep_embedding( 493 | timesteps, self.model_channels, repeat_only=False).to(x.dtype) 494 | emb = self.time_embed(t_emb) 495 | 496 | if self.num_classes is not None: 497 | assert y.shape[0] == x.shape[0] 498 | emb = emb + self.label_emb(y) 499 | 500 | h = x 501 | for id, module in enumerate(self.input_blocks): 502 | transformer_options["block"] = ("input", id) 503 | h = forward_timestep_embed(module, h, emb, context, transformer_options, time_context=time_context, 504 | num_video_frames=num_video_frames, image_only_indicator=image_only_indicator) 505 | h = apply_control(h, control, 'input') 506 | if "input_block_patch" in transformer_patches: 507 | patch = transformer_patches["input_block_patch"] 508 | for p in patch: 509 | h = apply_patch3d(h, transformer_options, p) 510 | 511 | hs.append(h) 512 | if "input_block_patch_after_skip" in transformer_patches: 513 | patch = transformer_patches["input_block_patch_after_skip"] 514 | for p in patch: 515 | h = apply_patch3d(h, transformer_options, p) 516 | 517 | transformer_options["block"] = ("middle", 0) 518 | h = forward_timestep_embed(self.middle_block, h, emb, context, transformer_options, time_context=time_context, 519 | num_video_frames=num_video_frames, image_only_indicator=image_only_indicator) 520 | h = apply_control(h, control, 'middle') 521 | 522 | for id, module in enumerate(self.output_blocks): 523 | transformer_options["block"] = ("output", id) 524 | hsp = hs.pop() 525 | hsp = apply_control(hsp, control, 'output') 526 | 527 | if "output_block_patch" in transformer_patches: 528 | patch = transformer_patches["output_block_patch"] 529 | for p in patch: 530 | h, hsp = apply_unet_patch3d(h, hsp, transformer_options, p) 531 | 532 | h = th.cat([h, hsp], dim=1) 533 | del hsp 534 | if len(hs) > 0 and forward_upsample_size: 535 | output_shape = hs[-1].shape[2:] 536 | else: 537 | output_shape = None 538 | 539 | h = forward_timestep_embed(module, h, emb, context, transformer_options, output_shape, 540 | time_context=time_context, num_video_frames=num_video_frames, image_only_indicator=image_only_indicator) 541 | h = h.type(x.dtype) 542 | if self.predict_codebook_ids: 543 | return self.id_predictor(h) 544 | else: 545 | return self.out(h) 546 | -------------------------------------------------------------------------------- /example_workflows/example_flatten_batched.json: -------------------------------------------------------------------------------- 1 | { 2 | "last_node_id": 119, 3 | "last_link_id": 324, 4 | "nodes": [ 5 | { 6 | "id": 16, 7 | "type": "CLIPTextEncode", 8 | "pos": [ 9 | -844, 10 | 569 11 | ], 12 | "size": { 13 | "0": 210, 14 | "1": 76 15 | }, 16 | "flags": {}, 17 | "order": 5, 18 | "mode": 0, 19 | "inputs": [ 20 | { 21 | "name": "clip", 22 | "type": "CLIP", 23 | "link": 30 24 | } 25 | ], 26 | "outputs": [ 27 | { 28 | "name": "CONDITIONING", 29 | "type": "CONDITIONING", 30 | "links": [ 31 | 276 32 | ], 33 | "slot_index": 0 34 | } 35 | ], 36 | "properties": { 37 | "Node name for S&R": "CLIPTextEncode" 38 | }, 39 | "widgets_values": [ 40 | "" 41 | ] 42 | }, 43 | { 44 | "id": 7, 45 | "type": "CLIPTextEncode", 46 | "pos": [ 47 | -846, 48 | 438 49 | ], 50 | "size": { 51 | "0": 210, 52 | "1": 76 53 | }, 54 | "flags": {}, 55 | "order": 6, 56 | "mode": 0, 57 | "inputs": [ 58 | { 59 | "name": "clip", 60 | "type": "CLIP", 61 | "link": 31 62 | } 63 | ], 64 | "outputs": [ 65 | { 66 | "name": "CONDITIONING", 67 | "type": "CONDITIONING", 68 | "links": [ 69 | 226 70 | ], 71 | "slot_index": 0 72 | } 73 | ], 74 | "properties": { 75 | "Node name for S&R": "CLIPTextEncode" 76 | }, 77 | "widgets_values": [ 78 | "" 79 | ] 80 | }, 81 | { 82 | "id": 117, 83 | "type": "ADE_StandardStaticContextOptions", 84 | "pos": [ 85 | -989, 86 | 1107 87 | ], 88 | "size": { 89 | "0": 319.20001220703125, 90 | "1": 190 91 | }, 92 | "flags": {}, 93 | "order": 9, 94 | "mode": 0, 95 | "inputs": [ 96 | { 97 | "name": "prev_context", 98 | "type": "CONTEXT_OPTIONS", 99 | "link": null 100 | }, 101 | { 102 | "name": "view_opts", 103 | "type": "VIEW_OPTS", 104 | "link": null 105 | }, 106 | { 107 | "name": "context_length", 108 | "type": "INT", 109 | "link": 320, 110 | "widget": { 111 | "name": "context_length" 112 | } 113 | }, 114 | { 115 | "name": "context_overlap", 116 | "type": "INT", 117 | "link": 323, 118 | "widget": { 119 | "name": "context_overlap" 120 | } 121 | } 122 | ], 123 | "outputs": [ 124 | { 125 | "name": "CONTEXT_OPTS", 126 | "type": "CONTEXT_OPTIONS", 127 | "links": [ 128 | 318 129 | ], 130 | "shape": 3, 131 | "slot_index": 0 132 | } 133 | ], 134 | "properties": { 135 | "Node name for S&R": "ADE_StandardStaticContextOptions" 136 | }, 137 | "widgets_values": [ 138 | 20, 139 | 10, 140 | "relative", 141 | false, 142 | 0, 143 | 1 144 | ] 145 | }, 146 | { 147 | "id": 14, 148 | "type": "VAEEncode", 149 | "pos": [ 150 | -850, 151 | 689 152 | ], 153 | "size": { 154 | "0": 210, 155 | "1": 46 156 | }, 157 | "flags": {}, 158 | "order": 12, 159 | "mode": 0, 160 | "inputs": [ 161 | { 162 | "name": "pixels", 163 | "type": "IMAGE", 164 | "link": 114 165 | }, 166 | { 167 | "name": "vae", 168 | "type": "VAE", 169 | "link": 33 170 | } 171 | ], 172 | "outputs": [ 173 | { 174 | "name": "LATENT", 175 | "type": "LATENT", 176 | "links": [ 177 | 277 178 | ], 179 | "shape": 3, 180 | "slot_index": 0 181 | } 182 | ], 183 | "properties": { 184 | "Node name for S&R": "VAEEncode" 185 | } 186 | }, 187 | { 188 | "id": 113, 189 | "type": "ADE_UseEvolvedSampling", 190 | "pos": [ 191 | -633, 192 | 1052 193 | ], 194 | "size": { 195 | "0": 235.1999969482422, 196 | "1": 118 197 | }, 198 | "flags": {}, 199 | "order": 14, 200 | "mode": 0, 201 | "inputs": [ 202 | { 203 | "name": "model", 204 | "type": "MODEL", 205 | "link": 314 206 | }, 207 | { 208 | "name": "m_models", 209 | "type": "M_MODELS", 210 | "link": null 211 | }, 212 | { 213 | "name": "context_options", 214 | "type": "CONTEXT_OPTIONS", 215 | "link": 318 216 | }, 217 | { 218 | "name": "sample_settings", 219 | "type": "SAMPLE_SETTINGS", 220 | "link": null 221 | } 222 | ], 223 | "outputs": [ 224 | { 225 | "name": "MODEL", 226 | "type": "MODEL", 227 | "links": [ 228 | 317, 229 | 324 230 | ], 231 | "shape": 3, 232 | "slot_index": 0 233 | } 234 | ], 235 | "properties": { 236 | "Node name for S&R": "ADE_UseEvolvedSampling" 237 | }, 238 | "widgets_values": [ 239 | "use existing" 240 | ] 241 | }, 242 | { 243 | "id": 15, 244 | "type": "TrajectoryNode", 245 | "pos": [ 246 | -856, 247 | 786 248 | ], 249 | "size": { 250 | "0": 220, 251 | "1": 75.51868438720703 252 | }, 253 | "flags": {}, 254 | "order": 10, 255 | "mode": 0, 256 | "inputs": [ 257 | { 258 | "name": "images", 259 | "type": "IMAGE", 260 | "link": 53 261 | }, 262 | { 263 | "name": "context_length", 264 | "type": "INT", 265 | "link": 322, 266 | "widget": { 267 | "name": "context_length" 268 | } 269 | }, 270 | { 271 | "name": "context_overlap", 272 | "type": "INT", 273 | "link": 321, 274 | "widget": { 275 | "name": "context_overlap" 276 | } 277 | } 278 | ], 279 | "outputs": [ 280 | { 281 | "name": "TRAJECTORY", 282 | "type": "TRAJECTORY", 283 | "links": [ 284 | 278, 285 | 287 286 | ], 287 | "shape": 3, 288 | "slot_index": 0 289 | } 290 | ], 291 | "properties": { 292 | "Node name for S&R": "TrajectoryNode" 293 | }, 294 | "widgets_values": [ 295 | 20, 296 | 10 297 | ] 298 | }, 299 | { 300 | "id": 24, 301 | "type": "VHS_VideoCombine", 302 | "pos": [ 303 | 706, 304 | 491 305 | ], 306 | "size": [ 307 | 320, 308 | 604 309 | ], 310 | "flags": {}, 311 | "order": 11, 312 | "mode": 0, 313 | "inputs": [ 314 | { 315 | "name": "images", 316 | "type": "IMAGE", 317 | "link": 55 318 | }, 319 | { 320 | "name": "audio", 321 | "type": "VHS_AUDIO", 322 | "link": null 323 | }, 324 | { 325 | "name": "batch_manager", 326 | "type": "VHS_BatchManager", 327 | "link": null 328 | } 329 | ], 330 | "outputs": [ 331 | { 332 | "name": "Filenames", 333 | "type": "VHS_FILENAMES", 334 | "links": null, 335 | "shape": 3 336 | } 337 | ], 338 | "properties": { 339 | "Node name for S&R": "VHS_VideoCombine" 340 | }, 341 | "widgets_values": { 342 | "frame_rate": 8, 343 | "loop_count": 0, 344 | "filename_prefix": "AnimateDiff", 345 | "format": "video/h264-mp4", 346 | "pix_fmt": "yuv420p", 347 | "crf": 19, 348 | "save_metadata": true, 349 | "pingpong": false, 350 | "save_output": true, 351 | "videopreview": { 352 | "hidden": false, 353 | "paused": false, 354 | "params": { 355 | "filename": "AnimateDiff_00514.mp4", 356 | "subfolder": "", 357 | "type": "output", 358 | "format": "video/h264-mp4" 359 | } 360 | } 361 | } 362 | }, 363 | { 364 | "id": 88, 365 | "type": "DepthAnythingPreprocessor", 366 | "pos": [ 367 | -855, 368 | 91 369 | ], 370 | "size": { 371 | "0": 254.98291015625, 372 | "1": 82 373 | }, 374 | "flags": {}, 375 | "order": 13, 376 | "mode": 0, 377 | "inputs": [ 378 | { 379 | "name": "image", 380 | "type": "IMAGE", 381 | "link": 253 382 | } 383 | ], 384 | "outputs": [ 385 | { 386 | "name": "IMAGE", 387 | "type": "IMAGE", 388 | "links": [ 389 | 254 390 | ], 391 | "shape": 3, 392 | "slot_index": 0 393 | } 394 | ], 395 | "properties": { 396 | "Node name for S&R": "DepthAnythingPreprocessor" 397 | }, 398 | "widgets_values": [ 399 | "depth_anything_vitl14.pth", 400 | 512 401 | ] 402 | }, 403 | { 404 | "id": 18, 405 | "type": "FlattenCheckpointLoaderNode", 406 | "pos": [ 407 | -1274, 408 | 647 409 | ], 410 | "size": { 411 | "0": 285.6000061035156, 412 | "1": 98 413 | }, 414 | "flags": {}, 415 | "order": 0, 416 | "mode": 0, 417 | "outputs": [ 418 | { 419 | "name": "MODEL", 420 | "type": "MODEL", 421 | "links": [ 422 | 314 423 | ], 424 | "shape": 3, 425 | "slot_index": 0 426 | }, 427 | { 428 | "name": "CLIP", 429 | "type": "CLIP", 430 | "links": [ 431 | 30, 432 | 31, 433 | 32 434 | ], 435 | "shape": 3, 436 | "slot_index": 1 437 | }, 438 | { 439 | "name": "VAE", 440 | "type": "VAE", 441 | "links": [ 442 | 33, 443 | 35 444 | ], 445 | "shape": 3, 446 | "slot_index": 2 447 | } 448 | ], 449 | "properties": { 450 | "Node name for S&R": "FlattenCheckpointLoaderNode" 451 | }, 452 | "widgets_values": [ 453 | "juggernaut_reborn.safetensors" 454 | ] 455 | }, 456 | { 457 | "id": 23, 458 | "type": "ImageScale", 459 | "pos": [ 460 | -1242, 461 | 394 462 | ], 463 | "size": { 464 | "0": 242.2943115234375, 465 | "1": 130 466 | }, 467 | "flags": {}, 468 | "order": 8, 469 | "mode": 0, 470 | "inputs": [ 471 | { 472 | "name": "image", 473 | "type": "IMAGE", 474 | "link": 52 475 | } 476 | ], 477 | "outputs": [ 478 | { 479 | "name": "IMAGE", 480 | "type": "IMAGE", 481 | "links": [ 482 | 53, 483 | 55, 484 | 114, 485 | 253 486 | ], 487 | "shape": 3, 488 | "slot_index": 0 489 | } 490 | ], 491 | "properties": { 492 | "Node name for S&R": "ImageScale" 493 | }, 494 | "widgets_values": [ 495 | "nearest-exact", 496 | 512, 497 | 512, 498 | "center" 499 | ] 500 | }, 501 | { 502 | "id": 8, 503 | "type": "VAEDecode", 504 | "pos": [ 505 | 180, 506 | 490 507 | ], 508 | "size": { 509 | "0": 140, 510 | "1": 46 511 | }, 512 | "flags": {}, 513 | "order": 18, 514 | "mode": 0, 515 | "inputs": [ 516 | { 517 | "name": "samples", 518 | "type": "LATENT", 519 | "link": 288 520 | }, 521 | { 522 | "name": "vae", 523 | "type": "VAE", 524 | "link": 35 525 | } 526 | ], 527 | "outputs": [ 528 | { 529 | "name": "IMAGE", 530 | "type": "IMAGE", 531 | "links": [ 532 | 10 533 | ], 534 | "slot_index": 0 535 | } 536 | ], 537 | "properties": { 538 | "Node name for S&R": "VAEDecode" 539 | } 540 | }, 541 | { 542 | "id": 26, 543 | "type": "ControlNetLoader", 544 | "pos": [ 545 | -829, 546 | -32 547 | ], 548 | "size": { 549 | "0": 210, 550 | "1": 58 551 | }, 552 | "flags": {}, 553 | "order": 1, 554 | "mode": 0, 555 | "outputs": [ 556 | { 557 | "name": "CONTROL_NET", 558 | "type": "CONTROL_NET", 559 | "links": [ 560 | 169 561 | ], 562 | "shape": 3, 563 | "slot_index": 0 564 | } 565 | ], 566 | "properties": { 567 | "Node name for S&R": "ControlNetLoader" 568 | }, 569 | "widgets_values": [ 570 | "control_v11f1p_sd15_depth.pth" 571 | ] 572 | }, 573 | { 574 | "id": 119, 575 | "type": "PrimitiveNode", 576 | "pos": [ 577 | -1226, 578 | 1012 579 | ], 580 | "size": { 581 | "0": 210, 582 | "1": 82 583 | }, 584 | "flags": {}, 585 | "order": 2, 586 | "mode": 0, 587 | "outputs": [ 588 | { 589 | "name": "INT", 590 | "type": "INT", 591 | "links": [ 592 | 321, 593 | 323 594 | ], 595 | "slot_index": 0, 596 | "widget": { 597 | "name": "context_overlap" 598 | } 599 | } 600 | ], 601 | "properties": { 602 | "Run widget replace on values": false 603 | }, 604 | "widgets_values": [ 605 | 10, 606 | "fixed" 607 | ] 608 | }, 609 | { 610 | "id": 95, 611 | "type": "UnsamplerFlattenNode", 612 | "pos": [ 613 | -362, 614 | 793 615 | ], 616 | "size": { 617 | "0": 210, 618 | "1": 238 619 | }, 620 | "flags": {}, 621 | "order": 16, 622 | "mode": 0, 623 | "inputs": [ 624 | { 625 | "name": "model", 626 | "type": "MODEL", 627 | "link": 324 628 | }, 629 | { 630 | "name": "positive", 631 | "type": "CONDITIONING", 632 | "link": 276 633 | }, 634 | { 635 | "name": "latent_image", 636 | "type": "LATENT", 637 | "link": 277 638 | }, 639 | { 640 | "name": "trajectories", 641 | "type": "TRAJECTORY", 642 | "link": 278 643 | } 644 | ], 645 | "outputs": [ 646 | { 647 | "name": "LATENT", 648 | "type": "LATENT", 649 | "links": [ 650 | 284 651 | ], 652 | "shape": 3, 653 | "slot_index": 0 654 | }, 655 | { 656 | "name": "INJECTIONS", 657 | "type": "INJECTIONS", 658 | "links": [ 659 | 283 660 | ], 661 | "shape": 3, 662 | "slot_index": 1 663 | } 664 | ], 665 | "properties": { 666 | "Node name for S&R": "UnsamplerFlattenNode" 667 | }, 668 | "widgets_values": [ 669 | 20, 670 | 8, 671 | "euler", 672 | "normal", 673 | "disable", 674 | 0 675 | ] 676 | }, 677 | { 678 | "id": 96, 679 | "type": "KSamplerFlattenNode", 680 | "pos": [ 681 | -116, 682 | 629 683 | ], 684 | "size": { 685 | "0": 275.4591064453125, 686 | "1": 422 687 | }, 688 | "flags": {}, 689 | "order": 17, 690 | "mode": 0, 691 | "inputs": [ 692 | { 693 | "name": "model", 694 | "type": "MODEL", 695 | "link": 317 696 | }, 697 | { 698 | "name": "trajectories", 699 | "type": "TRAJECTORY", 700 | "link": 287 701 | }, 702 | { 703 | "name": "positive", 704 | "type": "CONDITIONING", 705 | "link": 286 706 | }, 707 | { 708 | "name": "negative", 709 | "type": "CONDITIONING", 710 | "link": 285 711 | }, 712 | { 713 | "name": "latent_image", 714 | "type": "LATENT", 715 | "link": 284 716 | }, 717 | { 718 | "name": "injections", 719 | "type": "INJECTIONS", 720 | "link": 283 721 | } 722 | ], 723 | "outputs": [ 724 | { 725 | "name": "LATENT", 726 | "type": "LATENT", 727 | "links": [ 728 | 288 729 | ], 730 | "shape": 3, 731 | "slot_index": 0 732 | } 733 | ], 734 | "properties": { 735 | "Node name for S&R": "KSamplerFlattenNode" 736 | }, 737 | "widgets_values": [ 738 | "disable", 739 | 846365516879693, 740 | "fixed", 741 | 10, 742 | 8, 743 | 0, 744 | 6, 745 | "dpmpp_2m", 746 | "karras", 747 | 0, 748 | 10000, 749 | "disable" 750 | ] 751 | }, 752 | { 753 | "id": 6, 754 | "type": "CLIPTextEncode", 755 | "pos": [ 756 | -869, 757 | 259 758 | ], 759 | "size": { 760 | "0": 260.2884826660156, 761 | "1": 114.05644226074219 762 | }, 763 | "flags": {}, 764 | "order": 7, 765 | "mode": 0, 766 | "inputs": [ 767 | { 768 | "name": "clip", 769 | "type": "CLIP", 770 | "link": 32 771 | } 772 | ], 773 | "outputs": [ 774 | { 775 | "name": "CONDITIONING", 776 | "type": "CONDITIONING", 777 | "links": [ 778 | 225 779 | ], 780 | "slot_index": 0 781 | } 782 | ], 783 | "properties": { 784 | "Node name for S&R": "CLIPTextEncode" 785 | }, 786 | "widgets_values": [ 787 | "an armored knight" 788 | ] 789 | }, 790 | { 791 | "id": 10, 792 | "type": "VHS_VideoCombine", 793 | "pos": [ 794 | 354, 795 | 490 796 | ], 797 | "size": [ 798 | 320, 799 | 604 800 | ], 801 | "flags": {}, 802 | "order": 19, 803 | "mode": 0, 804 | "inputs": [ 805 | { 806 | "name": "images", 807 | "type": "IMAGE", 808 | "link": 10 809 | }, 810 | { 811 | "name": "audio", 812 | "type": "VHS_AUDIO", 813 | "link": null 814 | }, 815 | { 816 | "name": "batch_manager", 817 | "type": "VHS_BatchManager", 818 | "link": null 819 | } 820 | ], 821 | "outputs": [ 822 | { 823 | "name": "Filenames", 824 | "type": "VHS_FILENAMES", 825 | "links": null, 826 | "shape": 3 827 | } 828 | ], 829 | "properties": { 830 | "Node name for S&R": "VHS_VideoCombine" 831 | }, 832 | "widgets_values": { 833 | "frame_rate": 8, 834 | "loop_count": 0, 835 | "filename_prefix": "AnimateDiff", 836 | "format": "video/h264-mp4", 837 | "pix_fmt": "yuv420p", 838 | "crf": 18, 839 | "save_metadata": true, 840 | "pingpong": false, 841 | "save_output": true, 842 | "videopreview": { 843 | "hidden": false, 844 | "paused": false, 845 | "params": { 846 | "filename": "AnimateDiff_00515.mp4", 847 | "subfolder": "", 848 | "type": "output", 849 | "format": "video/h264-mp4" 850 | } 851 | } 852 | } 853 | }, 854 | { 855 | "id": 59, 856 | "type": "ACN_AdvancedControlNetApply", 857 | "pos": [ 858 | -492, 859 | 211 860 | ], 861 | "size": { 862 | "0": 285.6000061035156, 863 | "1": 266 864 | }, 865 | "flags": {}, 866 | "order": 15, 867 | "mode": 0, 868 | "inputs": [ 869 | { 870 | "name": "positive", 871 | "type": "CONDITIONING", 872 | "link": 225 873 | }, 874 | { 875 | "name": "negative", 876 | "type": "CONDITIONING", 877 | "link": 226 878 | }, 879 | { 880 | "name": "control_net", 881 | "type": "CONTROL_NET", 882 | "link": 169 883 | }, 884 | { 885 | "name": "image", 886 | "type": "IMAGE", 887 | "link": 254 888 | }, 889 | { 890 | "name": "mask_optional", 891 | "type": "MASK", 892 | "link": null 893 | }, 894 | { 895 | "name": "timestep_kf", 896 | "type": "TIMESTEP_KEYFRAME", 897 | "link": null 898 | }, 899 | { 900 | "name": "latent_kf_override", 901 | "type": "LATENT_KEYFRAME", 902 | "link": null 903 | }, 904 | { 905 | "name": "weights_override", 906 | "type": "CONTROL_NET_WEIGHTS", 907 | "link": null 908 | }, 909 | { 910 | "name": "model_optional", 911 | "type": "MODEL", 912 | "link": null 913 | } 914 | ], 915 | "outputs": [ 916 | { 917 | "name": "positive", 918 | "type": "CONDITIONING", 919 | "links": [ 920 | 286 921 | ], 922 | "shape": 3, 923 | "slot_index": 0 924 | }, 925 | { 926 | "name": "negative", 927 | "type": "CONDITIONING", 928 | "links": [ 929 | 285 930 | ], 931 | "shape": 3, 932 | "slot_index": 1 933 | }, 934 | { 935 | "name": "model_opt", 936 | "type": "MODEL", 937 | "links": null, 938 | "shape": 3 939 | } 940 | ], 941 | "properties": { 942 | "Node name for S&R": "ACN_AdvancedControlNetApply" 943 | }, 944 | "widgets_values": [ 945 | 0.5, 946 | 0, 947 | 0.75 948 | ] 949 | }, 950 | { 951 | "id": 13, 952 | "type": "VHS_LoadVideo", 953 | "pos": [ 954 | -1556, 955 | 409 956 | ], 957 | "size": [ 958 | 240, 959 | 476 960 | ], 961 | "flags": {}, 962 | "order": 3, 963 | "mode": 0, 964 | "inputs": [ 965 | { 966 | "name": "batch_manager", 967 | "type": "VHS_BatchManager", 968 | "link": null 969 | } 970 | ], 971 | "outputs": [ 972 | { 973 | "name": "IMAGE", 974 | "type": "IMAGE", 975 | "links": [ 976 | 52 977 | ], 978 | "shape": 3, 979 | "slot_index": 0 980 | }, 981 | { 982 | "name": "frame_count", 983 | "type": "INT", 984 | "links": null, 985 | "shape": 3 986 | }, 987 | { 988 | "name": "audio", 989 | "type": "VHS_AUDIO", 990 | "links": null, 991 | "shape": 3 992 | } 993 | ], 994 | "properties": { 995 | "Node name for S&R": "VHS_LoadVideo" 996 | }, 997 | "widgets_values": { 998 | "video": "waving2.mp4", 999 | "force_rate": 0, 1000 | "force_size": "Disabled", 1001 | "custom_width": 512, 1002 | "custom_height": 512, 1003 | "frame_load_cap": 40, 1004 | "skip_first_frames": 0, 1005 | "select_every_nth": 1, 1006 | "choose video to upload": "image", 1007 | "videopreview": { 1008 | "hidden": false, 1009 | "paused": false, 1010 | "params": { 1011 | "frame_load_cap": 40, 1012 | "skip_first_frames": 0, 1013 | "force_rate": 0, 1014 | "select_every_nth": 1, 1015 | "filename": "waving2.mp4", 1016 | "type": "input", 1017 | "format": "video/mp4" 1018 | } 1019 | } 1020 | } 1021 | }, 1022 | { 1023 | "id": 118, 1024 | "type": "PrimitiveNode", 1025 | "pos": [ 1026 | -1226, 1027 | 880 1028 | ], 1029 | "size": { 1030 | "0": 210, 1031 | "1": 82 1032 | }, 1033 | "flags": {}, 1034 | "order": 4, 1035 | "mode": 0, 1036 | "outputs": [ 1037 | { 1038 | "name": "INT", 1039 | "type": "INT", 1040 | "links": [ 1041 | 320, 1042 | 322 1043 | ], 1044 | "slot_index": 0, 1045 | "widget": { 1046 | "name": "context_length" 1047 | } 1048 | } 1049 | ], 1050 | "properties": { 1051 | "Run widget replace on values": false 1052 | }, 1053 | "widgets_values": [ 1054 | 20, 1055 | "fixed" 1056 | ] 1057 | } 1058 | ], 1059 | "links": [ 1060 | [ 1061 | 10, 1062 | 8, 1063 | 0, 1064 | 10, 1065 | 0, 1066 | "IMAGE" 1067 | ], 1068 | [ 1069 | 30, 1070 | 18, 1071 | 1, 1072 | 16, 1073 | 0, 1074 | "CLIP" 1075 | ], 1076 | [ 1077 | 31, 1078 | 18, 1079 | 1, 1080 | 7, 1081 | 0, 1082 | "CLIP" 1083 | ], 1084 | [ 1085 | 32, 1086 | 18, 1087 | 1, 1088 | 6, 1089 | 0, 1090 | "CLIP" 1091 | ], 1092 | [ 1093 | 33, 1094 | 18, 1095 | 2, 1096 | 14, 1097 | 1, 1098 | "VAE" 1099 | ], 1100 | [ 1101 | 35, 1102 | 18, 1103 | 2, 1104 | 8, 1105 | 1, 1106 | "VAE" 1107 | ], 1108 | [ 1109 | 52, 1110 | 13, 1111 | 0, 1112 | 23, 1113 | 0, 1114 | "IMAGE" 1115 | ], 1116 | [ 1117 | 53, 1118 | 23, 1119 | 0, 1120 | 15, 1121 | 0, 1122 | "IMAGE" 1123 | ], 1124 | [ 1125 | 55, 1126 | 23, 1127 | 0, 1128 | 24, 1129 | 0, 1130 | "IMAGE" 1131 | ], 1132 | [ 1133 | 114, 1134 | 23, 1135 | 0, 1136 | 14, 1137 | 0, 1138 | "IMAGE" 1139 | ], 1140 | [ 1141 | 169, 1142 | 26, 1143 | 0, 1144 | 59, 1145 | 2, 1146 | "CONTROL_NET" 1147 | ], 1148 | [ 1149 | 225, 1150 | 6, 1151 | 0, 1152 | 59, 1153 | 0, 1154 | "CONDITIONING" 1155 | ], 1156 | [ 1157 | 226, 1158 | 7, 1159 | 0, 1160 | 59, 1161 | 1, 1162 | "CONDITIONING" 1163 | ], 1164 | [ 1165 | 253, 1166 | 23, 1167 | 0, 1168 | 88, 1169 | 0, 1170 | "IMAGE" 1171 | ], 1172 | [ 1173 | 254, 1174 | 88, 1175 | 0, 1176 | 59, 1177 | 3, 1178 | "IMAGE" 1179 | ], 1180 | [ 1181 | 276, 1182 | 16, 1183 | 0, 1184 | 95, 1185 | 1, 1186 | "CONDITIONING" 1187 | ], 1188 | [ 1189 | 277, 1190 | 14, 1191 | 0, 1192 | 95, 1193 | 2, 1194 | "LATENT" 1195 | ], 1196 | [ 1197 | 278, 1198 | 15, 1199 | 0, 1200 | 95, 1201 | 3, 1202 | "TRAJECTORY" 1203 | ], 1204 | [ 1205 | 283, 1206 | 95, 1207 | 1, 1208 | 96, 1209 | 5, 1210 | "INJECTIONS" 1211 | ], 1212 | [ 1213 | 284, 1214 | 95, 1215 | 0, 1216 | 96, 1217 | 4, 1218 | "LATENT" 1219 | ], 1220 | [ 1221 | 285, 1222 | 59, 1223 | 1, 1224 | 96, 1225 | 3, 1226 | "CONDITIONING" 1227 | ], 1228 | [ 1229 | 286, 1230 | 59, 1231 | 0, 1232 | 96, 1233 | 2, 1234 | "CONDITIONING" 1235 | ], 1236 | [ 1237 | 287, 1238 | 15, 1239 | 0, 1240 | 96, 1241 | 1, 1242 | "TRAJECTORY" 1243 | ], 1244 | [ 1245 | 288, 1246 | 96, 1247 | 0, 1248 | 8, 1249 | 0, 1250 | "LATENT" 1251 | ], 1252 | [ 1253 | 314, 1254 | 18, 1255 | 0, 1256 | 113, 1257 | 0, 1258 | "MODEL" 1259 | ], 1260 | [ 1261 | 317, 1262 | 113, 1263 | 0, 1264 | 96, 1265 | 0, 1266 | "MODEL" 1267 | ], 1268 | [ 1269 | 318, 1270 | 117, 1271 | 0, 1272 | 113, 1273 | 2, 1274 | "CONTEXT_OPTIONS" 1275 | ], 1276 | [ 1277 | 320, 1278 | 118, 1279 | 0, 1280 | 117, 1281 | 2, 1282 | "INT" 1283 | ], 1284 | [ 1285 | 321, 1286 | 119, 1287 | 0, 1288 | 15, 1289 | 2, 1290 | "INT" 1291 | ], 1292 | [ 1293 | 322, 1294 | 118, 1295 | 0, 1296 | 15, 1297 | 1, 1298 | "INT" 1299 | ], 1300 | [ 1301 | 323, 1302 | 119, 1303 | 0, 1304 | 117, 1305 | 3, 1306 | "INT" 1307 | ], 1308 | [ 1309 | 324, 1310 | 113, 1311 | 0, 1312 | 95, 1313 | 0, 1314 | "MODEL" 1315 | ] 1316 | ], 1317 | "groups": [], 1318 | "config": {}, 1319 | "extra": {}, 1320 | "version": 0.4 1321 | } --------------------------------------------------------------------------------