├── .gitignore
├── DiT
    ├── LICENSE-DiT
    ├── conf.py
    ├── labels
    │   └── imagenet1000.json
    ├── loader.py
    ├── model.py
    └── nodes.py
├── HunYuan
    ├── __init__.py
    ├── conf.py
    ├── loader.py
    ├── models
    │   ├── __init__.py
    │   ├── attn_layers.py
    │   ├── embedders.py
    │   ├── models.py
    │   ├── norm_layers.py
    │   ├── poolers.py
    │   ├── posemb_layers.py
    │   └── text_encoder.py
    ├── nodes.py
    ├── wf.json
    └── wf.png
├── LICENSE
├── PixArt
    ├── LICENSE-PixArt
    ├── conf.py
    ├── diffusers_convert.py
    ├── loader.py
    ├── lora.py
    ├── models
    │   ├── PixArt.py
    │   ├── PixArtMS.py
    │   ├── PixArt_blocks.py
    │   ├── pixart_controlnet.py
    │   └── utils.py
    └── nodes.py
├── README.md
├── T5
    ├── LICENSE-ComfyUI
    ├── LICENSE-T5
    ├── loader.py
    ├── nodes.py
    ├── t5_tokenizer
    │   ├── special_tokens_map.json
    │   ├── spiece.model
    │   └── tokenizer_config.json
    ├── t5v11-xxl_config.json
    └── t5v11.py
├── VAE
    ├── conf.py
    ├── loader.py
    ├── models
    │   ├── LICENSE-Consistency-Decoder
    │   ├── LICENSE-Kandinsky-3
    │   ├── LICENSE-Latent-Diffusion
    │   ├── LICENSE-SAI
    │   ├── LICENSE-SDV
    │   ├── LICENSE-Taming-Transformers
    │   ├── consistencydecoder.py
    │   ├── kl.py
    │   ├── movq3.py
    │   ├── temporal_ae.py
    │   └── vq.py
    └── nodes.py
├── __init__.py
├── requirements.txt
└── utils
    └── dtype.py
/.gitignore: -------------------------------------------------------------------------------- 1 | # Byte-compiled / optimized / DLL files 2 | __pycache__/ 3 | *.py[cod] 4 | *$py.class 5 | 6 | # C extensions 7 | *.so 8 | 9 | # Distribution / packaging 10 | .Python 11 | build/ 12 | develop-eggs/ 13 | dist/ 14 | downloads/ 15 | eggs/ 16 | .eggs/ 17 | lib/ 18 | lib64/ 19 | parts/ 20 | sdist/ 21 | var/ 22 | wheels/ 23 | share/python-wheels/ 24 | *.egg-info/ 25 | .installed.cfg 26 | *.egg 27 | MANIFEST 28 | 29 | # PyInstaller 30 | # Usually these files are written by a python script from a template 31 | # before PyInstaller builds the exe, so as to inject date/other infos into it. 32 | *.manifest 33 | *.spec 34 | 35 | # Installer logs 36 | pip-log.txt 37 | pip-delete-this-directory.txt 38 | 39 | # Unit test / coverage reports 40 | htmlcov/ 41 | .tox/ 42 | .nox/ 43 | .coverage 44 | .coverage.* 45 | .cache 46 | nosetests.xml 47 | coverage.xml 48 | *.cover 49 | *.py,cover 50 | .hypothesis/ 51 | .pytest_cache/ 52 | cover/ 53 | 54 | # Translations 55 | *.mo 56 | *.pot 57 | 58 | # Django stuff: 59 | *.log 60 | local_settings.py 61 | db.sqlite3 62 | db.sqlite3-journal 63 | 64 | # Flask stuff: 65 | instance/ 66 | .webassets-cache 67 | 68 | # Scrapy stuff: 69 | .scrapy 70 | 71 | # Sphinx documentation 72 | docs/_build/ 73 | 74 | # PyBuilder 75 | .pybuilder/ 76 | target/ 77 | 78 | # Jupyter Notebook 79 | .ipynb_checkpoints 80 | 81 | # IPython 82 | profile_default/ 83 | ipython_config.py 84 | 85 | # pyenv 86 | # For a library or package, you might want to ignore these files since the code is 87 | # intended to run in multiple environments; otherwise, check them in: 88 | # .python-version 89 | 90 | # pipenv 91 | # According to pypa/pipenv#598, it is recommended to include Pipfile.lock in version control. 92 | # However, in case of collaboration, if having platform-specific dependencies or dependencies 93 | # having no cross-platform support, pipenv may install dependencies that don't work, or not 94 | # install all needed dependencies. 95 | #Pipfile.lock 96 | 97 | # poetry 98 | # Similar to Pipfile.lock, it is generally recommended to include poetry.lock in version control. 99 | # This is especially recommended for binary packages to ensure reproducibility, and is more 100 | # commonly ignored for libraries.
101 | # https://python-poetry.org/docs/basic-usage/#commit-your-poetrylock-file-to-version-control 102 | #poetry.lock 103 | 104 | # pdm 105 | # Similar to Pipfile.lock, it is generally recommended to include pdm.lock in version control. 106 | #pdm.lock 107 | # pdm stores project-wide configurations in .pdm.toml, but it is recommended to not include it 108 | # in version control. 109 | # https://pdm.fming.dev/#use-with-ide 110 | .pdm.toml 111 | 112 | # PEP 582; used by e.g. github.com/David-OConnor/pyflow and github.com/pdm-project/pdm 113 | __pypackages__/ 114 | 115 | # Celery stuff 116 | celerybeat-schedule 117 | celerybeat.pid 118 | 119 | # SageMath parsed files 120 | *.sage.py 121 | 122 | # Environments 123 | .env 124 | .venv 125 | env/ 126 | venv/ 127 | ENV/ 128 | env.bak/ 129 | venv.bak/ 130 | 131 | # Spyder project settings 132 | .spyderproject 133 | .spyproject 134 | 135 | # Rope project settings 136 | .ropeproject 137 | 138 | # mkdocs documentation 139 | /site 140 | 141 | # mypy 142 | .mypy_cache/ 143 | .dmypy.json 144 | dmypy.json 145 | 146 | # Pyre type checker 147 | .pyre/ 148 | 149 | # pytype static type analyzer 150 | .pytype/ 151 | 152 | # Cython debug symbols 153 | cython_debug/ 154 | 155 | # PyCharm 156 | # JetBrains specific template is maintained in a separate JetBrains.gitignore that can 157 | # be found at https://github.com/github/gitignore/blob/main/Global/JetBrains.gitignore 158 | # and can be added to the global gitignore or merged into this file. For a more nuclear 159 | # option (not recommended) you can uncomment the following to ignore the entire idea folder. 160 | #.idea/ 161 | -------------------------------------------------------------------------------- /DiT/conf.py: -------------------------------------------------------------------------------- 1 | """ 2 | List of all DiT model types / settings 3 | """ 4 | sampling_settings = { 5 | "beta_schedule" : "sqrt_linear", 6 | "linear_start" : 0.0001, 7 | "linear_end" : 0.02, 8 | "timesteps" : 1000, 9 | } 10 | 11 | dit_conf = { 12 | "XL/2": { # DiT_XL_2 13 | "unet_config": { 14 | "depth" : 28, 15 | "num_heads" : 16, 16 | "patch_size" : 2, 17 | "hidden_size" : 1152, 18 | }, 19 | "sampling_settings" : sampling_settings, 20 | }, 21 | "XL/4": { # DiT_XL_4 22 | "unet_config": { 23 | "depth" : 28, 24 | "num_heads" : 16, 25 | "patch_size" : 4, 26 | "hidden_size" : 1152, 27 | }, 28 | "sampling_settings" : sampling_settings, 29 | }, 30 | "XL/8": { # DiT_XL_8 31 | "unet_config": { 32 | "depth" : 28, 33 | "num_heads" : 16, 34 | "patch_size" : 8, 35 | "hidden_size" : 1152, 36 | }, 37 | "sampling_settings" : sampling_settings, 38 | }, 39 | "L/2": { # DiT_L_2 40 | "unet_config": { 41 | "depth" : 24, 42 | "num_heads" : 16, 43 | "patch_size" : 2, 44 | "hidden_size" : 1024, 45 | }, 46 | "sampling_settings" : sampling_settings, 47 | }, 48 | "L/4": { # DiT_L_4 49 | "unet_config": { 50 | "depth" : 24, 51 | "num_heads" : 16, 52 | "patch_size" : 4, 53 | "hidden_size" : 1024, 54 | }, 55 | "sampling_settings" : sampling_settings, 56 | }, 57 | "L/8": { # DiT_L_8 58 | "unet_config": { 59 | "depth" : 24, 60 | "num_heads" : 16, 61 | "patch_size" : 8, 62 | "hidden_size" : 1024, 63 | }, 64 | "sampling_settings" : sampling_settings, 65 | }, 66 | "B/2": { # DiT_B_2 67 | "unet_config": { 68 | "depth" : 12, 69 | "num_heads" : 12, 70 | "patch_size" : 2, 71 | "hidden_size" : 768, 72 | }, 73 | "sampling_settings" : sampling_settings, 74 | }, 75 | "B/4": { # DiT_B_4 76 | "unet_config": { 77 | "depth" : 12, 78 | "num_heads" : 12, 79 | "patch_size" : 4, 
80 | "hidden_size" : 768, 81 | }, 82 | "sampling_settings" : sampling_settings, 83 | }, 84 | "B/8": { # DiT_B_8 85 | "unet_config": { 86 | "depth" : 12, 87 | "num_heads" : 12, 88 | "patch_size" : 8, 89 | "hidden_size" : 768, 90 | }, 91 | "sampling_settings" : sampling_settings, 92 | }, 93 | "S/2": { # DiT_S_2 94 | "unet_config": { 95 | "depth" : 12, 96 | "num_heads" : 6, 97 | "patch_size" : 2, 98 | "hidden_size" : 384, 99 | }, 100 | "sampling_settings" : sampling_settings, 101 | }, 102 | "S/4": { # DiT_S_4 103 | "unet_config": { 104 | "depth" : 12, 105 | "num_heads" : 6, 106 | "patch_size" : 4, 107 | "hidden_size" : 384, 108 | }, 109 | "sampling_settings" : sampling_settings, 110 | }, 111 | "S/8": { # DiT_S_8 112 | "unet_config": { 113 | "depth" : 12, 114 | "num_heads" : 6, 115 | "patch_size" : 8, 116 | "hidden_size" : 384, 117 | }, 118 | "sampling_settings" : sampling_settings, 119 | }, 120 | } 121 | -------------------------------------------------------------------------------- /DiT/loader.py: -------------------------------------------------------------------------------- 1 | import comfy.supported_models_base 2 | import comfy.latent_formats 3 | import comfy.model_patcher 4 | import comfy.model_base 5 | import comfy.utils 6 | import torch 7 | from comfy import model_management 8 | 9 | class EXM_DiT(comfy.supported_models_base.BASE): 10 | unet_config = {} 11 | unet_extra_config = {} 12 | latent_format = comfy.latent_formats.SD15 13 | 14 | def __init__(self, model_conf): 15 | self.unet_config = model_conf.get("unet_config", {}) 16 | self.sampling_settings = model_conf.get("sampling_settings", {}) 17 | self.latent_format = self.latent_format() 18 | # UNET is handled by extension 19 | self.unet_config["disable_unet_model_creation"] = True 20 | 21 | def model_type(self, state_dict, prefix=""): 22 | return comfy.model_base.ModelType.EPS 23 | 24 | def load_dit(model_path, model_conf): 25 | state_dict = comfy.utils.load_torch_file(model_path) 26 | state_dict = state_dict.get("model", state_dict) 27 | parameters = comfy.utils.calculate_parameters(state_dict) 28 | unet_dtype = model_management.unet_dtype(model_params=parameters) 29 | load_device = comfy.model_management.get_torch_device() 30 | offload_device = comfy.model_management.unet_offload_device() 31 | 32 | # ignore fp8/etc and use directly for now 33 | manual_cast_dtype = model_management.unet_manual_cast(unet_dtype, load_device) 34 | if manual_cast_dtype: 35 | print(f"DiT: falling back to {manual_cast_dtype}") 36 | unet_dtype = manual_cast_dtype 37 | 38 | model_conf["unet_config"]["num_classes"] = state_dict["y_embedder.embedding_table.weight"].shape[0] - 1 # adj. 
for empty 39 | 40 | model_conf = EXM_DiT(model_conf) 41 | model = comfy.model_base.BaseModel( 42 | model_conf, 43 | model_type=comfy.model_base.ModelType.EPS, 44 | device=model_management.get_torch_device() 45 | ) 46 | 47 | from .model import DiT 48 | model.diffusion_model = DiT(**model_conf.unet_config) 49 | 50 | model.diffusion_model.load_state_dict(state_dict) 51 | model.diffusion_model.dtype = unet_dtype 52 | model.diffusion_model.eval() 53 | model.diffusion_model.to(unet_dtype) 54 | 55 | model_patcher = comfy.model_patcher.ModelPatcher( 56 | model, 57 | load_device = load_device, 58 | offload_device = offload_device, 59 | current_device = "cpu", 60 | ) 61 | return model_patcher 62 | -------------------------------------------------------------------------------- /DiT/model.py: -------------------------------------------------------------------------------- 1 | # Copyright (c) Meta Platforms, Inc. and affiliates. 2 | # All rights reserved. 3 | 4 | # This source code is licensed under the license found in the 5 | # LICENSE file in the root directory of this source tree. 6 | # -------------------------------------------------------- 7 | # References: 8 | # GLIDE: https://github.com/openai/glide-text2im 9 | # MAE: https://github.com/facebookresearch/mae/blob/main/models_mae.py 10 | # -------------------------------------------------------- 11 | 12 | import torch 13 | import torch.nn as nn 14 | import numpy as np 15 | import math 16 | from timm.models.vision_transformer import PatchEmbed, Attention, Mlp 17 | 18 | 19 | def modulate(x, shift, scale): 20 | return x * (1 + scale.unsqueeze(1)) + shift.unsqueeze(1) 21 | 22 | 23 | ################################################################################# 24 | # Embedding Layers for Timesteps and Class Labels # 25 | ################################################################################# 26 | 27 | class TimestepEmbedder(nn.Module): 28 | """ 29 | Embeds scalar timesteps into vector representations. 30 | """ 31 | def __init__(self, hidden_size, frequency_embedding_size=256): 32 | super().__init__() 33 | self.mlp = nn.Sequential( 34 | nn.Linear(frequency_embedding_size, hidden_size, bias=True), 35 | nn.SiLU(), 36 | nn.Linear(hidden_size, hidden_size, bias=True), 37 | ) 38 | self.frequency_embedding_size = frequency_embedding_size 39 | 40 | @staticmethod 41 | def timestep_embedding(t, dim, max_period=10000): 42 | """ 43 | Create sinusoidal timestep embeddings. 44 | :param t: a 1-D Tensor of N indices, one per batch element. 45 | These may be fractional. 46 | :param dim: the dimension of the output. 47 | :param max_period: controls the minimum frequency of the embeddings. 48 | :return: an (N, D) Tensor of positional embeddings. 49 | """ 50 | # https://github.com/openai/glide-text2im/blob/main/glide_text2im/nn.py 51 | half = dim // 2 52 | freqs = torch.exp( 53 | -math.log(max_period) * torch.arange(start=0, end=half, dtype=torch.float32) / half 54 | ).to(device=t.device) 55 | args = t[:, None].float() * freqs[None] 56 | embedding = torch.cat([torch.cos(args), torch.sin(args)], dim=-1) 57 | if dim % 2: 58 | embedding = torch.cat([embedding, torch.zeros_like(embedding[:, :1])], dim=-1) 59 | return embedding.to(dtype=t.dtype) 60 | 61 | def forward(self, t): 62 | t_freq = self.timestep_embedding(t, self.frequency_embedding_size) 63 | t_emb = self.mlp(t_freq) 64 | return t_emb 65 | 66 | 67 | class LabelEmbedder(nn.Module): 68 | """ 69 | Embeds class labels into vector representations. Also handles label dropout for classifier-free guidance. 
70 | """ 71 | def __init__(self, num_classes, hidden_size, dropout_prob): 72 | super().__init__() 73 | use_cfg_embedding = dropout_prob > 0 74 | self.embedding_table = nn.Embedding(num_classes + use_cfg_embedding, hidden_size) 75 | self.num_classes = num_classes 76 | self.dropout_prob = dropout_prob 77 | 78 | def token_drop(self, labels, force_drop_ids=None): 79 | """ 80 | Drops labels to enable classifier-free guidance. 81 | """ 82 | if force_drop_ids is None: 83 | drop_ids = torch.rand(labels.shape[0], device=labels.device) < self.dropout_prob 84 | else: 85 | drop_ids = force_drop_ids == 1 86 | labels = torch.where(drop_ids, self.num_classes, labels) 87 | return labels 88 | 89 | def forward(self, labels, train, force_drop_ids=None): 90 | use_dropout = self.dropout_prob > 0 91 | if (train and use_dropout) or (force_drop_ids is not None): 92 | labels = self.token_drop(labels, force_drop_ids) 93 | embeddings = self.embedding_table(labels) 94 | return embeddings 95 | 96 | 97 | ################################################################################# 98 | # Core DiT Model # 99 | ################################################################################# 100 | 101 | class DiTBlock(nn.Module): 102 | """ 103 | A DiT block with adaptive layer norm zero (adaLN-Zero) conditioning. 104 | """ 105 | def __init__(self, hidden_size, num_heads, mlp_ratio=4.0, **block_kwargs): 106 | super().__init__() 107 | self.norm1 = nn.LayerNorm(hidden_size, elementwise_affine=False, eps=1e-6) 108 | self.attn = Attention(hidden_size, num_heads=num_heads, qkv_bias=True, **block_kwargs) 109 | self.norm2 = nn.LayerNorm(hidden_size, elementwise_affine=False, eps=1e-6) 110 | mlp_hidden_dim = int(hidden_size * mlp_ratio) 111 | approx_gelu = lambda: nn.GELU(approximate="tanh") 112 | self.mlp = Mlp(in_features=hidden_size, hidden_features=mlp_hidden_dim, act_layer=approx_gelu, drop=0) 113 | self.adaLN_modulation = nn.Sequential( 114 | nn.SiLU(), 115 | nn.Linear(hidden_size, 6 * hidden_size, bias=True) 116 | ) 117 | 118 | def forward(self, x, c): 119 | shift_msa, scale_msa, gate_msa, shift_mlp, scale_mlp, gate_mlp = self.adaLN_modulation(c).chunk(6, dim=1) 120 | x = x + gate_msa.unsqueeze(1) * self.attn(modulate(self.norm1(x), shift_msa, scale_msa)) 121 | x = x + gate_mlp.unsqueeze(1) * self.mlp(modulate(self.norm2(x), shift_mlp, scale_mlp)) 122 | return x 123 | 124 | 125 | class FinalLayer(nn.Module): 126 | """ 127 | The final layer of DiT. 128 | """ 129 | def __init__(self, hidden_size, patch_size, out_channels): 130 | super().__init__() 131 | self.norm_final = nn.LayerNorm(hidden_size, elementwise_affine=False, eps=1e-6) 132 | self.linear = nn.Linear(hidden_size, patch_size * patch_size * out_channels, bias=True) 133 | self.adaLN_modulation = nn.Sequential( 134 | nn.SiLU(), 135 | nn.Linear(hidden_size, 2 * hidden_size, bias=True) 136 | ) 137 | 138 | def forward(self, x, c): 139 | shift, scale = self.adaLN_modulation(c).chunk(2, dim=1) 140 | x = modulate(self.norm_final(x), shift, scale) 141 | x = self.linear(x) 142 | return x 143 | 144 | 145 | class DiT(nn.Module): 146 | """ 147 | Diffusion model with a Transformer backbone. 
148 | """ 149 | def __init__( 150 | self, 151 | input_size=32, 152 | patch_size=2, 153 | in_channels=4, 154 | hidden_size=1152, 155 | depth=28, 156 | num_heads=16, 157 | mlp_ratio=4.0, 158 | class_dropout_prob=0.1, 159 | num_classes=1000, 160 | learn_sigma=True, 161 | **kwargs, 162 | ): 163 | super().__init__() 164 | self.learn_sigma = learn_sigma 165 | self.in_channels = in_channels 166 | self.out_channels = in_channels * 2 if learn_sigma else in_channels 167 | self.patch_size = patch_size 168 | self.num_heads = num_heads 169 | 170 | self.x_embedder = PatchEmbed(input_size, patch_size, in_channels, hidden_size, bias=True) 171 | self.t_embedder = TimestepEmbedder(hidden_size) 172 | self.y_embedder = LabelEmbedder(num_classes, hidden_size, class_dropout_prob) 173 | num_patches = self.x_embedder.num_patches 174 | # Will use fixed sin-cos embedding: 175 | self.pos_embed = nn.Parameter(torch.zeros(1, num_patches, hidden_size), requires_grad=False) 176 | 177 | self.blocks = nn.ModuleList([ 178 | DiTBlock(hidden_size, num_heads, mlp_ratio=mlp_ratio) for _ in range(depth) 179 | ]) 180 | self.final_layer = FinalLayer(hidden_size, patch_size, self.out_channels) 181 | self.initialize_weights() 182 | 183 | def initialize_weights(self): 184 | # Initialize transformer layers: 185 | def _basic_init(module): 186 | if isinstance(module, nn.Linear): 187 | torch.nn.init.xavier_uniform_(module.weight) 188 | if module.bias is not None: 189 | nn.init.constant_(module.bias, 0) 190 | self.apply(_basic_init) 191 | 192 | # Initialize (and freeze) pos_embed by sin-cos embedding: 193 | pos_embed = get_2d_sincos_pos_embed(self.pos_embed.shape[-1], int(self.x_embedder.num_patches ** 0.5)) 194 | self.pos_embed.data.copy_(torch.from_numpy(pos_embed).float().unsqueeze(0)) 195 | 196 | # Initialize patch_embed like nn.Linear (instead of nn.Conv2d): 197 | w = self.x_embedder.proj.weight.data 198 | nn.init.xavier_uniform_(w.view([w.shape[0], -1])) 199 | nn.init.constant_(self.x_embedder.proj.bias, 0) 200 | 201 | # Initialize label embedding table: 202 | nn.init.normal_(self.y_embedder.embedding_table.weight, std=0.02) 203 | 204 | # Initialize timestep embedding MLP: 205 | nn.init.normal_(self.t_embedder.mlp[0].weight, std=0.02) 206 | nn.init.normal_(self.t_embedder.mlp[2].weight, std=0.02) 207 | 208 | # Zero-out adaLN modulation layers in DiT blocks: 209 | for block in self.blocks: 210 | nn.init.constant_(block.adaLN_modulation[-1].weight, 0) 211 | nn.init.constant_(block.adaLN_modulation[-1].bias, 0) 212 | 213 | # Zero-out output layers: 214 | nn.init.constant_(self.final_layer.adaLN_modulation[-1].weight, 0) 215 | nn.init.constant_(self.final_layer.adaLN_modulation[-1].bias, 0) 216 | nn.init.constant_(self.final_layer.linear.weight, 0) 217 | nn.init.constant_(self.final_layer.linear.bias, 0) 218 | 219 | def unpatchify(self, x): 220 | """ 221 | x: (N, T, patch_size**2 * C) 222 | imgs: (N, H, W, C) 223 | """ 224 | c = self.out_channels 225 | p = self.x_embedder.patch_size[0] 226 | h = w = int(x.shape[1] ** 0.5) 227 | assert h * w == x.shape[1] 228 | 229 | x = x.reshape(shape=(x.shape[0], h, w, p, p, c)) 230 | x = torch.einsum('nhwpqc->nchpwq', x) 231 | imgs = x.reshape(shape=(x.shape[0], c, h * p, h * p)) 232 | return imgs 233 | 234 | def forward_raw(self, x, t, y): 235 | """ 236 | Forward pass of DiT. 
237 | x: (N, C, H, W) tensor of spatial inputs (images or latent representations of images) 238 | t: (N,) tensor of diffusion timesteps 239 | y: (N,) tensor of class labels 240 | """ 241 | x = self.x_embedder(x) + self.pos_embed # (N, T, D), where T = H * W / patch_size ** 2 242 | t = self.t_embedder(t) # (N, D) 243 | y = self.y_embedder(y, self.training) # (N, D) 244 | c = t + y # (N, D) 245 | for block in self.blocks: 246 | x = block(x, c) # (N, T, D) 247 | x = self.final_layer(x, c) # (N, T, patch_size ** 2 * out_channels) 248 | x = self.unpatchify(x) # (N, out_channels, H, W) 249 | return x 250 | 251 | def forward(self, x, timesteps, context, y=None, **kwargs): 252 | """ 253 | Forward pass that adapts comfy input to original forward function 254 | x: (N, C, H, W) tensor of spatial inputs (images or latent representations of images) 255 | timesteps: (N,) tensor of diffusion timesteps 256 | context: (N, [LabelID]) conditioning 257 | y: extra conditioning. 258 | """ 259 | ## Remove outer array from cond 260 | context = context[:, 0] 261 | 262 | ## run original forward pass 263 | out = self.forward_raw( 264 | x = x.to(self.dtype), 265 | t = timesteps.to(self.dtype), 266 | y = context.to(torch.int), 267 | ) 268 | 269 | ## only return EPS 270 | out = out.to(torch.float) 271 | eps, rest = out[:, :self.in_channels], out[:, self.in_channels:] 272 | return eps 273 | 274 | ################################################################################# 275 | # Sine/Cosine Positional Embedding Functions # 276 | ################################################################################# 277 | # https://github.com/facebookresearch/mae/blob/main/util/pos_embed.py 278 | 279 | def get_2d_sincos_pos_embed(embed_dim, grid_size, cls_token=False, extra_tokens=0): 280 | """ 281 | grid_size: int of the grid height and width 282 | return: 283 | pos_embed: [grid_size*grid_size, embed_dim] or [1+grid_size*grid_size, embed_dim] (w/ or w/o cls_token) 284 | """ 285 | grid_h = np.arange(grid_size, dtype=np.float32) 286 | grid_w = np.arange(grid_size, dtype=np.float32) 287 | grid = np.meshgrid(grid_w, grid_h) # here w goes first 288 | grid = np.stack(grid, axis=0) 289 | 290 | grid = grid.reshape([2, 1, grid_size, grid_size]) 291 | pos_embed = get_2d_sincos_pos_embed_from_grid(embed_dim, grid) 292 | if cls_token and extra_tokens > 0: 293 | pos_embed = np.concatenate([np.zeros([extra_tokens, embed_dim]), pos_embed], axis=0) 294 | return pos_embed 295 | 296 | 297 | def get_2d_sincos_pos_embed_from_grid(embed_dim, grid): 298 | assert embed_dim % 2 == 0 299 | 300 | # use half of dimensions to encode grid_h 301 | emb_h = get_1d_sincos_pos_embed_from_grid(embed_dim // 2, grid[0]) # (H*W, D/2) 302 | emb_w = get_1d_sincos_pos_embed_from_grid(embed_dim // 2, grid[1]) # (H*W, D/2) 303 | 304 | emb = np.concatenate([emb_h, emb_w], axis=1) # (H*W, D) 305 | return emb 306 | 307 | 308 | def get_1d_sincos_pos_embed_from_grid(embed_dim, pos): 309 | """ 310 | embed_dim: output dimension for each position 311 | pos: a list of positions to be encoded: size (M,) 312 | out: (M, D) 313 | """ 314 | assert embed_dim % 2 == 0 315 | omega = np.arange(embed_dim // 2, dtype=np.float64) 316 | omega /= embed_dim / 2. 317 | omega = 1. 
/ 10000**omega # (D/2,) 318 | 319 | pos = pos.reshape(-1) # (M,) 320 | out = np.einsum('m,d->md', pos, omega) # (M, D/2), outer product 321 | 322 | emb_sin = np.sin(out) # (M, D/2) 323 | emb_cos = np.cos(out) # (M, D/2) 324 | 325 | emb = np.concatenate([emb_sin, emb_cos], axis=1) # (M, D) 326 | return emb 327 | -------------------------------------------------------------------------------- /DiT/nodes.py: -------------------------------------------------------------------------------- 1 | import os 2 | import json 3 | import torch 4 | import folder_paths 5 | 6 | from .conf import dit_conf 7 | from .loader import load_dit 8 | 9 | class DitCheckpointLoader: 10 | @classmethod 11 | def INPUT_TYPES(s): 12 | return { 13 | "required": { 14 | "ckpt_name": (folder_paths.get_filename_list("checkpoints"),), 15 | "model": (list(dit_conf.keys()),), 16 | "image_size": ([256, 512],), 17 | # "num_classes": ("INT", {"default": 1000, "min": 0,}), 18 | } 19 | } 20 | RETURN_TYPES = ("MODEL",) 21 | RETURN_NAMES = ("model",) 22 | FUNCTION = "load_checkpoint" 23 | CATEGORY = "ExtraModels/DiT" 24 | TITLE = "DitCheckpointLoader" 25 | 26 | def load_checkpoint(self, ckpt_name, model, image_size): 27 | ckpt_path = folder_paths.get_full_path("checkpoints", ckpt_name) 28 | model_conf = dit_conf[model] 29 | model_conf["unet_config"]["input_size"] = image_size // 8 30 | # model_conf["unet_config"]["num_classes"] = num_classes 31 | dit = load_dit( 32 | model_path = ckpt_path, 33 | model_conf = model_conf, 34 | ) 35 | return (dit,) 36 | 37 | # todo: this needs frontend code to display properly 38 | def get_label_data(label_file="labels/imagenet1000.json"): 39 | label_path = os.path.join( 40 | os.path.dirname(os.path.realpath(__file__)), 41 | label_file, 42 | ) 43 | label_data = {0: "None"} 44 | with open(label_path, "r") as f: 45 | label_data = json.loads(f.read()) 46 | return label_data 47 | label_data = get_label_data() 48 | 49 | class DiTCondLabelSelect: 50 | @classmethod 51 | def INPUT_TYPES(s): 52 | global label_data 53 | return { 54 | "required": { 55 | "model" : ("MODEL",), 56 | "label_name": (list(label_data.values()),), 57 | } 58 | } 59 | 60 | RETURN_TYPES = ("CONDITIONING",) 61 | RETURN_NAMES = ("class",) 62 | FUNCTION = "cond_label" 63 | CATEGORY = "ExtraModels/DiT" 64 | TITLE = "DiTCondLabelSelect" 65 | 66 | def cond_label(self, model, label_name): 67 | global label_data 68 | class_labels = [int(k) for k,v in label_data.items() if v == label_name] 69 | y = torch.tensor([[class_labels[0]]]).to(torch.int) 70 | return ([[y, {}]], ) 71 | 72 | class DiTCondLabelEmpty: 73 | @classmethod 74 | def INPUT_TYPES(s): 75 | global label_data 76 | return { 77 | "required": { 78 | "model" : ("MODEL",), 79 | } 80 | } 81 | 82 | RETURN_TYPES = ("CONDITIONING",) 83 | RETURN_NAMES = ("empty",) 84 | FUNCTION = "cond_empty" 85 | CATEGORY = "ExtraModels/DiT" 86 | TITLE = "DiTCondLabelEmpty" 87 | 88 | def cond_empty(self, model): 89 | # [ID of last class + 1] == [num_classes] 90 | y_null = model.model.model_config.unet_config["num_classes"] 91 | y = torch.tensor([[y_null]]).to(torch.int) 92 | return ([[y, {}]], ) 93 | 94 | NODE_CLASS_MAPPINGS = { 95 | "DitCheckpointLoader" : DitCheckpointLoader, 96 | "DiTCondLabelSelect" : DiTCondLabelSelect, 97 | "DiTCondLabelEmpty" : DiTCondLabelEmpty, 98 | } 99 | -------------------------------------------------------------------------------- /HunYuan/__init__.py: -------------------------------------------------------------------------------- 
https://raw.githubusercontent.com/chaojie/ComfyUI_ExtraModels/d8b11e401de830ccfb27fa84bdd0091b52408af8/HunYuan/__init__.py -------------------------------------------------------------------------------- /HunYuan/conf.py: -------------------------------------------------------------------------------- 1 | """ 2 | List of all DiT model types / settings 3 | """ 4 | sampling_settings = { 5 | "beta_schedule" : "linear", 6 | "linear_start" : 0.00085, 7 | "linear_end" : 0.03, 8 | "timesteps" : 1000, 9 | 'steps_offset': 1, 10 | 'clip_sample': False, 11 | 'clip_sample_range': 1.0, 12 | 'beta_start': 0.00085, 13 | 'beta_end': 0.03, 14 | 'prediction_type': 'v_prediction', 15 | } 16 | 17 | dit_conf = { 18 | "DiT-g/2": { # DiT-g/2 19 | "unet_config": { 20 | "depth" : 40, 21 | "num_heads" : 16, 22 | "patch_size" : 2, 23 | "hidden_size" : 1408, 24 | 'mlp_ratio': 4.3637, 25 | }, 26 | "sampling_settings" : sampling_settings, 27 | }, 28 | "DiT-XL/2": { # DiT_XL_2 29 | "unet_config": { 30 | "depth" : 28, 31 | "num_heads" : 16, 32 | "patch_size" : 2, 33 | "hidden_size" : 1152, 34 | }, 35 | "sampling_settings" : sampling_settings, 36 | }, 37 | "DiT-L/2": { # DiT_L_2 38 | "unet_config": { 39 | "depth" : 24, 40 | "num_heads" : 16, 41 | "patch_size" : 2, 42 | "hidden_size" : 1024, 43 | }, 44 | "sampling_settings" : sampling_settings, 45 | }, 46 | "DiT-B/2": { # DiT_B_2 47 | "unet_config": { 48 | "depth" : 12, 49 | "num_heads" : 12, 50 | "patch_size" : 2, 51 | "hidden_size" : 768, 52 | }, 53 | "sampling_settings" : sampling_settings, 54 | }, 55 | } 56 | -------------------------------------------------------------------------------- /HunYuan/loader.py: -------------------------------------------------------------------------------- 1 | import comfy.supported_models_base 2 | import comfy.latent_formats 3 | import comfy.model_patcher 4 | import comfy.model_base 5 | import comfy.utils 6 | import torch 7 | from comfy import model_management 8 | from ..PixArt.diffusers_convert import convert_state_dict 9 | 10 | class EXM_DiT(comfy.supported_models_base.BASE): 11 | unet_config = {} 12 | unet_extra_config = {} 13 | latent_format = comfy.latent_formats.SDXL 14 | 15 | def __init__(self, model_conf): 16 | self.model_target = model_conf.get("target") 17 | self.unet_config = model_conf.get("unet_config", {}) 18 | self.sampling_settings = model_conf.get("sampling_settings", {}) 19 | self.latent_format = self.latent_format() 20 | # UNET is handled by extension 21 | self.unet_config["disable_unet_model_creation"] = True 22 | 23 | def model_type(self, state_dict, prefix=""): 24 | return comfy.model_base.ModelType.V_PREDICTION 25 | 26 | class EXM_Dit_Model(comfy.model_base.BaseModel): 27 | def __init__(self, *args, **kwargs): 28 | super().__init__(*args, **kwargs) 29 | 30 | def extra_conds(self, **kwargs): 31 | out = super().extra_conds(**kwargs) 32 | 33 | clip_prompt_embeds = kwargs.get("clip_prompt_embeds", None) 34 | if clip_prompt_embeds is not None: 35 | out["clip_prompt_embeds"] = comfy.conds.CONDRegular(torch.tensor(clip_prompt_embeds)) 36 | 37 | clip_attention_mask = kwargs.get("clip_attention_mask", None) 38 | if clip_attention_mask is not None: 39 | out["clip_attention_mask"] = comfy.conds.CONDRegular(torch.tensor(clip_attention_mask)) 40 | 41 | mt5_prompt_embeds = kwargs.get("mt5_prompt_embeds", None) 42 | if mt5_prompt_embeds is not None: 43 | out["mt5_prompt_embeds"] = comfy.conds.CONDRegular(torch.tensor(mt5_prompt_embeds)) 44 | 45 | mt5_attention_mask = kwargs.get("mt5_attention_mask", None) 46 | if 
mt5_attention_mask is not None: 47 | out["mt5_attention_mask"] = comfy.conds.CONDRegular(torch.tensor(mt5_attention_mask)) 48 | 49 | return out 50 | 51 | def load_dit(model_path, model_conf): 52 | from comfy.diffusers_convert import convert_unet_state_dict 53 | state_dict = comfy.utils.load_torch_file(model_path) 54 | #state_dict=convert_unet_state_dict(state_dict) 55 | #state_dict = state_dict.get("model", state_dict) 56 | 57 | parameters = comfy.utils.calculate_parameters(state_dict) 58 | unet_dtype = torch.float16 #model_management.unet_dtype(model_params=parameters) 59 | load_device = comfy.model_management.get_torch_device() 60 | offload_device = comfy.model_management.unet_offload_device() 61 | 62 | # ignore fp8/etc and use directly for now 63 | #manual_cast_dtype = model_management.unet_manual_cast(unet_dtype, load_device) 64 | #if manual_cast_dtype: 65 | # print(f"DiT: falling back to {manual_cast_dtype}") 66 | # unet_dtype = manual_cast_dtype 67 | 68 | #model_conf["unet_config"]["num_classes"] = state_dict["y_embedder.embedding_table.weight"].shape[0] - 1 # adj. for empty 69 | 70 | model_conf = EXM_DiT(model_conf) 71 | 72 | model = EXM_Dit_Model( # same as comfy.model_base.BaseModel 73 | model_conf, 74 | model_type=comfy.model_base.ModelType.V_PREDICTION, 75 | device=model_management.get_torch_device() 76 | ) 77 | 78 | from .models.models import HunYuan 79 | model.diffusion_model = HunYuan(**model_conf.unet_config) 80 | model.latent_format = comfy.latent_formats.SDXL() 81 | 82 | model.diffusion_model.load_state_dict(state_dict) 83 | model.diffusion_model.dtype = unet_dtype 84 | model.diffusion_model.eval() 85 | model.diffusion_model.to(unet_dtype) 86 | 87 | model_patcher = comfy.model_patcher.ModelPatcher( 88 | model, 89 | load_device = load_device, 90 | offload_device = offload_device, 91 | current_device = "cpu", 92 | ) 93 | return model_patcher 94 | -------------------------------------------------------------------------------- /HunYuan/models/__init__.py: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/chaojie/ComfyUI_ExtraModels/d8b11e401de830ccfb27fa84bdd0091b52408af8/HunYuan/models/__init__.py -------------------------------------------------------------------------------- /HunYuan/models/embedders.py: -------------------------------------------------------------------------------- 1 | import math 2 | import torch 3 | import torch.nn as nn 4 | from einops import repeat 5 | 6 | from timm.models.layers import to_2tuple 7 | 8 | 9 | class PatchEmbed(nn.Module): 10 | """ 2D Image to Patch Embedding 11 | 12 | Image to Patch Embedding using Conv2d 13 | 14 | A convolution based approach to patchifying a 2D image w/ embedding projection. 15 | 16 | Based on the impl in https://github.com/google-research/vision_transformer 17 | 18 | Hacked together by / Copyright 2020 Ross Wightman 19 | 20 | Remove the _assert function in forward function to be compatible with multi-resolution images. 21 | """ 22 | def __init__( 23 | self, 24 | img_size=224, 25 | patch_size=16, 26 | in_chans=3, 27 | embed_dim=768, 28 | norm_layer=None, 29 | flatten=True, 30 | bias=True, 31 | ): 32 | super().__init__() 33 | if isinstance(img_size, int): 34 | img_size = to_2tuple(img_size) 35 | elif isinstance(img_size, (tuple, list)) and len(img_size) == 2: 36 | img_size = tuple(img_size) 37 | else: 38 | raise ValueError(f"img_size must be int or tuple/list of length 2. 
Got {img_size}") 39 | patch_size = to_2tuple(patch_size) 40 | self.img_size = img_size 41 | self.patch_size = patch_size 42 | self.grid_size = (img_size[0] // patch_size[0], img_size[1] // patch_size[1]) 43 | self.num_patches = self.grid_size[0] * self.grid_size[1] 44 | self.flatten = flatten 45 | 46 | self.proj = nn.Conv2d(in_chans, embed_dim, kernel_size=patch_size, stride=patch_size, bias=bias) 47 | self.norm = norm_layer(embed_dim) if norm_layer else nn.Identity() 48 | 49 | def update_image_size(self, img_size): 50 | self.img_size = img_size 51 | self.grid_size = (img_size[0] // self.patch_size[0], img_size[1] // self.patch_size[1]) 52 | self.num_patches = self.grid_size[0] * self.grid_size[1] 53 | 54 | def forward(self, x): 55 | # B, C, H, W = x.shape 56 | # _assert(H == self.img_size[0], f"Input image height ({H}) doesn't match model ({self.img_size[0]}).") 57 | # _assert(W == self.img_size[1], f"Input image width ({W}) doesn't match model ({self.img_size[1]}).") 58 | x = self.proj(x) 59 | if self.flatten: 60 | x = x.flatten(2).transpose(1, 2) # BCHW -> BNC 61 | x = self.norm(x) 62 | return x 63 | 64 | 65 | def timestep_embedding(t, dim, max_period=10000, repeat_only=False): 66 | """ 67 | Create sinusoidal timestep embeddings. 68 | :param t: a 1-D Tensor of N indices, one per batch element. 69 | These may be fractional. 70 | :param dim: the dimension of the output. 71 | :param max_period: controls the minimum frequency of the embeddings. 72 | :return: an (N, D) Tensor of positional embeddings. 73 | """ 74 | # https://github.com/openai/glide-text2im/blob/main/glide_text2im/nn.py 75 | if not repeat_only: 76 | half = dim // 2 77 | freqs = torch.exp( 78 | -math.log(max_period) 79 | * torch.arange(start=0, end=half, dtype=torch.float32) 80 | / half 81 | ).to(device=t.device) # size: [dim/2], 一个指数衰减的曲线 82 | args = t[:, None].float() * freqs[None] 83 | embedding = torch.cat([torch.cos(args), torch.sin(args)], dim=-1) 84 | if dim % 2: 85 | embedding = torch.cat( 86 | [embedding, torch.zeros_like(embedding[:, :1])], dim=-1 87 | ) 88 | else: 89 | embedding = repeat(t, "b -> b d", d=dim) 90 | return embedding 91 | 92 | 93 | class TimestepEmbedder(nn.Module): 94 | """ 95 | Embeds scalar timesteps into vector representations. 96 | """ 97 | def __init__(self, hidden_size, frequency_embedding_size=256, out_size=None): 98 | super().__init__() 99 | if out_size is None: 100 | out_size = hidden_size 101 | self.mlp = nn.Sequential( 102 | nn.Linear(frequency_embedding_size, hidden_size, bias=True), 103 | nn.SiLU(), 104 | nn.Linear(hidden_size, out_size, bias=True), 105 | ) 106 | self.frequency_embedding_size = frequency_embedding_size 107 | 108 | def forward(self, t): 109 | t_freq = timestep_embedding(t, self.frequency_embedding_size).type(self.mlp[0].weight.dtype) 110 | t_emb = self.mlp(t_freq) 111 | return t_emb 112 | -------------------------------------------------------------------------------- /HunYuan/models/norm_layers.py: -------------------------------------------------------------------------------- 1 | import torch 2 | import torch.nn as nn 3 | 4 | 5 | class RMSNorm(nn.Module): 6 | def __init__(self, dim: int, elementwise_affine=True, eps: float = 1e-6): 7 | """ 8 | Initialize the RMSNorm normalization layer. 9 | 10 | Args: 11 | dim (int): The dimension of the input tensor. 12 | eps (float, optional): A small value added to the denominator for numerical stability. Default is 1e-6. 13 | 14 | Attributes: 15 | eps (float): A small value added to the denominator for numerical stability. 
16 | weight (nn.Parameter): Learnable scaling parameter. 17 | 18 | """ 19 | super().__init__() 20 | self.eps = eps 21 | if elementwise_affine: 22 | self.weight = nn.Parameter(torch.ones(dim)) 23 | 24 | def _norm(self, x): 25 | """ 26 | Apply the RMSNorm normalization to the input tensor. 27 | 28 | Args: 29 | x (torch.Tensor): The input tensor. 30 | 31 | Returns: 32 | torch.Tensor: The normalized tensor. 33 | 34 | """ 35 | return x * torch.rsqrt(x.pow(2).mean(-1, keepdim=True) + self.eps) 36 | 37 | def forward(self, x): 38 | """ 39 | Forward pass through the RMSNorm layer. 40 | 41 | Args: 42 | x (torch.Tensor): The input tensor. 43 | 44 | Returns: 45 | torch.Tensor: The output tensor after applying RMSNorm. 46 | 47 | """ 48 | output = self._norm(x.float()).type_as(x) 49 | if hasattr(self, "weight"): 50 | output = output * self.weight 51 | return output 52 | 53 | 54 | class GroupNorm32(nn.GroupNorm): 55 | def __init__(self, num_groups, num_channels, eps=1e-5, dtype=None): 56 | super().__init__(num_groups=num_groups, num_channels=num_channels, eps=eps, dtype=dtype) 57 | 58 | def forward(self, x): 59 | y = super().forward(x).to(x.dtype) 60 | return y 61 | 62 | def normalization(channels, dtype=None): 63 | """ 64 | Make a standard normalization layer. 65 | :param channels: number of input channels. 66 | :return: an nn.Module for normalization. 67 | """ 68 | return GroupNorm32(num_channels=channels, num_groups=32, dtype=dtype) 69 | -------------------------------------------------------------------------------- /HunYuan/models/poolers.py: -------------------------------------------------------------------------------- 1 | import torch 2 | import torch.nn as nn 3 | import torch.nn.functional as F 4 | 5 | 6 | class AttentionPool(nn.Module): 7 | def __init__(self, spacial_dim: int, embed_dim: int, num_heads: int, output_dim: int = None): 8 | super().__init__() 9 | self.positional_embedding = nn.Parameter(torch.randn(spacial_dim + 1, embed_dim) / embed_dim ** 0.5) 10 | self.k_proj = nn.Linear(embed_dim, embed_dim) 11 | self.q_proj = nn.Linear(embed_dim, embed_dim) 12 | self.v_proj = nn.Linear(embed_dim, embed_dim) 13 | self.c_proj = nn.Linear(embed_dim, output_dim or embed_dim) 14 | self.num_heads = num_heads 15 | 16 | def forward(self, x): 17 | x = x.permute(1, 0, 2) # NLC -> LNC 18 | x = torch.cat([x.mean(dim=0, keepdim=True), x], dim=0) # (L+1)NC 19 | x = x + self.positional_embedding[:, None, :].to(x.dtype) # (L+1)NC 20 | x, _ = F.multi_head_attention_forward( 21 | query=x[:1], key=x, value=x, 22 | embed_dim_to_check=x.shape[-1], 23 | num_heads=self.num_heads, 24 | q_proj_weight=self.q_proj.weight, 25 | k_proj_weight=self.k_proj.weight, 26 | v_proj_weight=self.v_proj.weight, 27 | in_proj_weight=None, 28 | in_proj_bias=torch.cat([self.q_proj.bias, self.k_proj.bias, self.v_proj.bias]), 29 | bias_k=None, 30 | bias_v=None, 31 | add_zero_attn=False, 32 | dropout_p=0, 33 | out_proj_weight=self.c_proj.weight, 34 | out_proj_bias=self.c_proj.bias, 35 | use_separate_proj_weight=True, 36 | training=self.training, 37 | need_weights=False 38 | ) 39 | return x.squeeze(0) 40 | -------------------------------------------------------------------------------- /HunYuan/models/posemb_layers.py: -------------------------------------------------------------------------------- 1 | import torch 2 | import numpy as np 3 | from typing import Union 4 | 5 | 6 | def _to_tuple(x): 7 | if isinstance(x, int): 8 | return x, x 9 | else: 10 | return x 11 | 12 | 13 | def get_fill_resize_and_crop(src, tgt): # src 来源的分辨率 tgt base 
分辨率 14 | th, tw = _to_tuple(tgt) 15 | h, w = _to_tuple(src) 16 | 17 | tr = th / tw # base 分辨率 18 | r = h / w # 目标分辨率 19 | 20 | # resize 21 | if r > tr: 22 | resize_height = th 23 | resize_width = int(round(th / h * w)) 24 | else: 25 | resize_width = tw 26 | resize_height = int(round(tw / w * h)) # 根据base分辨率,将目标分辨率resize下来 27 | 28 | crop_top = int(round((th - resize_height) / 2.0)) 29 | crop_left = int(round((tw - resize_width) / 2.0)) 30 | 31 | return (crop_top, crop_left), (crop_top + resize_height, crop_left + resize_width) 32 | 33 | 34 | def get_meshgrid(start, *args): 35 | if len(args) == 0: 36 | # start is grid_size 37 | num = _to_tuple(start) 38 | start = (0, 0) 39 | stop = num 40 | elif len(args) == 1: 41 | # start is start, args[0] is stop, step is 1 42 | start = _to_tuple(start) 43 | stop = _to_tuple(args[0]) 44 | num = (stop[0] - start[0], stop[1] - start[1]) 45 | elif len(args) == 2: 46 | # start is start, args[0] is stop, args[1] is num 47 | start = _to_tuple(start) # 左上角 eg: 12,0 48 | stop = _to_tuple(args[0]) # 右下角 eg: 20,32 49 | num = _to_tuple(args[1]) # 目标大小 eg: 32,124 50 | else: 51 | raise ValueError(f"len(args) should be 0, 1 or 2, but got {len(args)}") 52 | 53 | grid_h = np.linspace(start[0], stop[0], num[0], endpoint=False, dtype=np.float32) # 12-20 中间差值32份 0-32 中间差值124份 54 | grid_w = np.linspace(start[1], stop[1], num[1], endpoint=False, dtype=np.float32) 55 | grid = np.meshgrid(grid_w, grid_h) # here w goes first 56 | grid = np.stack(grid, axis=0) # [2, W, H] 57 | return grid 58 | 59 | ################################################################################# 60 | # Sine/Cosine Positional Embedding Functions # 61 | ################################################################################# 62 | # https://github.com/facebookresearch/mae/blob/main/util/pos_embed.py 63 | 64 | def get_2d_sincos_pos_embed(embed_dim, start, *args, cls_token=False, extra_tokens=0): 65 | """ 66 | grid_size: int of the grid height and width 67 | return: 68 | pos_embed: [grid_size*grid_size, embed_dim] or [1+grid_size*grid_size, embed_dim] (w/ or w/o cls_token) 69 | """ 70 | grid = get_meshgrid(start, *args) # [2, H, w] 71 | # grid_h = np.arange(grid_size, dtype=np.float32) 72 | # grid_w = np.arange(grid_size, dtype=np.float32) 73 | # grid = np.meshgrid(grid_w, grid_h) # here w goes first 74 | # grid = np.stack(grid, axis=0) # [2, W, H] 75 | 76 | grid = grid.reshape([2, 1, *grid.shape[1:]]) 77 | pos_embed = get_2d_sincos_pos_embed_from_grid(embed_dim, grid) 78 | if cls_token and extra_tokens > 0: 79 | pos_embed = np.concatenate([np.zeros([extra_tokens, embed_dim]), pos_embed], axis=0) 80 | return pos_embed 81 | 82 | 83 | def get_2d_sincos_pos_embed_from_grid(embed_dim, grid): 84 | assert embed_dim % 2 == 0 85 | 86 | # use half of dimensions to encode grid_h 87 | emb_h = get_1d_sincos_pos_embed_from_grid(embed_dim // 2, grid[0]) # (H*W, D/2) 88 | emb_w = get_1d_sincos_pos_embed_from_grid(embed_dim // 2, grid[1]) # (H*W, D/2) 89 | 90 | emb = np.concatenate([emb_h, emb_w], axis=1) # (H*W, D) 91 | return emb 92 | 93 | 94 | def get_1d_sincos_pos_embed_from_grid(embed_dim, pos): 95 | """ 96 | embed_dim: output dimension for each position 97 | pos: a list of positions to be encoded: size (W,H) 98 | out: (M, D) 99 | """ 100 | assert embed_dim % 2 == 0 101 | omega = np.arange(embed_dim // 2, dtype=np.float64) 102 | omega /= embed_dim / 2. 103 | omega = 1. 
/ 10000**omega # (D/2,) 104 | 105 | pos = pos.reshape(-1) # (M,) 106 | out = np.einsum('m,d->md', pos, omega) # (M, D/2), outer product 107 | 108 | emb_sin = np.sin(out) # (M, D/2) 109 | emb_cos = np.cos(out) # (M, D/2) 110 | 111 | emb = np.concatenate([emb_sin, emb_cos], axis=1) # (M, D) 112 | return emb 113 | 114 | 115 | ################################################################################# 116 | # Rotary Positional Embedding Functions # 117 | ################################################################################# 118 | # https://github.com/facebookresearch/llama/blob/main/llama/model.py#L443 119 | 120 | def get_2d_rotary_pos_embed(embed_dim, start, *args, use_real=True): 121 | """ 122 | This is a 2d version of precompute_freqs_cis, which is a RoPE for image tokens with 2d structure. 123 | 124 | Parameters 125 | ---------- 126 | embed_dim: int 127 | embedding dimension size 128 | start: int or tuple of int 129 | If len(args) == 0, start is num; If len(args) == 1, start is start, args[0] is stop, step is 1; 130 | If len(args) == 2, start is start, args[0] is stop, args[1] is num. 131 | use_real: bool 132 | If True, return real part and imaginary part separately. Otherwise, return complex numbers. 133 | 134 | Returns 135 | ------- 136 | pos_embed: torch.Tensor 137 | [HW, D/2] 138 | """ 139 | grid = get_meshgrid(start, *args) # [2, H, w] 140 | grid = grid.reshape([2, 1, *grid.shape[1:]]) # 返回一个采样矩阵 分辨率与目标分辨率一致 141 | pos_embed = get_2d_rotary_pos_embed_from_grid(embed_dim, grid, use_real=use_real) 142 | return pos_embed 143 | 144 | 145 | def get_2d_rotary_pos_embed_from_grid(embed_dim, grid, use_real=False): 146 | assert embed_dim % 4 == 0 147 | 148 | # use half of dimensions to encode grid_h 149 | emb_h = get_1d_rotary_pos_embed(embed_dim // 2, grid[0].reshape(-1), use_real=use_real) # (H*W, D/4) 150 | emb_w = get_1d_rotary_pos_embed(embed_dim // 2, grid[1].reshape(-1), use_real=use_real) # (H*W, D/4) 151 | 152 | if use_real: 153 | cos = torch.cat([emb_h[0], emb_w[0]], dim=1) # (H*W, D/2) 154 | sin = torch.cat([emb_h[1], emb_w[1]], dim=1) # (H*W, D/2) 155 | return cos, sin 156 | else: 157 | emb = torch.cat([emb_h, emb_w], dim=1) # (H*W, D/2) 158 | return emb 159 | 160 | 161 | def get_1d_rotary_pos_embed(dim: int, pos: Union[np.ndarray, int], theta: float = 10000.0, use_real=False): 162 | """ 163 | Precompute the frequency tensor for complex exponentials (cis) with given dimensions. 164 | 165 | This function calculates a frequency tensor with complex exponentials using the given dimension 'dim' 166 | and the end index 'end'. The 'theta' parameter scales the frequencies. 167 | The returned tensor contains complex values in complex64 data type. 168 | 169 | Args: 170 | dim (int): Dimension of the frequency tensor. 171 | pos (np.ndarray, int): Position indices for the frequency tensor. [S] or scalar 172 | theta (float, optional): Scaling factor for frequency computation. Defaults to 10000.0. 173 | use_real (bool, optional): If True, return real part and imaginary part separately. 174 | Otherwise, return complex numbers. 175 | 176 | Returns: 177 | torch.Tensor: Precomputed frequency tensor with complex exponentials. 
[S, D/2] 178 | 179 | """ 180 | if isinstance(pos, int): 181 | pos = np.arange(pos) 182 | freqs = 1.0 / (theta ** (torch.arange(0, dim, 2)[: (dim // 2)].float() / dim)) # [D/2] 183 | t = torch.from_numpy(pos).to(freqs.device) # type: ignore # [S] 184 | freqs = torch.outer(t, freqs).float() # type: ignore # [S, D/2] 185 | if use_real: 186 | freqs_cos = freqs.cos().repeat_interleave(2, dim=1) # [S, D] 187 | freqs_sin = freqs.sin().repeat_interleave(2, dim=1) # [S, D] 188 | return freqs_cos, freqs_sin 189 | else: 190 | freqs_cis = torch.polar(torch.ones_like(freqs), freqs) # complex64 # [S, D/2] 191 | return freqs_cis 192 | 193 | 194 | 195 | def calc_sizes(rope_img, patch_size, th, tw): 196 | """ 计算 RoPE 的尺寸. """ 197 | if rope_img == 'extend': 198 | # 拓展模式 199 | sub_args = [(th, tw)] 200 | elif rope_img.startswith('base'): 201 | # 基于一个尺寸, 其他尺寸插值获得. 202 | base_size = int(rope_img[4:]) // 8 // patch_size # 基于512作为base,其他根据512差值得到 203 | start, stop = get_fill_resize_and_crop((th, tw), base_size) # 需要在32x32里面 crop的左上角和右下角 204 | sub_args = [start, stop, (th, tw)] 205 | else: 206 | raise ValueError(f"Unknown rope_img: {rope_img}") 207 | return sub_args 208 | 209 | 210 | def init_image_posemb(rope_img, 211 | resolutions, 212 | patch_size, 213 | hidden_size, 214 | num_heads, 215 | log_fn, 216 | rope_real=True, 217 | ): 218 | freqs_cis_img = {} 219 | for reso in resolutions: 220 | th, tw = reso.height // 8 // patch_size, reso.width // 8 // patch_size 221 | sub_args = calc_sizes(rope_img, patch_size, th, tw) # [左上角, 右下角, 目标高宽] 需要在32x32里面 crop的左上角和右下角 222 | freqs_cis_img[str(reso)] = get_2d_rotary_pos_embed(hidden_size // num_heads, *sub_args, use_real=rope_real) 223 | log_fn(f" Using image RoPE ({rope_img}) ({'real' if rope_real else 'complex'}): {sub_args} | ({reso}) " 224 | f"{freqs_cis_img[str(reso)][0].shape if rope_real else freqs_cis_img[str(reso)].shape}") 225 | return freqs_cis_img 226 | -------------------------------------------------------------------------------- /HunYuan/models/text_encoder.py: -------------------------------------------------------------------------------- 1 | import torch 2 | import torch.nn as nn 3 | from transformers import AutoTokenizer, T5EncoderModel, T5ForConditionalGeneration 4 | 5 | 6 | class MT5Embedder(nn.Module): 7 | available_models = ["t5-v1_1-xxl"] 8 | 9 | def __init__( 10 | self, 11 | model_dir="t5-v1_1-xxl", 12 | model_kwargs=None, 13 | torch_dtype=None, 14 | use_tokenizer_only=False, 15 | conditional_generation=False, 16 | max_length=128, 17 | device="cuda", 18 | ): 19 | super().__init__() 20 | self.device = device #"cuda" if torch.cuda.is_available() else "cpu" 21 | self.torch_dtype = torch_dtype or torch.bfloat16 22 | self.max_length = max_length 23 | if model_kwargs is None: 24 | model_kwargs = { 25 | # "low_cpu_mem_usage": True, 26 | "torch_dtype": self.torch_dtype, 27 | } 28 | model_kwargs["device_map"] = {"shared": self.device, "encoder": self.device} 29 | self.tokenizer = AutoTokenizer.from_pretrained(model_dir) 30 | if use_tokenizer_only: 31 | return 32 | if conditional_generation: 33 | self.model = None 34 | self.generation_model = T5ForConditionalGeneration.from_pretrained( 35 | model_dir 36 | ) 37 | return 38 | self.model = T5EncoderModel.from_pretrained(model_dir, **model_kwargs).eval().to(self.torch_dtype) 39 | 40 | def get_tokens_and_mask(self, texts): 41 | text_tokens_and_mask = self.tokenizer( 42 | texts, 43 | max_length=self.max_length, 44 | padding="max_length", 45 | truncation=True, 46 | return_attention_mask=True, 47 | 
add_special_tokens=True, 48 | return_tensors="pt", 49 | ) 50 | tokens = text_tokens_and_mask["input_ids"][0] 51 | mask = text_tokens_and_mask["attention_mask"][0] 52 | # tokens = torch.tensor(tokens).clone().detach() 53 | # mask = torch.tensor(mask, dtype=torch.bool).clone().detach() 54 | return tokens, mask 55 | 56 | def get_text_embeddings(self, texts, attention_mask=True, layer_index=-1): 57 | text_tokens_and_mask = self.tokenizer( 58 | texts, 59 | max_length=self.max_length, 60 | padding="max_length", 61 | truncation=True, 62 | return_attention_mask=True, 63 | add_special_tokens=True, 64 | return_tensors="pt", 65 | ) 66 | 67 | with torch.no_grad(): 68 | outputs = self.model( 69 | input_ids=text_tokens_and_mask["input_ids"].to(self.device), 70 | attention_mask=text_tokens_and_mask["attention_mask"].to(self.device) 71 | if attention_mask 72 | else None, 73 | output_hidden_states=True, 74 | ) 75 | text_encoder_embs = outputs["hidden_states"][layer_index].detach() 76 | 77 | return text_encoder_embs, text_tokens_and_mask["attention_mask"].to(self.device) 78 | 79 | @torch.no_grad() 80 | def __call__(self, tokens, attention_mask, layer_index=-1): 81 | with torch.cuda.amp.autocast(): 82 | outputs = self.model( 83 | input_ids=tokens, 84 | attention_mask=attention_mask, 85 | output_hidden_states=True, 86 | ) 87 | 88 | z = outputs.hidden_states[layer_index].detach() 89 | return z 90 | 91 | def general(self, text: str): 92 | # input_ids = input_ids = torch.tensor([list(text.encode("utf-8"))]) + num_special_tokens 93 | input_ids = self.tokenizer(text, max_length=128).input_ids 94 | print(input_ids) 95 | outputs = self.generation_model(input_ids) 96 | return outputs -------------------------------------------------------------------------------- /HunYuan/nodes.py: -------------------------------------------------------------------------------- 1 | import os 2 | import json 3 | import torch 4 | import folder_paths 5 | 6 | from .conf import dit_conf 7 | from .loader import load_dit 8 | from .models.text_encoder import MT5Embedder 9 | from transformers import BertModel, BertTokenizer 10 | 11 | class MT5Loader: 12 | @classmethod 13 | def INPUT_TYPES(s): 14 | return { 15 | "required": { 16 | "HunyuanDiTfolder": (os.listdir(os.path.join(folder_paths.models_dir,"diffusers")), {"default": "HunyuanDiT"}), 17 | "device": (["cpu", "cuda"], {"default": "cuda"}), 18 | } 19 | } 20 | RETURN_TYPES = ("MT5","CLIP","Tokenizer",) 21 | FUNCTION = "load_model" 22 | CATEGORY = "ExtraModels/T5" 23 | TITLE = "MT5 Loader" 24 | 25 | def load_model(self, HunyuanDiTfolder, device): 26 | HunyuanDiTfolder=os.path.join(os.path.join(folder_paths.models_dir,"diffusers"),HunyuanDiTfolder) 27 | mt5folder=os.path.join(HunyuanDiTfolder,"t2i/mt5") 28 | clipfolder=os.path.join(HunyuanDiTfolder,"t2i/clip_text_encoder") 29 | tokenizerfolder=os.path.join(HunyuanDiTfolder,"t2i/tokenizer") 30 | torch_dtype=torch.float16 31 | if device=="cpu": 32 | torch_dtype=torch.float32 33 | clip_text_encoder = BertModel.from_pretrained(str(clipfolder), False, revision=None).to(device) 34 | tokenizer = BertTokenizer.from_pretrained(str(tokenizerfolder)) 35 | embedder_t5 = MT5Embedder(mt5folder, torch_dtype=torch_dtype, max_length=256, device=device) 36 | 37 | return (embedder_t5,clip_text_encoder,tokenizer,) 38 | 39 | def clip_get_text_embeddings(clip_text_encoder,tokenizer,text,device): 40 | max_length=tokenizer.model_max_length 41 | text_inputs = tokenizer( 42 | text, 43 | padding="max_length", 44 | max_length=max_length, 45 | truncation=True, 46 | 
return_attention_mask=True, 47 | return_tensors="pt", 48 | ) 49 | text_input_ids = text_inputs.input_ids 50 | attention_mask = text_inputs.attention_mask.to(device) 51 | prompt_embeds = clip_text_encoder( 52 | text_input_ids.to(device), 53 | attention_mask=attention_mask, 54 | ) 55 | prompt_embeds = prompt_embeds[0] 56 | attention_mask = attention_mask.repeat(1, 1) 57 | 58 | return (prompt_embeds,attention_mask) 59 | 60 | class MT5TextEncode: 61 | @classmethod 62 | def INPUT_TYPES(s): 63 | return { 64 | "required": { 65 | "embedder_t5": ("MT5",), 66 | "clip_text_encoder": ("CLIP",), 67 | "tokenizer": ("Tokenizer",), 68 | "prompt": ("STRING", {"multiline": True}), 69 | "negative_prompt": ("STRING", {"multiline": True}), 70 | } 71 | } 72 | 73 | RETURN_TYPES = ("CONDITIONING","CONDITIONING",) 74 | RETURN_NAMES = ("positive","negative",) 75 | FUNCTION = "encode" 76 | CATEGORY = "ExtraModels/T5" 77 | TITLE = "MT5 Text Encode" 78 | 79 | def encode(self, embedder_t5, clip_text_encoder, tokenizer, prompt, negative_prompt): 80 | print(f'prompt{prompt}') 81 | clip_prompt_embeds,clip_attention_mask = clip_get_text_embeddings(clip_text_encoder,tokenizer,prompt,embedder_t5.device) 82 | 83 | clip_negative_prompt_embeds,clip_negative_attention_mask = clip_get_text_embeddings(clip_text_encoder,tokenizer,negative_prompt,embedder_t5.device) 84 | 85 | mt5_prompt_embeds,mt5_attention_mask = embedder_t5.get_text_embeddings(prompt) 86 | 87 | mt5_negative_prompt_embeds,mt5_negative_attention_mask = embedder_t5.get_text_embeddings(negative_prompt) 88 | 89 | return ([[clip_prompt_embeds, {"clip_prompt_embeds":clip_prompt_embeds,"clip_attention_mask":clip_attention_mask,"mt5_prompt_embeds":mt5_prompt_embeds,"mt5_attention_mask":mt5_attention_mask}]],[[clip_negative_prompt_embeds, {"clip_prompt_embeds":clip_negative_prompt_embeds,"clip_attention_mask":clip_negative_attention_mask,"mt5_prompt_embeds":mt5_negative_prompt_embeds,"mt5_attention_mask":mt5_negative_attention_mask}]], ) 90 | 91 | class HunYuanDitCheckpointLoader: 92 | @classmethod 93 | def INPUT_TYPES(s): 94 | return { 95 | "required": { 96 | "HunyuanDiTfolder": (os.listdir(os.path.join(folder_paths.models_dir,"diffusers")), {"default": "HunyuanDiT"}), 97 | "model": (list(dit_conf.keys()),), 98 | "image_size_width": ("INT",{"default":1024}), 99 | "image_size_height": ("INT",{"default":1024}), 100 | # "num_classes": ("INT", {"default": 1000, "min": 0,}), 101 | } 102 | } 103 | RETURN_TYPES = ("MODEL",) 104 | RETURN_NAMES = ("model",) 105 | FUNCTION = "load_checkpoint" 106 | CATEGORY = "ExtraModels/DiT" 107 | TITLE = "HunYuanDitCheckpointLoader" 108 | 109 | def load_checkpoint(self, HunyuanDiTfolder, model, image_size_width, image_size_height): 110 | image_size_width = int((image_size_width // 16) * 16) 111 | image_size_height = int((image_size_height // 16) * 16) 112 | HunyuanDiTfolder=os.path.join(os.path.join(folder_paths.models_dir,"diffusers"),HunyuanDiTfolder) 113 | ckpt_path=os.path.join(HunyuanDiTfolder,"t2i/model/pytorch_model_ema.pt") 114 | model_conf = dit_conf[model] 115 | model_conf["unet_config"]["input_size"] = (image_size_height // 8, image_size_width // 8) 116 | # model_conf["unet_config"]["num_classes"] = num_classes 117 | dit = load_dit( 118 | model_path = ckpt_path, 119 | model_conf = model_conf, 120 | ) 121 | return (dit,) 122 | 123 | NODE_CLASS_MAPPINGS = { 124 | "HunYuanDitCheckpointLoader" : HunYuanDitCheckpointLoader, 125 | "MT5Loader" : MT5Loader, 126 | "MT5TextEncode" : MT5TextEncode, 127 | } 128 | 
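The three HunYuan nodes above are meant to be chained: MT5Loader returns the mT5 embedder, the BERT text encoder and its tokenizer; MT5TextEncode packs the CLIP and mT5 embeddings plus attention masks into CONDITIONING entries whose extra keys are later consumed by EXM_Dit_Model.extra_conds() in HunYuan/loader.py; HunYuanDitCheckpointLoader builds the patched diffusion model. A minimal usage sketch follows, assuming a ComfyUI environment with these classes importable and a "HunyuanDiT" checkout under models/diffusers — the folder name and prompt are hypothetical, not taken from the repository.

# Hypothetical chaining of the nodes defined in HunYuan/nodes.py above.
embedder_t5, clip_text_encoder, tokenizer = MT5Loader().load_model("HunyuanDiT", "cuda")

positive, negative = MT5TextEncode().encode(
    embedder_t5, clip_text_encoder, tokenizer,
    prompt="a watercolor lighthouse at dawn",  # hypothetical prompt
    negative_prompt="",
)
# Each CONDITIONING entry is [clip_prompt_embeds, extras]; the extras dict holds
# clip_attention_mask, mt5_prompt_embeds and mt5_attention_mask, which
# EXM_Dit_Model.extra_conds() wraps in comfy.conds.CONDRegular at sampling time.

(model,) = HunYuanDitCheckpointLoader().load_checkpoint(
    "HunyuanDiT", "DiT-g/2",
    image_size_width=1024, image_size_height=1024,
)
# model is a comfy.model_patcher.ModelPatcher using the SDXL latent format, so it
# is meant to be fed to a sampler node together with positive/negative and
# decoded with an SDXL-compatible VAE.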
-------------------------------------------------------------------------------- /HunYuan/wf.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/chaojie/ComfyUI_ExtraModels/d8b11e401de830ccfb27fa84bdd0091b52408af8/HunYuan/wf.png -------------------------------------------------------------------------------- /LICENSE: -------------------------------------------------------------------------------- 1 | Apache License 2 | Version 2.0, January 2004 3 | http://www.apache.org/licenses/ 4 | 5 | TERMS AND CONDITIONS FOR USE, REPRODUCTION, AND DISTRIBUTION 6 | 7 | 1. Definitions. 8 | 9 | "License" shall mean the terms and conditions for use, reproduction, 10 | and distribution as defined by Sections 1 through 9 of this document. 11 | 12 | "Licensor" shall mean the copyright owner or entity authorized by 13 | the copyright owner that is granting the License. 14 | 15 | "Legal Entity" shall mean the union of the acting entity and all 16 | other entities that control, are controlled by, or are under common 17 | control with that entity. For the purposes of this definition, 18 | "control" means (i) the power, direct or indirect, to cause the 19 | direction or management of such entity, whether by contract or 20 | otherwise, or (ii) ownership of fifty percent (50%) or more of the 21 | outstanding shares, or (iii) beneficial ownership of such entity. 22 | 23 | "You" (or "Your") shall mean an individual or Legal Entity 24 | exercising permissions granted by this License. 25 | 26 | "Source" form shall mean the preferred form for making modifications, 27 | including but not limited to software source code, documentation 28 | source, and configuration files. 29 | 30 | "Object" form shall mean any form resulting from mechanical 31 | transformation or translation of a Source form, including but 32 | not limited to compiled object code, generated documentation, 33 | and conversions to other media types. 34 | 35 | "Work" shall mean the work of authorship, whether in Source or 36 | Object form, made available under the License, as indicated by a 37 | copyright notice that is included in or attached to the work 38 | (an example is provided in the Appendix below). 39 | 40 | "Derivative Works" shall mean any work, whether in Source or Object 41 | form, that is based on (or derived from) the Work and for which the 42 | editorial revisions, annotations, elaborations, or other modifications 43 | represent, as a whole, an original work of authorship. For the purposes 44 | of this License, Derivative Works shall not include works that remain 45 | separable from, or merely link (or bind by name) to the interfaces of, 46 | the Work and Derivative Works thereof. 47 | 48 | "Contribution" shall mean any work of authorship, including 49 | the original version of the Work and any modifications or additions 50 | to that Work or Derivative Works thereof, that is intentionally 51 | submitted to Licensor for inclusion in the Work by the copyright owner 52 | or by an individual or Legal Entity authorized to submit on behalf of 53 | the copyright owner. 
For the purposes of this definition, "submitted" 54 | means any form of electronic, verbal, or written communication sent 55 | to the Licensor or its representatives, including but not limited to 56 | communication on electronic mailing lists, source code control systems, 57 | and issue tracking systems that are managed by, or on behalf of, the 58 | Licensor for the purpose of discussing and improving the Work, but 59 | excluding communication that is conspicuously marked or otherwise 60 | designated in writing by the copyright owner as "Not a Contribution." 61 | 62 | "Contributor" shall mean Licensor and any individual or Legal Entity 63 | on behalf of whom a Contribution has been received by Licensor and 64 | subsequently incorporated within the Work. 65 | 66 | 2. Grant of Copyright License. Subject to the terms and conditions of 67 | this License, each Contributor hereby grants to You a perpetual, 68 | worldwide, non-exclusive, no-charge, royalty-free, irrevocable 69 | copyright license to reproduce, prepare Derivative Works of, 70 | publicly display, publicly perform, sublicense, and distribute the 71 | Work and such Derivative Works in Source or Object form. 72 | 73 | 3. Grant of Patent License. Subject to the terms and conditions of 74 | this License, each Contributor hereby grants to You a perpetual, 75 | worldwide, non-exclusive, no-charge, royalty-free, irrevocable 76 | (except as stated in this section) patent license to make, have made, 77 | use, offer to sell, sell, import, and otherwise transfer the Work, 78 | where such license applies only to those patent claims licensable 79 | by such Contributor that are necessarily infringed by their 80 | Contribution(s) alone or by combination of their Contribution(s) 81 | with the Work to which such Contribution(s) was submitted. If You 82 | institute patent litigation against any entity (including a 83 | cross-claim or counterclaim in a lawsuit) alleging that the Work 84 | or a Contribution incorporated within the Work constitutes direct 85 | or contributory patent infringement, then any patent licenses 86 | granted to You under this License for that Work shall terminate 87 | as of the date such litigation is filed. 88 | 89 | 4. Redistribution. 
You may reproduce and distribute copies of the 90 | Work or Derivative Works thereof in any medium, with or without 91 | modifications, and in Source or Object form, provided that You 92 | meet the following conditions: 93 | 94 | (a) You must give any other recipients of the Work or 95 | Derivative Works a copy of this License; and 96 | 97 | (b) You must cause any modified files to carry prominent notices 98 | stating that You changed the files; and 99 | 100 | (c) You must retain, in the Source form of any Derivative Works 101 | that You distribute, all copyright, patent, trademark, and 102 | attribution notices from the Source form of the Work, 103 | excluding those notices that do not pertain to any part of 104 | the Derivative Works; and 105 | 106 | (d) If the Work includes a "NOTICE" text file as part of its 107 | distribution, then any Derivative Works that You distribute must 108 | include a readable copy of the attribution notices contained 109 | within such NOTICE file, excluding those notices that do not 110 | pertain to any part of the Derivative Works, in at least one 111 | of the following places: within a NOTICE text file distributed 112 | as part of the Derivative Works; within the Source form or 113 | documentation, if provided along with the Derivative Works; or, 114 | within a display generated by the Derivative Works, if and 115 | wherever such third-party notices normally appear. The contents 116 | of the NOTICE file are for informational purposes only and 117 | do not modify the License. You may add Your own attribution 118 | notices within Derivative Works that You distribute, alongside 119 | or as an addendum to the NOTICE text from the Work, provided 120 | that such additional attribution notices cannot be construed 121 | as modifying the License. 122 | 123 | You may add Your own copyright statement to Your modifications and 124 | may provide additional or different license terms and conditions 125 | for use, reproduction, or distribution of Your modifications, or 126 | for any such Derivative Works as a whole, provided Your use, 127 | reproduction, and distribution of the Work otherwise complies with 128 | the conditions stated in this License. 129 | 130 | 5. Submission of Contributions. Unless You explicitly state otherwise, 131 | any Contribution intentionally submitted for inclusion in the Work 132 | by You to the Licensor shall be under the terms and conditions of 133 | this License, without any additional terms or conditions. 134 | Notwithstanding the above, nothing herein shall supersede or modify 135 | the terms of any separate license agreement you may have executed 136 | with Licensor regarding such Contributions. 137 | 138 | 6. Trademarks. This License does not grant permission to use the trade 139 | names, trademarks, service marks, or product names of the Licensor, 140 | except as required for reasonable and customary use in describing the 141 | origin of the Work and reproducing the content of the NOTICE file. 142 | 143 | 7. Disclaimer of Warranty. Unless required by applicable law or 144 | agreed to in writing, Licensor provides the Work (and each 145 | Contributor provides its Contributions) on an "AS IS" BASIS, 146 | WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or 147 | implied, including, without limitation, any warranties or conditions 148 | of TITLE, NON-INFRINGEMENT, MERCHANTABILITY, or FITNESS FOR A 149 | PARTICULAR PURPOSE. 
You are solely responsible for determining the 150 | appropriateness of using or redistributing the Work and assume any 151 | risks associated with Your exercise of permissions under this License. 152 | 153 | 8. Limitation of Liability. In no event and under no legal theory, 154 | whether in tort (including negligence), contract, or otherwise, 155 | unless required by applicable law (such as deliberate and grossly 156 | negligent acts) or agreed to in writing, shall any Contributor be 157 | liable to You for damages, including any direct, indirect, special, 158 | incidental, or consequential damages of any character arising as a 159 | result of this License or out of the use or inability to use the 160 | Work (including but not limited to damages for loss of goodwill, 161 | work stoppage, computer failure or malfunction, or any and all 162 | other commercial damages or losses), even if such Contributor 163 | has been advised of the possibility of such damages. 164 | 165 | 9. Accepting Warranty or Additional Liability. While redistributing 166 | the Work or Derivative Works thereof, You may choose to offer, 167 | and charge a fee for, acceptance of support, warranty, indemnity, 168 | or other liability obligations and/or rights consistent with this 169 | License. However, in accepting such obligations, You may act only 170 | on Your own behalf and on Your sole responsibility, not on behalf 171 | of any other Contributor, and only if You agree to indemnify, 172 | defend, and hold each Contributor harmless for any liability 173 | incurred by, or claims asserted against, such Contributor by reason 174 | of your accepting any such warranty or additional liability. 175 | 176 | END OF TERMS AND CONDITIONS 177 | 178 | APPENDIX: How to apply the Apache License to your work. 179 | 180 | To apply the Apache License to your work, attach the following 181 | boilerplate notice, with the fields enclosed by brackets "[]" 182 | replaced with your own identifying information. (Don't include 183 | the brackets!) The text should be enclosed in the appropriate 184 | comment syntax for the file format. We also recommend that a 185 | file or class name and description of purpose be included on the 186 | same "printed page" as the copyright notice for easier 187 | identification within third-party archives. 188 | 189 | Copyright [yyyy] [name of copyright owner] 190 | 191 | Licensed under the Apache License, Version 2.0 (the "License"); 192 | you may not use this file except in compliance with the License. 193 | You may obtain a copy of the License at 194 | 195 | http://www.apache.org/licenses/LICENSE-2.0 196 | 197 | Unless required by applicable law or agreed to in writing, software 198 | distributed under the License is distributed on an "AS IS" BASIS, 199 | WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 200 | See the License for the specific language governing permissions and 201 | limitations under the License. 
202 | -------------------------------------------------------------------------------- /PixArt/conf.py: -------------------------------------------------------------------------------- 1 | """ 2 | List of all PixArt model types / settings 3 | """ 4 | 5 | sampling_settings = { 6 | "beta_schedule" : "sqrt_linear", 7 | "linear_start" : 0.0001, 8 | "linear_end" : 0.02, 9 | "timesteps" : 1000, 10 | } 11 | 12 | pixart_conf = { 13 | "PixArtMS_XL_2": { # models/PixArtMS 14 | "target": "PixArtMS", 15 | "unet_config": { 16 | "input_size" : 1024//8, 17 | "depth" : 28, 18 | "num_heads" : 16, 19 | "patch_size" : 2, 20 | "hidden_size" : 1152, 21 | "pe_interpolation": 2, 22 | }, 23 | "sampling_settings" : sampling_settings, 24 | }, 25 | "PixArtMS_Sigma_XL_2": { 26 | "target": "PixArtMSSigma", 27 | "unet_config": { 28 | "input_size" : 1024//8, 29 | "token_num" : 300, 30 | "depth" : 28, 31 | "num_heads" : 16, 32 | "patch_size" : 2, 33 | "hidden_size" : 1152, 34 | "micro_condition": False, 35 | "pe_interpolation": 2, 36 | "model_max_length": 300, 37 | }, 38 | "sampling_settings" : sampling_settings, 39 | }, 40 | "PixArtMS_Sigma_XL_2_2K": { 41 | "target": "PixArtMSSigma", 42 | "unet_config": { 43 | "input_size" : 2048//8, 44 | "token_num" : 300, 45 | "depth" : 28, 46 | "num_heads" : 16, 47 | "patch_size" : 2, 48 | "hidden_size" : 1152, 49 | "micro_condition": False, 50 | "pe_interpolation": 4, 51 | "model_max_length": 300, 52 | }, 53 | "sampling_settings" : sampling_settings, 54 | }, 55 | "PixArt_XL_2": { # models/PixArt 56 | "target": "PixArt", 57 | "unet_config": { 58 | "input_size" : 512//8, 59 | "token_num" : 120, 60 | "depth" : 28, 61 | "num_heads" : 16, 62 | "patch_size" : 2, 63 | "hidden_size" : 1152, 64 | "pe_interpolation": 1, 65 | }, 66 | "sampling_settings" : sampling_settings, 67 | }, 68 | } 69 | 70 | pixart_conf.update({ # controlnet models 71 | "ControlPixArtHalf": { 72 | "target": "ControlPixArtHalf", 73 | "unet_config": pixart_conf["PixArt_XL_2"]["unet_config"], 74 | "sampling_settings": pixart_conf["PixArt_XL_2"]["sampling_settings"], 75 | }, 76 | "ControlPixArtMSHalf": { 77 | "target": "ControlPixArtMSHalf", 78 | "unet_config": pixart_conf["PixArtMS_XL_2"]["unet_config"], 79 | "sampling_settings": pixart_conf["PixArtMS_XL_2"]["sampling_settings"], 80 | } 81 | }) 82 | 83 | pixart_res = { 84 | "PixArtMS_XL_2": { # models/PixArtMS 1024x1024 85 | '0.25': [512, 2048], '0.26': [512, 1984], '0.27': [512, 1920], '0.28': [512, 1856], 86 | '0.32': [576, 1792], '0.33': [576, 1728], '0.35': [576, 1664], '0.40': [640, 1600], 87 | '0.42': [640, 1536], '0.48': [704, 1472], '0.50': [704, 1408], '0.52': [704, 1344], 88 | '0.57': [768, 1344], '0.60': [768, 1280], '0.68': [832, 1216], '0.72': [832, 1152], 89 | '0.78': [896, 1152], '0.82': [896, 1088], '0.88': [960, 1088], '0.94': [960, 1024], 90 | '1.00': [1024,1024], '1.07': [1024, 960], '1.13': [1088, 960], '1.21': [1088, 896], 91 | '1.29': [1152, 896], '1.38': [1152, 832], '1.46': [1216, 832], '1.67': [1280, 768], 92 | '1.75': [1344, 768], '2.00': [1408, 704], '2.09': [1472, 704], '2.40': [1536, 640], 93 | '2.50': [1600, 640], '2.89': [1664, 576], '3.00': [1728, 576], '3.11': [1792, 576], 94 | '3.62': [1856, 512], '3.75': [1920, 512], '3.88': [1984, 512], '4.00': [2048, 512], 95 | }, 96 | "PixArt_XL_2": { # models/PixArt 512x512 97 | '0.25': [256,1024], '0.26': [256, 992], '0.27': [256, 960], '0.28': [256, 928], 98 | '0.32': [288, 896], '0.33': [288, 864], '0.35': [288, 832], '0.40': [320, 800], 99 | '0.42': [320, 768], '0.48': [352, 736], '0.50': [352, 
704], '0.52': [352, 672], 100 | '0.57': [384, 672], '0.60': [384, 640], '0.68': [416, 608], '0.72': [416, 576], 101 | '0.78': [448, 576], '0.82': [448, 544], '0.88': [480, 544], '0.94': [480, 512], 102 | '1.00': [512, 512], '1.07': [512, 480], '1.13': [544, 480], '1.21': [544, 448], 103 | '1.29': [576, 448], '1.38': [576, 416], '1.46': [608, 416], '1.67': [640, 384], 104 | '1.75': [672, 384], '2.00': [704, 352], '2.09': [736, 352], '2.40': [768, 320], 105 | '2.50': [800, 320], '2.89': [832, 288], '3.00': [864, 288], '3.11': [896, 288], 106 | '3.62': [928, 256], '3.75': [960, 256], '3.88': [992, 256], '4.00': [1024,256] 107 | }, 108 | "PixArtMS_Sigma_XL_2_2K": { 109 | '0.25': [1024, 4096], '0.26': [1024, 3968], '0.27': [1024, 3840], '0.28': [1024, 3712], 110 | '0.32': [1152, 3584], '0.33': [1152, 3456], '0.35': [1152, 3328], '0.40': [1280, 3200], 111 | '0.42': [1280, 3072], '0.48': [1408, 2944], '0.50': [1408, 2816], '0.52': [1408, 2688], 112 | '0.57': [1536, 2688], '0.60': [1536, 2560], '0.68': [1664, 2432], '0.72': [1664, 2304], 113 | '0.78': [1792, 2304], '0.82': [1792, 2176], '0.88': [1920, 2176], '0.94': [1920, 2048], 114 | '1.00': [2048, 2048], '1.07': [2048, 1920], '1.13': [2176, 1920], '1.21': [2176, 1792], 115 | '1.29': [2304, 1792], '1.38': [2304, 1664], '1.46': [2432, 1664], '1.67': [2560, 1536], 116 | '1.75': [2688, 1536], '2.00': [2816, 1408], '2.09': [2944, 1408], '2.40': [3072, 1280], 117 | '2.50': [3200, 1280], '2.89': [3328, 1152], '3.00': [3456, 1152], '3.11': [3584, 1152], 118 | '3.62': [3712, 1024], '3.75': [3840, 1024], '3.88': [3968, 1024], '4.00': [4096, 1024] 119 | } 120 | } 121 | # These should be the same 122 | pixart_res.update({ 123 | "PixArtMS_Sigma_XL_2": pixart_res["PixArtMS_XL_2"], 124 | "PixArtMS_Sigma_XL_2_512": pixart_res["PixArt_XL_2"], 125 | }) 126 | -------------------------------------------------------------------------------- /PixArt/diffusers_convert.py: -------------------------------------------------------------------------------- 1 | # For using the diffusers format weights 2 | # Based on the original ComfyUI function + 3 | # https://github.com/PixArt-alpha/PixArt-alpha/blob/master/tools/convert_pixart_alpha_to_diffusers.py 4 | import torch 5 | 6 | conversion_map = [ # main SD conversion map (PixArt reference, HF Diffusers) 7 | # Patch embeddings 8 | ("x_embedder.proj.weight", "pos_embed.proj.weight"), 9 | ("x_embedder.proj.bias", "pos_embed.proj.bias"), 10 | # Caption projection 11 | ("y_embedder.y_embedding", "caption_projection.y_embedding"), 12 | ("y_embedder.y_proj.fc1.weight", "caption_projection.linear_1.weight"), 13 | ("y_embedder.y_proj.fc1.bias", "caption_projection.linear_1.bias"), 14 | ("y_embedder.y_proj.fc2.weight", "caption_projection.linear_2.weight"), 15 | ("y_embedder.y_proj.fc2.bias", "caption_projection.linear_2.bias"), 16 | # AdaLN-single LN 17 | ("t_embedder.mlp.0.weight", "adaln_single.emb.timestep_embedder.linear_1.weight"), 18 | ("t_embedder.mlp.0.bias", "adaln_single.emb.timestep_embedder.linear_1.bias"), 19 | ("t_embedder.mlp.2.weight", "adaln_single.emb.timestep_embedder.linear_2.weight"), 20 | ("t_embedder.mlp.2.bias", "adaln_single.emb.timestep_embedder.linear_2.bias"), 21 | # Shared norm 22 | ("t_block.1.weight", "adaln_single.linear.weight"), 23 | ("t_block.1.bias", "adaln_single.linear.bias"), 24 | # Final block 25 | ("final_layer.linear.weight", "proj_out.weight"), 26 | ("final_layer.linear.bias", "proj_out.bias"), 27 | ("final_layer.scale_shift_table", "scale_shift_table"), 28 | ] 29 | 30 | conversion_map_ms = 
[ # for multi_scale_train (MS) 31 | # Resolution 32 | ("csize_embedder.mlp.0.weight", "adaln_single.emb.resolution_embedder.linear_1.weight"), 33 | ("csize_embedder.mlp.0.bias", "adaln_single.emb.resolution_embedder.linear_1.bias"), 34 | ("csize_embedder.mlp.2.weight", "adaln_single.emb.resolution_embedder.linear_2.weight"), 35 | ("csize_embedder.mlp.2.bias", "adaln_single.emb.resolution_embedder.linear_2.bias"), 36 | # Aspect ratio 37 | ("ar_embedder.mlp.0.weight", "adaln_single.emb.aspect_ratio_embedder.linear_1.weight"), 38 | ("ar_embedder.mlp.0.bias", "adaln_single.emb.aspect_ratio_embedder.linear_1.bias"), 39 | ("ar_embedder.mlp.2.weight", "adaln_single.emb.aspect_ratio_embedder.linear_2.weight"), 40 | ("ar_embedder.mlp.2.bias", "adaln_single.emb.aspect_ratio_embedder.linear_2.bias"), 41 | ] 42 | 43 | # Add actual transformer blocks 44 | for depth in range(28): 45 | # Transformer blocks 46 | conversion_map += [ 47 | (f"blocks.{depth}.scale_shift_table", f"transformer_blocks.{depth}.scale_shift_table"), 48 | # Projection 49 | (f"blocks.{depth}.attn.proj.weight", f"transformer_blocks.{depth}.attn1.to_out.0.weight"), 50 | (f"blocks.{depth}.attn.proj.bias", f"transformer_blocks.{depth}.attn1.to_out.0.bias"), 51 | # Feed-forward 52 | (f"blocks.{depth}.mlp.fc1.weight", f"transformer_blocks.{depth}.ff.net.0.proj.weight"), 53 | (f"blocks.{depth}.mlp.fc1.bias", f"transformer_blocks.{depth}.ff.net.0.proj.bias"), 54 | (f"blocks.{depth}.mlp.fc2.weight", f"transformer_blocks.{depth}.ff.net.2.weight"), 55 | (f"blocks.{depth}.mlp.fc2.bias", f"transformer_blocks.{depth}.ff.net.2.bias"), 56 | # Cross-attention (proj) 57 | (f"blocks.{depth}.cross_attn.proj.weight" ,f"transformer_blocks.{depth}.attn2.to_out.0.weight"), 58 | (f"blocks.{depth}.cross_attn.proj.bias" ,f"transformer_blocks.{depth}.attn2.to_out.0.bias"), 59 | ] 60 | 61 | def find_prefix(state_dict, target_key): 62 | prefix = "" 63 | for k in state_dict.keys(): 64 | if k.endswith(target_key): 65 | prefix = k.split(target_key)[0] 66 | break 67 | return prefix 68 | 69 | def convert_state_dict(state_dict): 70 | if "adaln_single.emb.resolution_embedder.linear_1.weight" in state_dict.keys(): 71 | cmap = conversion_map + conversion_map_ms 72 | else: 73 | cmap = conversion_map 74 | 75 | missing = [k for k,v in cmap if v not in state_dict] 76 | new_state_dict = {k: state_dict[v] for k,v in cmap if k not in missing} 77 | matched = list(v for k,v in cmap if v in state_dict.keys()) 78 | 79 | for depth in range(28): 80 | for wb in ["weight", "bias"]: 81 | # Self Attention 82 | key = lambda a: f"transformer_blocks.{depth}.attn1.to_{a}.{wb}" 83 | new_state_dict[f"blocks.{depth}.attn.qkv.{wb}"] = torch.cat(( 84 | state_dict[key('q')], state_dict[key('k')], state_dict[key('v')] 85 | ), dim=0) 86 | matched += [key('q'), key('k'), key('v')] 87 | 88 | # Cross-attention (linear) 89 | key = lambda a: f"transformer_blocks.{depth}.attn2.to_{a}.{wb}" 90 | new_state_dict[f"blocks.{depth}.cross_attn.q_linear.{wb}"] = state_dict[key('q')] 91 | new_state_dict[f"blocks.{depth}.cross_attn.kv_linear.{wb}"] = torch.cat(( 92 | state_dict[key('k')], state_dict[key('v')] 93 | ), dim=0) 94 | matched += [key('q'), key('k'), key('v')] 95 | 96 | if len(matched) < len(state_dict): 97 | print(f"PixArt: UNET conversion has leftover keys! 
({len(matched)} vs {len(state_dict)})") 98 | print(list( set(state_dict.keys()) - set(matched) )) 99 | 100 | if len(missing) > 0: 101 | print(f"PixArt: UNET conversion has missing keys!") 102 | print(missing) 103 | 104 | return new_state_dict 105 | 106 | # Same as above but for LoRA weights: 107 | def convert_lora_state_dict(state_dict): 108 | # peft 109 | rep_ap = lambda x: x.replace(".weight", ".lora_A.weight") 110 | rep_bp = lambda x: x.replace(".weight", ".lora_B.weight") 111 | # koyha 112 | rep_ak = lambda x: x.replace(".weight", ".lora_down.weight") 113 | rep_bk = lambda x: x.replace(".weight", ".lora_up.weight") 114 | 115 | prefix = find_prefix(state_dict, "adaln_single.linear.lora_A.weight") 116 | state_dict = {k[len(prefix):]:v for k,v in state_dict.items()} 117 | 118 | cmap = [] 119 | cmap_unet = conversion_map + conversion_map_ms # todo: 512 model 120 | for k, v in cmap_unet: 121 | if not v.endswith(".weight"): 122 | continue 123 | cmap.append((rep_ak(k), rep_ap(v))) 124 | cmap.append((rep_bk(k), rep_bp(v))) 125 | 126 | missing = [k for k,v in cmap if v not in state_dict] 127 | new_state_dict = {k: state_dict[v] for k,v in cmap if k not in missing} 128 | matched = list(v for k,v in cmap if v in state_dict.keys()) 129 | 130 | for fp, fk in ((rep_ap, rep_ak),(rep_bp, rep_bk)): 131 | for depth in range(28): 132 | # Self Attention 133 | key = lambda a: fp(f"transformer_blocks.{depth}.attn1.to_{a}.weight") 134 | new_state_dict[fk(f"blocks.{depth}.attn.qkv.weight")] = torch.cat(( 135 | state_dict[key('q')], state_dict[key('k')], state_dict[key('v')] 136 | ), dim=0) 137 | matched += [key('q'), key('k'), key('v')] 138 | 139 | # Cross-attention (linear) 140 | key = lambda a: fp(f"transformer_blocks.{depth}.attn2.to_{a}.weight") 141 | new_state_dict[fk(f"blocks.{depth}.cross_attn.q_linear.weight")] = state_dict[key('q')] 142 | new_state_dict[fk(f"blocks.{depth}.cross_attn.kv_linear.weight")] = torch.cat(( 143 | state_dict[key('k')], state_dict[key('v')] 144 | ), dim=0) 145 | matched += [key('q'), key('k'), key('v')] 146 | 147 | if len(matched) < len(state_dict): 148 | print(f"PixArt: LoRA conversion has leftover keys! 
({len(matched)} vs {len(state_dict)})") 149 | print(list( set(state_dict.keys()) - set(matched) )) 150 | 151 | if len(missing) > 0: 152 | print(f"PixArt: LoRA conversion has missing keys!") 153 | print(missing) 154 | 155 | return new_state_dict 156 | -------------------------------------------------------------------------------- /PixArt/loader.py: -------------------------------------------------------------------------------- 1 | import comfy.supported_models_base 2 | import comfy.latent_formats 3 | import comfy.model_patcher 4 | import comfy.model_base 5 | import comfy.utils 6 | import comfy.conds 7 | import torch 8 | from comfy import model_management 9 | from .diffusers_convert import convert_state_dict 10 | 11 | class EXM_PixArt(comfy.supported_models_base.BASE): 12 | unet_config = {} 13 | unet_extra_config = {} 14 | latent_format = comfy.latent_formats.SD15 15 | 16 | def __init__(self, model_conf): 17 | self.model_target = model_conf.get("target") 18 | self.unet_config = model_conf.get("unet_config", {}) 19 | self.sampling_settings = model_conf.get("sampling_settings", {}) 20 | self.latent_format = self.latent_format() 21 | # UNET is handled by extension 22 | self.unet_config["disable_unet_model_creation"] = True 23 | 24 | def model_type(self, state_dict, prefix=""): 25 | return comfy.model_base.ModelType.EPS 26 | 27 | class EXM_PixArt_Model(comfy.model_base.BaseModel): 28 | def __init__(self, *args, **kwargs): 29 | super().__init__(*args, **kwargs) 30 | 31 | def extra_conds(self, **kwargs): 32 | out = super().extra_conds(**kwargs) 33 | 34 | img_hw = kwargs.get("img_hw", None) 35 | if img_hw is not None: 36 | out["img_hw"] = comfy.conds.CONDRegular(torch.tensor(img_hw)) 37 | 38 | aspect_ratio = kwargs.get("aspect_ratio", None) 39 | if aspect_ratio is not None: 40 | out["aspect_ratio"] = comfy.conds.CONDRegular(torch.tensor(aspect_ratio)) 41 | 42 | cn_hint = kwargs.get("cn_hint", None) 43 | if cn_hint is not None: 44 | out["cn_hint"] = comfy.conds.CONDRegular(cn_hint) 45 | 46 | return out 47 | 48 | def load_pixart(model_path, model_conf): 49 | state_dict = comfy.utils.load_torch_file(model_path) 50 | state_dict = state_dict.get("model", state_dict) 51 | 52 | # prefix 53 | for prefix in ["model.diffusion_model.",]: 54 | if any(True for x in state_dict if x.startswith(prefix)): 55 | state_dict = {k[len(prefix):]:v for k,v in state_dict.items()} 56 | 57 | # diffusers 58 | if "adaln_single.linear.weight" in state_dict: 59 | state_dict = convert_state_dict(state_dict) # Diffusers 60 | 61 | parameters = comfy.utils.calculate_parameters(state_dict) 62 | unet_dtype = model_management.unet_dtype(model_params=parameters) 63 | load_device = comfy.model_management.get_torch_device() 64 | offload_device = comfy.model_management.unet_offload_device() 65 | 66 | # ignore fp8/etc and use directly for now 67 | manual_cast_dtype = model_management.unet_manual_cast(unet_dtype, load_device) 68 | if manual_cast_dtype: 69 | print(f"PixArt: falling back to {manual_cast_dtype}") 70 | unet_dtype = manual_cast_dtype 71 | 72 | model_conf = EXM_PixArt(model_conf) # convert to object 73 | model = EXM_PixArt_Model( # same as comfy.model_base.BaseModel 74 | model_conf, 75 | model_type=comfy.model_base.ModelType.EPS, 76 | device=model_management.get_torch_device() 77 | ) 78 | 79 | if model_conf.model_target == "PixArtMS": 80 | from .models.PixArtMS import PixArtMS 81 | model.diffusion_model = PixArtMS(**model_conf.unet_config) 82 | elif model_conf.model_target == "PixArt": 83 | from .models.PixArt import PixArt 84 | 
model.diffusion_model = PixArt(**model_conf.unet_config) 85 | elif model_conf.model_target == "PixArtMSSigma": 86 | from .models.PixArtMS import PixArtMS 87 | model.diffusion_model = PixArtMS(**model_conf.unet_config) 88 | model.latent_format = comfy.latent_formats.SDXL() 89 | elif model_conf.model_target == "ControlPixArtMSHalf": 90 | from .models.PixArtMS import PixArtMS 91 | from .models.pixart_controlnet import ControlPixArtMSHalf 92 | model.diffusion_model = PixArtMS(**model_conf.unet_config) 93 | model.diffusion_model = ControlPixArtMSHalf(model.diffusion_model) 94 | elif model_conf.model_target == "ControlPixArtHalf": 95 | from .models.PixArt import PixArt 96 | from .models.pixart_controlnet import ControlPixArtHalf 97 | model.diffusion_model = PixArt(**model_conf.unet_config) 98 | model.diffusion_model = ControlPixArtHalf(model.diffusion_model) 99 | else: 100 | raise NotImplementedError(f"Unknown model target '{model_conf.model_target}'") 101 | 102 | m, u = model.diffusion_model.load_state_dict(state_dict, strict=False) 103 | if len(m) > 0: print("Missing UNET keys", m) 104 | if len(u) > 0: print("Leftover UNET keys", u) 105 | model.diffusion_model.dtype = unet_dtype 106 | model.diffusion_model.eval() 107 | model.diffusion_model.to(unet_dtype) 108 | 109 | model_patcher = comfy.model_patcher.ModelPatcher( 110 | model, 111 | load_device = load_device, 112 | offload_device = offload_device, 113 | current_device = "cpu", 114 | ) 115 | return model_patcher 116 | -------------------------------------------------------------------------------- /PixArt/lora.py: -------------------------------------------------------------------------------- 1 | import os 2 | import copy 3 | import json 4 | import torch 5 | import comfy.lora 6 | import comfy.model_management 7 | from comfy.model_patcher import ModelPatcher 8 | from .diffusers_convert import convert_lora_state_dict 9 | 10 | class EXM_PixArt_ModelPatcher(ModelPatcher): 11 | def calculate_weight(self, patches, weight, key): 12 | """ 13 | This is almost the same as the comfy function, but stripped down to just the LoRA patch code. 14 | The problem with the original code is the q/k/v keys being combined into one for the attention. 15 | In the diffusers code, they're treated as separate keys, but in the reference code they're recombined (q+kv|qkv). 16 | This means, for example, that the [1152,1152] weights become [3456,1152] in the state dict. 17 | The issue with this is that the LoRA weights are [128,1152],[1152,128] and become [384,1162],[3456,128] instead. 18 | 19 | This is the best thing I could think of that would fix that, but it's very fragile. 
20 | - Check key shape to determine if it needs the fallback logic 21 | - Cut the input into parts based on the shape (undoing the torch.cat) 22 | - Do the matrix multiplication logic 23 | - Recombine them to match the expected shape 24 | """ 25 | for p in patches: 26 | alpha = p[0] 27 | v = p[1] 28 | strength_model = p[2] 29 | if strength_model != 1.0: 30 | weight *= strength_model 31 | 32 | if isinstance(v, list): 33 | v = (self.calculate_weight(v[1:], v[0].clone(), key), ) 34 | 35 | if len(v) == 2: 36 | patch_type = v[0] 37 | v = v[1] 38 | 39 | if patch_type == "lora": 40 | mat1 = comfy.model_management.cast_to_device(v[0], weight.device, torch.float32) 41 | mat2 = comfy.model_management.cast_to_device(v[1], weight.device, torch.float32) 42 | if v[2] is not None: 43 | alpha *= v[2] / mat2.shape[0] 44 | try: 45 | mat1 = mat1.flatten(start_dim=1) 46 | mat2 = mat2.flatten(start_dim=1) 47 | 48 | ch1 = mat1.shape[0] // mat2.shape[1] 49 | ch2 = mat2.shape[0] // mat1.shape[1] 50 | ### Fallback logic for shape mismatch ### 51 | if mat1.shape[0] != mat2.shape[1] and ch1 == ch2 and (mat1.shape[0]/mat2.shape[1])%1 == 0: 52 | mat1 = mat1.chunk(ch1, dim=0) 53 | mat2 = mat2.chunk(ch1, dim=0) 54 | weight += torch.cat( 55 | [alpha * torch.mm(mat1[x], mat2[x]) for x in range(ch1)], 56 | dim=0, 57 | ).reshape(weight.shape).type(weight.dtype) 58 | else: 59 | weight += (alpha * torch.mm(mat1, mat2)).reshape(weight.shape).type(weight.dtype) 60 | except Exception as e: 61 | print("ERROR", key, e) 62 | return weight 63 | 64 | def clone(self): 65 | n = EXM_PixArt_ModelPatcher(self.model, self.load_device, self.offload_device, self.size, self.current_device, weight_inplace_update=self.weight_inplace_update) 66 | n.patches = {} 67 | for k in self.patches: 68 | n.patches[k] = self.patches[k][:] 69 | 70 | n.object_patches = self.object_patches.copy() 71 | n.model_options = copy.deepcopy(self.model_options) 72 | n.model_keys = self.model_keys 73 | return n 74 | 75 | def replace_model_patcher(model): 76 | n = EXM_PixArt_ModelPatcher( 77 | model = model.model, 78 | size = model.size, 79 | load_device = model.load_device, 80 | offload_device = model.offload_device, 81 | current_device = model.current_device, 82 | weight_inplace_update = model.weight_inplace_update, 83 | ) 84 | n.patches = {} 85 | for k in model.patches: 86 | n.patches[k] = model.patches[k][:] 87 | 88 | n.object_patches = model.object_patches.copy() 89 | n.model_options = copy.deepcopy(model.model_options) 90 | n.model_keys = model.model_keys 91 | return n 92 | 93 | def find_peft_alpha(path): 94 | def load_json(json_path): 95 | with open(json_path) as f: 96 | data = json.load(f) 97 | alpha = data.get("lora_alpha") 98 | alpha = alpha or data.get("alpha") 99 | if not alpha: 100 | print(" Found config but `lora_alpha` is missing!") 101 | else: 102 | print(f" Found config at {json_path} [alpha:{alpha}]") 103 | return alpha 104 | 105 | # For some weird reason peft doesn't include the alpha in the actual model 106 | print("PixArt: Warning! This is a PEFT LoRA. Trying to find config...") 107 | files = [ 108 | f"{os.path.splitext(path)[0]}.json", 109 | f"{os.path.splitext(path)[0]}.config.json", 110 | os.path.join(os.path.dirname(path),"adapter_config.json"), 111 | ] 112 | for file in files: 113 | if os.path.isfile(file): 114 | return load_json(file) 115 | 116 | print(" Missing config/alpha! assuming alpha of 8. 
Consider converting it/adding a config json to it.") 117 | return 8.0 118 | 119 | def load_pixart_lora(model, lora, lora_path, strength): 120 | k_back = lambda x: x.replace(".lora_up.weight", "") 121 | # need to convert the actual weights for this to work. 122 | if any(True for x in lora.keys() if x.endswith("adaln_single.linear.lora_A.weight")): 123 | lora = convert_lora_state_dict(lora) 124 | alpha = find_peft_alpha(lora_path) 125 | lora.update({f"{k_back(x)}.alpha":torch.tensor(alpha) for x in lora.keys() if "lora_up" in x}) 126 | 127 | key_map = {k_back(x):f"diffusion_model.{k_back(x)}.weight" for x in lora.keys() if "lora_up" in x} # fake 128 | 129 | loaded = comfy.lora.load_lora(lora, key_map) 130 | if model is not None: 131 | # switch to custom model patcher when using LoRAs 132 | if isinstance(model, EXM_PixArt_ModelPatcher): 133 | new_modelpatcher = model.clone() 134 | else: 135 | new_modelpatcher = replace_model_patcher(model) 136 | k = new_modelpatcher.add_patches(loaded, strength) 137 | else: 138 | k = () 139 | new_modelpatcher = None 140 | 141 | k = set(k) 142 | for x in loaded: 143 | if (x not in k): 144 | print("NOT LOADED", x) 145 | 146 | return new_modelpatcher 147 | -------------------------------------------------------------------------------- /PixArt/models/PixArt.py: -------------------------------------------------------------------------------- 1 | # Copyright (c) Meta Platforms, Inc. and affiliates. 2 | # All rights reserved. 3 | 4 | # This source code is licensed under the license found in the 5 | # LICENSE file in the root directory of this source tree. 6 | # -------------------------------------------------------- 7 | # References: 8 | # GLIDE: https://github.com/openai/glide-text2im 9 | # MAE: https://github.com/facebookresearch/mae/blob/main/models_mae.py 10 | # -------------------------------------------------------- 11 | import math 12 | import torch 13 | import torch.nn as nn 14 | import os 15 | import numpy as np 16 | from timm.models.layers import DropPath 17 | from timm.models.vision_transformer import PatchEmbed, Mlp 18 | 19 | 20 | from .utils import auto_grad_checkpoint, to_2tuple 21 | from .PixArt_blocks import t2i_modulate, CaptionEmbedder, AttentionKVCompress, MultiHeadCrossAttention, T2IFinalLayer, TimestepEmbedder, LabelEmbedder, FinalLayer 22 | 23 | 24 | class PixArtBlock(nn.Module): 25 | """ 26 | A PixArt block with adaptive layer norm (adaLN-single) conditioning. 27 | """ 28 | def __init__(self, hidden_size, num_heads, mlp_ratio=4.0, drop_path=0, input_size=None, sampling=None, sr_ratio=1, qk_norm=False, **block_kwargs): 29 | super().__init__() 30 | self.norm1 = nn.LayerNorm(hidden_size, elementwise_affine=False, eps=1e-6) 31 | self.attn = AttentionKVCompress( 32 | hidden_size, num_heads=num_heads, qkv_bias=True, sampling=sampling, sr_ratio=sr_ratio, 33 | qk_norm=qk_norm, **block_kwargs 34 | ) 35 | self.cross_attn = MultiHeadCrossAttention(hidden_size, num_heads, **block_kwargs) 36 | self.norm2 = nn.LayerNorm(hidden_size, elementwise_affine=False, eps=1e-6) 37 | # to be compatible with lower version pytorch 38 | approx_gelu = lambda: nn.GELU(approximate="tanh") 39 | self.mlp = Mlp(in_features=hidden_size, hidden_features=int(hidden_size * mlp_ratio), act_layer=approx_gelu, drop=0) 40 | self.drop_path = DropPath(drop_path) if drop_path > 0. 
else nn.Identity() 41 | self.scale_shift_table = nn.Parameter(torch.randn(6, hidden_size) / hidden_size ** 0.5) 42 | self.sampling = sampling 43 | self.sr_ratio = sr_ratio 44 | 45 | def forward(self, x, y, t, mask=None, **kwargs): 46 | B, N, C = x.shape 47 | 48 | shift_msa, scale_msa, gate_msa, shift_mlp, scale_mlp, gate_mlp = (self.scale_shift_table[None] + t.reshape(B, 6, -1)).chunk(6, dim=1) 49 | x = x + self.drop_path(gate_msa * self.attn(t2i_modulate(self.norm1(x), shift_msa, scale_msa)).reshape(B, N, C)) 50 | x = x + self.cross_attn(x, y, mask) 51 | x = x + self.drop_path(gate_mlp * self.mlp(t2i_modulate(self.norm2(x), shift_mlp, scale_mlp))) 52 | 53 | return x 54 | 55 | 56 | ### Core PixArt Model ### 57 | class PixArt(nn.Module): 58 | """ 59 | Diffusion model with a Transformer backbone. 60 | """ 61 | def __init__( 62 | self, 63 | input_size=32, 64 | patch_size=2, 65 | in_channels=4, 66 | hidden_size=1152, 67 | depth=28, 68 | num_heads=16, 69 | mlp_ratio=4.0, 70 | class_dropout_prob=0.1, 71 | pred_sigma=True, 72 | drop_path: float = 0., 73 | caption_channels=4096, 74 | pe_interpolation=1.0, 75 | config=None, 76 | model_max_length=120, 77 | qk_norm=False, 78 | kv_compress_config=None, 79 | **kwargs, 80 | ): 81 | super().__init__() 82 | self.pred_sigma = pred_sigma 83 | self.in_channels = in_channels 84 | self.out_channels = in_channels * 2 if pred_sigma else in_channels 85 | self.patch_size = patch_size 86 | self.num_heads = num_heads 87 | self.pe_interpolation = pe_interpolation 88 | self.depth = depth 89 | 90 | self.x_embedder = PatchEmbed(input_size, patch_size, in_channels, hidden_size, bias=True) 91 | self.t_embedder = TimestepEmbedder(hidden_size) 92 | num_patches = self.x_embedder.num_patches 93 | self.base_size = input_size // self.patch_size 94 | # Will use fixed sin-cos embedding: 95 | self.register_buffer("pos_embed", torch.zeros(1, num_patches, hidden_size)) 96 | 97 | approx_gelu = lambda: nn.GELU(approximate="tanh") 98 | self.t_block = nn.Sequential( 99 | nn.SiLU(), 100 | nn.Linear(hidden_size, 6 * hidden_size, bias=True) 101 | ) 102 | self.y_embedder = CaptionEmbedder( 103 | in_channels=caption_channels, hidden_size=hidden_size, uncond_prob=class_dropout_prob, 104 | act_layer=approx_gelu, token_num=model_max_length 105 | ) 106 | drop_path = [x.item() for x in torch.linspace(0, drop_path, depth)] # stochastic depth decay rule 107 | self.kv_compress_config = kv_compress_config 108 | if kv_compress_config is None: 109 | self.kv_compress_config = { 110 | 'sampling': None, 111 | 'scale_factor': 1, 112 | 'kv_compress_layer': [], 113 | } 114 | self.blocks = nn.ModuleList([ 115 | PixArtBlock( 116 | hidden_size, num_heads, mlp_ratio=mlp_ratio, drop_path=drop_path[i], 117 | input_size=(input_size // patch_size, input_size // patch_size), 118 | sampling=self.kv_compress_config['sampling'], 119 | sr_ratio=int( 120 | self.kv_compress_config['scale_factor'] 121 | ) if i in self.kv_compress_config['kv_compress_layer'] else 1, 122 | qk_norm=qk_norm, 123 | ) 124 | for i in range(depth) 125 | ]) 126 | self.final_layer = T2IFinalLayer(hidden_size, patch_size, self.out_channels) 127 | 128 | def forward_raw(self, x, t, y, mask=None, data_info=None): 129 | """ 130 | Original forward pass of PixArt. 
131 | x: (N, C, H, W) tensor of spatial inputs (images or latent representations of images) 132 | t: (N,) tensor of diffusion timesteps 133 | y: (N, 1, 120, C) tensor of class labels 134 | """ 135 | x = x.to(self.dtype) 136 | timestep = t.to(self.dtype) 137 | y = y.to(self.dtype) 138 | pos_embed = self.pos_embed.to(self.dtype) 139 | self.h, self.w = x.shape[-2]//self.patch_size, x.shape[-1]//self.patch_size 140 | x = self.x_embedder(x) + pos_embed # (N, T, D), where T = H * W / patch_size ** 2 141 | t = self.t_embedder(timestep.to(x.dtype)) # (N, D) 142 | t0 = self.t_block(t) 143 | y = self.y_embedder(y, self.training) # (N, 1, L, D) 144 | if mask is not None: 145 | if mask.shape[0] != y.shape[0]: 146 | mask = mask.repeat(y.shape[0] // mask.shape[0], 1) 147 | mask = mask.squeeze(1).squeeze(1) 148 | y = y.squeeze(1).masked_select(mask.unsqueeze(-1) != 0).view(1, -1, x.shape[-1]) 149 | y_lens = mask.sum(dim=1).tolist() 150 | else: 151 | y_lens = [y.shape[2]] * y.shape[0] 152 | y = y.squeeze(1).view(1, -1, x.shape[-1]) 153 | for block in self.blocks: 154 | x = auto_grad_checkpoint(block, x, y, t0, y_lens) # (N, T, D) #support grad checkpoint 155 | x = self.final_layer(x, t) # (N, T, patch_size ** 2 * out_channels) 156 | x = self.unpatchify(x) # (N, out_channels, H, W) 157 | return x 158 | 159 | def forward(self, x, timesteps, context, y=None, **kwargs): 160 | """ 161 | Forward pass that adapts comfy input to original forward function 162 | x: (N, C, H, W) tensor of spatial inputs (images or latent representations of images) 163 | timesteps: (N,) tensor of diffusion timesteps 164 | context: (N, 1, 120, C) conditioning 165 | y: extra conditioning. 166 | """ 167 | ## Still accepts the input w/o that dim but returns garbage 168 | if len(context.shape) == 3: 169 | context = context.unsqueeze(1) 170 | 171 | ## run original forward pass 172 | out = self.forward_raw( 173 | x = x.to(self.dtype), 174 | t = timesteps.to(self.dtype), 175 | y = context.to(self.dtype), 176 | ) 177 | 178 | ## only return EPS 179 | out = out.to(torch.float) 180 | eps, rest = out[:, :self.in_channels], out[:, self.in_channels:] 181 | return eps 182 | 183 | def unpatchify(self, x): 184 | """ 185 | x: (N, T, patch_size**2 * C) 186 | imgs: (N, H, W, C) 187 | """ 188 | c = self.out_channels 189 | p = self.x_embedder.patch_size[0] 190 | h = w = int(x.shape[1] ** 0.5) 191 | assert h * w == x.shape[1] 192 | 193 | x = x.reshape(shape=(x.shape[0], h, w, p, p, c)) 194 | x = torch.einsum('nhwpqc->nchpwq', x) 195 | imgs = x.reshape(shape=(x.shape[0], c, h * p, h * p)) 196 | return imgs 197 | 198 | 199 | def get_2d_sincos_pos_embed(embed_dim, grid_size, cls_token=False, extra_tokens=0, pe_interpolation=1.0, base_size=16): 200 | """ 201 | grid_size: int of the grid height and width 202 | return: 203 | pos_embed: [grid_size*grid_size, embed_dim] or [1+grid_size*grid_size, embed_dim] (w/ or w/o cls_token) 204 | """ 205 | if isinstance(grid_size, int): 206 | grid_size = to_2tuple(grid_size) 207 | grid_h = np.arange(grid_size[0], dtype=np.float32) / (grid_size[0]/base_size) / pe_interpolation 208 | grid_w = np.arange(grid_size[1], dtype=np.float32) / (grid_size[1]/base_size) / pe_interpolation 209 | grid = np.meshgrid(grid_w, grid_h) # here w goes first 210 | grid = np.stack(grid, axis=0) 211 | grid = grid.reshape([2, 1, grid_size[1], grid_size[0]]) 212 | 213 | pos_embed = get_2d_sincos_pos_embed_from_grid(embed_dim, grid) 214 | if cls_token and extra_tokens > 0: 215 | pos_embed = np.concatenate([np.zeros([extra_tokens, embed_dim]), 
pos_embed], axis=0) 216 | return pos_embed 217 | 218 | 219 | def get_2d_sincos_pos_embed_from_grid(embed_dim, grid): 220 | assert embed_dim % 2 == 0 221 | 222 | # use half of dimensions to encode grid_h 223 | emb_h = get_1d_sincos_pos_embed_from_grid(embed_dim // 2, grid[0]) # (H*W, D/2) 224 | emb_w = get_1d_sincos_pos_embed_from_grid(embed_dim // 2, grid[1]) # (H*W, D/2) 225 | 226 | emb = np.concatenate([emb_h, emb_w], axis=1) # (H*W, D) 227 | return emb 228 | 229 | 230 | def get_1d_sincos_pos_embed_from_grid(embed_dim, pos): 231 | """ 232 | embed_dim: output dimension for each position 233 | pos: a list of positions to be encoded: size (M,) 234 | out: (M, D) 235 | """ 236 | assert embed_dim % 2 == 0 237 | omega = np.arange(embed_dim // 2, dtype=np.float64) 238 | omega /= embed_dim / 2. 239 | omega = 1. / 10000 ** omega # (D/2,) 240 | 241 | pos = pos.reshape(-1) # (M,) 242 | out = np.einsum('m,d->md', pos, omega) # (M, D/2), outer product 243 | 244 | emb_sin = np.sin(out) # (M, D/2) 245 | emb_cos = np.cos(out) # (M, D/2) 246 | 247 | emb = np.concatenate([emb_sin, emb_cos], axis=1) # (M, D) 248 | return emb 249 | -------------------------------------------------------------------------------- /PixArt/models/PixArtMS.py: -------------------------------------------------------------------------------- 1 | # Copyright (c) Meta Platforms, Inc. and affiliates. 2 | # All rights reserved. 3 | 4 | # This source code is licensed under the license found in the 5 | # LICENSE file in the root directory of this source tree. 6 | # -------------------------------------------------------- 7 | # References: 8 | # GLIDE: https://github.com/openai/glide-text2im 9 | # MAE: https://github.com/facebookresearch/mae/blob/main/models_mae.py 10 | # -------------------------------------------------------- 11 | import torch 12 | import torch.nn as nn 13 | from tqdm import tqdm 14 | from timm.models.layers import DropPath 15 | from timm.models.vision_transformer import Mlp 16 | 17 | from .utils import auto_grad_checkpoint, to_2tuple 18 | from .PixArt_blocks import t2i_modulate, CaptionEmbedder, AttentionKVCompress, MultiHeadCrossAttention, T2IFinalLayer, TimestepEmbedder, SizeEmbedder 19 | from .PixArt import PixArt, get_2d_sincos_pos_embed 20 | 21 | 22 | class PatchEmbed(nn.Module): 23 | """ 24 | 2D Image to Patch Embedding 25 | """ 26 | def __init__( 27 | self, 28 | patch_size=16, 29 | in_chans=3, 30 | embed_dim=768, 31 | norm_layer=None, 32 | flatten=True, 33 | bias=True, 34 | ): 35 | super().__init__() 36 | patch_size = to_2tuple(patch_size) 37 | self.patch_size = patch_size 38 | self.flatten = flatten 39 | self.proj = nn.Conv2d(in_chans, embed_dim, kernel_size=patch_size, stride=patch_size, bias=bias) 40 | self.norm = norm_layer(embed_dim) if norm_layer else nn.Identity() 41 | 42 | def forward(self, x): 43 | x = self.proj(x) 44 | if self.flatten: 45 | x = x.flatten(2).transpose(1, 2) # BCHW -> BNC 46 | x = self.norm(x) 47 | return x 48 | 49 | 50 | class PixArtMSBlock(nn.Module): 51 | """ 52 | A PixArt block with adaptive layer norm zero (adaLN-Zero) conditioning. 
53 | """ 54 | def __init__(self, hidden_size, num_heads, mlp_ratio=4.0, drop_path=0., input_size=None, 55 | sampling=None, sr_ratio=1, qk_norm=False, **block_kwargs): 56 | super().__init__() 57 | self.hidden_size = hidden_size 58 | self.norm1 = nn.LayerNorm(hidden_size, elementwise_affine=False, eps=1e-6) 59 | self.attn = AttentionKVCompress( 60 | hidden_size, num_heads=num_heads, qkv_bias=True, sampling=sampling, sr_ratio=sr_ratio, 61 | qk_norm=qk_norm, **block_kwargs 62 | ) 63 | self.cross_attn = MultiHeadCrossAttention(hidden_size, num_heads, **block_kwargs) 64 | self.norm2 = nn.LayerNorm(hidden_size, elementwise_affine=False, eps=1e-6) 65 | # to be compatible with lower version pytorch 66 | approx_gelu = lambda: nn.GELU(approximate="tanh") 67 | self.mlp = Mlp(in_features=hidden_size, hidden_features=int(hidden_size * mlp_ratio), act_layer=approx_gelu, drop=0) 68 | self.drop_path = DropPath(drop_path) if drop_path > 0. else nn.Identity() 69 | self.scale_shift_table = nn.Parameter(torch.randn(6, hidden_size) / hidden_size ** 0.5) 70 | 71 | def forward(self, x, y, t, mask=None, HW=None, **kwargs): 72 | B, N, C = x.shape 73 | 74 | shift_msa, scale_msa, gate_msa, shift_mlp, scale_mlp, gate_mlp = (self.scale_shift_table[None] + t.reshape(B, 6, -1)).chunk(6, dim=1) 75 | x = x + self.drop_path(gate_msa * self.attn(t2i_modulate(self.norm1(x), shift_msa, scale_msa), HW=HW)) 76 | x = x + self.cross_attn(x, y, mask) 77 | x = x + self.drop_path(gate_mlp * self.mlp(t2i_modulate(self.norm2(x), shift_mlp, scale_mlp))) 78 | 79 | return x 80 | 81 | 82 | ### Core PixArt Model ### 83 | class PixArtMS(PixArt): 84 | """ 85 | Diffusion model with a Transformer backbone. 86 | """ 87 | def __init__( 88 | self, 89 | input_size=32, 90 | patch_size=2, 91 | in_channels=4, 92 | hidden_size=1152, 93 | depth=28, 94 | num_heads=16, 95 | mlp_ratio=4.0, 96 | class_dropout_prob=0.1, 97 | learn_sigma=True, 98 | pred_sigma=True, 99 | drop_path: float = 0., 100 | caption_channels=4096, 101 | pe_interpolation=1., 102 | config=None, 103 | model_max_length=120, 104 | micro_condition=True, 105 | qk_norm=False, 106 | kv_compress_config=None, 107 | **kwargs, 108 | ): 109 | super().__init__( 110 | input_size=input_size, 111 | patch_size=patch_size, 112 | in_channels=in_channels, 113 | hidden_size=hidden_size, 114 | depth=depth, 115 | num_heads=num_heads, 116 | mlp_ratio=mlp_ratio, 117 | class_dropout_prob=class_dropout_prob, 118 | learn_sigma=learn_sigma, 119 | pred_sigma=pred_sigma, 120 | drop_path=drop_path, 121 | pe_interpolation=pe_interpolation, 122 | config=config, 123 | model_max_length=model_max_length, 124 | qk_norm=qk_norm, 125 | kv_compress_config=kv_compress_config, 126 | **kwargs, 127 | ) 128 | self.dtype = torch.get_default_dtype() 129 | self.h = self.w = 0 130 | approx_gelu = lambda: nn.GELU(approximate="tanh") 131 | self.t_block = nn.Sequential( 132 | nn.SiLU(), 133 | nn.Linear(hidden_size, 6 * hidden_size, bias=True) 134 | ) 135 | self.x_embedder = PatchEmbed(patch_size, in_channels, hidden_size, bias=True) 136 | self.y_embedder = CaptionEmbedder(in_channels=caption_channels, hidden_size=hidden_size, uncond_prob=class_dropout_prob, act_layer=approx_gelu, token_num=model_max_length) 137 | self.micro_conditioning = micro_condition 138 | if self.micro_conditioning: 139 | self.csize_embedder = SizeEmbedder(hidden_size//3) # c_size embed 140 | self.ar_embedder = SizeEmbedder(hidden_size//3) # aspect ratio embed 141 | drop_path = [x.item() for x in torch.linspace(0, drop_path, depth)] # stochastic depth decay rule 142 | 
if kv_compress_config is None: 143 | kv_compress_config = { 144 | 'sampling': None, 145 | 'scale_factor': 1, 146 | 'kv_compress_layer': [], 147 | } 148 | self.blocks = nn.ModuleList([ 149 | PixArtMSBlock( 150 | hidden_size, num_heads, mlp_ratio=mlp_ratio, drop_path=drop_path[i], 151 | input_size=(input_size // patch_size, input_size // patch_size), 152 | sampling=kv_compress_config['sampling'], 153 | sr_ratio=int(kv_compress_config['scale_factor']) if i in kv_compress_config['kv_compress_layer'] else 1, 154 | qk_norm=qk_norm, 155 | ) 156 | for i in range(depth) 157 | ]) 158 | self.final_layer = T2IFinalLayer(hidden_size, patch_size, self.out_channels) 159 | 160 | def forward_raw(self, x, t, y, mask=None, data_info=None, **kwargs): 161 | """ 162 | Original forward pass of PixArt. 163 | x: (N, C, H, W) tensor of spatial inputs (images or latent representations of images) 164 | t: (N,) tensor of diffusion timesteps 165 | y: (N, 1, 120, C) tensor of class labels 166 | """ 167 | bs = x.shape[0] 168 | x = x.to(self.dtype) 169 | timestep = t.to(self.dtype) 170 | y = y.to(self.dtype) 171 | self.h, self.w = x.shape[-2]//self.patch_size, x.shape[-1]//self.patch_size 172 | pos_embed = torch.from_numpy( 173 | get_2d_sincos_pos_embed( 174 | self.pos_embed.shape[-1], (self.h, self.w), pe_interpolation=self.pe_interpolation, 175 | base_size=self.base_size 176 | ) 177 | ).unsqueeze(0).to(x.device).to(self.dtype) 178 | 179 | x = self.x_embedder(x) + pos_embed # (N, T, D), where T = H * W / patch_size ** 2 180 | t = self.t_embedder(timestep) # (N, D) 181 | 182 | if self.micro_conditioning: 183 | c_size, ar = data_info['img_hw'].to(self.dtype), data_info['aspect_ratio'].to(self.dtype) 184 | csize = self.csize_embedder(c_size, bs) # (N, D) 185 | ar = self.ar_embedder(ar, bs) # (N, D) 186 | t = t + torch.cat([csize, ar], dim=1) 187 | 188 | t0 = self.t_block(t) 189 | y = self.y_embedder(y, self.training) # (N, D) 190 | 191 | if mask is not None: 192 | if mask.shape[0] != y.shape[0]: 193 | mask = mask.repeat(y.shape[0] // mask.shape[0], 1) 194 | mask = mask.squeeze(1).squeeze(1) 195 | y = y.squeeze(1).masked_select(mask.unsqueeze(-1) != 0).view(1, -1, x.shape[-1]) 196 | y_lens = mask.sum(dim=1).tolist() 197 | else: 198 | y_lens = [y.shape[2]] * y.shape[0] 199 | y = y.squeeze(1).view(1, -1, x.shape[-1]) 200 | for block in self.blocks: 201 | x = auto_grad_checkpoint(block, x, y, t0, y_lens, (self.h, self.w), **kwargs) # (N, T, D) #support grad checkpoint 202 | 203 | x = self.final_layer(x, t) # (N, T, patch_size ** 2 * out_channels) 204 | x = self.unpatchify(x) # (N, out_channels, H, W) 205 | 206 | return x 207 | 208 | def forward(self, x, timesteps, context, img_hw=None, aspect_ratio=None, **kwargs): 209 | """ 210 | Forward pass that adapts comfy input to original forward function 211 | x: (N, C, H, W) tensor of spatial inputs (images or latent representations of images) 212 | timesteps: (N,) tensor of diffusion timesteps 213 | context: (N, 1, 120, C) conditioning 214 | img_hw: height|width conditioning 215 | aspect_ratio: aspect ratio conditioning 216 | """ 217 | ## size/ar from cond with fallback based on the latent image shape. 
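        ## img_hw falls back to the latent spatial size times the VAE factor of 8 when
        ## it is not provided; aspect_ratio is recomputed from the latent as H/W (note the
        ## `or True` below makes that branch unconditional). Both tensors are repeated
        ## across the batch before being passed to forward_raw as data_info.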
218 | bs = x.shape[0] 219 | data_info = {} 220 | if img_hw is None: 221 | data_info["img_hw"] = torch.tensor( 222 | [[x.shape[2]*8, x.shape[3]*8]], 223 | dtype=self.dtype, 224 | device=x.device 225 | ).repeat(bs, 1) 226 | else: 227 | data_info["img_hw"] = img_hw.to(x.dtype).to(x.device) 228 | if aspect_ratio is None or True: 229 | data_info["aspect_ratio"] = torch.tensor( 230 | [[x.shape[2]/x.shape[3]]], 231 | dtype=self.dtype, 232 | device=x.device 233 | ).repeat(bs, 1) 234 | else: 235 | data_info["aspect_ratio"] = aspect_ratio.to(x.dtype).to(x.device) 236 | 237 | ## Still accepts the input w/o that dim but returns garbage 238 | if len(context.shape) == 3: 239 | context = context.unsqueeze(1) 240 | 241 | ## run original forward pass 242 | out = self.forward_raw( 243 | x = x.to(self.dtype), 244 | t = timesteps.to(self.dtype), 245 | y = context.to(self.dtype), 246 | data_info=data_info, 247 | ) 248 | 249 | ## only return EPS 250 | out = out.to(torch.float) 251 | eps, rest = out[:, :self.in_channels], out[:, self.in_channels:] 252 | return eps 253 | 254 | def unpatchify(self, x): 255 | """ 256 | x: (N, T, patch_size**2 * C) 257 | imgs: (N, H, W, C) 258 | """ 259 | c = self.out_channels 260 | p = self.x_embedder.patch_size[0] 261 | assert self.h * self.w == x.shape[1] 262 | 263 | x = x.reshape(shape=(x.shape[0], self.h, self.w, p, p, c)) 264 | x = torch.einsum('nhwpqc->nchpwq', x) 265 | imgs = x.reshape(shape=(x.shape[0], c, self.h * p, self.w * p)) 266 | return imgs 267 | -------------------------------------------------------------------------------- /PixArt/models/utils.py: -------------------------------------------------------------------------------- 1 | import torch 2 | import torch.nn as nn 3 | import torch.nn.functional as F 4 | from torch.utils.checkpoint import checkpoint, checkpoint_sequential 5 | from collections.abc import Iterable 6 | from itertools import repeat 7 | 8 | def _ntuple(n): 9 | def parse(x): 10 | if isinstance(x, Iterable) and not isinstance(x, str): 11 | return x 12 | return tuple(repeat(x, n)) 13 | return parse 14 | 15 | to_1tuple = _ntuple(1) 16 | to_2tuple = _ntuple(2) 17 | 18 | def set_grad_checkpoint(model, use_fp32_attention=False, gc_step=1): 19 | assert isinstance(model, nn.Module) 20 | 21 | def set_attr(module): 22 | module.grad_checkpointing = True 23 | module.fp32_attention = use_fp32_attention 24 | module.grad_checkpointing_step = gc_step 25 | model.apply(set_attr) 26 | 27 | def auto_grad_checkpoint(module, *args, **kwargs): 28 | if getattr(module, 'grad_checkpointing', False): 29 | if isinstance(module, Iterable): 30 | gc_step = module[0].grad_checkpointing_step 31 | return checkpoint_sequential(module, gc_step, *args, **kwargs) 32 | else: 33 | return checkpoint(module, *args, **kwargs) 34 | return module(*args, **kwargs) 35 | 36 | def checkpoint_sequential(functions, step, input, *args, **kwargs): 37 | 38 | # Hack for keyword-only parameter in a python 2.7-compliant way 39 | preserve = kwargs.pop('preserve_rng_state', True) 40 | if kwargs: 41 | raise ValueError("Unexpected keyword arguments: " + ",".join(arg for arg in kwargs)) 42 | 43 | def run_function(start, end, functions): 44 | def forward(input): 45 | for j in range(start, end + 1): 46 | input = functions[j](input, *args) 47 | return input 48 | return forward 49 | 50 | if isinstance(functions, torch.nn.Sequential): 51 | functions = list(functions.children()) 52 | 53 | # the last chunk has to be non-volatile 54 | end = -1 55 | segment = len(functions) // step 56 | for start in range(0, step * 
(segment - 1), step): 57 | end = start + step - 1 58 | input = checkpoint(run_function(start, end, functions), input, preserve_rng_state=preserve) 59 | return run_function(end + 1, len(functions) - 1, functions)(input) 60 | 61 | def get_rel_pos(q_size, k_size, rel_pos): 62 | """ 63 | Get relative positional embeddings according to the relative positions of 64 | query and key sizes. 65 | Args: 66 | q_size (int): size of query q. 67 | k_size (int): size of key k. 68 | rel_pos (Tensor): relative position embeddings (L, C). 69 | 70 | Returns: 71 | Extracted positional embeddings according to relative positions. 72 | """ 73 | max_rel_dist = int(2 * max(q_size, k_size) - 1) 74 | # Interpolate rel pos if needed. 75 | if rel_pos.shape[0] != max_rel_dist: 76 | # Interpolate rel pos. 77 | rel_pos_resized = F.interpolate( 78 | rel_pos.reshape(1, rel_pos.shape[0], -1).permute(0, 2, 1), 79 | size=max_rel_dist, 80 | mode="linear", 81 | ) 82 | rel_pos_resized = rel_pos_resized.reshape(-1, max_rel_dist).permute(1, 0) 83 | else: 84 | rel_pos_resized = rel_pos 85 | 86 | # Scale the coords with short length if shapes for q and k are different. 87 | q_coords = torch.arange(q_size)[:, None] * max(k_size / q_size, 1.0) 88 | k_coords = torch.arange(k_size)[None, :] * max(q_size / k_size, 1.0) 89 | relative_coords = (q_coords - k_coords) + (k_size - 1) * max(q_size / k_size, 1.0) 90 | 91 | return rel_pos_resized[relative_coords.long()] 92 | 93 | def add_decomposed_rel_pos(attn, q, rel_pos_h, rel_pos_w, q_size, k_size): 94 | """ 95 | Calculate decomposed Relative Positional Embeddings from :paper:`mvitv2`. 96 | https://github.com/facebookresearch/mvit/blob/19786631e330df9f3622e5402b4a419a263a2c80/mvit/models/attention.py # noqa B950 97 | Args: 98 | attn (Tensor): attention map. 99 | q (Tensor): query q in the attention layer with shape (B, q_h * q_w, C). 100 | rel_pos_h (Tensor): relative position embeddings (Lh, C) for height axis. 101 | rel_pos_w (Tensor): relative position embeddings (Lw, C) for width axis. 102 | q_size (Tuple): spatial sequence size of query q with (q_h, q_w). 103 | k_size (Tuple): spatial sequence size of key k with (k_h, k_w). 104 | 105 | Returns: 106 | attn (Tensor): attention map with added relative positional embeddings. 
107 | """ 108 | q_h, q_w = q_size 109 | k_h, k_w = k_size 110 | Rh = get_rel_pos(q_h, k_h, rel_pos_h) 111 | Rw = get_rel_pos(q_w, k_w, rel_pos_w) 112 | 113 | B, _, dim = q.shape 114 | r_q = q.reshape(B, q_h, q_w, dim) 115 | rel_h = torch.einsum("bhwc,hkc->bhwk", r_q, Rh) 116 | rel_w = torch.einsum("bhwc,wkc->bhwk", r_q, Rw) 117 | 118 | attn = ( 119 | attn.view(B, q_h, q_w, k_h, k_w) + rel_h[:, :, :, :, None] + rel_w[:, :, :, None, :] 120 | ).view(B, q_h * q_w, k_h * k_w) 121 | 122 | return attn 123 | -------------------------------------------------------------------------------- /PixArt/nodes.py: -------------------------------------------------------------------------------- 1 | import os 2 | import json 3 | import torch 4 | import folder_paths 5 | 6 | from comfy import utils 7 | from .conf import pixart_conf, pixart_res 8 | from .lora import load_pixart_lora 9 | from .loader import load_pixart 10 | 11 | class PixArtCheckpointLoader: 12 | @classmethod 13 | def INPUT_TYPES(s): 14 | return { 15 | "required": { 16 | "ckpt_name": (folder_paths.get_filename_list("checkpoints"),), 17 | "model": (list(pixart_conf.keys()),), 18 | } 19 | } 20 | RETURN_TYPES = ("MODEL",) 21 | RETURN_NAMES = ("model",) 22 | FUNCTION = "load_checkpoint" 23 | CATEGORY = "ExtraModels/PixArt" 24 | TITLE = "PixArt Checkpoint Loader" 25 | 26 | def load_checkpoint(self, ckpt_name, model): 27 | ckpt_path = folder_paths.get_full_path("checkpoints", ckpt_name) 28 | model_conf = pixart_conf[model] 29 | model = load_pixart( 30 | model_path = ckpt_path, 31 | model_conf = model_conf, 32 | ) 33 | return (model,) 34 | 35 | class PixArtResolutionSelect(): 36 | @classmethod 37 | def INPUT_TYPES(s): 38 | return { 39 | "required": { 40 | "model": (list(pixart_res.keys()),), 41 | # keys are the same for both 42 | "ratio": (list(pixart_res["PixArtMS_XL_2"].keys()),{"default":"1.00"}), 43 | } 44 | } 45 | RETURN_TYPES = ("INT","INT") 46 | RETURN_NAMES = ("width","height") 47 | FUNCTION = "get_res" 48 | CATEGORY = "ExtraModels/PixArt" 49 | TITLE = "PixArt Resolution Select" 50 | 51 | def get_res(self, model, ratio): 52 | width, height = pixart_res[model][ratio] 53 | return (width,height) 54 | 55 | class PixArtLoraLoader: 56 | def __init__(self): 57 | self.loaded_lora = None 58 | 59 | @classmethod 60 | def INPUT_TYPES(s): 61 | return { 62 | "required": { 63 | "model": ("MODEL",), 64 | "lora_name": (folder_paths.get_filename_list("loras"), ), 65 | "strength": ("FLOAT", {"default": 1.0, "min": -20.0, "max": 20.0, "step": 0.01}), 66 | } 67 | } 68 | RETURN_TYPES = ("MODEL",) 69 | FUNCTION = "load_lora" 70 | CATEGORY = "ExtraModels/PixArt" 71 | TITLE = "PixArt Load LoRA" 72 | 73 | def load_lora(self, model, lora_name, strength,): 74 | if strength == 0: 75 | return (model) 76 | 77 | lora_path = folder_paths.get_full_path("loras", lora_name) 78 | lora = None 79 | if self.loaded_lora is not None: 80 | if self.loaded_lora[0] == lora_path: 81 | lora = self.loaded_lora[1] 82 | else: 83 | temp = self.loaded_lora 84 | self.loaded_lora = None 85 | del temp 86 | 87 | if lora is None: 88 | lora = utils.load_torch_file(lora_path, safe_load=True) 89 | self.loaded_lora = (lora_path, lora) 90 | 91 | model_lora = load_pixart_lora(model, lora, lora_path, strength,) 92 | return (model_lora,) 93 | 94 | class PixArtResolutionCond: 95 | @classmethod 96 | def INPUT_TYPES(s): 97 | return { 98 | "required": { 99 | "cond": ("CONDITIONING", ), 100 | "width": ("INT", {"default": 1024.0, "min": 0, "max": 8192}), 101 | "height": ("INT", {"default": 1024.0, "min": 0, "max": 
8192}), 102 | } 103 | } 104 | 105 | RETURN_TYPES = ("CONDITIONING",) 106 | RETURN_NAMES = ("cond",) 107 | FUNCTION = "add_cond" 108 | CATEGORY = "ExtraModels/PixArt" 109 | TITLE = "PixArt Resolution Conditioning" 110 | 111 | def add_cond(self, cond, width, height): 112 | for c in range(len(cond)): 113 | cond[c][1].update({ 114 | "img_hw": [[height, width]], 115 | "aspect_ratio": [[height/width]], 116 | }) 117 | return (cond,) 118 | 119 | class PixArtControlNetCond: 120 | @classmethod 121 | def INPUT_TYPES(s): 122 | return { 123 | "required": { 124 | "cond": ("CONDITIONING",), 125 | "latent": ("LATENT",), 126 | # "image": ("IMAGE",), 127 | # "vae": ("VAE",), 128 | # "strength": ("FLOAT", {"default": 1.0, "min": 0.0, "max": 10.0, "step": 0.01}) 129 | } 130 | } 131 | 132 | RETURN_TYPES = ("CONDITIONING",) 133 | RETURN_NAMES = ("cond",) 134 | FUNCTION = "add_cond" 135 | CATEGORY = "ExtraModels/PixArt" 136 | TITLE = "PixArt ControlNet Conditioning" 137 | 138 | def add_cond(self, cond, latent): 139 | for c in range(len(cond)): 140 | cond[c][1]["cn_hint"] = latent["samples"] * 0.18215 141 | return (cond,) 142 | 143 | class PixArtT5TextEncode: 144 | """ 145 | Reference code, mostly to verify compatibility. 146 | Once everything works, this should instead inherit from the 147 | T5 text encode node and simply add the extra conds (res/ar). 148 | """ 149 | @classmethod 150 | def INPUT_TYPES(s): 151 | return { 152 | "required": { 153 | "text": ("STRING", {"multiline": True}), 154 | "T5": ("T5",), 155 | } 156 | } 157 | 158 | RETURN_TYPES = ("CONDITIONING",) 159 | FUNCTION = "encode" 160 | CATEGORY = "ExtraModels/PixArt" 161 | TITLE = "PixArt T5 Text Encode [Reference]" 162 | 163 | def mask_feature(self, emb, mask): 164 | if emb.shape[0] == 1: 165 | keep_index = mask.sum().item() 166 | return emb[:, :, :keep_index, :], keep_index 167 | else: 168 | masked_feature = emb * mask[:, None, :, None] 169 | return masked_feature, emb.shape[2] 170 | 171 | def encode(self, text, T5): 172 | text = text.lower().strip() 173 | tokenizer_out = T5.tokenizer.tokenizer( 174 | text, 175 | max_length = 120, 176 | padding = 'max_length', 177 | truncation = True, 178 | return_attention_mask = True, 179 | add_special_tokens = True, 180 | return_tensors = 'pt' 181 | ) 182 | tokens = tokenizer_out["input_ids"] 183 | mask = tokenizer_out["attention_mask"] 184 | embs = T5.cond_stage_model.transformer( 185 | input_ids = tokens.to(T5.load_device), 186 | attention_mask = mask.to(T5.load_device), 187 | )['last_hidden_state'].float()[:, None] 188 | masked_embs, keep_index = self.mask_feature( 189 | embs.detach().to("cpu"), 190 | mask.detach().to("cpu") 191 | ) 192 | masked_embs = masked_embs.squeeze(0) # match CLIP/internal 193 | print("Encoded T5:", masked_embs.shape) 194 | return ([[masked_embs, {}]], ) 195 | 196 | NODE_CLASS_MAPPINGS = { 197 | "PixArtCheckpointLoader" : PixArtCheckpointLoader, 198 | "PixArtResolutionSelect" : PixArtResolutionSelect, 199 | "PixArtLoraLoader" : PixArtLoraLoader, 200 | "PixArtT5TextEncode" : PixArtT5TextEncode, 201 | "PixArtResolutionCond" : PixArtResolutionCond, 202 | "PixArtControlNetCond" : PixArtControlNetCond, 203 | } 204 | -------------------------------------------------------------------------------- /README.md: -------------------------------------------------------------------------------- 1 | # ComfyUI HunyuanDiT (WIP) 2 | 3 | [HunyuanDiT](https://github.com/Tencent/HunyuanDiT) 4 | 5 | ``` 6 | huggingface-cli download --resume-download Tencent-Hunyuan/HunyuanDiT --local-dir 
ComfyUI/models/diffusers --local-dir-use-symlinks False 7 | ``` 8 | 9 | sdxl vae 10 | 11 | ## workflow 12 | 13 | [Recommended complete Workflow](https://github.com/chaojie/ComfyUI_ExtraModels/blob/main/HunYuan/wf.json) 14 | 15 | -------------------------------------------------------------------------------- /T5/LICENSE-T5: -------------------------------------------------------------------------------- 1 | Apache License 2 | Version 2.0, January 2004 3 | http://www.apache.org/licenses/ 4 | 5 | TERMS AND CONDITIONS FOR USE, REPRODUCTION, AND DISTRIBUTION 6 | 7 | 1. Definitions. 8 | 9 | "License" shall mean the terms and conditions for use, reproduction, 10 | and distribution as defined by Sections 1 through 9 of this document. 11 | 12 | "Licensor" shall mean the copyright owner or entity authorized by 13 | the copyright owner that is granting the License. 14 | 15 | "Legal Entity" shall mean the union of the acting entity and all 16 | other entities that control, are controlled by, or are under common 17 | control with that entity. For the purposes of this definition, 18 | "control" means (i) the power, direct or indirect, to cause the 19 | direction or management of such entity, whether by contract or 20 | otherwise, or (ii) ownership of fifty percent (50%) or more of the 21 | outstanding shares, or (iii) beneficial ownership of such entity. 22 | 23 | "You" (or "Your") shall mean an individual or Legal Entity 24 | exercising permissions granted by this License. 25 | 26 | "Source" form shall mean the preferred form for making modifications, 27 | including but not limited to software source code, documentation 28 | source, and configuration files. 29 | 30 | "Object" form shall mean any form resulting from mechanical 31 | transformation or translation of a Source form, including but 32 | not limited to compiled object code, generated documentation, 33 | and conversions to other media types. 34 | 35 | "Work" shall mean the work of authorship, whether in Source or 36 | Object form, made available under the License, as indicated by a 37 | copyright notice that is included in or attached to the work 38 | (an example is provided in the Appendix below). 39 | 40 | "Derivative Works" shall mean any work, whether in Source or Object 41 | form, that is based on (or derived from) the Work and for which the 42 | editorial revisions, annotations, elaborations, or other modifications 43 | represent, as a whole, an original work of authorship. For the purposes 44 | of this License, Derivative Works shall not include works that remain 45 | separable from, or merely link (or bind by name) to the interfaces of, 46 | the Work and Derivative Works thereof. 47 | 48 | "Contribution" shall mean any work of authorship, including 49 | the original version of the Work and any modifications or additions 50 | to that Work or Derivative Works thereof, that is intentionally 51 | submitted to Licensor for inclusion in the Work by the copyright owner 52 | or by an individual or Legal Entity authorized to submit on behalf of 53 | the copyright owner. 
For the purposes of this definition, "submitted" 54 | means any form of electronic, verbal, or written communication sent 55 | to the Licensor or its representatives, including but not limited to 56 | communication on electronic mailing lists, source code control systems, 57 | and issue tracking systems that are managed by, or on behalf of, the 58 | Licensor for the purpose of discussing and improving the Work, but 59 | excluding communication that is conspicuously marked or otherwise 60 | designated in writing by the copyright owner as "Not a Contribution." 61 | 62 | "Contributor" shall mean Licensor and any individual or Legal Entity 63 | on behalf of whom a Contribution has been received by Licensor and 64 | subsequently incorporated within the Work. 65 | 66 | 2. Grant of Copyright License. Subject to the terms and conditions of 67 | this License, each Contributor hereby grants to You a perpetual, 68 | worldwide, non-exclusive, no-charge, royalty-free, irrevocable 69 | copyright license to reproduce, prepare Derivative Works of, 70 | publicly display, publicly perform, sublicense, and distribute the 71 | Work and such Derivative Works in Source or Object form. 72 | 73 | 3. Grant of Patent License. Subject to the terms and conditions of 74 | this License, each Contributor hereby grants to You a perpetual, 75 | worldwide, non-exclusive, no-charge, royalty-free, irrevocable 76 | (except as stated in this section) patent license to make, have made, 77 | use, offer to sell, sell, import, and otherwise transfer the Work, 78 | where such license applies only to those patent claims licensable 79 | by such Contributor that are necessarily infringed by their 80 | Contribution(s) alone or by combination of their Contribution(s) 81 | with the Work to which such Contribution(s) was submitted. If You 82 | institute patent litigation against any entity (including a 83 | cross-claim or counterclaim in a lawsuit) alleging that the Work 84 | or a Contribution incorporated within the Work constitutes direct 85 | or contributory patent infringement, then any patent licenses 86 | granted to You under this License for that Work shall terminate 87 | as of the date such litigation is filed. 88 | 89 | 4. Redistribution. 
You may reproduce and distribute copies of the 90 | Work or Derivative Works thereof in any medium, with or without 91 | modifications, and in Source or Object form, provided that You 92 | meet the following conditions: 93 | 94 | (a) You must give any other recipients of the Work or 95 | Derivative Works a copy of this License; and 96 | 97 | (b) You must cause any modified files to carry prominent notices 98 | stating that You changed the files; and 99 | 100 | (c) You must retain, in the Source form of any Derivative Works 101 | that You distribute, all copyright, patent, trademark, and 102 | attribution notices from the Source form of the Work, 103 | excluding those notices that do not pertain to any part of 104 | the Derivative Works; and 105 | 106 | (d) If the Work includes a "NOTICE" text file as part of its 107 | distribution, then any Derivative Works that You distribute must 108 | include a readable copy of the attribution notices contained 109 | within such NOTICE file, excluding those notices that do not 110 | pertain to any part of the Derivative Works, in at least one 111 | of the following places: within a NOTICE text file distributed 112 | as part of the Derivative Works; within the Source form or 113 | documentation, if provided along with the Derivative Works; or, 114 | within a display generated by the Derivative Works, if and 115 | wherever such third-party notices normally appear. The contents 116 | of the NOTICE file are for informational purposes only and 117 | do not modify the License. You may add Your own attribution 118 | notices within Derivative Works that You distribute, alongside 119 | or as an addendum to the NOTICE text from the Work, provided 120 | that such additional attribution notices cannot be construed 121 | as modifying the License. 122 | 123 | You may add Your own copyright statement to Your modifications and 124 | may provide additional or different license terms and conditions 125 | for use, reproduction, or distribution of Your modifications, or 126 | for any such Derivative Works as a whole, provided Your use, 127 | reproduction, and distribution of the Work otherwise complies with 128 | the conditions stated in this License. 129 | 130 | 5. Submission of Contributions. Unless You explicitly state otherwise, 131 | any Contribution intentionally submitted for inclusion in the Work 132 | by You to the Licensor shall be under the terms and conditions of 133 | this License, without any additional terms or conditions. 134 | Notwithstanding the above, nothing herein shall supersede or modify 135 | the terms of any separate license agreement you may have executed 136 | with Licensor regarding such Contributions. 137 | 138 | 6. Trademarks. This License does not grant permission to use the trade 139 | names, trademarks, service marks, or product names of the Licensor, 140 | except as required for reasonable and customary use in describing the 141 | origin of the Work and reproducing the content of the NOTICE file. 142 | 143 | 7. Disclaimer of Warranty. Unless required by applicable law or 144 | agreed to in writing, Licensor provides the Work (and each 145 | Contributor provides its Contributions) on an "AS IS" BASIS, 146 | WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or 147 | implied, including, without limitation, any warranties or conditions 148 | of TITLE, NON-INFRINGEMENT, MERCHANTABILITY, or FITNESS FOR A 149 | PARTICULAR PURPOSE. 
You are solely responsible for determining the 150 | appropriateness of using or redistributing the Work and assume any 151 | risks associated with Your exercise of permissions under this License. 152 | 153 | 8. Limitation of Liability. In no event and under no legal theory, 154 | whether in tort (including negligence), contract, or otherwise, 155 | unless required by applicable law (such as deliberate and grossly 156 | negligent acts) or agreed to in writing, shall any Contributor be 157 | liable to You for damages, including any direct, indirect, special, 158 | incidental, or consequential damages of any character arising as a 159 | result of this License or out of the use or inability to use the 160 | Work (including but not limited to damages for loss of goodwill, 161 | work stoppage, computer failure or malfunction, or any and all 162 | other commercial damages or losses), even if such Contributor 163 | has been advised of the possibility of such damages. 164 | 165 | 9. Accepting Warranty or Additional Liability. While redistributing 166 | the Work or Derivative Works thereof, You may choose to offer, 167 | and charge a fee for, acceptance of support, warranty, indemnity, 168 | or other liability obligations and/or rights consistent with this 169 | License. However, in accepting such obligations, You may act only 170 | on Your own behalf and on Your sole responsibility, not on behalf 171 | of any other Contributor, and only if You agree to indemnify, 172 | defend, and hold each Contributor harmless for any liability 173 | incurred by, or claims asserted against, such Contributor by reason 174 | of your accepting any such warranty or additional liability. 175 | 176 | END OF TERMS AND CONDITIONS 177 | 178 | APPENDIX: How to apply the Apache License to your work. 179 | 180 | To apply the Apache License to your work, attach the following 181 | boilerplate notice, with the fields enclosed by brackets "[]" 182 | replaced with your own identifying information. (Don't include 183 | the brackets!) The text should be enclosed in the appropriate 184 | comment syntax for the file format. We also recommend that a 185 | file or class name and description of purpose be included on the 186 | same "printed page" as the copyright notice for easier 187 | identification within third-party archives. 188 | 189 | Copyright [yyyy] [name of copyright owner] 190 | 191 | Licensed under the Apache License, Version 2.0 (the "License"); 192 | you may not use this file except in compliance with the License. 193 | You may obtain a copy of the License at 194 | 195 | http://www.apache.org/licenses/LICENSE-2.0 196 | 197 | Unless required by applicable law or agreed to in writing, software 198 | distributed under the License is distributed on an "AS IS" BASIS, 199 | WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 200 | See the License for the specific language governing permissions and 201 | limitations under the License. 
202 | -------------------------------------------------------------------------------- /T5/loader.py: -------------------------------------------------------------------------------- 1 | import os 2 | import torch 3 | import comfy.utils 4 | import comfy.model_patcher 5 | from comfy import model_management 6 | import folder_paths 7 | 8 | from .t5v11 import T5v11Model, T5v11Tokenizer 9 | 10 | class EXM_T5v11: 11 | def __init__(self, textmodel_ver="xxl", embedding_directory=None, textmodel_path=None, no_init=False, device="cpu", dtype=None): 12 | if no_init: 13 | return 14 | 15 | if device == "auto": 16 | size = 0 17 | self.load_device = model_management.text_encoder_device() 18 | self.offload_device = model_management.text_encoder_offload_device() 19 | self.init_device = "cpu" 20 | elif dtype == "bnb8bit": 21 | # BNB doesn't support size enum 22 | size = 12.4 * (1024**3) 23 | # Or moving between devices 24 | self.load_device = model_management.get_torch_device() 25 | self.offload_device = self.load_device 26 | self.init_device = self.load_device 27 | elif dtype == "bnb4bit": 28 | # This seems to use the same VRAM as 8bit on Pascal? 29 | size = 6.2 * (1024**3) 30 | self.load_device = model_management.get_torch_device() 31 | self.offload_device = self.load_device 32 | self.init_device = self.load_device 33 | elif device == "cpu": 34 | size = 0 35 | self.load_device = "cpu" 36 | self.offload_device = "cpu" 37 | self.init_device="cpu" 38 | elif device.startswith("cuda"): 39 | print("Direct CUDA device override!\nVRAM will not be freed by default.") 40 | size = 0 41 | self.load_device = device 42 | self.offload_device = device 43 | self.init_device = device 44 | else: 45 | size = 0 46 | self.load_device = model_management.get_torch_device() 47 | self.offload_device = "cpu" 48 | self.init_device="cpu" 49 | 50 | self.cond_stage_model = T5v11Model( 51 | textmodel_ver = textmodel_ver, 52 | textmodel_path = textmodel_path, 53 | device = device, 54 | dtype = dtype, 55 | ) 56 | self.tokenizer = T5v11Tokenizer(embedding_directory=embedding_directory) 57 | self.patcher = comfy.model_patcher.ModelPatcher( 58 | self.cond_stage_model, 59 | load_device = self.load_device, 60 | offload_device = self.offload_device, 61 | current_device = self.load_device, 62 | size = size, 63 | ) 64 | 65 | def clone(self): 66 | n = T5(no_init=True) 67 | n.patcher = self.patcher.clone() 68 | n.cond_stage_model = self.cond_stage_model 69 | n.tokenizer = self.tokenizer 70 | return n 71 | 72 | def tokenize(self, text, return_word_ids=False): 73 | return self.tokenizer.tokenize_with_weights(text, return_word_ids) 74 | 75 | def encode_from_tokens(self, tokens): 76 | self.load_model() 77 | return self.cond_stage_model.encode_token_weights(tokens) 78 | 79 | def encode(self, text): 80 | tokens = self.tokenize(text) 81 | return self.encode_from_tokens(tokens) 82 | 83 | def load_sd(self, sd): 84 | return self.cond_stage_model.load_sd(sd) 85 | 86 | def get_sd(self): 87 | return self.cond_stage_model.state_dict() 88 | 89 | def load_model(self): 90 | if self.load_device != "cpu": 91 | model_management.load_model_gpu(self.patcher) 92 | return self.patcher 93 | 94 | def add_patches(self, patches, strength_patch=1.0, strength_model=1.0): 95 | return self.patcher.add_patches(patches, strength_patch, strength_model) 96 | 97 | def get_key_patches(self): 98 | return self.patcher.get_key_patches() 99 | 100 | 101 | def load_t5(model_type, model_ver, model_path, path_type="file", device="cpu", dtype=None): 102 | assert model_type in ["t5v11"] # Only 
supported model for now 103 | model_args = { 104 | "textmodel_ver" : model_ver, 105 | "device" : device, 106 | "dtype" : dtype, 107 | } 108 | 109 | if path_type == "folder": 110 | # pass directly to transformers and initialize there 111 | # this is to avoid having to handle multi-file state dict loading for now. 112 | model_args["textmodel_path"] = os.path.dirname(model_path) 113 | return EXM_T5v11(**model_args) 114 | else: 115 | # for some reason this returns garbage with torch.int8 weights, or just OOMs 116 | model = EXM_T5v11(**model_args) 117 | sd = comfy.utils.load_torch_file(model_path) 118 | model.load_sd(sd) 119 | return model 120 | -------------------------------------------------------------------------------- /T5/nodes.py: -------------------------------------------------------------------------------- 1 | import os 2 | import json 3 | import torch 4 | import folder_paths 5 | 6 | from .loader import load_t5 7 | from ..utils.dtype import string_to_dtype 8 | 9 | # initialize custom folder path 10 | os.makedirs( 11 | os.path.join(folder_paths.models_dir,"t5"), 12 | exist_ok = True, 13 | ) 14 | folder_paths.folder_names_and_paths["t5"] = ( 15 | [ 16 | os.path.join(folder_paths.models_dir,"t5"), 17 | *folder_paths.folder_names_and_paths.get("t5", [[],set()])[0] 18 | ], 19 | folder_paths.supported_pt_extensions 20 | ) 21 | 22 | dtypes = [ 23 | "default", 24 | "auto (comfy)", 25 | "FP32", 26 | "FP16", 27 | # Note: remove these at some point 28 | "bnb8bit", 29 | "bnb4bit", 30 | ] 31 | try: torch.float8_e5m2 32 | except AttributeError: print("Torch version too old for FP8") 33 | else: dtypes += ["FP8 E4M3", "FP8 E5M2"] 34 | 35 | class T5v11Loader: 36 | @classmethod 37 | def INPUT_TYPES(s): 38 | devices = ["auto", "cpu", "gpu"] 39 | # hack for using second GPU as offload 40 | for k in range(1, torch.cuda.device_count()): 41 | devices.append(f"cuda:{k}") 42 | return { 43 | "required": { 44 | "t5v11_name": (folder_paths.get_filename_list("t5"),), 45 | "t5v11_ver": (["xxl"],), 46 | "path_type": (["folder", "file"],), 47 | "device": (devices, {"default":"cpu"}), 48 | "dtype": (dtypes,), 49 | } 50 | } 51 | RETURN_TYPES = ("T5",) 52 | FUNCTION = "load_model" 53 | CATEGORY = "ExtraModels/T5" 54 | TITLE = "T5v1.1 Loader" 55 | 56 | def load_model(self, t5v11_name, t5v11_ver, path_type, device, dtype): 57 | if "bnb" in dtype: 58 | assert device == "gpu" or device.startswith("cuda"), "BitsAndBytes only works on CUDA! Set device to 'gpu'." 59 | dtype = string_to_dtype(dtype, "text_encoder") 60 | if device == "cpu": 61 | assert dtype in [None, torch.float32], f"Can't use dtype '{dtype}' with CPU! Set dtype to 'default'." 
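# The checks above enforce the loader constraints: bnb 4/8-bit quantization only runs on CUDA,
# and CPU loading is restricted to fp32 ("default"), since any other dtype is rejected there.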
62 | 63 | return (load_t5( 64 | model_type = "t5v11", 65 | model_ver = t5v11_ver, 66 | model_path = folder_paths.get_full_path("t5", t5v11_name), 67 | path_type = path_type, 68 | device = device, 69 | dtype = dtype, 70 | ),) 71 | 72 | class T5TextEncode: 73 | @classmethod 74 | def INPUT_TYPES(s): 75 | return { 76 | "required": { 77 | "text": ("STRING", {"multiline": True}), 78 | "T5": ("T5",), 79 | } 80 | } 81 | 82 | RETURN_TYPES = ("CONDITIONING",) 83 | FUNCTION = "encode" 84 | CATEGORY = "ExtraModels/T5" 85 | TITLE = "T5 Text Encode" 86 | 87 | def encode(self, text, T5=None): 88 | tokens = T5.tokenize(text) 89 | cond = T5.encode_from_tokens(tokens) 90 | return ([[cond, {}]], ) 91 | 92 | NODE_CLASS_MAPPINGS = { 93 | "T5v11Loader" : T5v11Loader, 94 | "T5TextEncode" : T5TextEncode, 95 | } 96 | -------------------------------------------------------------------------------- /T5/t5_tokenizer/special_tokens_map.json: -------------------------------------------------------------------------------- 1 | {"eos_token": "", "unk_token": "", "pad_token": "", "additional_special_tokens": ["", "", "", "", "", "", "", "", "", "", "", "", "", "", "", "", "", "", "", "", "", "", "", "", "", "", "", "", "", "", "", "", "", "", "", "", "", "", "", "", "", "", "", "", "", "", "", "", "", "", "", "", "", "", "", "", "", "", "", "", "", "", "", "", "", "", "", "", "", "", "", "", "", "", "", "", "", "", "", "", "", "", "", "", "", "", "", "", "", "", "", "", "", "", "", "", "", "", "", ""]} -------------------------------------------------------------------------------- /T5/t5_tokenizer/spiece.model: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/chaojie/ComfyUI_ExtraModels/d8b11e401de830ccfb27fa84bdd0091b52408af8/T5/t5_tokenizer/spiece.model -------------------------------------------------------------------------------- /T5/t5_tokenizer/tokenizer_config.json: -------------------------------------------------------------------------------- 1 | {"eos_token": "", "unk_token": "", "pad_token": "", "extra_ids": 100, "additional_special_tokens": ["", "", "", "", "", "", "", "", "", "", "", "", "", "", "", "", "", "", "", "", "", "", "", "", "", "", "", "", "", "", "", "", "", "", "", "", "", "", "", "", "", "", "", "", "", "", "", "", "", "", "", "", "", "", "", "", "", "", "", "", "", "", "", "", "", "", "", "", "", "", "", "", "", "", "", "", "", "", "", "", "", "", "", "", "", "", "", "", "", "", "", "", "", "", "", "", "", "", "", ""], "model_max_length": 512, "name_or_path": "t5-small"} -------------------------------------------------------------------------------- /T5/t5v11-xxl_config.json: -------------------------------------------------------------------------------- 1 | { 2 | "_name_or_path": "google/t5-v1_1-xxl", 3 | "architectures": [ 4 | "T5EncoderModel" 5 | ], 6 | "d_ff": 10240, 7 | "d_kv": 64, 8 | "d_model": 4096, 9 | "decoder_start_token_id": 0, 10 | "dense_act_fn": "gelu_new", 11 | "dropout_rate": 0.1, 12 | "eos_token_id": 1, 13 | "feed_forward_proj": "gated-gelu", 14 | "initializer_factor": 1.0, 15 | "is_encoder_decoder": true, 16 | "is_gated_act": true, 17 | "layer_norm_epsilon": 1e-06, 18 | "model_type": "t5", 19 | "num_decoder_layers": 24, 20 | "num_heads": 64, 21 | "num_layers": 24, 22 | "output_past": true, 23 | "pad_token_id": 0, 24 | "relative_attention_max_distance": 128, 25 | "relative_attention_num_buckets": 32, 26 | "tie_word_embeddings": false, 27 | "torch_dtype": "float32", 28 | "transformers_version": "4.21.1", 29 
| "use_cache": true, 30 | "vocab_size": 32128 31 | } 32 | -------------------------------------------------------------------------------- /T5/t5v11.py: -------------------------------------------------------------------------------- 1 | """ 2 | Adapted from comfyui CLIP code. 3 | https://github.com/comfyanonymous/ComfyUI/blob/master/comfy/sd1_clip.py 4 | """ 5 | 6 | import os 7 | 8 | from transformers import T5Tokenizer, T5EncoderModel, T5Config, modeling_utils 9 | import torch 10 | import traceback 11 | import zipfile 12 | from comfy import model_management 13 | 14 | from comfy.sd1_clip import parse_parentheses, token_weights, escape_important, unescape_important, safe_load_embed_zip, expand_directory_list, load_embed 15 | 16 | class T5v11Model(torch.nn.Module): 17 | def __init__(self, textmodel_ver="xxl", textmodel_json_config=None, textmodel_path=None, device="cpu", max_length=120, freeze=True, dtype=None): 18 | super().__init__() 19 | 20 | self.num_layers = 24 21 | self.max_length = max_length 22 | self.bnb = False 23 | 24 | if textmodel_path is not None: 25 | model_args = {} 26 | model_args["low_cpu_mem_usage"] = True # Don't take 2x system ram on cpu 27 | if dtype == "bnb8bit": 28 | self.bnb = True 29 | model_args["load_in_8bit"] = True 30 | elif dtype == "bnb4bit": 31 | self.bnb = True 32 | model_args["load_in_4bit"] = True 33 | else: 34 | if dtype: model_args["torch_dtype"] = dtype 35 | self.bnb = False 36 | # second GPU offload hack part 2 37 | if device.startswith("cuda"): 38 | model_args["device_map"] = device 39 | print(f"Loading T5 from '{textmodel_path}'") 40 | self.transformer = T5EncoderModel.from_pretrained(textmodel_path, **model_args) 41 | else: 42 | if textmodel_json_config is None: 43 | textmodel_json_config = os.path.join( 44 | os.path.dirname(os.path.realpath(__file__)), 45 | f"t5v11-{textmodel_ver}_config.json" 46 | ) 47 | config = T5Config.from_json_file(textmodel_json_config) 48 | self.num_layers = config.num_hidden_layers 49 | with modeling_utils.no_init_weights(): 50 | self.transformer = T5EncoderModel(config) 51 | 52 | if freeze: 53 | self.freeze() 54 | self.empty_tokens = [[0] * self.max_length] # token 55 | 56 | def freeze(self): 57 | self.transformer = self.transformer.eval() 58 | for param in self.parameters(): 59 | param.requires_grad = False 60 | 61 | def forward(self, tokens): 62 | device = self.transformer.get_input_embeddings().weight.device 63 | tokens = torch.LongTensor(tokens).to(device) 64 | attention_mask = torch.zeros_like(tokens) 65 | max_token = 1 # token 66 | for x in range(attention_mask.shape[0]): 67 | for y in range(attention_mask.shape[1]): 68 | attention_mask[x, y] = 1 69 | if tokens[x, y] == max_token: 70 | break 71 | 72 | outputs = self.transformer(input_ids=tokens, attention_mask=attention_mask) 73 | 74 | z = outputs['last_hidden_state'] 75 | z.detach().cpu().float() 76 | return z 77 | 78 | def encode(self, tokens): 79 | return self(tokens) 80 | 81 | def load_sd(self, sd): 82 | return self.transformer.load_state_dict(sd, strict=False) 83 | 84 | def to(self, *args, **kwargs): 85 | """BNB complains if you try to change the device or dtype""" 86 | if self.bnb: 87 | print("Thanks to BitsAndBytes, T5 becomes an immovable rock.", args, kwargs) 88 | else: 89 | self.transformer.to(*args, **kwargs) 90 | 91 | def encode_token_weights(self, token_weight_pairs, return_padded=False): 92 | to_encode = list(self.empty_tokens) 93 | for x in token_weight_pairs: 94 | tokens = list(map(lambda a: a[0], x)) 95 | to_encode.append(tokens) 96 | 97 | out = 
self.encode(to_encode) 98 | z_empty = out[0:1] 99 | 100 | output = [] 101 | for k in range(1, out.shape[0]): 102 | z = out[k:k+1] 103 | for i in range(len(z)): 104 | for j in range(len(z[i])): 105 | weight = token_weight_pairs[k - 1][j][1] 106 | z[i][j] = (z[i][j] - z_empty[0][j]) * weight + z_empty[0][j] 107 | output.append(z) 108 | 109 | if (len(output) == 0): 110 | return z_empty.cpu() 111 | 112 | out = torch.cat(output, dim=-2) 113 | if not return_padded: 114 | # Count number of tokens that aren't , then use that number as an index. 115 | keep_index = sum([sum([1 for y in x if y[0] != 0]) for x in token_weight_pairs]) 116 | out = out[:, :keep_index, :] 117 | return out 118 | 119 | 120 | class T5v11Tokenizer: 121 | """ 122 | This is largely just based on the ComfyUI CLIP code. 123 | """ 124 | def __init__(self, tokenizer_path=None, max_length=120, embedding_directory=None, embedding_size=4096, embedding_key='t5'): 125 | if tokenizer_path is None: 126 | tokenizer_path = os.path.join(os.path.dirname(os.path.realpath(__file__)), "t5_tokenizer") 127 | self.tokenizer = T5Tokenizer.from_pretrained(tokenizer_path) 128 | self.max_length = max_length 129 | self.max_tokens_per_section = self.max_length - 1 # but no 130 | 131 | self.pad_token = self.tokenizer("", add_special_tokens=False)["input_ids"][0] 132 | self.end_token = self.tokenizer("", add_special_tokens=False)["input_ids"][0] 133 | vocab = self.tokenizer.get_vocab() 134 | self.inv_vocab = {v: k for k, v in vocab.items()} 135 | self.embedding_directory = embedding_directory 136 | self.max_word_length = 8 # haven't verified this 137 | self.embedding_identifier = "embedding:" 138 | self.embedding_size = embedding_size 139 | self.embedding_key = embedding_key 140 | 141 | def _try_get_embedding(self, embedding_name:str): 142 | ''' 143 | Takes a potential embedding name and tries to retrieve it. 144 | Returns a Tuple consisting of the embedding and any leftover string, embedding can be None. 145 | ''' 146 | embed = load_embed(embedding_name, self.embedding_directory, self.embedding_size, self.embedding_key) 147 | if embed is None: 148 | stripped = embedding_name.strip(',') 149 | if len(stripped) < len(embedding_name): 150 | embed = load_embed(stripped, self.embedding_directory, self.embedding_size, self.embedding_key) 151 | return (embed, embedding_name[len(stripped):]) 152 | return (embed, "") 153 | 154 | def tokenize_with_weights(self, text:str, return_word_ids=False): 155 | ''' 156 | Takes a prompt and converts it to a list of (token, weight, word id) elements. 157 | Tokens can both be integer tokens and pre computed T5 tensors. 158 | Word id values are unique per word and embedding, where the id 0 is reserved for non word tokens. 
159 | Returned list has the dimensions NxM where M is the input size of T5 160 | ''' 161 | pad_token = self.pad_token 162 | text = escape_important(text) 163 | parsed_weights = token_weights(text, 1.0) 164 | 165 | #tokenize words 166 | tokens = [] 167 | for weighted_segment, weight in parsed_weights: 168 | to_tokenize = unescape_important(weighted_segment).replace("\n", " ").split(' ') 169 | to_tokenize = [x for x in to_tokenize if x != ""] 170 | for word in to_tokenize: 171 | #if we find an embedding, deal with the embedding 172 | if word.startswith(self.embedding_identifier) and self.embedding_directory is not None: 173 | embedding_name = word[len(self.embedding_identifier):].strip('\n') 174 | embed, leftover = self._try_get_embedding(embedding_name) 175 | if embed is None: 176 | print(f"warning, embedding:{embedding_name} does not exist, ignoring") 177 | else: 178 | if len(embed.shape) == 1: 179 | tokens.append([(embed, weight)]) 180 | else: 181 | tokens.append([(embed[x], weight) for x in range(embed.shape[0])]) 182 | #if we accidentally have leftover text, continue parsing using leftover, else move on to next word 183 | if leftover != "": 184 | word = leftover 185 | else: 186 | continue 187 | #parse word 188 | tokens.append([(t, weight) for t in self.tokenizer(word, add_special_tokens=False)["input_ids"]]) 189 | 190 | #reshape token array to T5 input size 191 | batched_tokens = [] 192 | batch = [] 193 | batched_tokens.append(batch) 194 | for i, t_group in enumerate(tokens): 195 | #determine if we're going to try and keep the tokens in a single batch 196 | is_large = len(t_group) >= self.max_word_length 197 | 198 | while len(t_group) > 0: 199 | if len(t_group) + len(batch) > self.max_length - 1: 200 | remaining_length = self.max_length - len(batch) - 1 201 | #break word in two and add end token 202 | if is_large: 203 | batch.extend([(t,w,i+1) for t,w in t_group[:remaining_length]]) 204 | batch.append((self.end_token, 1.0, 0)) 205 | t_group = t_group[remaining_length:] 206 | #add end token and pad 207 | else: 208 | batch.append((self.end_token, 1.0, 0)) 209 | batch.extend([(self.pad_token, 1.0, 0)] * (remaining_length)) 210 | #start new batch 211 | batch = [] 212 | batched_tokens.append(batch) 213 | else: 214 | batch.extend([(t,w,i+1) for t,w in t_group]) 215 | t_group = [] 216 | 217 | # fill last batch 218 | batch.extend([(self.end_token, 1.0, 0)] + [(self.pad_token, 1.0, 0)] * (self.max_length - len(batch) - 1)) 219 | # instead of filling, just add EOS (DEBUG) 220 | # batch.extend([(self.end_token, 1.0, 0)]) 221 | 222 | if not return_word_ids: 223 | batched_tokens = [[(t, w) for t, w,_ in x] for x in batched_tokens] 224 | return batched_tokens 225 | 226 | def untokenize(self, token_weight_pair): 227 | return list(map(lambda a: (a, self.inv_vocab[a[0]]), token_weight_pair)) 228 | -------------------------------------------------------------------------------- /VAE/conf.py: -------------------------------------------------------------------------------- 1 | """ 2 | List of all VAE configs, with training parts stripped. 
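Keys: 'embed_scale' is the spatial downscale factor and 'embed_dim' the latent channel count, as consumed by EXVAE in VAE/loader.py.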
3 | """ 4 | vae_conf = { 5 | ### AutoencoderKL ### 6 | "kl-f4": { 7 | "type" : "AutoencoderKL", 8 | "embed_scale" : 4, 9 | "embed_dim" : 3, 10 | "z_channels" : 3, 11 | "double_z" : True, 12 | "resolution" : 256, 13 | "in_channels" : 3, 14 | "out_ch" : 3, 15 | "ch" : 128, 16 | "ch_mult" : [1,2,4], 17 | "num_res_blocks" : 2, 18 | "attn_resolutions" : [], 19 | }, 20 | "kl-f8": { # Default SD1.5 VAE 21 | "type" : "AutoencoderKL", 22 | "embed_scale" : 8, 23 | "embed_dim" : 4, 24 | "z_channels" : 4, 25 | "double_z" : True, 26 | "resolution" : 256, 27 | "in_channels" : 3, 28 | "out_ch" : 3, 29 | "ch" : 128, 30 | "ch_mult" : [1,2,4,4], 31 | "num_res_blocks" : 2, 32 | "attn_resolutions" : [], 33 | }, 34 | "kl-f16": { 35 | "type" : "AutoencoderKL", 36 | "embed_scale" : 16, 37 | "embed_dim" : 16, 38 | "z_channels" : 16, 39 | "double_z" : True, 40 | "resolution" : 256, 41 | "in_channels" : 3, 42 | "out_ch" : 3, 43 | "ch" : 128, 44 | "ch_mult" : [1,1,2,2,4], 45 | "num_res_blocks" : 2, 46 | "attn_resolutions" : [16], 47 | }, 48 | "kl-f32": { 49 | "type" : "AutoencoderKL", 50 | "embed_scale" : 32, 51 | "embed_dim" : 64, 52 | "z_channels" : 64, 53 | "double_z" : True, 54 | "resolution" : 256, 55 | "in_channels" : 3, 56 | "out_ch" : 3, 57 | "ch" : 128, 58 | "ch_mult" : [1,1,2,2,4,4], 59 | "num_res_blocks" : 2, 60 | "attn_resolutions" : [16,8], 61 | }, 62 | ### VQModel ### 63 | "vq-f4": { 64 | "type" : "VQModel", 65 | "embed_scale" : 4, 66 | "n_embed" : 8192, 67 | "embed_dim" : 3, 68 | "z_channels" : 3, 69 | "double_z" : False, 70 | "resolution" : 256, 71 | "in_channels" : 3, 72 | "out_ch" : 3, 73 | "ch" : 128, 74 | "ch_mult" : [1,2,4], 75 | "num_res_blocks" : 2, 76 | "attn_resolutions" : [], 77 | }, 78 | "vq-f8": { 79 | "type" : "VQModel", 80 | "embed_scale" : 8, 81 | "n_embed" : 16384, 82 | "embed_dim" : 4, 83 | "z_channels" : 4, 84 | "double_z" : False, 85 | "resolution" : 256, 86 | "in_channels" : 3, 87 | "out_ch" : 3, 88 | "ch" : 128, 89 | "ch_mult" : [1,2,2,4], 90 | "num_res_blocks" : 2, 91 | "attn_resolutions" : [32], 92 | }, 93 | "vq-f16": { 94 | "type" : "VQModel", 95 | "embed_scale" : 16, 96 | "n_embed" : 16384, 97 | "embed_dim" : 8, 98 | "z_channels" : 8, 99 | "double_z" : False, 100 | "resolution" : 256, 101 | "in_channels" : 3, 102 | "out_ch" : 3, 103 | "ch" : 128, 104 | "ch_mult" : [1,1,2,2,4], 105 | "num_res_blocks" : 2, 106 | "attn_resolutions" : [16], 107 | }, 108 | # OpenAI Consistency Decoder 109 | "Consistency-Decoder": { 110 | "type" : "ConsistencyDecoder", 111 | "embed_scale" : 8, 112 | "embed_dim" : 4, 113 | }, 114 | # SAI Video Decoder 115 | "SDV-VideoDecoder": { 116 | "type" : "AutoencoderKL-VideoDecoder", 117 | "embed_scale" : 8, 118 | "embed_dim" : 4, 119 | "z_channels" : 4, 120 | "double_z" : True, 121 | "resolution" : 256, 122 | "in_channels" : 3, 123 | "out_ch" : 3, 124 | "ch" : 128, 125 | "ch_mult" : [1,2,4,4], 126 | "num_res_blocks" : 2, 127 | "attn_resolutions" : [], 128 | "video_kernel_size": [3, 1, 1] 129 | }, 130 | # Kandinsky-3 131 | "MoVQ3": { 132 | "type" : "MoVQ3", 133 | "embed_scale" : 8, 134 | "embed_dim" : 4, 135 | "double_z" : False, 136 | "z_channels" : 4, 137 | "resolution" : 256, 138 | "in_channels" : 3, 139 | "out_ch" : 3, 140 | "ch" : 256, 141 | "ch_mult" : [1, 2, 2, 4], 142 | "num_res_blocks" : 2, 143 | "attn_resolutions" : [32], 144 | } 145 | } 146 | -------------------------------------------------------------------------------- /VAE/loader.py: -------------------------------------------------------------------------------- 1 | import torch 2 | import 
comfy.sd 3 | import comfy.utils 4 | from comfy import model_management 5 | from comfy import diffusers_convert 6 | 7 | class EXVAE(comfy.sd.VAE): 8 | def __init__(self, model_path, model_conf, dtype=torch.float32): 9 | self.latent_dim = model_conf["embed_dim"] 10 | self.latent_scale = model_conf["embed_scale"] 11 | self.device = model_management.vae_device() 12 | self.offload_device = model_management.vae_offload_device() 13 | self.vae_dtype = dtype 14 | 15 | sd = comfy.utils.load_torch_file(model_path) 16 | model = None 17 | if model_conf["type"] == "AutoencoderKL": 18 | from .models.kl import AutoencoderKL 19 | model = AutoencoderKL(config=model_conf) 20 | if 'decoder.up_blocks.0.resnets.0.norm1.weight' in sd.keys(): 21 | sd = diffusers_convert.convert_vae_state_dict(sd) 22 | elif model_conf["type"] == "AutoencoderKL-VideoDecoder": 23 | from .models.temporal_ae import AutoencoderKL 24 | model = AutoencoderKL(config=model_conf) 25 | elif model_conf["type"] == "VQModel": 26 | from .models.vq import VQModel 27 | model = VQModel(config=model_conf) 28 | elif model_conf["type"] == "ConsistencyDecoder": 29 | from .models.consistencydecoder import ConsistencyDecoder 30 | model = ConsistencyDecoder() 31 | sd = {f"model.{k}":v for k,v in sd.items()} 32 | elif model_conf["type"] == "MoVQ3": 33 | from .models.movq3 import MoVQ 34 | model = MoVQ(model_conf) 35 | else: 36 | raise NotImplementedError(f"Unknown VAE type '{model_conf['type']}'") 37 | 38 | self.first_stage_model = model.eval() 39 | m, u = self.first_stage_model.load_state_dict(sd, strict=False) 40 | if len(m) > 0: print("Missing VAE keys", m) 41 | if len(u) > 0: print("Leftover VAE keys", u) 42 | 43 | self.first_stage_model.to(self.vae_dtype).to(self.offload_device) 44 | 45 | ### Encode/Decode functions below needed due to source repo having 4 VAE channels and a scale factor of 8 hardcoded 46 | def decode_tiled_(self, samples, tile_x=64, tile_y=64, overlap = 16): 47 | steps = samples.shape[0] * comfy.utils.get_tiled_scale_steps(samples.shape[3], samples.shape[2], tile_x, tile_y, overlap) 48 | steps += samples.shape[0] * comfy.utils.get_tiled_scale_steps(samples.shape[3], samples.shape[2], tile_x // 2, tile_y * 2, overlap) 49 | steps += samples.shape[0] * comfy.utils.get_tiled_scale_steps(samples.shape[3], samples.shape[2], tile_x * 2, tile_y // 2, overlap) 50 | pbar = comfy.utils.ProgressBar(steps) 51 | 52 | decode_fn = lambda a: (self.first_stage_model.decode(a.to(self.vae_dtype).to(self.device)) + 1.0).float() 53 | output = torch.clamp(( 54 | (comfy.utils.tiled_scale(samples, decode_fn, tile_x // 2, tile_y * 2, overlap, upscale_amount = self.latent_scale, pbar = pbar) + 55 | comfy.utils.tiled_scale(samples, decode_fn, tile_x * 2, tile_y // 2, overlap, upscale_amount = self.latent_scale, pbar = pbar) + 56 | comfy.utils.tiled_scale(samples, decode_fn, tile_x, tile_y, overlap, upscale_amount = self.latent_scale, pbar = pbar)) 57 | / 3.0) / 2.0, min=0.0, max=1.0) 58 | return output 59 | 60 | def encode_tiled_(self, pixel_samples, tile_x=512, tile_y=512, overlap = 64): 61 | steps = pixel_samples.shape[0] * comfy.utils.get_tiled_scale_steps(pixel_samples.shape[3], pixel_samples.shape[2], tile_x, tile_y, overlap) 62 | steps += pixel_samples.shape[0] * comfy.utils.get_tiled_scale_steps(pixel_samples.shape[3], pixel_samples.shape[2], tile_x // 2, tile_y * 2, overlap) 63 | steps += pixel_samples.shape[0] * comfy.utils.get_tiled_scale_steps(pixel_samples.shape[3], pixel_samples.shape[2], tile_x * 2, tile_y // 2, overlap) 64 | pbar = 
comfy.utils.ProgressBar(steps) 65 | 66 | encode_fn = lambda a: self.first_stage_model.encode((2. * a - 1.).to(self.vae_dtype).to(self.device)).float() 67 | samples = comfy.utils.tiled_scale(pixel_samples, encode_fn, tile_x, tile_y, overlap, upscale_amount = (1/self.latent_scale), out_channels=self.latent_dim, pbar=pbar) 68 | samples += comfy.utils.tiled_scale(pixel_samples, encode_fn, tile_x * 2, tile_y // 2, overlap, upscale_amount = (1/self.latent_scale), out_channels=self.latent_dim, pbar=pbar) 69 | samples += comfy.utils.tiled_scale(pixel_samples, encode_fn, tile_x // 2, tile_y * 2, overlap, upscale_amount = (1/self.latent_scale), out_channels=self.latent_dim, pbar=pbar) 70 | samples /= 3.0 71 | return samples 72 | 73 | def decode(self, samples_in): 74 | self.first_stage_model = self.first_stage_model.to(self.device) 75 | try: 76 | memory_used = (2562 * samples_in.shape[2] * samples_in.shape[3] * 64) * 1.7 77 | model_management.free_memory(memory_used, self.device) 78 | free_memory = model_management.get_free_memory(self.device) 79 | batch_number = int(free_memory / memory_used) 80 | batch_number = max(1, batch_number) 81 | 82 | pixel_samples = torch.empty((samples_in.shape[0], 3, round(samples_in.shape[2] * self.latent_scale), round(samples_in.shape[3] * self.latent_scale)), device="cpu") 83 | for x in range(0, samples_in.shape[0], batch_number): 84 | samples = samples_in[x:x+batch_number].to(self.vae_dtype).to(self.device) 85 | pixel_samples[x:x+batch_number] = torch.clamp((self.first_stage_model.decode(samples).cpu().float() + 1.0) / 2.0, min=0.0, max=1.0) 86 | except model_management.OOM_EXCEPTION as e: 87 | print("Warning: Ran out of memory when regular VAE decoding, retrying with tiled VAE decoding.") 88 | pixel_samples = self.decode_tiled_(samples_in) 89 | 90 | self.first_stage_model = self.first_stage_model.to(self.offload_device) 91 | pixel_samples = pixel_samples.cpu().movedim(1,-1) 92 | return pixel_samples 93 | 94 | def encode(self, pixel_samples): 95 | self.first_stage_model = self.first_stage_model.to(self.device) 96 | pixel_samples = pixel_samples.movedim(-1,1) 97 | try: 98 | memory_used = (2078 * pixel_samples.shape[2] * pixel_samples.shape[3]) * 1.7 #NOTE: this constant along with the one in the decode above are estimated from the mem usage for the VAE and could change. 99 | model_management.free_memory(memory_used, self.device) 100 | free_memory = model_management.get_free_memory(self.device) 101 | batch_number = int(free_memory / memory_used) 102 | batch_number = max(1, batch_number) 103 | samples = torch.empty((pixel_samples.shape[0], self.latent_dim, round(pixel_samples.shape[2] // self.latent_scale), round(pixel_samples.shape[3] // self.latent_scale)), device="cpu") 104 | for x in range(0, pixel_samples.shape[0], batch_number): 105 | pixels_in = (2. 
* pixel_samples[x:x+batch_number] - 1.).to(self.vae_dtype).to(self.device) 106 | samples[x:x+batch_number] = self.first_stage_model.encode(pixels_in).cpu().float() 107 | 108 | except model_management.OOM_EXCEPTION as e: 109 | print("Warning: Ran out of memory when regular VAE encoding, retrying with tiled VAE encoding.") 110 | samples = self.encode_tiled_(pixel_samples) 111 | 112 | self.first_stage_model = self.first_stage_model.to(self.offload_device) 113 | return samples 114 | -------------------------------------------------------------------------------- /VAE/models/LICENSE-Consistency-Decoder: -------------------------------------------------------------------------------- 1 | MIT License 2 | 3 | Copyright (c) 2023 OpenAI 4 | 5 | Permission is hereby granted, free of charge, to any person obtaining a copy 6 | of this software and associated documentation files (the "Software"), to deal 7 | in the Software without restriction, including without limitation the rights 8 | to use, copy, modify, merge, publish, distribute, sublicense, and/or sell 9 | copies of the Software, and to permit persons to whom the Software is 10 | furnished to do so, subject to the following conditions: 11 | 12 | The above copyright notice and this permission notice shall be included in all 13 | copies or substantial portions of the Software. 14 | 15 | THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR 16 | IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, 17 | FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE 18 | AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER 19 | LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, 20 | OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE 21 | SOFTWARE. 22 | -------------------------------------------------------------------------------- /VAE/models/LICENSE-Kandinsky-3: -------------------------------------------------------------------------------- 1 | Apache License 2 | Version 2.0, January 2004 3 | http://www.apache.org/licenses/ 4 | 5 | TERMS AND CONDITIONS FOR USE, REPRODUCTION, AND DISTRIBUTION 6 | 7 | 1. Definitions. 8 | 9 | "License" shall mean the terms and conditions for use, reproduction, 10 | and distribution as defined by Sections 1 through 9 of this document. 11 | 12 | "Licensor" shall mean the copyright owner or entity authorized by 13 | the copyright owner that is granting the License. 14 | 15 | "Legal Entity" shall mean the union of the acting entity and all 16 | other entities that control, are controlled by, or are under common 17 | control with that entity. For the purposes of this definition, 18 | "control" means (i) the power, direct or indirect, to cause the 19 | direction or management of such entity, whether by contract or 20 | otherwise, or (ii) ownership of fifty percent (50%) or more of the 21 | outstanding shares, or (iii) beneficial ownership of such entity. 22 | 23 | "You" (or "Your") shall mean an individual or Legal Entity 24 | exercising permissions granted by this License. 25 | 26 | "Source" form shall mean the preferred form for making modifications, 27 | including but not limited to software source code, documentation 28 | source, and configuration files. 
29 | 30 | "Object" form shall mean any form resulting from mechanical 31 | transformation or translation of a Source form, including but 32 | not limited to compiled object code, generated documentation, 33 | and conversions to other media types. 34 | 35 | "Work" shall mean the work of authorship, whether in Source or 36 | Object form, made available under the License, as indicated by a 37 | copyright notice that is included in or attached to the work 38 | (an example is provided in the Appendix below). 39 | 40 | "Derivative Works" shall mean any work, whether in Source or Object 41 | form, that is based on (or derived from) the Work and for which the 42 | editorial revisions, annotations, elaborations, or other modifications 43 | represent, as a whole, an original work of authorship. For the purposes 44 | of this License, Derivative Works shall not include works that remain 45 | separable from, or merely link (or bind by name) to the interfaces of, 46 | the Work and Derivative Works thereof. 47 | 48 | "Contribution" shall mean any work of authorship, including 49 | the original version of the Work and any modifications or additions 50 | to that Work or Derivative Works thereof, that is intentionally 51 | submitted to Licensor for inclusion in the Work by the copyright owner 52 | or by an individual or Legal Entity authorized to submit on behalf of 53 | the copyright owner. For the purposes of this definition, "submitted" 54 | means any form of electronic, verbal, or written communication sent 55 | to the Licensor or its representatives, including but not limited to 56 | communication on electronic mailing lists, source code control systems, 57 | and issue tracking systems that are managed by, or on behalf of, the 58 | Licensor for the purpose of discussing and improving the Work, but 59 | excluding communication that is conspicuously marked or otherwise 60 | designated in writing by the copyright owner as "Not a Contribution." 61 | 62 | "Contributor" shall mean Licensor and any individual or Legal Entity 63 | on behalf of whom a Contribution has been received by Licensor and 64 | subsequently incorporated within the Work. 65 | 66 | 2. Grant of Copyright License. Subject to the terms and conditions of 67 | this License, each Contributor hereby grants to You a perpetual, 68 | worldwide, non-exclusive, no-charge, royalty-free, irrevocable 69 | copyright license to reproduce, prepare Derivative Works of, 70 | publicly display, publicly perform, sublicense, and distribute the 71 | Work and such Derivative Works in Source or Object form. 72 | 73 | 3. Grant of Patent License. Subject to the terms and conditions of 74 | this License, each Contributor hereby grants to You a perpetual, 75 | worldwide, non-exclusive, no-charge, royalty-free, irrevocable 76 | (except as stated in this section) patent license to make, have made, 77 | use, offer to sell, sell, import, and otherwise transfer the Work, 78 | where such license applies only to those patent claims licensable 79 | by such Contributor that are necessarily infringed by their 80 | Contribution(s) alone or by combination of their Contribution(s) 81 | with the Work to which such Contribution(s) was submitted. 
If You 82 | institute patent litigation against any entity (including a 83 | cross-claim or counterclaim in a lawsuit) alleging that the Work 84 | or a Contribution incorporated within the Work constitutes direct 85 | or contributory patent infringement, then any patent licenses 86 | granted to You under this License for that Work shall terminate 87 | as of the date such litigation is filed. 88 | 89 | 4. Redistribution. You may reproduce and distribute copies of the 90 | Work or Derivative Works thereof in any medium, with or without 91 | modifications, and in Source or Object form, provided that You 92 | meet the following conditions: 93 | 94 | (a) You must give any other recipients of the Work or 95 | Derivative Works a copy of this License; and 96 | 97 | (b) You must cause any modified files to carry prominent notices 98 | stating that You changed the files; and 99 | 100 | (c) You must retain, in the Source form of any Derivative Works 101 | that You distribute, all copyright, patent, trademark, and 102 | attribution notices from the Source form of the Work, 103 | excluding those notices that do not pertain to any part of 104 | the Derivative Works; and 105 | 106 | (d) If the Work includes a "NOTICE" text file as part of its 107 | distribution, then any Derivative Works that You distribute must 108 | include a readable copy of the attribution notices contained 109 | within such NOTICE file, excluding those notices that do not 110 | pertain to any part of the Derivative Works, in at least one 111 | of the following places: within a NOTICE text file distributed 112 | as part of the Derivative Works; within the Source form or 113 | documentation, if provided along with the Derivative Works; or, 114 | within a display generated by the Derivative Works, if and 115 | wherever such third-party notices normally appear. The contents 116 | of the NOTICE file are for informational purposes only and 117 | do not modify the License. You may add Your own attribution 118 | notices within Derivative Works that You distribute, alongside 119 | or as an addendum to the NOTICE text from the Work, provided 120 | that such additional attribution notices cannot be construed 121 | as modifying the License. 122 | 123 | You may add Your own copyright statement to Your modifications and 124 | may provide additional or different license terms and conditions 125 | for use, reproduction, or distribution of Your modifications, or 126 | for any such Derivative Works as a whole, provided Your use, 127 | reproduction, and distribution of the Work otherwise complies with 128 | the conditions stated in this License. 129 | 130 | 5. Submission of Contributions. Unless You explicitly state otherwise, 131 | any Contribution intentionally submitted for inclusion in the Work 132 | by You to the Licensor shall be under the terms and conditions of 133 | this License, without any additional terms or conditions. 134 | Notwithstanding the above, nothing herein shall supersede or modify 135 | the terms of any separate license agreement you may have executed 136 | with Licensor regarding such Contributions. 137 | 138 | 6. Trademarks. This License does not grant permission to use the trade 139 | names, trademarks, service marks, or product names of the Licensor, 140 | except as required for reasonable and customary use in describing the 141 | origin of the Work and reproducing the content of the NOTICE file. 142 | 143 | 7. Disclaimer of Warranty. 
Unless required by applicable law or 144 | agreed to in writing, Licensor provides the Work (and each 145 | Contributor provides its Contributions) on an "AS IS" BASIS, 146 | WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or 147 | implied, including, without limitation, any warranties or conditions 148 | of TITLE, NON-INFRINGEMENT, MERCHANTABILITY, or FITNESS FOR A 149 | PARTICULAR PURPOSE. You are solely responsible for determining the 150 | appropriateness of using or redistributing the Work and assume any 151 | risks associated with Your exercise of permissions under this License. 152 | 153 | 8. Limitation of Liability. In no event and under no legal theory, 154 | whether in tort (including negligence), contract, or otherwise, 155 | unless required by applicable law (such as deliberate and grossly 156 | negligent acts) or agreed to in writing, shall any Contributor be 157 | liable to You for damages, including any direct, indirect, special, 158 | incidental, or consequential damages of any character arising as a 159 | result of this License or out of the use or inability to use the 160 | Work (including but not limited to damages for loss of goodwill, 161 | work stoppage, computer failure or malfunction, or any and all 162 | other commercial damages or losses), even if such Contributor 163 | has been advised of the possibility of such damages. 164 | 165 | 9. Accepting Warranty or Additional Liability. While redistributing 166 | the Work or Derivative Works thereof, You may choose to offer, 167 | and charge a fee for, acceptance of support, warranty, indemnity, 168 | or other liability obligations and/or rights consistent with this 169 | License. However, in accepting such obligations, You may act only 170 | on Your own behalf and on Your sole responsibility, not on behalf 171 | of any other Contributor, and only if You agree to indemnify, 172 | defend, and hold each Contributor harmless for any liability 173 | incurred by, or claims asserted against, such Contributor by reason 174 | of your accepting any such warranty or additional liability. 175 | 176 | END OF TERMS AND CONDITIONS 177 | 178 | APPENDIX: How to apply the Apache License to your work. 179 | 180 | To apply the Apache License to your work, attach the following 181 | boilerplate notice, with the fields enclosed by brackets "[]" 182 | replaced with your own identifying information. (Don't include 183 | the brackets!) The text should be enclosed in the appropriate 184 | comment syntax for the file format. We also recommend that a 185 | file or class name and description of purpose be included on the 186 | same "printed page" as the copyright notice for easier 187 | identification within third-party archives. 188 | 189 | Copyright [yyyy] [name of copyright owner] 190 | 191 | Licensed under the Apache License, Version 2.0 (the "License"); 192 | you may not use this file except in compliance with the License. 193 | You may obtain a copy of the License at 194 | 195 | http://www.apache.org/licenses/LICENSE-2.0 196 | 197 | Unless required by applicable law or agreed to in writing, software 198 | distributed under the License is distributed on an "AS IS" BASIS, 199 | WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 200 | See the License for the specific language governing permissions and 201 | limitations under the License. 
202 | -------------------------------------------------------------------------------- /VAE/models/LICENSE-Latent-Diffusion: -------------------------------------------------------------------------------- 1 | MIT License 2 | 3 | Copyright (c) 2022 Machine Vision and Learning Group, LMU Munich 4 | 5 | Permission is hereby granted, free of charge, to any person obtaining a copy 6 | of this software and associated documentation files (the "Software"), to deal 7 | in the Software without restriction, including without limitation the rights 8 | to use, copy, modify, merge, publish, distribute, sublicense, and/or sell 9 | copies of the Software, and to permit persons to whom the Software is 10 | furnished to do so, subject to the following conditions: 11 | 12 | The above copyright notice and this permission notice shall be included in all 13 | copies or substantial portions of the Software. 14 | 15 | THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR 16 | IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, 17 | FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE 18 | AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER 19 | LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, 20 | OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE 21 | SOFTWARE. 22 | -------------------------------------------------------------------------------- /VAE/models/LICENSE-SAI: -------------------------------------------------------------------------------- 1 | MIT License 2 | 3 | Copyright (c) 2023 Stability AI 4 | 5 | Permission is hereby granted, free of charge, to any person obtaining a copy 6 | of this software and associated documentation files (the "Software"), to deal 7 | in the Software without restriction, including without limitation the rights 8 | to use, copy, modify, merge, publish, distribute, sublicense, and/or sell 9 | copies of the Software, and to permit persons to whom the Software is 10 | furnished to do so, subject to the following conditions: 11 | 12 | The above copyright notice and this permission notice shall be included in all 13 | copies or substantial portions of the Software. 14 | 15 | THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR 16 | IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, 17 | FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE 18 | AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER 19 | LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, 20 | OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE 21 | SOFTWARE. -------------------------------------------------------------------------------- /VAE/models/LICENSE-SDV: -------------------------------------------------------------------------------- 1 | STABLE VIDEO DIFFUSION NON-COMMERCIAL COMMUNITY LICENSE AGREEMENT 2 | Dated: November 21, 2023 3 | 4 | “AUP” means the Stability AI Acceptable Use Policy available at https://stability.ai/use-policy, as may be updated from time to time. 5 | 6 | "Agreement" means the terms and conditions for use, reproduction, distribution and modification of the Software Products set forth herein. 7 | "Derivative Work(s)” means (a) any derivative work of the Software Products as recognized by U.S. 
copyright laws and (b) any modifications to a Model, and any other model created which is based on or derived from the Model or the Model’s output. For clarity, Derivative Works do not include the output of any Model. 8 | “Documentation” means any specifications, manuals, documentation, and other written information provided by Stability AI related to the Software. 9 | 10 | "Licensee" or "you" means you, or your employer or any other person or entity (if you are entering into this Agreement on such person or entity's behalf), of the age required under applicable laws, rules or regulations to provide legal consent and that has legal authority to bind your employer or such other person or entity if you are entering in this Agreement on their behalf. 11 | 12 | "Stability AI" or "we" means Stability AI Ltd. 13 | 14 | "Software" means, collectively, Stability AI’s proprietary models and algorithms, including machine-learning models, trained model weights and other elements of the foregoing, made available under this Agreement. 15 | 16 | “Software Products” means Software and Documentation. 17 | 18 | By using or distributing any portion or element of the Software Products, you agree to be bound by this Agreement. 19 | 20 | 21 | 22 | License Rights and Redistribution. 23 | Subject to your compliance with this Agreement, the AUP (which is hereby incorporated herein by reference), and the Documentation, Stability AI grants you a non-exclusive, worldwide, non-transferable, non-sublicensable, revocable, royalty free and limited license under Stability AI’s intellectual property or other rights owned by Stability AI embodied in the Software Products to reproduce, distribute, and create Derivative Works of the Software Products for purposes other than commercial or production use. 24 | b. If you distribute or make the Software Products, or any Derivative Works thereof, available to a third party, the Software Products, Derivative Works, or any portion thereof, respectively, will remain subject to this Agreement and you must (i) provide a copy of this Agreement to such third party, and (ii) retain the following attribution notice within a "Notice" text file distributed as a part of such copies: "Stable Video Diffusion is licensed under the Stable Video Diffusion Research License, Copyright (c) Stability AI Ltd. All Rights Reserved.” If you create a Derivative Work of a Software Product, you may add your own attribution notices to the Notice file included with the Software Product, provided that you clearly indicate which attributions apply to the Software Product and you must state in the NOTICE file that you changed the Software Product and how it was modified. 25 | 2. Disclaimer of Warranty. UNLESS REQUIRED BY APPLICABLE LAW, THE SOFTWARE PRODUCTS AND ANY OUTPUT AND RESULTS THEREFROM ARE PROVIDED ON AN "AS IS" BASIS, WITHOUT WARRANTIES OF ANY KIND, EITHER EXPRESS OR IMPLIED, INCLUDING, WITHOUT LIMITATION, ANY WARRANTIES OF TITLE, NON-INFRINGEMENT, MERCHANTABILITY, OR FITNESS FOR A PARTICULAR PURPOSE. YOU ARE SOLELY RESPONSIBLE FOR DETERMINING THE APPROPRIATENESS OF USING OR REDISTRIBUTING THE SOFTWARE PRODUCTS AND ASSUME ANY RISKS ASSOCIATED WITH YOUR USE OF THE SOFTWARE PRODUCTS AND ANY OUTPUT AND RESULTS. 26 | 3. Limitation of Liability. 
IN NO EVENT WILL STABILITY AI OR ITS AFFILIATES BE LIABLE UNDER ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, TORT, NEGLIGENCE, PRODUCTS LIABILITY, OR OTHERWISE, ARISING OUT OF THIS AGREEMENT, FOR ANY LOST PROFITS OR ANY INDIRECT, SPECIAL, CONSEQUENTIAL, INCIDENTAL, EXEMPLARY OR PUNITIVE DAMAGES, EVEN IF STABILITY AI OR ITS AFFILIATES HAVE BEEN ADVISED OF THE POSSIBILITY OF ANY OF THE FOREGOING. 27 | 3. Intellectual Property. 28 | a. No trademark licenses are granted under this Agreement, and in connection with the Software Products, neither Stability AI nor Licensee may use any name or mark owned by or associated with the other or any of its affiliates, except as required for reasonable and customary use in describing and redistributing the Software Products. 29 | Subject to Stability AI’s ownership of the Software Products and Derivative Works made by or for Stability AI, with respect to any Derivative Works that are made by you, as between you and Stability AI, you are and will be the owner of such Derivative Works. 30 | If you institute litigation or other proceedings against Stability AI (including a cross-claim or counterclaim in a lawsuit) alleging that the Software Products or associated outputs or results, or any portion of any of the foregoing, constitutes infringement of intellectual property or other rights owned or licensable by you, then any licenses granted to you under this Agreement shall terminate as of the date such litigation or claim is filed or instituted. You will indemnify and hold harmless Stability AI from and against any claim by any third party arising out of or related to your use or distribution of the Software Products in violation of this Agreement. 31 | 4. Term and Termination. The term of this Agreement will commence upon your acceptance of this Agreement or access to the Software Products and will continue in full force and effect until terminated in accordance with the terms and conditions herein. Stability AI may terminate this Agreement if you are in breach of any term or condition of this Agreement. Upon termination of this Agreement, you shall delete and cease use of the Software Products. Sections 2-4 shall survive the termination of this Agreement. 32 | -------------------------------------------------------------------------------- /VAE/models/LICENSE-Taming-Transformers: -------------------------------------------------------------------------------- 1 | Copyright (c) 2020 Patrick Esser and Robin Rombach and Björn Ommer 2 | 3 | Permission is hereby granted, free of charge, to any person obtaining a copy 4 | of this software and associated documentation files (the "Software"), to deal 5 | in the Software without restriction, including without limitation the rights 6 | to use, copy, modify, merge, publish, distribute, sublicense, and/or sell 7 | copies of the Software, and to permit persons to whom the Software is 8 | furnished to do so, subject to the following conditions: 9 | 10 | The above copyright notice and this permission notice shall be included in all 11 | copies or substantial portions of the Software. 12 | 13 | THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, 14 | EXPRESS OR IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF 15 | MERCHANTABILITY, FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. 
16 | IN NO EVENT SHALL THE AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, 17 | DAMAGES OR OTHER LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR 18 | OTHERWISE, ARISING FROM, OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE 19 | OR OTHER DEALINGS IN THE SOFTWARE./ 20 | -------------------------------------------------------------------------------- /VAE/models/consistencydecoder.py: -------------------------------------------------------------------------------- 1 | import math 2 | import torch 3 | import torch.nn.functional as F 4 | import torch.nn as nn 5 | 6 | """ 7 | Code below ported from https://github.com/openai/consistencydecoder 8 | """ 9 | 10 | def _extract_into_tensor(arr, timesteps, broadcast_shape): 11 | # from: https://github.com/openai/guided-diffusion/blob/22e0df8183507e13a7813f8d38d51b072ca1e67c/guided_diffusion/gaussian_diffusion.py#L895 """ 12 | res = arr[timesteps.to(torch.int).cpu()].float().to(timesteps.device) 13 | dims_to_append = len(broadcast_shape) - len(res.shape) 14 | return res[(...,) + (None,) * dims_to_append] 15 | 16 | def betas_for_alpha_bar(num_diffusion_timesteps, alpha_bar, max_beta=0.999): 17 | # from: https://github.com/openai/guided-diffusion/blob/22e0df8183507e13a7813f8d38d51b072ca1e67c/guided_diffusion/gaussian_diffusion.py#L45 18 | betas = [] 19 | for i in range(num_diffusion_timesteps): 20 | t1 = i / num_diffusion_timesteps 21 | t2 = (i + 1) / num_diffusion_timesteps 22 | betas.append(min(1 - alpha_bar(t2) / alpha_bar(t1), max_beta)) 23 | return torch.tensor(betas) 24 | 25 | class ConsistencyDecoder(torch.nn.Module): 26 | # From https://github.com/openai/consistencydecoder 27 | def __init__(self): 28 | super().__init__() 29 | self.model = ConvUNetVAE() 30 | self.n_distilled_steps = 64 31 | 32 | sigma_data = 0.5 33 | betas = betas_for_alpha_bar( 34 | 1024, lambda t: math.cos((t + 0.008) / 1.008 * math.pi / 2) ** 2 35 | ) 36 | alphas = 1.0 - betas 37 | alphas_cumprod = torch.cumprod(alphas, dim=0) 38 | self.sqrt_alphas_cumprod = torch.sqrt(alphas_cumprod) 39 | self.sqrt_one_minus_alphas_cumprod = torch.sqrt(1.0 - alphas_cumprod) 40 | sqrt_recip_alphas_cumprod = torch.sqrt(1.0 / alphas_cumprod) 41 | sigmas = torch.sqrt(1.0 / alphas_cumprod - 1) 42 | self.c_skip = ( 43 | sqrt_recip_alphas_cumprod 44 | * sigma_data**2 45 | / (sigmas**2 + sigma_data**2) 46 | ) 47 | self.c_out = sigmas * sigma_data / (sigmas**2 + sigma_data**2) ** 0.5 48 | self.c_in = sqrt_recip_alphas_cumprod / (sigmas**2 + sigma_data**2) ** 0.5 49 | 50 | @staticmethod 51 | def round_timesteps(timesteps, total_timesteps, n_distilled_steps, truncate_start=True): 52 | with torch.no_grad(): 53 | space = torch.div(total_timesteps, n_distilled_steps, rounding_mode="floor") 54 | rounded_timesteps = ( 55 | torch.div(timesteps, space, rounding_mode="floor") + 1 56 | ) * space 57 | if truncate_start: 58 | rounded_timesteps[rounded_timesteps == total_timesteps] -= space 59 | else: 60 | rounded_timesteps[rounded_timesteps == total_timesteps] -= space 61 | rounded_timesteps[rounded_timesteps == 0] += space 62 | return rounded_timesteps 63 | 64 | @staticmethod 65 | def ldm_transform_latent(z, extra_scale_factor=1): 66 | channel_means = [0.38862467, 0.02253063, 0.07381133, -0.0171294] 67 | channel_stds = [0.9654121, 1.0440036, 0.76147926, 0.77022034] 68 | 69 | if len(z.shape) != 4: 70 | raise ValueError() 71 | 72 | z = z * 0.18215 73 | channels = [z[:, i] for i in range(z.shape[1])] 74 | 75 | channels = [ 76 | extra_scale_factor * (c - channel_means[i]) / channel_stds[i] 77 | for i, c 
in enumerate(channels) 78 | ] 79 | return torch.stack(channels, dim=1) 80 | 81 | @torch.no_grad() 82 | def decode(self, features: torch.Tensor, schedule=[1.0, 0.5]): 83 | features = self.ldm_transform_latent(features) 84 | ts = self.round_timesteps( 85 | torch.arange(0, 1024), 86 | 1024, 87 | self.n_distilled_steps, 88 | truncate_start=False, 89 | ) 90 | shape = ( 91 | features.size(0), 92 | 3, 93 | 8 * features.size(2), 94 | 8 * features.size(3), 95 | ) 96 | x_start = torch.zeros(shape, device=features.device, dtype=features.dtype) 97 | schedule_timesteps = [int((1024 - 1) * s) for s in schedule] 98 | for i in schedule_timesteps: 99 | t = ts[i].item() 100 | t_ = torch.tensor([t] * features.shape[0], device=features.device) 101 | noise = torch.randn_like(x_start, device=features.device) 102 | x_start = ( 103 | _extract_into_tensor(self.sqrt_alphas_cumprod, t_, x_start.shape) 104 | * x_start 105 | + _extract_into_tensor( 106 | self.sqrt_one_minus_alphas_cumprod, t_, x_start.shape 107 | ) 108 | * noise 109 | ) 110 | c_in = _extract_into_tensor(self.c_in, t_, x_start.shape) 111 | model_output = self.model((c_in * x_start).to(features.dtype), t_, features=features) 112 | B, C = x_start.shape[:2] 113 | model_output, _ = torch.split(model_output, C, dim=1) 114 | pred_xstart = ( 115 | _extract_into_tensor(self.c_out, t_, x_start.shape) * model_output 116 | + _extract_into_tensor(self.c_skip, t_, x_start.shape) * x_start 117 | ).clamp(-1, 1) 118 | x_start = pred_xstart 119 | return x_start 120 | 121 | def encode(self, *args, **kwargs): 122 | raise NotImplementedError("ConsistencyDecoder can't be used for encoding!") 123 | 124 | """ 125 | Model definitions ported from: 126 | https://gist.github.com/madebyollin/865fa6a18d9099351ddbdfbe7299ccbf 127 | https://gist.github.com/mrsteyk/74ad3ec2f6f823111ae4c90e168505ac. 
128 | """ 129 | 130 | class TimestepEmbedding(nn.Module): 131 | def __init__(self, n_time=1024, n_emb=320, n_out=1280) -> None: 132 | super().__init__() 133 | self.emb = nn.Embedding(n_time, n_emb) 134 | self.f_1 = nn.Linear(n_emb, n_out) 135 | self.f_2 = nn.Linear(n_out, n_out) 136 | 137 | def forward(self, x) -> torch.Tensor: 138 | x = self.emb(x) 139 | x = self.f_1(x) 140 | x = F.silu(x) 141 | return self.f_2(x) 142 | 143 | 144 | class ImageEmbedding(nn.Module): 145 | def __init__(self, in_channels=7, out_channels=320) -> None: 146 | super().__init__() 147 | self.f = nn.Conv2d(in_channels, out_channels, kernel_size=3, padding=1) 148 | 149 | def forward(self, x) -> torch.Tensor: 150 | return self.f(x) 151 | 152 | 153 | class ImageUnembedding(nn.Module): 154 | def __init__(self, in_channels=320, out_channels=6) -> None: 155 | super().__init__() 156 | self.gn = nn.GroupNorm(32, in_channels) 157 | self.f = nn.Conv2d(in_channels, out_channels, kernel_size=3, padding=1) 158 | 159 | def forward(self, x) -> torch.Tensor: 160 | return self.f(F.silu(self.gn(x))) 161 | 162 | 163 | class ConvResblock(nn.Module): 164 | def __init__(self, in_features=320, out_features=320) -> None: 165 | super().__init__() 166 | self.f_t = nn.Linear(1280, out_features * 2) 167 | 168 | self.gn_1 = nn.GroupNorm(32, in_features) 169 | self.f_1 = nn.Conv2d(in_features, out_features, kernel_size=3, padding=1) 170 | 171 | self.gn_2 = nn.GroupNorm(32, out_features) 172 | self.f_2 = nn.Conv2d(out_features, out_features, kernel_size=3, padding=1) 173 | 174 | skip_conv = in_features != out_features 175 | self.f_s = ( 176 | nn.Conv2d(in_features, out_features, kernel_size=1, padding=0) 177 | if skip_conv 178 | else nn.Identity() 179 | ) 180 | 181 | def forward(self, x, t): 182 | x_skip = x 183 | t = self.f_t(F.silu(t)) 184 | t = t.chunk(2, dim=1) 185 | t_1 = t[0].unsqueeze(dim=2).unsqueeze(dim=3) + 1 186 | t_2 = t[1].unsqueeze(dim=2).unsqueeze(dim=3) 187 | 188 | gn_1 = F.silu(self.gn_1(x)) 189 | f_1 = self.f_1(gn_1) 190 | 191 | gn_2 = self.gn_2(f_1) 192 | 193 | return self.f_s(x_skip) + self.f_2(F.silu(gn_2 * t_1 + t_2)) 194 | 195 | 196 | # Also ConvResblock 197 | class Downsample(nn.Module): 198 | def __init__(self, in_channels=320) -> None: 199 | super().__init__() 200 | self.f_t = nn.Linear(1280, in_channels * 2) 201 | 202 | self.gn_1 = nn.GroupNorm(32, in_channels) 203 | self.f_1 = nn.Conv2d(in_channels, in_channels, kernel_size=3, padding=1) 204 | self.gn_2 = nn.GroupNorm(32, in_channels) 205 | 206 | self.f_2 = nn.Conv2d(in_channels, in_channels, kernel_size=3, padding=1) 207 | 208 | def forward(self, x, t) -> torch.Tensor: 209 | x_skip = x 210 | 211 | t = self.f_t(F.silu(t)) 212 | t_1, t_2 = t.chunk(2, dim=1) 213 | t_1 = t_1.unsqueeze(2).unsqueeze(3) + 1 214 | t_2 = t_2.unsqueeze(2).unsqueeze(3) 215 | 216 | gn_1 = F.silu(self.gn_1(x)) 217 | avg_pool2d = F.avg_pool2d(gn_1, kernel_size=(2, 2), stride=None) 218 | f_1 = self.f_1(avg_pool2d) 219 | gn_2 = self.gn_2(f_1) 220 | 221 | f_2 = self.f_2(F.silu(t_2 + (t_1 * gn_2))) 222 | 223 | return f_2 + F.avg_pool2d(x_skip, kernel_size=(2, 2), stride=None) 224 | 225 | 226 | # Also ConvResblock 227 | class Upsample(nn.Module): 228 | def __init__(self, in_channels=1024) -> None: 229 | super().__init__() 230 | self.f_t = nn.Linear(1280, in_channels * 2) 231 | 232 | self.gn_1 = nn.GroupNorm(32, in_channels) 233 | self.f_1 = nn.Conv2d(in_channels, in_channels, kernel_size=3, padding=1) 234 | self.gn_2 = nn.GroupNorm(32, in_channels) 235 | 236 | self.f_2 = nn.Conv2d(in_channels, 
in_channels, kernel_size=3, padding=1) 237 | 238 | def forward(self, x, t) -> torch.Tensor: 239 | x_skip = x 240 | 241 | t = self.f_t(F.silu(t)) 242 | t_1, t_2 = t.chunk(2, dim=1) 243 | t_1 = t_1.unsqueeze(2).unsqueeze(3) + 1 244 | t_2 = t_2.unsqueeze(2).unsqueeze(3) 245 | 246 | gn_1 = F.silu(self.gn_1(x)) 247 | upsample = F.interpolate(gn_1.float(), scale_factor=2, mode="nearest").to(gn_1.dtype) 248 | 249 | f_1 = self.f_1(upsample) 250 | gn_2 = self.gn_2(f_1) 251 | 252 | f_2 = self.f_2(F.silu(t_2 + (t_1 * gn_2))) 253 | 254 | return f_2 + F.interpolate(x_skip.float(), scale_factor=2, mode="nearest").to(x_skip.dtype) 255 | 256 | 257 | class ConvUNetVAE(nn.Module): 258 | def __init__(self) -> None: 259 | super().__init__() 260 | self.embed_image = ImageEmbedding() 261 | self.embed_time = TimestepEmbedding() 262 | 263 | down_0 = nn.ModuleList( 264 | [ 265 | ConvResblock(320, 320), 266 | ConvResblock(320, 320), 267 | ConvResblock(320, 320), 268 | Downsample(320), 269 | ] 270 | ) 271 | down_1 = nn.ModuleList( 272 | [ 273 | ConvResblock(320, 640), 274 | ConvResblock(640, 640), 275 | ConvResblock(640, 640), 276 | Downsample(640), 277 | ] 278 | ) 279 | down_2 = nn.ModuleList( 280 | [ 281 | ConvResblock(640, 1024), 282 | ConvResblock(1024, 1024), 283 | ConvResblock(1024, 1024), 284 | Downsample(1024), 285 | ] 286 | ) 287 | down_3 = nn.ModuleList( 288 | [ 289 | ConvResblock(1024, 1024), 290 | ConvResblock(1024, 1024), 291 | ConvResblock(1024, 1024), 292 | ] 293 | ) 294 | self.down = nn.ModuleList( 295 | [ 296 | down_0, 297 | down_1, 298 | down_2, 299 | down_3, 300 | ] 301 | ) 302 | 303 | self.mid = nn.ModuleList( 304 | [ 305 | ConvResblock(1024, 1024), 306 | ConvResblock(1024, 1024), 307 | ] 308 | ) 309 | 310 | up_3 = nn.ModuleList( 311 | [ 312 | ConvResblock(1024 * 2, 1024), 313 | ConvResblock(1024 * 2, 1024), 314 | ConvResblock(1024 * 2, 1024), 315 | ConvResblock(1024 * 2, 1024), 316 | Upsample(1024), 317 | ] 318 | ) 319 | up_2 = nn.ModuleList( 320 | [ 321 | ConvResblock(1024 * 2, 1024), 322 | ConvResblock(1024 * 2, 1024), 323 | ConvResblock(1024 * 2, 1024), 324 | ConvResblock(1024 + 640, 1024), 325 | Upsample(1024), 326 | ] 327 | ) 328 | up_1 = nn.ModuleList( 329 | [ 330 | ConvResblock(1024 + 640, 640), 331 | ConvResblock(640 * 2, 640), 332 | ConvResblock(640 * 2, 640), 333 | ConvResblock(320 + 640, 640), 334 | Upsample(640), 335 | ] 336 | ) 337 | up_0 = nn.ModuleList( 338 | [ 339 | ConvResblock(320 + 640, 320), 340 | ConvResblock(320 * 2, 320), 341 | ConvResblock(320 * 2, 320), 342 | ConvResblock(320 * 2, 320), 343 | ] 344 | ) 345 | self.up = nn.ModuleList( 346 | [ 347 | up_0, 348 | up_1, 349 | up_2, 350 | up_3, 351 | ] 352 | ) 353 | 354 | self.output = ImageUnembedding() 355 | 356 | def forward(self, x, t, features) -> torch.Tensor: 357 | x = torch.cat([x, F.interpolate(features.float(),scale_factor=8,mode="nearest").to(features.dtype)], dim=1) 358 | t = self.embed_time(t) 359 | x = self.embed_image(x) 360 | 361 | skips = [x] 362 | for down in self.down: 363 | for block in down: 364 | x = block(x, t) 365 | skips.append(x) 366 | 367 | for i in range(2): 368 | x = self.mid[i](x, t) 369 | 370 | for up in self.up[::-1]: 371 | for block in up: 372 | if isinstance(block, ConvResblock): 373 | x = torch.concat([x, skips.pop()], dim=1) 374 | x = block(x, t) 375 | 376 | return self.output(x) 377 | -------------------------------------------------------------------------------- /VAE/models/vq.py: -------------------------------------------------------------------------------- 1 | import torch 2 | import 
numpy as np 3 | from torch import nn 4 | from einops import rearrange 5 | 6 | from .kl import Encoder, Decoder 7 | 8 | class VQModel(nn.Module): 9 | def __init__(self, 10 | config, 11 | remap=None, 12 | sane_index_shape=False, # tell vector quantizer to return indices as bhw 13 | ): 14 | super().__init__() 15 | self.embed_dim = config["embed_dim"] 16 | self.n_embed = config["n_embed"] 17 | self.encoder = Encoder(**config) 18 | self.decoder = Decoder(**config) 19 | self.quantize = VectorQuantizer(self.n_embed, self.embed_dim, beta=0.25, 20 | remap=remap, 21 | sane_index_shape=sane_index_shape) 22 | self.quant_conv = torch.nn.Conv2d(config["z_channels"], self.embed_dim, 1) 23 | self.post_quant_conv = torch.nn.Conv2d(self.embed_dim, config["z_channels"], 1) 24 | 25 | def encode(self, x): 26 | h = self.encoder(x) 27 | h = self.quant_conv(h) 28 | return h 29 | 30 | def decode(self, h, force_not_quantize=False): 31 | # also go through quantization layer 32 | if not force_not_quantize: 33 | quant, emb_loss, info = self.quantize(h) 34 | else: 35 | quant = h 36 | quant = self.post_quant_conv(quant) 37 | dec = self.decoder(quant) 38 | return dec 39 | 40 | def forward(self, input, return_pred_indices=False): 41 | quant, diff, (_,_,ind) = self.quantize(self.encode(input)) # encode() returns pre-quantization latents, so quantize here 42 | dec = self.decode(quant, force_not_quantize=True) 43 | if return_pred_indices: 44 | return dec, diff, ind 45 | return dec, diff 46 | 47 | 48 | class VectorQuantizer(nn.Module): 49 | """ 50 | Improved version over VectorQuantizer, can be used as a drop-in replacement. Mostly 51 | avoids costly matrix multiplications and allows for post-hoc remapping of indices. 52 | """ 53 | # NOTE: due to a bug the beta term was applied to the wrong term. for 54 | # backwards compatibility we use the buggy version by default, but you can 55 | # specify legacy=False to fix it. 56 | def __init__(self, n_e, e_dim, beta, remap=None, unknown_index="random", 57 | sane_index_shape=False, legacy=True): 58 | super().__init__() 59 | self.n_e = n_e 60 | self.e_dim = e_dim 61 | self.beta = beta 62 | self.legacy = legacy 63 | 64 | self.embedding = nn.Embedding(self.n_e, self.e_dim) 65 | self.embedding.weight.data.uniform_(-1.0 / self.n_e, 1.0 / self.n_e) 66 | 67 | self.remap = remap 68 | if self.remap is not None: 69 | self.register_buffer("used", torch.tensor(np.load(self.remap))) 70 | self.re_embed = self.used.shape[0] 71 | self.unknown_index = unknown_index # "random" or "extra" or integer 72 | if self.unknown_index == "extra": 73 | self.unknown_index = self.re_embed 74 | self.re_embed = self.re_embed+1 75 | print(f"Remapping {self.n_e} indices to {self.re_embed} indices. 
" 76 | f"Using {self.unknown_index} for unknown indices.") 77 | else: 78 | self.re_embed = n_e 79 | 80 | self.sane_index_shape = sane_index_shape 81 | 82 | def remap_to_used(self, inds): 83 | ishape = inds.shape 84 | assert len(ishape)>1 85 | inds = inds.reshape(ishape[0],-1) 86 | used = self.used.to(inds) 87 | match = (inds[:,:,None]==used[None,None,...]).long() 88 | new = match.argmax(-1) 89 | unknown = match.sum(2)<1 90 | if self.unknown_index == "random": 91 | new[unknown]=torch.randint(0,self.re_embed,size=new[unknown].shape).to(device=new.device) 92 | else: 93 | new[unknown] = self.unknown_index 94 | return new.reshape(ishape) 95 | 96 | def unmap_to_all(self, inds): 97 | ishape = inds.shape 98 | assert len(ishape)>1 99 | inds = inds.reshape(ishape[0],-1) 100 | used = self.used.to(inds) 101 | if self.re_embed > self.used.shape[0]: # extra token 102 | inds[inds>=self.used.shape[0]] = 0 # simply set to zero 103 | back=torch.gather(used[None,:][inds.shape[0]*[0],:], 1, inds) 104 | return back.reshape(ishape) 105 | 106 | def forward(self, z, temp=None, rescale_logits=False, return_logits=False): 107 | assert temp is None or temp==1.0, "Only for interface compatible with Gumbel" 108 | assert rescale_logits==False, "Only for interface compatible with Gumbel" 109 | assert return_logits==False, "Only for interface compatible with Gumbel" 110 | # reshape z -> (batch, height, width, channel) and flatten 111 | z = rearrange(z, 'b c h w -> b h w c').contiguous() 112 | z_flattened = z.view(-1, self.e_dim) 113 | # distances from z to embeddings e_j (z - e)^2 = z^2 + e^2 - 2 e * z 114 | 115 | d = torch.sum(z_flattened ** 2, dim=1, keepdim=True) + \ 116 | torch.sum(self.embedding.weight**2, dim=1) - 2 * \ 117 | torch.einsum('bd,dn->bn', z_flattened, rearrange(self.embedding.weight, 'n d -> d n')) 118 | 119 | min_encoding_indices = torch.argmin(d, dim=1) 120 | z_q = self.embedding(min_encoding_indices).view(z.shape) 121 | perplexity = None 122 | min_encodings = None 123 | 124 | # compute loss for embedding 125 | if not self.legacy: 126 | loss = self.beta * torch.mean((z_q.detach()-z)**2) + \ 127 | torch.mean((z_q - z.detach()) ** 2) 128 | else: 129 | loss = torch.mean((z_q.detach()-z)**2) + self.beta * \ 130 | torch.mean((z_q - z.detach()) ** 2) 131 | 132 | # preserve gradients 133 | z_q = z + (z_q - z).detach() 134 | 135 | # reshape back to match original input shape 136 | z_q = rearrange(z_q, 'b h w c -> b c h w').contiguous() 137 | 138 | if self.remap is not None: 139 | min_encoding_indices = min_encoding_indices.reshape(z.shape[0],-1) # add batch axis 140 | min_encoding_indices = self.remap_to_used(min_encoding_indices) 141 | min_encoding_indices = min_encoding_indices.reshape(-1,1) # flatten 142 | 143 | if self.sane_index_shape: 144 | min_encoding_indices = min_encoding_indices.reshape( 145 | z_q.shape[0], z_q.shape[2], z_q.shape[3]) 146 | 147 | return z_q, loss, (perplexity, min_encodings, min_encoding_indices) 148 | 149 | def get_codebook_entry(self, indices, shape): 150 | # shape specifying (batch, height, width, channel) 151 | if self.remap is not None: 152 | indices = indices.reshape(shape[0],-1) # add batch axis 153 | indices = self.unmap_to_all(indices) 154 | indices = indices.reshape(-1) # flatten again 155 | 156 | # get quantized latent vectors 157 | z_q = self.embedding(indices) 158 | 159 | if shape is not None: 160 | z_q = z_q.view(shape) 161 | # reshape back to match original input shape 162 | z_q = z_q.permute(0, 3, 1, 2).contiguous() 163 | 164 | return z_q 165 | 
-------------------------------------------------------------------------------- /VAE/nodes.py: -------------------------------------------------------------------------------- 1 | import folder_paths 2 | 3 | from .conf import vae_conf 4 | from .loader import EXVAE 5 | 6 | from ..utils.dtype import string_to_dtype 7 | 8 | dtypes = [ 9 | "auto", 10 | "FP32", 11 | "FP16", 12 | "BF16" 13 | ] 14 | 15 | class ExtraVAELoader: 16 | @classmethod 17 | def INPUT_TYPES(s): 18 | return { 19 | "required": { 20 | "vae_name": (folder_paths.get_filename_list("vae"),), 21 | "vae_type": (list(vae_conf.keys()), {"default":"kl-f8"}), 22 | "dtype" : (dtypes,), 23 | } 24 | } 25 | RETURN_TYPES = ("VAE",) 26 | FUNCTION = "load_vae" 27 | CATEGORY = "ExtraModels" 28 | TITLE = "ExtraVAELoader" 29 | 30 | def load_vae(self, vae_name, vae_type, dtype): 31 | model_path = folder_paths.get_full_path("vae", vae_name) 32 | model_conf = vae_conf[vae_type] 33 | vae = EXVAE(model_path, model_conf, string_to_dtype(dtype, "vae")) 34 | return (vae,) 35 | 36 | NODE_CLASS_MAPPINGS = { 37 | "ExtraVAELoader" : ExtraVAELoader, 38 | } 39 | -------------------------------------------------------------------------------- /__init__.py: -------------------------------------------------------------------------------- 1 | # only import if running as a custom node 2 | try: 3 | import comfy.utils 4 | except ImportError: 5 | pass 6 | else: 7 | NODE_CLASS_MAPPINGS = {} 8 | 9 | # Deci Diffusion 10 | # from .DeciDiffusion.nodes import NODE_CLASS_MAPPINGS as DeciDiffusion_Nodes 11 | # NODE_CLASS_MAPPINGS.update(DeciDiffusion_Nodes) 12 | 13 | # HunYuan 14 | from .HunYuan.nodes import NODE_CLASS_MAPPINGS as HunYuan_Nodes 15 | NODE_CLASS_MAPPINGS.update(HunYuan_Nodes) 16 | 17 | # DiT 18 | from .DiT.nodes import NODE_CLASS_MAPPINGS as DiT_Nodes 19 | NODE_CLASS_MAPPINGS.update(DiT_Nodes) 20 | 21 | # PixArt 22 | from .PixArt.nodes import NODE_CLASS_MAPPINGS as PixArt_Nodes 23 | NODE_CLASS_MAPPINGS.update(PixArt_Nodes) 24 | 25 | # T5 26 | from .T5.nodes import NODE_CLASS_MAPPINGS as T5_Nodes 27 | NODE_CLASS_MAPPINGS.update(T5_Nodes) 28 | 29 | # VAE 30 | from .VAE.nodes import NODE_CLASS_MAPPINGS as VAE_Nodes 31 | NODE_CLASS_MAPPINGS.update(VAE_Nodes) 32 | 33 | NODE_DISPLAY_NAME_MAPPINGS = {k:v.TITLE for k,v in NODE_CLASS_MAPPINGS.items()} 34 | __all__ = ['NODE_CLASS_MAPPINGS', 'NODE_DISPLAY_NAME_MAPPINGS'] 35 | -------------------------------------------------------------------------------- /requirements.txt: -------------------------------------------------------------------------------- 1 | timm==0.6.13 2 | sentencepiece>=0.1.97 3 | transformers>=4.34.1 4 | accelerate>=0.23.0 5 | einops -------------------------------------------------------------------------------- /utils/dtype.py: -------------------------------------------------------------------------------- 1 | import torch 2 | from comfy import model_management 3 | 4 | def string_to_dtype(s="none", mode=None): 5 | s = s.lower().strip() 6 | if s in ["default", "as-is"]: 7 | return None 8 | elif s in ["auto", "auto (comfy)"]: 9 | if mode == "vae": 10 | return model_management.vae_dtype() 11 | elif mode == "text_encoder": 12 | return model_management.text_encoder_dtype() 13 | elif mode == "unet": 14 | return model_management.unet_dtype() 15 | else: 16 | raise NotImplementedError(f"Unknown dtype mode '{mode}'") 17 | elif s in ["none", "auto (hf)", "auto (hf/bnb)"]: 18 | return None 19 | elif s in ["fp32", "float32", "float"]: 20 | return torch.float32 21 | elif s in ["bf16", "bfloat16"]: 22 | return 
torch.bfloat16 23 | elif s in ["fp16", "float16", "half"]: 24 | return torch.float16 25 | elif "fp8" in s or "float8" in s: 26 | if "e5m2" in s: 27 | return torch.float8_e5m2 28 | elif "e4m3" in s: 29 | return torch.float8_e4m3fn 30 | else: 31 | raise NotImplementedError(f"Unknown 8bit dtype '{s}'") 32 | elif "bnb" in s: 33 | assert s in ["bnb8bit", "bnb4bit"], f"Unknown bnb mode '{s}'" 34 | return s 35 | elif s is None: 36 | return None 37 | else: 38 | raise NotImplementedError(f"Unknown dtype '{s}'") 39 | --------------------------------------------------------------------------------
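A minimal usage sketch for the string_to_dtype() helper above (not part of the repository). It assumes ComfyUI is importable, since utils/dtype.py pulls in comfy.model_management at module load, and the import path below is illustrative only -- adjust it to however the custom node package is actually mounted.

import torch
from utils.dtype import string_to_dtype  # illustrative path; depends on how the package is installed

assert string_to_dtype("FP16") == torch.float16        # input is lower()-ed and strip()-ed internally
assert string_to_dtype("bf16") == torch.bfloat16
assert string_to_dtype("fp32") == torch.float32
assert string_to_dtype("none") is None                 # "default" and "auto (hf)" also map to None
assert string_to_dtype("fp8_e4m3") == torch.float8_e4m3fn   # needs a torch build with float8 support
assert string_to_dtype("bnb4bit") == "bnb4bit"         # bitsandbytes modes pass through as plain strings

# "auto" defers to ComfyUI's model_management heuristics and needs a mode hint,
# mirroring how ExtraVAELoader.load_vae() calls string_to_dtype(dtype, "vae"):
vae_dtype = string_to_dtype("auto", mode="vae")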