├── .gitignore
├── DiT
│   ├── LICENSE-DiT
│   ├── conf.py
│   ├── labels
│   │   └── imagenet1000.json
│   ├── loader.py
│   ├── model.py
│   └── nodes.py
├── HunYuan
│   ├── __init__.py
│   ├── conf.py
│   ├── loader.py
│   ├── models
│   │   ├── __init__.py
│   │   ├── attn_layers.py
│   │   ├── embedders.py
│   │   ├── models.py
│   │   ├── norm_layers.py
│   │   ├── poolers.py
│   │   ├── posemb_layers.py
│   │   └── text_encoder.py
│   ├── nodes.py
│   ├── wf.json
│   └── wf.png
├── LICENSE
├── PixArt
│   ├── LICENSE-PixArt
│   ├── conf.py
│   ├── diffusers_convert.py
│   ├── loader.py
│   ├── lora.py
│   ├── models
│   │   ├── PixArt.py
│   │   ├── PixArtMS.py
│   │   ├── PixArt_blocks.py
│   │   ├── pixart_controlnet.py
│   │   └── utils.py
│   └── nodes.py
├── README.md
├── T5
│   ├── LICENSE-ComfyUI
│   ├── LICENSE-T5
│   ├── loader.py
│   ├── nodes.py
│   ├── t5_tokenizer
│   │   ├── special_tokens_map.json
│   │   ├── spiece.model
│   │   └── tokenizer_config.json
│   ├── t5v11-xxl_config.json
│   └── t5v11.py
├── VAE
│   ├── conf.py
│   ├── loader.py
│   ├── models
│   │   ├── LICENSE-Consistency-Decoder
│   │   ├── LICENSE-Kandinsky-3
│   │   ├── LICENSE-Latent-Diffusion
│   │   ├── LICENSE-SAI
│   │   ├── LICENSE-SDV
│   │   ├── LICENSE-Taming-Transformers
│   │   ├── consistencydecoder.py
│   │   ├── kl.py
│   │   ├── movq3.py
│   │   ├── temporal_ae.py
│   │   └── vq.py
│   └── nodes.py
├── __init__.py
├── requirements.txt
└── utils
    └── dtype.py
/.gitignore:
--------------------------------------------------------------------------------
1 | # Byte-compiled / optimized / DLL files
2 | __pycache__/
3 | *.py[cod]
4 | *$py.class
5 |
6 | # C extensions
7 | *.so
8 |
9 | # Distribution / packaging
10 | .Python
11 | build/
12 | develop-eggs/
13 | dist/
14 | downloads/
15 | eggs/
16 | .eggs/
17 | lib/
18 | lib64/
19 | parts/
20 | sdist/
21 | var/
22 | wheels/
23 | share/python-wheels/
24 | *.egg-info/
25 | .installed.cfg
26 | *.egg
27 | MANIFEST
28 |
29 | # PyInstaller
30 | # Usually these files are written by a python script from a template
31 | # before PyInstaller builds the exe, so as to inject date/other infos into it.
32 | *.manifest
33 | *.spec
34 |
35 | # Installer logs
36 | pip-log.txt
37 | pip-delete-this-directory.txt
38 |
39 | # Unit test / coverage reports
40 | htmlcov/
41 | .tox/
42 | .nox/
43 | .coverage
44 | .coverage.*
45 | .cache
46 | nosetests.xml
47 | coverage.xml
48 | *.cover
49 | *.py,cover
50 | .hypothesis/
51 | .pytest_cache/
52 | cover/
53 |
54 | # Translations
55 | *.mo
56 | *.pot
57 |
58 | # Django stuff:
59 | *.log
60 | local_settings.py
61 | db.sqlite3
62 | db.sqlite3-journal
63 |
64 | # Flask stuff:
65 | instance/
66 | .webassets-cache
67 |
68 | # Scrapy stuff:
69 | .scrapy
70 |
71 | # Sphinx documentation
72 | docs/_build/
73 |
74 | # PyBuilder
75 | .pybuilder/
76 | target/
77 |
78 | # Jupyter Notebook
79 | .ipynb_checkpoints
80 |
81 | # IPython
82 | profile_default/
83 | ipython_config.py
84 |
85 | # pyenv
86 | # For a library or package, you might want to ignore these files since the code is
87 | # intended to run in multiple environments; otherwise, check them in:
88 | # .python-version
89 |
90 | # pipenv
91 | # According to pypa/pipenv#598, it is recommended to include Pipfile.lock in version control.
92 | # However, in case of collaboration, if having platform-specific dependencies or dependencies
93 | # having no cross-platform support, pipenv may install dependencies that don't work, or not
94 | # install all needed dependencies.
95 | #Pipfile.lock
96 |
97 | # poetry
98 | # Similar to Pipfile.lock, it is generally recommended to include poetry.lock in version control.
99 | # This is especially recommended for binary packages to ensure reproducibility, and is more
100 | # commonly ignored for libraries.
101 | # https://python-poetry.org/docs/basic-usage/#commit-your-poetrylock-file-to-version-control
102 | #poetry.lock
103 |
104 | # pdm
105 | # Similar to Pipfile.lock, it is generally recommended to include pdm.lock in version control.
106 | #pdm.lock
107 | # pdm stores project-wide configurations in .pdm.toml, but it is recommended to not include it
108 | # in version control.
109 | # https://pdm.fming.dev/#use-with-ide
110 | .pdm.toml
111 |
112 | # PEP 582; used by e.g. github.com/David-OConnor/pyflow and github.com/pdm-project/pdm
113 | __pypackages__/
114 |
115 | # Celery stuff
116 | celerybeat-schedule
117 | celerybeat.pid
118 |
119 | # SageMath parsed files
120 | *.sage.py
121 |
122 | # Environments
123 | .env
124 | .venv
125 | env/
126 | venv/
127 | ENV/
128 | env.bak/
129 | venv.bak/
130 |
131 | # Spyder project settings
132 | .spyderproject
133 | .spyproject
134 |
135 | # Rope project settings
136 | .ropeproject
137 |
138 | # mkdocs documentation
139 | /site
140 |
141 | # mypy
142 | .mypy_cache/
143 | .dmypy.json
144 | dmypy.json
145 |
146 | # Pyre type checker
147 | .pyre/
148 |
149 | # pytype static type analyzer
150 | .pytype/
151 |
152 | # Cython debug symbols
153 | cython_debug/
154 |
155 | # PyCharm
156 | # JetBrains specific template is maintained in a separate JetBrains.gitignore that can
157 | # be found at https://github.com/github/gitignore/blob/main/Global/JetBrains.gitignore
158 | # and can be added to the global gitignore or merged into this file. For a more nuclear
159 | # option (not recommended) you can uncomment the following to ignore the entire idea folder.
160 | #.idea/
161 |
--------------------------------------------------------------------------------
/DiT/conf.py:
--------------------------------------------------------------------------------
1 | """
2 | List of all DiT model types / settings
3 | """
4 | sampling_settings = {
5 | "beta_schedule" : "sqrt_linear",
6 | "linear_start" : 0.0001,
7 | "linear_end" : 0.02,
8 | "timesteps" : 1000,
9 | }
10 |
11 | dit_conf = {
12 | "XL/2": { # DiT_XL_2
13 | "unet_config": {
14 | "depth" : 28,
15 | "num_heads" : 16,
16 | "patch_size" : 2,
17 | "hidden_size" : 1152,
18 | },
19 | "sampling_settings" : sampling_settings,
20 | },
21 | "XL/4": { # DiT_XL_4
22 | "unet_config": {
23 | "depth" : 28,
24 | "num_heads" : 16,
25 | "patch_size" : 4,
26 | "hidden_size" : 1152,
27 | },
28 | "sampling_settings" : sampling_settings,
29 | },
30 | "XL/8": { # DiT_XL_8
31 | "unet_config": {
32 | "depth" : 28,
33 | "num_heads" : 16,
34 | "patch_size" : 8,
35 | "hidden_size" : 1152,
36 | },
37 | "sampling_settings" : sampling_settings,
38 | },
39 | "L/2": { # DiT_L_2
40 | "unet_config": {
41 | "depth" : 24,
42 | "num_heads" : 16,
43 | "patch_size" : 2,
44 | "hidden_size" : 1024,
45 | },
46 | "sampling_settings" : sampling_settings,
47 | },
48 | "L/4": { # DiT_L_4
49 | "unet_config": {
50 | "depth" : 24,
51 | "num_heads" : 16,
52 | "patch_size" : 4,
53 | "hidden_size" : 1024,
54 | },
55 | "sampling_settings" : sampling_settings,
56 | },
57 | "L/8": { # DiT_L_8
58 | "unet_config": {
59 | "depth" : 24,
60 | "num_heads" : 16,
61 | "patch_size" : 8,
62 | "hidden_size" : 1024,
63 | },
64 | "sampling_settings" : sampling_settings,
65 | },
66 | "B/2": { # DiT_B_2
67 | "unet_config": {
68 | "depth" : 12,
69 | "num_heads" : 12,
70 | "patch_size" : 2,
71 | "hidden_size" : 768,
72 | },
73 | "sampling_settings" : sampling_settings,
74 | },
75 | "B/4": { # DiT_B_4
76 | "unet_config": {
77 | "depth" : 12,
78 | "num_heads" : 12,
79 | "patch_size" : 4,
80 | "hidden_size" : 768,
81 | },
82 | "sampling_settings" : sampling_settings,
83 | },
84 | "B/8": { # DiT_B_8
85 | "unet_config": {
86 | "depth" : 12,
87 | "num_heads" : 12,
88 | "patch_size" : 8,
89 | "hidden_size" : 768,
90 | },
91 | "sampling_settings" : sampling_settings,
92 | },
93 | "S/2": { # DiT_S_2
94 | "unet_config": {
95 | "depth" : 12,
96 | "num_heads" : 6,
97 | "patch_size" : 2,
98 | "hidden_size" : 384,
99 | },
100 | "sampling_settings" : sampling_settings,
101 | },
102 | "S/4": { # DiT_S_4
103 | "unet_config": {
104 | "depth" : 12,
105 | "num_heads" : 6,
106 | "patch_size" : 4,
107 | "hidden_size" : 384,
108 | },
109 | "sampling_settings" : sampling_settings,
110 | },
111 | "S/8": { # DiT_S_8
112 | "unet_config": {
113 | "depth" : 12,
114 | "num_heads" : 6,
115 | "patch_size" : 8,
116 | "hidden_size" : 384,
117 | },
118 | "sampling_settings" : sampling_settings,
119 | },
120 | }
121 |
--------------------------------------------------------------------------------
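
A minimal sketch (standalone Python, no ComfyUI imports; values copied from the "XL/2" entry above) of how a dit_conf entry is typically consumed: pick a model key, derive the latent input_size from the pixel resolution as the loader node does, and count the resulting tokens.

dit_conf = {
    "XL/2": {
        "unet_config": {"depth": 28, "num_heads": 16, "patch_size": 2, "hidden_size": 1152},
    },
}

image_size = 256                                    # pixel resolution of the checkpoint
unet_config = dict(dit_conf["XL/2"]["unet_config"])
unet_config["input_size"] = image_size // 8         # latent resolution (VAE downscale factor 8)
num_tokens = (unet_config["input_size"] // unet_config["patch_size"]) ** 2
print(num_tokens)                                   # 256 patches for a 256px DiT-XL/2
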
/DiT/loader.py:
--------------------------------------------------------------------------------
1 | import comfy.supported_models_base
2 | import comfy.latent_formats
3 | import comfy.model_patcher
4 | import comfy.model_base
5 | import comfy.utils
6 | import torch
7 | from comfy import model_management
8 |
9 | class EXM_DiT(comfy.supported_models_base.BASE):
10 | unet_config = {}
11 | unet_extra_config = {}
12 | latent_format = comfy.latent_formats.SD15
13 |
14 | def __init__(self, model_conf):
15 | self.unet_config = model_conf.get("unet_config", {})
16 | self.sampling_settings = model_conf.get("sampling_settings", {})
17 | self.latent_format = self.latent_format()
18 | # UNET is handled by extension
19 | self.unet_config["disable_unet_model_creation"] = True
20 |
21 | def model_type(self, state_dict, prefix=""):
22 | return comfy.model_base.ModelType.EPS
23 |
24 | def load_dit(model_path, model_conf):
25 | state_dict = comfy.utils.load_torch_file(model_path)
26 | state_dict = state_dict.get("model", state_dict)
27 | parameters = comfy.utils.calculate_parameters(state_dict)
28 | unet_dtype = model_management.unet_dtype(model_params=parameters)
29 | load_device = comfy.model_management.get_torch_device()
30 | offload_device = comfy.model_management.unet_offload_device()
31 |
32 | # ignore fp8/etc and use directly for now
33 | manual_cast_dtype = model_management.unet_manual_cast(unet_dtype, load_device)
34 | if manual_cast_dtype:
35 | print(f"DiT: falling back to {manual_cast_dtype}")
36 | unet_dtype = manual_cast_dtype
37 |
38 | model_conf["unet_config"]["num_classes"] = state_dict["y_embedder.embedding_table.weight"].shape[0] - 1 # adj. for empty
39 |
40 | model_conf = EXM_DiT(model_conf)
41 | model = comfy.model_base.BaseModel(
42 | model_conf,
43 | model_type=comfy.model_base.ModelType.EPS,
44 | device=model_management.get_torch_device()
45 | )
46 |
47 | from .model import DiT
48 | model.diffusion_model = DiT(**model_conf.unet_config)
49 |
50 | model.diffusion_model.load_state_dict(state_dict)
51 | model.diffusion_model.dtype = unet_dtype
52 | model.diffusion_model.eval()
53 | model.diffusion_model.to(unet_dtype)
54 |
55 | model_patcher = comfy.model_patcher.ModelPatcher(
56 | model,
57 | load_device = load_device,
58 | offload_device = offload_device,
59 | current_device = "cpu",
60 | )
61 | return model_patcher
62 |
--------------------------------------------------------------------------------
/DiT/model.py:
--------------------------------------------------------------------------------
1 | # Copyright (c) Meta Platforms, Inc. and affiliates.
2 | # All rights reserved.
3 |
4 | # This source code is licensed under the license found in the
5 | # LICENSE file in the root directory of this source tree.
6 | # --------------------------------------------------------
7 | # References:
8 | # GLIDE: https://github.com/openai/glide-text2im
9 | # MAE: https://github.com/facebookresearch/mae/blob/main/models_mae.py
10 | # --------------------------------------------------------
11 |
12 | import torch
13 | import torch.nn as nn
14 | import numpy as np
15 | import math
16 | from timm.models.vision_transformer import PatchEmbed, Attention, Mlp
17 |
18 |
19 | def modulate(x, shift, scale):
20 | return x * (1 + scale.unsqueeze(1)) + shift.unsqueeze(1)
21 |
22 |
23 | #################################################################################
24 | # Embedding Layers for Timesteps and Class Labels #
25 | #################################################################################
26 |
27 | class TimestepEmbedder(nn.Module):
28 | """
29 | Embeds scalar timesteps into vector representations.
30 | """
31 | def __init__(self, hidden_size, frequency_embedding_size=256):
32 | super().__init__()
33 | self.mlp = nn.Sequential(
34 | nn.Linear(frequency_embedding_size, hidden_size, bias=True),
35 | nn.SiLU(),
36 | nn.Linear(hidden_size, hidden_size, bias=True),
37 | )
38 | self.frequency_embedding_size = frequency_embedding_size
39 |
40 | @staticmethod
41 | def timestep_embedding(t, dim, max_period=10000):
42 | """
43 | Create sinusoidal timestep embeddings.
44 | :param t: a 1-D Tensor of N indices, one per batch element.
45 | These may be fractional.
46 | :param dim: the dimension of the output.
47 | :param max_period: controls the minimum frequency of the embeddings.
48 | :return: an (N, D) Tensor of positional embeddings.
49 | """
50 | # https://github.com/openai/glide-text2im/blob/main/glide_text2im/nn.py
51 | half = dim // 2
52 | freqs = torch.exp(
53 | -math.log(max_period) * torch.arange(start=0, end=half, dtype=torch.float32) / half
54 | ).to(device=t.device)
55 | args = t[:, None].float() * freqs[None]
56 | embedding = torch.cat([torch.cos(args), torch.sin(args)], dim=-1)
57 | if dim % 2:
58 | embedding = torch.cat([embedding, torch.zeros_like(embedding[:, :1])], dim=-1)
59 | return embedding.to(dtype=t.dtype)
60 |
61 | def forward(self, t):
62 | t_freq = self.timestep_embedding(t, self.frequency_embedding_size)
63 | t_emb = self.mlp(t_freq)
64 | return t_emb
65 |
66 |
67 | class LabelEmbedder(nn.Module):
68 | """
69 | Embeds class labels into vector representations. Also handles label dropout for classifier-free guidance.
70 | """
71 | def __init__(self, num_classes, hidden_size, dropout_prob):
72 | super().__init__()
73 | use_cfg_embedding = dropout_prob > 0
74 | self.embedding_table = nn.Embedding(num_classes + use_cfg_embedding, hidden_size)
75 | self.num_classes = num_classes
76 | self.dropout_prob = dropout_prob
77 |
78 | def token_drop(self, labels, force_drop_ids=None):
79 | """
80 | Drops labels to enable classifier-free guidance.
81 | """
82 | if force_drop_ids is None:
83 | drop_ids = torch.rand(labels.shape[0], device=labels.device) < self.dropout_prob
84 | else:
85 | drop_ids = force_drop_ids == 1
86 | labels = torch.where(drop_ids, self.num_classes, labels)
87 | return labels
88 |
89 | def forward(self, labels, train, force_drop_ids=None):
90 | use_dropout = self.dropout_prob > 0
91 | if (train and use_dropout) or (force_drop_ids is not None):
92 | labels = self.token_drop(labels, force_drop_ids)
93 | embeddings = self.embedding_table(labels)
94 | return embeddings
95 |
96 |
97 | #################################################################################
98 | # Core DiT Model #
99 | #################################################################################
100 |
101 | class DiTBlock(nn.Module):
102 | """
103 | A DiT block with adaptive layer norm zero (adaLN-Zero) conditioning.
104 | """
105 | def __init__(self, hidden_size, num_heads, mlp_ratio=4.0, **block_kwargs):
106 | super().__init__()
107 | self.norm1 = nn.LayerNorm(hidden_size, elementwise_affine=False, eps=1e-6)
108 | self.attn = Attention(hidden_size, num_heads=num_heads, qkv_bias=True, **block_kwargs)
109 | self.norm2 = nn.LayerNorm(hidden_size, elementwise_affine=False, eps=1e-6)
110 | mlp_hidden_dim = int(hidden_size * mlp_ratio)
111 | approx_gelu = lambda: nn.GELU(approximate="tanh")
112 | self.mlp = Mlp(in_features=hidden_size, hidden_features=mlp_hidden_dim, act_layer=approx_gelu, drop=0)
113 | self.adaLN_modulation = nn.Sequential(
114 | nn.SiLU(),
115 | nn.Linear(hidden_size, 6 * hidden_size, bias=True)
116 | )
117 |
118 | def forward(self, x, c):
119 | shift_msa, scale_msa, gate_msa, shift_mlp, scale_mlp, gate_mlp = self.adaLN_modulation(c).chunk(6, dim=1)
120 | x = x + gate_msa.unsqueeze(1) * self.attn(modulate(self.norm1(x), shift_msa, scale_msa))
121 | x = x + gate_mlp.unsqueeze(1) * self.mlp(modulate(self.norm2(x), shift_mlp, scale_mlp))
122 | return x
123 |
124 |
125 | class FinalLayer(nn.Module):
126 | """
127 | The final layer of DiT.
128 | """
129 | def __init__(self, hidden_size, patch_size, out_channels):
130 | super().__init__()
131 | self.norm_final = nn.LayerNorm(hidden_size, elementwise_affine=False, eps=1e-6)
132 | self.linear = nn.Linear(hidden_size, patch_size * patch_size * out_channels, bias=True)
133 | self.adaLN_modulation = nn.Sequential(
134 | nn.SiLU(),
135 | nn.Linear(hidden_size, 2 * hidden_size, bias=True)
136 | )
137 |
138 | def forward(self, x, c):
139 | shift, scale = self.adaLN_modulation(c).chunk(2, dim=1)
140 | x = modulate(self.norm_final(x), shift, scale)
141 | x = self.linear(x)
142 | return x
143 |
144 |
145 | class DiT(nn.Module):
146 | """
147 | Diffusion model with a Transformer backbone.
148 | """
149 | def __init__(
150 | self,
151 | input_size=32,
152 | patch_size=2,
153 | in_channels=4,
154 | hidden_size=1152,
155 | depth=28,
156 | num_heads=16,
157 | mlp_ratio=4.0,
158 | class_dropout_prob=0.1,
159 | num_classes=1000,
160 | learn_sigma=True,
161 | **kwargs,
162 | ):
163 | super().__init__()
164 | self.learn_sigma = learn_sigma
165 | self.in_channels = in_channels
166 | self.out_channels = in_channels * 2 if learn_sigma else in_channels
167 | self.patch_size = patch_size
168 | self.num_heads = num_heads
169 |
170 | self.x_embedder = PatchEmbed(input_size, patch_size, in_channels, hidden_size, bias=True)
171 | self.t_embedder = TimestepEmbedder(hidden_size)
172 | self.y_embedder = LabelEmbedder(num_classes, hidden_size, class_dropout_prob)
173 | num_patches = self.x_embedder.num_patches
174 | # Will use fixed sin-cos embedding:
175 | self.pos_embed = nn.Parameter(torch.zeros(1, num_patches, hidden_size), requires_grad=False)
176 |
177 | self.blocks = nn.ModuleList([
178 | DiTBlock(hidden_size, num_heads, mlp_ratio=mlp_ratio) for _ in range(depth)
179 | ])
180 | self.final_layer = FinalLayer(hidden_size, patch_size, self.out_channels)
181 | self.initialize_weights()
182 |
183 | def initialize_weights(self):
184 | # Initialize transformer layers:
185 | def _basic_init(module):
186 | if isinstance(module, nn.Linear):
187 | torch.nn.init.xavier_uniform_(module.weight)
188 | if module.bias is not None:
189 | nn.init.constant_(module.bias, 0)
190 | self.apply(_basic_init)
191 |
192 | # Initialize (and freeze) pos_embed by sin-cos embedding:
193 | pos_embed = get_2d_sincos_pos_embed(self.pos_embed.shape[-1], int(self.x_embedder.num_patches ** 0.5))
194 | self.pos_embed.data.copy_(torch.from_numpy(pos_embed).float().unsqueeze(0))
195 |
196 | # Initialize patch_embed like nn.Linear (instead of nn.Conv2d):
197 | w = self.x_embedder.proj.weight.data
198 | nn.init.xavier_uniform_(w.view([w.shape[0], -1]))
199 | nn.init.constant_(self.x_embedder.proj.bias, 0)
200 |
201 | # Initialize label embedding table:
202 | nn.init.normal_(self.y_embedder.embedding_table.weight, std=0.02)
203 |
204 | # Initialize timestep embedding MLP:
205 | nn.init.normal_(self.t_embedder.mlp[0].weight, std=0.02)
206 | nn.init.normal_(self.t_embedder.mlp[2].weight, std=0.02)
207 |
208 | # Zero-out adaLN modulation layers in DiT blocks:
209 | for block in self.blocks:
210 | nn.init.constant_(block.adaLN_modulation[-1].weight, 0)
211 | nn.init.constant_(block.adaLN_modulation[-1].bias, 0)
212 |
213 | # Zero-out output layers:
214 | nn.init.constant_(self.final_layer.adaLN_modulation[-1].weight, 0)
215 | nn.init.constant_(self.final_layer.adaLN_modulation[-1].bias, 0)
216 | nn.init.constant_(self.final_layer.linear.weight, 0)
217 | nn.init.constant_(self.final_layer.linear.bias, 0)
218 |
219 | def unpatchify(self, x):
220 | """
221 | x: (N, T, patch_size**2 * C)
222 |         imgs: (N, C, H, W)
223 | """
224 | c = self.out_channels
225 | p = self.x_embedder.patch_size[0]
226 | h = w = int(x.shape[1] ** 0.5)
227 | assert h * w == x.shape[1]
228 |
229 | x = x.reshape(shape=(x.shape[0], h, w, p, p, c))
230 | x = torch.einsum('nhwpqc->nchpwq', x)
231 | imgs = x.reshape(shape=(x.shape[0], c, h * p, h * p))
232 | return imgs
233 |
234 | def forward_raw(self, x, t, y):
235 | """
236 | Forward pass of DiT.
237 | x: (N, C, H, W) tensor of spatial inputs (images or latent representations of images)
238 | t: (N,) tensor of diffusion timesteps
239 | y: (N,) tensor of class labels
240 | """
241 | x = self.x_embedder(x) + self.pos_embed # (N, T, D), where T = H * W / patch_size ** 2
242 | t = self.t_embedder(t) # (N, D)
243 | y = self.y_embedder(y, self.training) # (N, D)
244 | c = t + y # (N, D)
245 | for block in self.blocks:
246 | x = block(x, c) # (N, T, D)
247 | x = self.final_layer(x, c) # (N, T, patch_size ** 2 * out_channels)
248 | x = self.unpatchify(x) # (N, out_channels, H, W)
249 | return x
250 |
251 | def forward(self, x, timesteps, context, y=None, **kwargs):
252 | """
253 | Forward pass that adapts comfy input to original forward function
254 | x: (N, C, H, W) tensor of spatial inputs (images or latent representations of images)
255 | timesteps: (N,) tensor of diffusion timesteps
256 | context: (N, [LabelID]) conditioning
257 | y: extra conditioning.
258 | """
259 | ## Remove outer array from cond
260 | context = context[:, 0]
261 |
262 | ## run original forward pass
263 | out = self.forward_raw(
264 | x = x.to(self.dtype),
265 | t = timesteps.to(self.dtype),
266 | y = context.to(torch.int),
267 | )
268 |
269 | ## only return EPS
270 | out = out.to(torch.float)
271 | eps, rest = out[:, :self.in_channels], out[:, self.in_channels:]
272 | return eps
273 |
274 | #################################################################################
275 | # Sine/Cosine Positional Embedding Functions #
276 | #################################################################################
277 | # https://github.com/facebookresearch/mae/blob/main/util/pos_embed.py
278 |
279 | def get_2d_sincos_pos_embed(embed_dim, grid_size, cls_token=False, extra_tokens=0):
280 | """
281 | grid_size: int of the grid height and width
282 | return:
283 | pos_embed: [grid_size*grid_size, embed_dim] or [1+grid_size*grid_size, embed_dim] (w/ or w/o cls_token)
284 | """
285 | grid_h = np.arange(grid_size, dtype=np.float32)
286 | grid_w = np.arange(grid_size, dtype=np.float32)
287 | grid = np.meshgrid(grid_w, grid_h) # here w goes first
288 | grid = np.stack(grid, axis=0)
289 |
290 | grid = grid.reshape([2, 1, grid_size, grid_size])
291 | pos_embed = get_2d_sincos_pos_embed_from_grid(embed_dim, grid)
292 | if cls_token and extra_tokens > 0:
293 | pos_embed = np.concatenate([np.zeros([extra_tokens, embed_dim]), pos_embed], axis=0)
294 | return pos_embed
295 |
296 |
297 | def get_2d_sincos_pos_embed_from_grid(embed_dim, grid):
298 | assert embed_dim % 2 == 0
299 |
300 | # use half of dimensions to encode grid_h
301 | emb_h = get_1d_sincos_pos_embed_from_grid(embed_dim // 2, grid[0]) # (H*W, D/2)
302 | emb_w = get_1d_sincos_pos_embed_from_grid(embed_dim // 2, grid[1]) # (H*W, D/2)
303 |
304 | emb = np.concatenate([emb_h, emb_w], axis=1) # (H*W, D)
305 | return emb
306 |
307 |
308 | def get_1d_sincos_pos_embed_from_grid(embed_dim, pos):
309 | """
310 | embed_dim: output dimension for each position
311 | pos: a list of positions to be encoded: size (M,)
312 | out: (M, D)
313 | """
314 | assert embed_dim % 2 == 0
315 | omega = np.arange(embed_dim // 2, dtype=np.float64)
316 | omega /= embed_dim / 2.
317 | omega = 1. / 10000**omega # (D/2,)
318 |
319 | pos = pos.reshape(-1) # (M,)
320 | out = np.einsum('m,d->md', pos, omega) # (M, D/2), outer product
321 |
322 | emb_sin = np.sin(out) # (M, D/2)
323 | emb_cos = np.cos(out) # (M, D/2)
324 |
325 | emb = np.concatenate([emb_sin, emb_cos], axis=1) # (M, D)
326 | return emb
327 |
--------------------------------------------------------------------------------
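
A quick shape check for the DiT class above, assuming torch and timm are installed and the file is importable locally (the import path is hypothetical). It exercises forward_raw directly rather than the ComfyUI forward wrapper.

import torch
from model import DiT  # hypothetical: run next to DiT/model.py

net = DiT(input_size=32, patch_size=2, hidden_size=384, depth=2, num_heads=6).eval()
x = torch.randn(1, 4, 32, 32)      # latent input (N, C, H, W)
t = torch.tensor([999.0])          # timestep as float (timestep_embedding casts to t.dtype)
y = torch.tensor([0])              # one class id per sample
with torch.no_grad():
    out = net.forward_raw(x, t, y)
print(out.shape)                   # torch.Size([1, 8, 32, 32]): eps + learned-sigma channels
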
/DiT/nodes.py:
--------------------------------------------------------------------------------
1 | import os
2 | import json
3 | import torch
4 | import folder_paths
5 |
6 | from .conf import dit_conf
7 | from .loader import load_dit
8 |
9 | class DitCheckpointLoader:
10 | @classmethod
11 | def INPUT_TYPES(s):
12 | return {
13 | "required": {
14 | "ckpt_name": (folder_paths.get_filename_list("checkpoints"),),
15 | "model": (list(dit_conf.keys()),),
16 | "image_size": ([256, 512],),
17 | # "num_classes": ("INT", {"default": 1000, "min": 0,}),
18 | }
19 | }
20 | RETURN_TYPES = ("MODEL",)
21 | RETURN_NAMES = ("model",)
22 | FUNCTION = "load_checkpoint"
23 | CATEGORY = "ExtraModels/DiT"
24 | TITLE = "DitCheckpointLoader"
25 |
26 | def load_checkpoint(self, ckpt_name, model, image_size):
27 | ckpt_path = folder_paths.get_full_path("checkpoints", ckpt_name)
28 | model_conf = dit_conf[model]
29 | model_conf["unet_config"]["input_size"] = image_size // 8
30 | # model_conf["unet_config"]["num_classes"] = num_classes
31 | dit = load_dit(
32 | model_path = ckpt_path,
33 | model_conf = model_conf,
34 | )
35 | return (dit,)
36 |
37 | # todo: this needs frontend code to display properly
38 | def get_label_data(label_file="labels/imagenet1000.json"):
39 | label_path = os.path.join(
40 | os.path.dirname(os.path.realpath(__file__)),
41 | label_file,
42 | )
43 | label_data = {0: "None"}
44 | with open(label_path, "r") as f:
45 | label_data = json.loads(f.read())
46 | return label_data
47 | label_data = get_label_data()
48 |
49 | class DiTCondLabelSelect:
50 | @classmethod
51 | def INPUT_TYPES(s):
52 | global label_data
53 | return {
54 | "required": {
55 | "model" : ("MODEL",),
56 | "label_name": (list(label_data.values()),),
57 | }
58 | }
59 |
60 | RETURN_TYPES = ("CONDITIONING",)
61 | RETURN_NAMES = ("class",)
62 | FUNCTION = "cond_label"
63 | CATEGORY = "ExtraModels/DiT"
64 | TITLE = "DiTCondLabelSelect"
65 |
66 | def cond_label(self, model, label_name):
67 | global label_data
68 | class_labels = [int(k) for k,v in label_data.items() if v == label_name]
69 | y = torch.tensor([[class_labels[0]]]).to(torch.int)
70 | return ([[y, {}]], )
71 |
72 | class DiTCondLabelEmpty:
73 | @classmethod
74 | def INPUT_TYPES(s):
75 | global label_data
76 | return {
77 | "required": {
78 | "model" : ("MODEL",),
79 | }
80 | }
81 |
82 | RETURN_TYPES = ("CONDITIONING",)
83 | RETURN_NAMES = ("empty",)
84 | FUNCTION = "cond_empty"
85 | CATEGORY = "ExtraModels/DiT"
86 | TITLE = "DiTCondLabelEmpty"
87 |
88 | def cond_empty(self, model):
89 | # [ID of last class + 1] == [num_classes]
90 | y_null = model.model.model_config.unet_config["num_classes"]
91 | y = torch.tensor([[y_null]]).to(torch.int)
92 | return ([[y, {}]], )
93 |
94 | NODE_CLASS_MAPPINGS = {
95 | "DitCheckpointLoader" : DitCheckpointLoader,
96 | "DiTCondLabelSelect" : DiTCondLabelSelect,
97 | "DiTCondLabelEmpty" : DiTCondLabelEmpty,
98 | }
99 |
--------------------------------------------------------------------------------
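
For reference, a sketch of the CONDITIONING value the two label nodes above emit: a list of [tensor, options] pairs, where the tensor carries one class id with shape (1, 1). The id used here is a placeholder; real values come from labels/imagenet1000.json.

import torch

num_classes = 1000
positive = [[torch.tensor([[207]], dtype=torch.int), {}]]       # some class id from the label file
empty = [[torch.tensor([[num_classes]], dtype=torch.int), {}]]  # "no label" slot used for CFG
print(positive[0][0].shape, empty[0][0].item())                 # torch.Size([1, 1]) 1000
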
/HunYuan/__init__.py:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/chaojie/ComfyUI_ExtraModels/d8b11e401de830ccfb27fa84bdd0091b52408af8/HunYuan/__init__.py
--------------------------------------------------------------------------------
/HunYuan/conf.py:
--------------------------------------------------------------------------------
1 | """
2 | List of all HunYuan-DiT model types / settings
3 | """
4 | sampling_settings = {
5 | "beta_schedule" : "linear",
6 | "linear_start" : 0.00085,
7 | "linear_end" : 0.03,
8 | "timesteps" : 1000,
9 | 'steps_offset': 1,
10 | 'clip_sample': False,
11 | 'clip_sample_range': 1.0,
12 | 'beta_start': 0.00085,
13 | 'beta_end': 0.03,
14 | 'prediction_type': 'v_prediction',
15 | }
16 |
17 | dit_conf = {
18 | "DiT-g/2": { # DiT-g/2
19 | "unet_config": {
20 | "depth" : 40,
21 | "num_heads" : 16,
22 | "patch_size" : 2,
23 | "hidden_size" : 1408,
24 | 'mlp_ratio': 4.3637,
25 | },
26 | "sampling_settings" : sampling_settings,
27 | },
28 | "DiT-XL/2": { # DiT_XL_2
29 | "unet_config": {
30 | "depth" : 28,
31 | "num_heads" : 16,
32 | "patch_size" : 2,
33 | "hidden_size" : 1152,
34 | },
35 | "sampling_settings" : sampling_settings,
36 | },
37 | "DiT-L/2": { # DiT_L_2
38 | "unet_config": {
39 | "depth" : 24,
40 | "num_heads" : 16,
41 | "patch_size" : 2,
42 | "hidden_size" : 1024,
43 | },
44 | "sampling_settings" : sampling_settings,
45 | },
46 | "DiT-B/2": { # DiT_B_2
47 | "unet_config": {
48 | "depth" : 12,
49 | "num_heads" : 12,
50 | "patch_size" : 2,
51 | "hidden_size" : 768,
52 | },
53 | "sampling_settings" : sampling_settings,
54 | },
55 | }
56 |
--------------------------------------------------------------------------------
/HunYuan/loader.py:
--------------------------------------------------------------------------------
1 | import comfy.supported_models_base
2 | import comfy.latent_formats
3 | import comfy.model_patcher
4 | import comfy.model_base
5 | import comfy.utils
6 | import torch
7 | from comfy import model_management
8 | from ..PixArt.diffusers_convert import convert_state_dict
9 |
10 | class EXM_DiT(comfy.supported_models_base.BASE):
11 | unet_config = {}
12 | unet_extra_config = {}
13 | latent_format = comfy.latent_formats.SDXL
14 |
15 | def __init__(self, model_conf):
16 | self.model_target = model_conf.get("target")
17 | self.unet_config = model_conf.get("unet_config", {})
18 | self.sampling_settings = model_conf.get("sampling_settings", {})
19 | self.latent_format = self.latent_format()
20 | # UNET is handled by extension
21 | self.unet_config["disable_unet_model_creation"] = True
22 |
23 | def model_type(self, state_dict, prefix=""):
24 | return comfy.model_base.ModelType.V_PREDICTION
25 |
26 | class EXM_Dit_Model(comfy.model_base.BaseModel):
27 | def __init__(self, *args, **kwargs):
28 | super().__init__(*args, **kwargs)
29 |
30 | def extra_conds(self, **kwargs):
31 | out = super().extra_conds(**kwargs)
32 |
33 | clip_prompt_embeds = kwargs.get("clip_prompt_embeds", None)
34 | if clip_prompt_embeds is not None:
35 | out["clip_prompt_embeds"] = comfy.conds.CONDRegular(torch.tensor(clip_prompt_embeds))
36 |
37 | clip_attention_mask = kwargs.get("clip_attention_mask", None)
38 | if clip_attention_mask is not None:
39 | out["clip_attention_mask"] = comfy.conds.CONDRegular(torch.tensor(clip_attention_mask))
40 |
41 | mt5_prompt_embeds = kwargs.get("mt5_prompt_embeds", None)
42 | if mt5_prompt_embeds is not None:
43 | out["mt5_prompt_embeds"] = comfy.conds.CONDRegular(torch.tensor(mt5_prompt_embeds))
44 |
45 | mt5_attention_mask = kwargs.get("mt5_attention_mask", None)
46 | if mt5_attention_mask is not None:
47 | out["mt5_attention_mask"] = comfy.conds.CONDRegular(torch.tensor(mt5_attention_mask))
48 |
49 | return out
50 |
51 | def load_dit(model_path, model_conf):
52 | from comfy.diffusers_convert import convert_unet_state_dict
53 | state_dict = comfy.utils.load_torch_file(model_path)
54 | #state_dict=convert_unet_state_dict(state_dict)
55 | #state_dict = state_dict.get("model", state_dict)
56 |
57 | parameters = comfy.utils.calculate_parameters(state_dict)
58 | unet_dtype = torch.float16 #model_management.unet_dtype(model_params=parameters)
59 | load_device = comfy.model_management.get_torch_device()
60 | offload_device = comfy.model_management.unet_offload_device()
61 |
62 | # ignore fp8/etc and use directly for now
63 | #manual_cast_dtype = model_management.unet_manual_cast(unet_dtype, load_device)
64 | #if manual_cast_dtype:
65 | # print(f"DiT: falling back to {manual_cast_dtype}")
66 | # unet_dtype = manual_cast_dtype
67 |
68 | #model_conf["unet_config"]["num_classes"] = state_dict["y_embedder.embedding_table.weight"].shape[0] - 1 # adj. for empty
69 |
70 | model_conf = EXM_DiT(model_conf)
71 |
72 | model = EXM_Dit_Model( # same as comfy.model_base.BaseModel
73 | model_conf,
74 | model_type=comfy.model_base.ModelType.V_PREDICTION,
75 | device=model_management.get_torch_device()
76 | )
77 |
78 | from .models.models import HunYuan
79 | model.diffusion_model = HunYuan(**model_conf.unet_config)
80 | model.latent_format = comfy.latent_formats.SDXL()
81 |
82 | model.diffusion_model.load_state_dict(state_dict)
83 | model.diffusion_model.dtype = unet_dtype
84 | model.diffusion_model.eval()
85 | model.diffusion_model.to(unet_dtype)
86 |
87 | model_patcher = comfy.model_patcher.ModelPatcher(
88 | model,
89 | load_device = load_device,
90 | offload_device = offload_device,
91 | current_device = "cpu",
92 | )
93 | return model_patcher
94 |
--------------------------------------------------------------------------------
/HunYuan/models/__init__.py:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/chaojie/ComfyUI_ExtraModels/d8b11e401de830ccfb27fa84bdd0091b52408af8/HunYuan/models/__init__.py
--------------------------------------------------------------------------------
/HunYuan/models/embedders.py:
--------------------------------------------------------------------------------
1 | import math
2 | import torch
3 | import torch.nn as nn
4 | from einops import repeat
5 |
6 | from timm.models.layers import to_2tuple
7 |
8 |
9 | class PatchEmbed(nn.Module):
10 | """ 2D Image to Patch Embedding
11 |
12 | Image to Patch Embedding using Conv2d
13 |
14 | A convolution based approach to patchifying a 2D image w/ embedding projection.
15 |
16 | Based on the impl in https://github.com/google-research/vision_transformer
17 |
18 | Hacked together by / Copyright 2020 Ross Wightman
19 |
20 |     The _assert calls in forward() are removed to stay compatible with multi-resolution inputs.
21 | """
22 | def __init__(
23 | self,
24 | img_size=224,
25 | patch_size=16,
26 | in_chans=3,
27 | embed_dim=768,
28 | norm_layer=None,
29 | flatten=True,
30 | bias=True,
31 | ):
32 | super().__init__()
33 | if isinstance(img_size, int):
34 | img_size = to_2tuple(img_size)
35 | elif isinstance(img_size, (tuple, list)) and len(img_size) == 2:
36 | img_size = tuple(img_size)
37 | else:
38 | raise ValueError(f"img_size must be int or tuple/list of length 2. Got {img_size}")
39 | patch_size = to_2tuple(patch_size)
40 | self.img_size = img_size
41 | self.patch_size = patch_size
42 | self.grid_size = (img_size[0] // patch_size[0], img_size[1] // patch_size[1])
43 | self.num_patches = self.grid_size[0] * self.grid_size[1]
44 | self.flatten = flatten
45 |
46 | self.proj = nn.Conv2d(in_chans, embed_dim, kernel_size=patch_size, stride=patch_size, bias=bias)
47 | self.norm = norm_layer(embed_dim) if norm_layer else nn.Identity()
48 |
49 | def update_image_size(self, img_size):
50 | self.img_size = img_size
51 | self.grid_size = (img_size[0] // self.patch_size[0], img_size[1] // self.patch_size[1])
52 | self.num_patches = self.grid_size[0] * self.grid_size[1]
53 |
54 | def forward(self, x):
55 | # B, C, H, W = x.shape
56 | # _assert(H == self.img_size[0], f"Input image height ({H}) doesn't match model ({self.img_size[0]}).")
57 | # _assert(W == self.img_size[1], f"Input image width ({W}) doesn't match model ({self.img_size[1]}).")
58 | x = self.proj(x)
59 | if self.flatten:
60 | x = x.flatten(2).transpose(1, 2) # BCHW -> BNC
61 | x = self.norm(x)
62 | return x
63 |
64 |
65 | def timestep_embedding(t, dim, max_period=10000, repeat_only=False):
66 | """
67 | Create sinusoidal timestep embeddings.
68 | :param t: a 1-D Tensor of N indices, one per batch element.
69 | These may be fractional.
70 | :param dim: the dimension of the output.
71 | :param max_period: controls the minimum frequency of the embeddings.
72 | :return: an (N, D) Tensor of positional embeddings.
73 | """
74 | # https://github.com/openai/glide-text2im/blob/main/glide_text2im/nn.py
75 | if not repeat_only:
76 | half = dim // 2
77 | freqs = torch.exp(
78 | -math.log(max_period)
79 | * torch.arange(start=0, end=half, dtype=torch.float32)
80 | / half
81 |         ).to(device=t.device)  # size: [dim/2], an exponentially decaying curve
82 | args = t[:, None].float() * freqs[None]
83 | embedding = torch.cat([torch.cos(args), torch.sin(args)], dim=-1)
84 | if dim % 2:
85 | embedding = torch.cat(
86 | [embedding, torch.zeros_like(embedding[:, :1])], dim=-1
87 | )
88 | else:
89 | embedding = repeat(t, "b -> b d", d=dim)
90 | return embedding
91 |
92 |
93 | class TimestepEmbedder(nn.Module):
94 | """
95 | Embeds scalar timesteps into vector representations.
96 | """
97 | def __init__(self, hidden_size, frequency_embedding_size=256, out_size=None):
98 | super().__init__()
99 | if out_size is None:
100 | out_size = hidden_size
101 | self.mlp = nn.Sequential(
102 | nn.Linear(frequency_embedding_size, hidden_size, bias=True),
103 | nn.SiLU(),
104 | nn.Linear(hidden_size, out_size, bias=True),
105 | )
106 | self.frequency_embedding_size = frequency_embedding_size
107 |
108 | def forward(self, t):
109 | t_freq = timestep_embedding(t, self.frequency_embedding_size).type(self.mlp[0].weight.dtype)
110 | t_emb = self.mlp(t_freq)
111 | return t_emb
112 |
--------------------------------------------------------------------------------
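
A shape sketch for the multi-resolution PatchEmbed and timestep_embedding above (assumes torch, timm and einops are installed; the local import path is hypothetical).

import torch
from embedders import PatchEmbed, timestep_embedding  # hypothetical local import

pe = PatchEmbed(img_size=128, patch_size=2, in_chans=4, embed_dim=1408)
x = torch.randn(1, 4, 96, 160)     # a resolution the constructor never saw; no assert, so it still works
print(pe(x).shape)                 # torch.Size([1, 3840, 1408]): (96/2) * (160/2) tokens
print(timestep_embedding(torch.tensor([999.0]), 256).shape)  # torch.Size([1, 256])
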
/HunYuan/models/norm_layers.py:
--------------------------------------------------------------------------------
1 | import torch
2 | import torch.nn as nn
3 |
4 |
5 | class RMSNorm(nn.Module):
6 | def __init__(self, dim: int, elementwise_affine=True, eps: float = 1e-6):
7 | """
8 | Initialize the RMSNorm normalization layer.
9 |
10 | Args:
11 | dim (int): The dimension of the input tensor.
12 | eps (float, optional): A small value added to the denominator for numerical stability. Default is 1e-6.
13 |
14 | Attributes:
15 | eps (float): A small value added to the denominator for numerical stability.
16 | weight (nn.Parameter): Learnable scaling parameter.
17 |
18 | """
19 | super().__init__()
20 | self.eps = eps
21 | if elementwise_affine:
22 | self.weight = nn.Parameter(torch.ones(dim))
23 |
24 | def _norm(self, x):
25 | """
26 | Apply the RMSNorm normalization to the input tensor.
27 |
28 | Args:
29 | x (torch.Tensor): The input tensor.
30 |
31 | Returns:
32 | torch.Tensor: The normalized tensor.
33 |
34 | """
35 | return x * torch.rsqrt(x.pow(2).mean(-1, keepdim=True) + self.eps)
36 |
37 | def forward(self, x):
38 | """
39 | Forward pass through the RMSNorm layer.
40 |
41 | Args:
42 | x (torch.Tensor): The input tensor.
43 |
44 | Returns:
45 | torch.Tensor: The output tensor after applying RMSNorm.
46 |
47 | """
48 | output = self._norm(x.float()).type_as(x)
49 | if hasattr(self, "weight"):
50 | output = output * self.weight
51 | return output
52 |
53 |
54 | class GroupNorm32(nn.GroupNorm):
55 | def __init__(self, num_groups, num_channels, eps=1e-5, dtype=None):
56 | super().__init__(num_groups=num_groups, num_channels=num_channels, eps=eps, dtype=dtype)
57 |
58 | def forward(self, x):
59 | y = super().forward(x).to(x.dtype)
60 | return y
61 |
62 | def normalization(channels, dtype=None):
63 | """
64 | Make a standard normalization layer.
65 | :param channels: number of input channels.
66 | :return: an nn.Module for normalization.
67 | """
68 | return GroupNorm32(num_channels=channels, num_groups=32, dtype=dtype)
69 |
--------------------------------------------------------------------------------
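
A small numeric check of the RMSNorm above against its defining formula y = x / sqrt(mean(x^2) + eps) * w (assumes torch; local import path hypothetical).

import torch
from norm_layers import RMSNorm  # hypothetical local import

torch.manual_seed(0)
x = torch.randn(2, 8)
norm = RMSNorm(8, eps=1e-6)
expected = x * torch.rsqrt(x.pow(2).mean(-1, keepdim=True) + 1e-6) * norm.weight
print(torch.allclose(norm(x), expected))  # True
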
/HunYuan/models/poolers.py:
--------------------------------------------------------------------------------
1 | import torch
2 | import torch.nn as nn
3 | import torch.nn.functional as F
4 |
5 |
6 | class AttentionPool(nn.Module):
7 | def __init__(self, spacial_dim: int, embed_dim: int, num_heads: int, output_dim: int = None):
8 | super().__init__()
9 | self.positional_embedding = nn.Parameter(torch.randn(spacial_dim + 1, embed_dim) / embed_dim ** 0.5)
10 | self.k_proj = nn.Linear(embed_dim, embed_dim)
11 | self.q_proj = nn.Linear(embed_dim, embed_dim)
12 | self.v_proj = nn.Linear(embed_dim, embed_dim)
13 | self.c_proj = nn.Linear(embed_dim, output_dim or embed_dim)
14 | self.num_heads = num_heads
15 |
16 | def forward(self, x):
17 | x = x.permute(1, 0, 2) # NLC -> LNC
18 | x = torch.cat([x.mean(dim=0, keepdim=True), x], dim=0) # (L+1)NC
19 | x = x + self.positional_embedding[:, None, :].to(x.dtype) # (L+1)NC
20 | x, _ = F.multi_head_attention_forward(
21 | query=x[:1], key=x, value=x,
22 | embed_dim_to_check=x.shape[-1],
23 | num_heads=self.num_heads,
24 | q_proj_weight=self.q_proj.weight,
25 | k_proj_weight=self.k_proj.weight,
26 | v_proj_weight=self.v_proj.weight,
27 | in_proj_weight=None,
28 | in_proj_bias=torch.cat([self.q_proj.bias, self.k_proj.bias, self.v_proj.bias]),
29 | bias_k=None,
30 | bias_v=None,
31 | add_zero_attn=False,
32 | dropout_p=0,
33 | out_proj_weight=self.c_proj.weight,
34 | out_proj_bias=self.c_proj.bias,
35 | use_separate_proj_weight=True,
36 | training=self.training,
37 | need_weights=False
38 | )
39 | return x.squeeze(0)
40 |
--------------------------------------------------------------------------------
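
Shape sketch for AttentionPool above: it pools an (N, L, C) token sequence into a single (N, output_dim) vector by attending against the sequence mean (assumes torch; import path hypothetical).

import torch
from poolers import AttentionPool  # hypothetical local import

pool = AttentionPool(spacial_dim=256, embed_dim=1024, num_heads=8, output_dim=768)
tokens = torch.randn(2, 256, 1024)   # e.g. per-token text-encoder states
print(pool(tokens).shape)            # torch.Size([2, 768])
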
/HunYuan/models/posemb_layers.py:
--------------------------------------------------------------------------------
1 | import torch
2 | import numpy as np
3 | from typing import Union
4 |
5 |
6 | def _to_tuple(x):
7 | if isinstance(x, int):
8 | return x, x
9 | else:
10 | return x
11 |
12 |
13 | def get_fill_resize_and_crop(src, tgt):  # src: source resolution, tgt: base resolution
14 | th, tw = _to_tuple(tgt)
15 | h, w = _to_tuple(src)
16 |
17 |     tr = th / tw  # aspect ratio of the base resolution
18 |     r = h / w  # aspect ratio of the source resolution
19 |
20 | # resize
21 | if r > tr:
22 | resize_height = th
23 | resize_width = int(round(th / h * w))
24 | else:
25 | resize_width = tw
26 |         resize_height = int(round(tw / w * h))  # resize the source resolution down to fit the base resolution
27 |
28 | crop_top = int(round((th - resize_height) / 2.0))
29 | crop_left = int(round((tw - resize_width) / 2.0))
30 |
31 | return (crop_top, crop_left), (crop_top + resize_height, crop_left + resize_width)
32 |
33 |
34 | def get_meshgrid(start, *args):
35 | if len(args) == 0:
36 | # start is grid_size
37 | num = _to_tuple(start)
38 | start = (0, 0)
39 | stop = num
40 | elif len(args) == 1:
41 | # start is start, args[0] is stop, step is 1
42 | start = _to_tuple(start)
43 | stop = _to_tuple(args[0])
44 | num = (stop[0] - start[0], stop[1] - start[1])
45 | elif len(args) == 2:
46 | # start is start, args[0] is stop, args[1] is num
47 |         start = _to_tuple(start)   # top-left corner, e.g. (12, 0)
48 |         stop = _to_tuple(args[0])  # bottom-right corner, e.g. (20, 32)
49 |         num = _to_tuple(args[1])   # target size, e.g. (32, 124)
50 | else:
51 | raise ValueError(f"len(args) should be 0, 1 or 2, but got {len(args)}")
52 |
53 |     grid_h = np.linspace(start[0], stop[0], num[0], endpoint=False, dtype=np.float32)  # e.g. 32 evenly spaced values over 12-20, or 124 over 0-32
54 | grid_w = np.linspace(start[1], stop[1], num[1], endpoint=False, dtype=np.float32)
55 | grid = np.meshgrid(grid_w, grid_h) # here w goes first
56 | grid = np.stack(grid, axis=0) # [2, W, H]
57 | return grid
58 |
59 | #################################################################################
60 | # Sine/Cosine Positional Embedding Functions #
61 | #################################################################################
62 | # https://github.com/facebookresearch/mae/blob/main/util/pos_embed.py
63 |
64 | def get_2d_sincos_pos_embed(embed_dim, start, *args, cls_token=False, extra_tokens=0):
65 | """
66 | grid_size: int of the grid height and width
67 | return:
68 | pos_embed: [grid_size*grid_size, embed_dim] or [1+grid_size*grid_size, embed_dim] (w/ or w/o cls_token)
69 | """
70 | grid = get_meshgrid(start, *args) # [2, H, w]
71 | # grid_h = np.arange(grid_size, dtype=np.float32)
72 | # grid_w = np.arange(grid_size, dtype=np.float32)
73 | # grid = np.meshgrid(grid_w, grid_h) # here w goes first
74 | # grid = np.stack(grid, axis=0) # [2, W, H]
75 |
76 | grid = grid.reshape([2, 1, *grid.shape[1:]])
77 | pos_embed = get_2d_sincos_pos_embed_from_grid(embed_dim, grid)
78 | if cls_token and extra_tokens > 0:
79 | pos_embed = np.concatenate([np.zeros([extra_tokens, embed_dim]), pos_embed], axis=0)
80 | return pos_embed
81 |
82 |
83 | def get_2d_sincos_pos_embed_from_grid(embed_dim, grid):
84 | assert embed_dim % 2 == 0
85 |
86 | # use half of dimensions to encode grid_h
87 | emb_h = get_1d_sincos_pos_embed_from_grid(embed_dim // 2, grid[0]) # (H*W, D/2)
88 | emb_w = get_1d_sincos_pos_embed_from_grid(embed_dim // 2, grid[1]) # (H*W, D/2)
89 |
90 | emb = np.concatenate([emb_h, emb_w], axis=1) # (H*W, D)
91 | return emb
92 |
93 |
94 | def get_1d_sincos_pos_embed_from_grid(embed_dim, pos):
95 | """
96 | embed_dim: output dimension for each position
97 | pos: a list of positions to be encoded: size (W,H)
98 | out: (M, D)
99 | """
100 | assert embed_dim % 2 == 0
101 | omega = np.arange(embed_dim // 2, dtype=np.float64)
102 | omega /= embed_dim / 2.
103 | omega = 1. / 10000**omega # (D/2,)
104 |
105 | pos = pos.reshape(-1) # (M,)
106 | out = np.einsum('m,d->md', pos, omega) # (M, D/2), outer product
107 |
108 | emb_sin = np.sin(out) # (M, D/2)
109 | emb_cos = np.cos(out) # (M, D/2)
110 |
111 | emb = np.concatenate([emb_sin, emb_cos], axis=1) # (M, D)
112 | return emb
113 |
114 |
115 | #################################################################################
116 | # Rotary Positional Embedding Functions #
117 | #################################################################################
118 | # https://github.com/facebookresearch/llama/blob/main/llama/model.py#L443
119 |
120 | def get_2d_rotary_pos_embed(embed_dim, start, *args, use_real=True):
121 | """
122 | This is a 2d version of precompute_freqs_cis, which is a RoPE for image tokens with 2d structure.
123 |
124 | Parameters
125 | ----------
126 | embed_dim: int
127 | embedding dimension size
128 | start: int or tuple of int
129 | If len(args) == 0, start is num; If len(args) == 1, start is start, args[0] is stop, step is 1;
130 | If len(args) == 2, start is start, args[0] is stop, args[1] is num.
131 | use_real: bool
132 | If True, return real part and imaginary part separately. Otherwise, return complex numbers.
133 |
134 | Returns
135 | -------
136 | pos_embed: torch.Tensor
137 | [HW, D/2]
138 | """
139 | grid = get_meshgrid(start, *args) # [2, H, w]
140 |     grid = grid.reshape([2, 1, *grid.shape[1:]])  # a sampling grid with the same resolution as the target
141 | pos_embed = get_2d_rotary_pos_embed_from_grid(embed_dim, grid, use_real=use_real)
142 | return pos_embed
143 |
144 |
145 | def get_2d_rotary_pos_embed_from_grid(embed_dim, grid, use_real=False):
146 | assert embed_dim % 4 == 0
147 |
148 | # use half of dimensions to encode grid_h
149 | emb_h = get_1d_rotary_pos_embed(embed_dim // 2, grid[0].reshape(-1), use_real=use_real) # (H*W, D/4)
150 | emb_w = get_1d_rotary_pos_embed(embed_dim // 2, grid[1].reshape(-1), use_real=use_real) # (H*W, D/4)
151 |
152 | if use_real:
153 | cos = torch.cat([emb_h[0], emb_w[0]], dim=1) # (H*W, D/2)
154 | sin = torch.cat([emb_h[1], emb_w[1]], dim=1) # (H*W, D/2)
155 | return cos, sin
156 | else:
157 | emb = torch.cat([emb_h, emb_w], dim=1) # (H*W, D/2)
158 | return emb
159 |
160 |
161 | def get_1d_rotary_pos_embed(dim: int, pos: Union[np.ndarray, int], theta: float = 10000.0, use_real=False):
162 | """
163 | Precompute the frequency tensor for complex exponentials (cis) with given dimensions.
164 |
165 | This function calculates a frequency tensor with complex exponentials using the given dimension 'dim'
166 | and the end index 'end'. The 'theta' parameter scales the frequencies.
167 | The returned tensor contains complex values in complex64 data type.
168 |
169 | Args:
170 | dim (int): Dimension of the frequency tensor.
171 | pos (np.ndarray, int): Position indices for the frequency tensor. [S] or scalar
172 | theta (float, optional): Scaling factor for frequency computation. Defaults to 10000.0.
173 | use_real (bool, optional): If True, return real part and imaginary part separately.
174 | Otherwise, return complex numbers.
175 |
176 | Returns:
177 | torch.Tensor: Precomputed frequency tensor with complex exponentials. [S, D/2]
178 |
179 | """
180 | if isinstance(pos, int):
181 | pos = np.arange(pos)
182 | freqs = 1.0 / (theta ** (torch.arange(0, dim, 2)[: (dim // 2)].float() / dim)) # [D/2]
183 | t = torch.from_numpy(pos).to(freqs.device) # type: ignore # [S]
184 | freqs = torch.outer(t, freqs).float() # type: ignore # [S, D/2]
185 | if use_real:
186 | freqs_cos = freqs.cos().repeat_interleave(2, dim=1) # [S, D]
187 | freqs_sin = freqs.sin().repeat_interleave(2, dim=1) # [S, D]
188 | return freqs_cos, freqs_sin
189 | else:
190 | freqs_cis = torch.polar(torch.ones_like(freqs), freqs) # complex64 # [S, D/2]
191 | return freqs_cis
192 |
193 |
194 |
195 | def calc_sizes(rope_img, patch_size, th, tw):
196 |     """ Compute the RoPE grid sizes. """
197 |     if rope_img == 'extend':
198 |         # extend mode
199 |         sub_args = [(th, tw)]
200 |     elif rope_img.startswith('base'):
201 |         # based on one size; the other sizes are obtained by interpolation
202 |         base_size = int(rope_img[4:]) // 8 // patch_size  # e.g. 'base512': use 512 as the base and interpolate the rest from it
203 |         start, stop = get_fill_resize_and_crop((th, tw), base_size)  # top-left and bottom-right corners of the crop within the 32x32 base grid
204 | sub_args = [start, stop, (th, tw)]
205 | else:
206 | raise ValueError(f"Unknown rope_img: {rope_img}")
207 | return sub_args
208 |
209 |
210 | def init_image_posemb(rope_img,
211 | resolutions,
212 | patch_size,
213 | hidden_size,
214 | num_heads,
215 | log_fn,
216 | rope_real=True,
217 | ):
218 | freqs_cis_img = {}
219 | for reso in resolutions:
220 | th, tw = reso.height // 8 // patch_size, reso.width // 8 // patch_size
221 |         sub_args = calc_sizes(rope_img, patch_size, th, tw)  # [top-left, bottom-right, target (h, w)]: the crop corners within the 32x32 base grid
222 | freqs_cis_img[str(reso)] = get_2d_rotary_pos_embed(hidden_size // num_heads, *sub_args, use_real=rope_real)
223 | log_fn(f" Using image RoPE ({rope_img}) ({'real' if rope_real else 'complex'}): {sub_args} | ({reso}) "
224 | f"{freqs_cis_img[str(reso)][0].shape if rope_real else freqs_cis_img[str(reso)].shape}")
225 | return freqs_cis_img
226 |
--------------------------------------------------------------------------------
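
Shape sketch for the rotary helpers above (assumes torch/numpy; import path hypothetical). With hidden_size=1408 and 16 heads the per-head dim is 88, and a 32x32 token grid yields 1024 positions.

from posemb_layers import get_2d_rotary_pos_embed  # hypothetical local import

cos, sin = get_2d_rotary_pos_embed(88, 32, use_real=True)   # head_dim=88, 32x32 grid
print(cos.shape, sin.shape)          # torch.Size([1024, 88]) torch.Size([1024, 88])
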
/HunYuan/models/text_encoder.py:
--------------------------------------------------------------------------------
1 | import torch
2 | import torch.nn as nn
3 | from transformers import AutoTokenizer, T5EncoderModel, T5ForConditionalGeneration
4 |
5 |
6 | class MT5Embedder(nn.Module):
7 | available_models = ["t5-v1_1-xxl"]
8 |
9 | def __init__(
10 | self,
11 | model_dir="t5-v1_1-xxl",
12 | model_kwargs=None,
13 | torch_dtype=None,
14 | use_tokenizer_only=False,
15 | conditional_generation=False,
16 | max_length=128,
17 | device="cuda",
18 | ):
19 | super().__init__()
20 | self.device = device #"cuda" if torch.cuda.is_available() else "cpu"
21 | self.torch_dtype = torch_dtype or torch.bfloat16
22 | self.max_length = max_length
23 | if model_kwargs is None:
24 | model_kwargs = {
25 | # "low_cpu_mem_usage": True,
26 | "torch_dtype": self.torch_dtype,
27 | }
28 | model_kwargs["device_map"] = {"shared": self.device, "encoder": self.device}
29 | self.tokenizer = AutoTokenizer.from_pretrained(model_dir)
30 | if use_tokenizer_only:
31 | return
32 | if conditional_generation:
33 | self.model = None
34 | self.generation_model = T5ForConditionalGeneration.from_pretrained(
35 | model_dir
36 | )
37 | return
38 | self.model = T5EncoderModel.from_pretrained(model_dir, **model_kwargs).eval().to(self.torch_dtype)
39 |
40 | def get_tokens_and_mask(self, texts):
41 | text_tokens_and_mask = self.tokenizer(
42 | texts,
43 | max_length=self.max_length,
44 | padding="max_length",
45 | truncation=True,
46 | return_attention_mask=True,
47 | add_special_tokens=True,
48 | return_tensors="pt",
49 | )
50 | tokens = text_tokens_and_mask["input_ids"][0]
51 | mask = text_tokens_and_mask["attention_mask"][0]
52 | # tokens = torch.tensor(tokens).clone().detach()
53 | # mask = torch.tensor(mask, dtype=torch.bool).clone().detach()
54 | return tokens, mask
55 |
56 | def get_text_embeddings(self, texts, attention_mask=True, layer_index=-1):
57 | text_tokens_and_mask = self.tokenizer(
58 | texts,
59 | max_length=self.max_length,
60 | padding="max_length",
61 | truncation=True,
62 | return_attention_mask=True,
63 | add_special_tokens=True,
64 | return_tensors="pt",
65 | )
66 |
67 | with torch.no_grad():
68 | outputs = self.model(
69 | input_ids=text_tokens_and_mask["input_ids"].to(self.device),
70 | attention_mask=text_tokens_and_mask["attention_mask"].to(self.device)
71 | if attention_mask
72 | else None,
73 | output_hidden_states=True,
74 | )
75 | text_encoder_embs = outputs["hidden_states"][layer_index].detach()
76 |
77 | return text_encoder_embs, text_tokens_and_mask["attention_mask"].to(self.device)
78 |
79 | @torch.no_grad()
80 | def __call__(self, tokens, attention_mask, layer_index=-1):
81 | with torch.cuda.amp.autocast():
82 | outputs = self.model(
83 | input_ids=tokens,
84 | attention_mask=attention_mask,
85 | output_hidden_states=True,
86 | )
87 |
88 | z = outputs.hidden_states[layer_index].detach()
89 | return z
90 |
91 |     def general(self, text: str):
92 |         # input_ids = input_ids = torch.tensor([list(text.encode("utf-8"))]) + num_special_tokens
93 |         input_ids = self.tokenizer(text, max_length=128, truncation=True, return_tensors="pt").input_ids
94 |         print(input_ids)
95 |         outputs = self.generation_model.generate(input_ids)
96 |         return outputs
--------------------------------------------------------------------------------
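
A hypothetical tokenizer-only use of MT5Embedder above (no 11B encoder weights loaded; assumes transformers is installed and "path/to/mt5" points at a local mT5/T5 tokenizer directory).

from text_encoder import MT5Embedder  # hypothetical local import

emb = MT5Embedder("path/to/mt5", use_tokenizer_only=True, max_length=256, device="cpu")
tokens, mask = emb.get_tokens_and_mask(["a photo of a cat"])
print(tokens.shape, int(mask.sum()))  # padded ids of length 256, plus the real token count
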
/HunYuan/nodes.py:
--------------------------------------------------------------------------------
1 | import os
2 | import json
3 | import torch
4 | import folder_paths
5 |
6 | from .conf import dit_conf
7 | from .loader import load_dit
8 | from .models.text_encoder import MT5Embedder
9 | from transformers import BertModel, BertTokenizer
10 |
11 | class MT5Loader:
12 | @classmethod
13 | def INPUT_TYPES(s):
14 | return {
15 | "required": {
16 | "HunyuanDiTfolder": (os.listdir(os.path.join(folder_paths.models_dir,"diffusers")), {"default": "HunyuanDiT"}),
17 | "device": (["cpu", "cuda"], {"default": "cuda"}),
18 | }
19 | }
20 | RETURN_TYPES = ("MT5","CLIP","Tokenizer",)
21 | FUNCTION = "load_model"
22 | CATEGORY = "ExtraModels/T5"
23 | TITLE = "MT5 Loader"
24 |
25 | def load_model(self, HunyuanDiTfolder, device):
26 | HunyuanDiTfolder=os.path.join(os.path.join(folder_paths.models_dir,"diffusers"),HunyuanDiTfolder)
27 | mt5folder=os.path.join(HunyuanDiTfolder,"t2i/mt5")
28 | clipfolder=os.path.join(HunyuanDiTfolder,"t2i/clip_text_encoder")
29 | tokenizerfolder=os.path.join(HunyuanDiTfolder,"t2i/tokenizer")
30 | torch_dtype=torch.float16
31 | if device=="cpu":
32 | torch_dtype=torch.float32
33 | clip_text_encoder = BertModel.from_pretrained(str(clipfolder), False, revision=None).to(device)
34 | tokenizer = BertTokenizer.from_pretrained(str(tokenizerfolder))
35 | embedder_t5 = MT5Embedder(mt5folder, torch_dtype=torch_dtype, max_length=256, device=device)
36 |
37 | return (embedder_t5,clip_text_encoder,tokenizer,)
38 |
39 | def clip_get_text_embeddings(clip_text_encoder,tokenizer,text,device):
40 | max_length=tokenizer.model_max_length
41 | text_inputs = tokenizer(
42 | text,
43 | padding="max_length",
44 | max_length=max_length,
45 | truncation=True,
46 | return_attention_mask=True,
47 | return_tensors="pt",
48 | )
49 | text_input_ids = text_inputs.input_ids
50 | attention_mask = text_inputs.attention_mask.to(device)
51 | prompt_embeds = clip_text_encoder(
52 | text_input_ids.to(device),
53 | attention_mask=attention_mask,
54 | )
55 | prompt_embeds = prompt_embeds[0]
56 | attention_mask = attention_mask.repeat(1, 1)
57 |
58 | return (prompt_embeds,attention_mask)
59 |
60 | class MT5TextEncode:
61 | @classmethod
62 | def INPUT_TYPES(s):
63 | return {
64 | "required": {
65 | "embedder_t5": ("MT5",),
66 | "clip_text_encoder": ("CLIP",),
67 | "tokenizer": ("Tokenizer",),
68 | "prompt": ("STRING", {"multiline": True}),
69 | "negative_prompt": ("STRING", {"multiline": True}),
70 | }
71 | }
72 |
73 | RETURN_TYPES = ("CONDITIONING","CONDITIONING",)
74 | RETURN_NAMES = ("positive","negative",)
75 | FUNCTION = "encode"
76 | CATEGORY = "ExtraModels/T5"
77 | TITLE = "MT5 Text Encode"
78 |
79 | def encode(self, embedder_t5, clip_text_encoder, tokenizer, prompt, negative_prompt):
80 |         print(f'prompt: {prompt}')
81 | clip_prompt_embeds,clip_attention_mask = clip_get_text_embeddings(clip_text_encoder,tokenizer,prompt,embedder_t5.device)
82 |
83 | clip_negative_prompt_embeds,clip_negative_attention_mask = clip_get_text_embeddings(clip_text_encoder,tokenizer,negative_prompt,embedder_t5.device)
84 |
85 | mt5_prompt_embeds,mt5_attention_mask = embedder_t5.get_text_embeddings(prompt)
86 |
87 | mt5_negative_prompt_embeds,mt5_negative_attention_mask = embedder_t5.get_text_embeddings(negative_prompt)
88 |
89 | return ([[clip_prompt_embeds, {"clip_prompt_embeds":clip_prompt_embeds,"clip_attention_mask":clip_attention_mask,"mt5_prompt_embeds":mt5_prompt_embeds,"mt5_attention_mask":mt5_attention_mask}]],[[clip_negative_prompt_embeds, {"clip_prompt_embeds":clip_negative_prompt_embeds,"clip_attention_mask":clip_negative_attention_mask,"mt5_prompt_embeds":mt5_negative_prompt_embeds,"mt5_attention_mask":mt5_negative_attention_mask}]], )
90 |
91 | class HunYuanDitCheckpointLoader:
92 | @classmethod
93 | def INPUT_TYPES(s):
94 | return {
95 | "required": {
96 | "HunyuanDiTfolder": (os.listdir(os.path.join(folder_paths.models_dir,"diffusers")), {"default": "HunyuanDiT"}),
97 | "model": (list(dit_conf.keys()),),
98 | "image_size_width": ("INT",{"default":1024}),
99 | "image_size_height": ("INT",{"default":1024}),
100 | # "num_classes": ("INT", {"default": 1000, "min": 0,}),
101 | }
102 | }
103 | RETURN_TYPES = ("MODEL",)
104 | RETURN_NAMES = ("model",)
105 | FUNCTION = "load_checkpoint"
106 | CATEGORY = "ExtraModels/DiT"
107 | TITLE = "HunYuanDitCheckpointLoader"
108 |
109 | def load_checkpoint(self, HunyuanDiTfolder, model, image_size_width, image_size_height):
110 | image_size_width = int((image_size_width // 16) * 16)
111 | image_size_height = int((image_size_height // 16) * 16)
112 | HunyuanDiTfolder = os.path.join(folder_paths.models_dir, "diffusers", HunyuanDiTfolder)
113 | ckpt_path=os.path.join(HunyuanDiTfolder,"t2i/model/pytorch_model_ema.pt")
114 | model_conf = dit_conf[model]
115 | model_conf["unet_config"]["input_size"] = (image_size_height // 8, image_size_width // 8)
116 | # model_conf["unet_config"]["num_classes"] = num_classes
117 | dit = load_dit(
118 | model_path = ckpt_path,
119 | model_conf = model_conf,
120 | )
121 | return (dit,)
122 |
123 | NODE_CLASS_MAPPINGS = {
124 | "HunYuanDitCheckpointLoader" : HunYuanDitCheckpointLoader,
125 | "MT5Loader" : MT5Loader,
126 | "MT5TextEncode" : MT5TextEncode,
127 | }
128 |
--------------------------------------------------------------------------------
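For reference, the two CONDITIONING values returned by MT5TextEncode above follow ComfyUI's usual list-of-[tensor, dict] layout, with the HunYuan-specific tensors carried in the dict. Below is a minimal sketch of how a downstream consumer might unpack one of them; only the key names come from the return statement above, and the helper itself is illustrative rather than part of this extension.

# Illustrative helper (not part of the extension): unpack the conditioning
# produced by MT5TextEncode. Each CONDITIONING is a list of [tensor, dict] pairs.
def unpack_hunyuan_cond(conditioning):
    cond_tensor, extra = conditioning[0]
    return (
        extra["clip_prompt_embeds"],   # BERT last_hidden_state
        extra["clip_attention_mask"],  # matching attention mask
        extra["mt5_prompt_embeds"],    # mT5 hidden states (max_length=256 in load_model above)
        extra["mt5_attention_mask"],
    )

# positive, negative = MT5TextEncode().encode(embedder_t5, clip_text_encoder,
#                                             tokenizer, "a photo of a cat", "")
# clip_emb, clip_mask, mt5_emb, mt5_mask = unpack_hunyuan_cond(positive)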
/HunYuan/wf.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/chaojie/ComfyUI_ExtraModels/d8b11e401de830ccfb27fa84bdd0091b52408af8/HunYuan/wf.png
--------------------------------------------------------------------------------
/LICENSE:
--------------------------------------------------------------------------------
1 | Apache License
2 | Version 2.0, January 2004
3 | http://www.apache.org/licenses/
4 |
5 | TERMS AND CONDITIONS FOR USE, REPRODUCTION, AND DISTRIBUTION
6 |
7 | 1. Definitions.
8 |
9 | "License" shall mean the terms and conditions for use, reproduction,
10 | and distribution as defined by Sections 1 through 9 of this document.
11 |
12 | "Licensor" shall mean the copyright owner or entity authorized by
13 | the copyright owner that is granting the License.
14 |
15 | "Legal Entity" shall mean the union of the acting entity and all
16 | other entities that control, are controlled by, or are under common
17 | control with that entity. For the purposes of this definition,
18 | "control" means (i) the power, direct or indirect, to cause the
19 | direction or management of such entity, whether by contract or
20 | otherwise, or (ii) ownership of fifty percent (50%) or more of the
21 | outstanding shares, or (iii) beneficial ownership of such entity.
22 |
23 | "You" (or "Your") shall mean an individual or Legal Entity
24 | exercising permissions granted by this License.
25 |
26 | "Source" form shall mean the preferred form for making modifications,
27 | including but not limited to software source code, documentation
28 | source, and configuration files.
29 |
30 | "Object" form shall mean any form resulting from mechanical
31 | transformation or translation of a Source form, including but
32 | not limited to compiled object code, generated documentation,
33 | and conversions to other media types.
34 |
35 | "Work" shall mean the work of authorship, whether in Source or
36 | Object form, made available under the License, as indicated by a
37 | copyright notice that is included in or attached to the work
38 | (an example is provided in the Appendix below).
39 |
40 | "Derivative Works" shall mean any work, whether in Source or Object
41 | form, that is based on (or derived from) the Work and for which the
42 | editorial revisions, annotations, elaborations, or other modifications
43 | represent, as a whole, an original work of authorship. For the purposes
44 | of this License, Derivative Works shall not include works that remain
45 | separable from, or merely link (or bind by name) to the interfaces of,
46 | the Work and Derivative Works thereof.
47 |
48 | "Contribution" shall mean any work of authorship, including
49 | the original version of the Work and any modifications or additions
50 | to that Work or Derivative Works thereof, that is intentionally
51 | submitted to Licensor for inclusion in the Work by the copyright owner
52 | or by an individual or Legal Entity authorized to submit on behalf of
53 | the copyright owner. For the purposes of this definition, "submitted"
54 | means any form of electronic, verbal, or written communication sent
55 | to the Licensor or its representatives, including but not limited to
56 | communication on electronic mailing lists, source code control systems,
57 | and issue tracking systems that are managed by, or on behalf of, the
58 | Licensor for the purpose of discussing and improving the Work, but
59 | excluding communication that is conspicuously marked or otherwise
60 | designated in writing by the copyright owner as "Not a Contribution."
61 |
62 | "Contributor" shall mean Licensor and any individual or Legal Entity
63 | on behalf of whom a Contribution has been received by Licensor and
64 | subsequently incorporated within the Work.
65 |
66 | 2. Grant of Copyright License. Subject to the terms and conditions of
67 | this License, each Contributor hereby grants to You a perpetual,
68 | worldwide, non-exclusive, no-charge, royalty-free, irrevocable
69 | copyright license to reproduce, prepare Derivative Works of,
70 | publicly display, publicly perform, sublicense, and distribute the
71 | Work and such Derivative Works in Source or Object form.
72 |
73 | 3. Grant of Patent License. Subject to the terms and conditions of
74 | this License, each Contributor hereby grants to You a perpetual,
75 | worldwide, non-exclusive, no-charge, royalty-free, irrevocable
76 | (except as stated in this section) patent license to make, have made,
77 | use, offer to sell, sell, import, and otherwise transfer the Work,
78 | where such license applies only to those patent claims licensable
79 | by such Contributor that are necessarily infringed by their
80 | Contribution(s) alone or by combination of their Contribution(s)
81 | with the Work to which such Contribution(s) was submitted. If You
82 | institute patent litigation against any entity (including a
83 | cross-claim or counterclaim in a lawsuit) alleging that the Work
84 | or a Contribution incorporated within the Work constitutes direct
85 | or contributory patent infringement, then any patent licenses
86 | granted to You under this License for that Work shall terminate
87 | as of the date such litigation is filed.
88 |
89 | 4. Redistribution. You may reproduce and distribute copies of the
90 | Work or Derivative Works thereof in any medium, with or without
91 | modifications, and in Source or Object form, provided that You
92 | meet the following conditions:
93 |
94 | (a) You must give any other recipients of the Work or
95 | Derivative Works a copy of this License; and
96 |
97 | (b) You must cause any modified files to carry prominent notices
98 | stating that You changed the files; and
99 |
100 | (c) You must retain, in the Source form of any Derivative Works
101 | that You distribute, all copyright, patent, trademark, and
102 | attribution notices from the Source form of the Work,
103 | excluding those notices that do not pertain to any part of
104 | the Derivative Works; and
105 |
106 | (d) If the Work includes a "NOTICE" text file as part of its
107 | distribution, then any Derivative Works that You distribute must
108 | include a readable copy of the attribution notices contained
109 | within such NOTICE file, excluding those notices that do not
110 | pertain to any part of the Derivative Works, in at least one
111 | of the following places: within a NOTICE text file distributed
112 | as part of the Derivative Works; within the Source form or
113 | documentation, if provided along with the Derivative Works; or,
114 | within a display generated by the Derivative Works, if and
115 | wherever such third-party notices normally appear. The contents
116 | of the NOTICE file are for informational purposes only and
117 | do not modify the License. You may add Your own attribution
118 | notices within Derivative Works that You distribute, alongside
119 | or as an addendum to the NOTICE text from the Work, provided
120 | that such additional attribution notices cannot be construed
121 | as modifying the License.
122 |
123 | You may add Your own copyright statement to Your modifications and
124 | may provide additional or different license terms and conditions
125 | for use, reproduction, or distribution of Your modifications, or
126 | for any such Derivative Works as a whole, provided Your use,
127 | reproduction, and distribution of the Work otherwise complies with
128 | the conditions stated in this License.
129 |
130 | 5. Submission of Contributions. Unless You explicitly state otherwise,
131 | any Contribution intentionally submitted for inclusion in the Work
132 | by You to the Licensor shall be under the terms and conditions of
133 | this License, without any additional terms or conditions.
134 | Notwithstanding the above, nothing herein shall supersede or modify
135 | the terms of any separate license agreement you may have executed
136 | with Licensor regarding such Contributions.
137 |
138 | 6. Trademarks. This License does not grant permission to use the trade
139 | names, trademarks, service marks, or product names of the Licensor,
140 | except as required for reasonable and customary use in describing the
141 | origin of the Work and reproducing the content of the NOTICE file.
142 |
143 | 7. Disclaimer of Warranty. Unless required by applicable law or
144 | agreed to in writing, Licensor provides the Work (and each
145 | Contributor provides its Contributions) on an "AS IS" BASIS,
146 | WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or
147 | implied, including, without limitation, any warranties or conditions
148 | of TITLE, NON-INFRINGEMENT, MERCHANTABILITY, or FITNESS FOR A
149 | PARTICULAR PURPOSE. You are solely responsible for determining the
150 | appropriateness of using or redistributing the Work and assume any
151 | risks associated with Your exercise of permissions under this License.
152 |
153 | 8. Limitation of Liability. In no event and under no legal theory,
154 | whether in tort (including negligence), contract, or otherwise,
155 | unless required by applicable law (such as deliberate and grossly
156 | negligent acts) or agreed to in writing, shall any Contributor be
157 | liable to You for damages, including any direct, indirect, special,
158 | incidental, or consequential damages of any character arising as a
159 | result of this License or out of the use or inability to use the
160 | Work (including but not limited to damages for loss of goodwill,
161 | work stoppage, computer failure or malfunction, or any and all
162 | other commercial damages or losses), even if such Contributor
163 | has been advised of the possibility of such damages.
164 |
165 | 9. Accepting Warranty or Additional Liability. While redistributing
166 | the Work or Derivative Works thereof, You may choose to offer,
167 | and charge a fee for, acceptance of support, warranty, indemnity,
168 | or other liability obligations and/or rights consistent with this
169 | License. However, in accepting such obligations, You may act only
170 | on Your own behalf and on Your sole responsibility, not on behalf
171 | of any other Contributor, and only if You agree to indemnify,
172 | defend, and hold each Contributor harmless for any liability
173 | incurred by, or claims asserted against, such Contributor by reason
174 | of your accepting any such warranty or additional liability.
175 |
176 | END OF TERMS AND CONDITIONS
177 |
178 | APPENDIX: How to apply the Apache License to your work.
179 |
180 | To apply the Apache License to your work, attach the following
181 | boilerplate notice, with the fields enclosed by brackets "[]"
182 | replaced with your own identifying information. (Don't include
183 | the brackets!) The text should be enclosed in the appropriate
184 | comment syntax for the file format. We also recommend that a
185 | file or class name and description of purpose be included on the
186 | same "printed page" as the copyright notice for easier
187 | identification within third-party archives.
188 |
189 | Copyright [yyyy] [name of copyright owner]
190 |
191 | Licensed under the Apache License, Version 2.0 (the "License");
192 | you may not use this file except in compliance with the License.
193 | You may obtain a copy of the License at
194 |
195 | http://www.apache.org/licenses/LICENSE-2.0
196 |
197 | Unless required by applicable law or agreed to in writing, software
198 | distributed under the License is distributed on an "AS IS" BASIS,
199 | WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
200 | See the License for the specific language governing permissions and
201 | limitations under the License.
202 |
--------------------------------------------------------------------------------
/PixArt/conf.py:
--------------------------------------------------------------------------------
1 | """
2 | List of all PixArt model types / settings
3 | """
4 |
5 | sampling_settings = {
6 | "beta_schedule" : "sqrt_linear",
7 | "linear_start" : 0.0001,
8 | "linear_end" : 0.02,
9 | "timesteps" : 1000,
10 | }
11 |
12 | pixart_conf = {
13 | "PixArtMS_XL_2": { # models/PixArtMS
14 | "target": "PixArtMS",
15 | "unet_config": {
16 | "input_size" : 1024//8,
17 | "depth" : 28,
18 | "num_heads" : 16,
19 | "patch_size" : 2,
20 | "hidden_size" : 1152,
21 | "pe_interpolation": 2,
22 | },
23 | "sampling_settings" : sampling_settings,
24 | },
25 | "PixArtMS_Sigma_XL_2": {
26 | "target": "PixArtMSSigma",
27 | "unet_config": {
28 | "input_size" : 1024//8,
29 | "token_num" : 300,
30 | "depth" : 28,
31 | "num_heads" : 16,
32 | "patch_size" : 2,
33 | "hidden_size" : 1152,
34 | "micro_condition": False,
35 | "pe_interpolation": 2,
36 | "model_max_length": 300,
37 | },
38 | "sampling_settings" : sampling_settings,
39 | },
40 | "PixArtMS_Sigma_XL_2_2K": {
41 | "target": "PixArtMSSigma",
42 | "unet_config": {
43 | "input_size" : 2048//8,
44 | "token_num" : 300,
45 | "depth" : 28,
46 | "num_heads" : 16,
47 | "patch_size" : 2,
48 | "hidden_size" : 1152,
49 | "micro_condition": False,
50 | "pe_interpolation": 4,
51 | "model_max_length": 300,
52 | },
53 | "sampling_settings" : sampling_settings,
54 | },
55 | "PixArt_XL_2": { # models/PixArt
56 | "target": "PixArt",
57 | "unet_config": {
58 | "input_size" : 512//8,
59 | "token_num" : 120,
60 | "depth" : 28,
61 | "num_heads" : 16,
62 | "patch_size" : 2,
63 | "hidden_size" : 1152,
64 | "pe_interpolation": 1,
65 | },
66 | "sampling_settings" : sampling_settings,
67 | },
68 | }
69 |
70 | pixart_conf.update({ # controlnet models
71 | "ControlPixArtHalf": {
72 | "target": "ControlPixArtHalf",
73 | "unet_config": pixart_conf["PixArt_XL_2"]["unet_config"],
74 | "sampling_settings": pixart_conf["PixArt_XL_2"]["sampling_settings"],
75 | },
76 | "ControlPixArtMSHalf": {
77 | "target": "ControlPixArtMSHalf",
78 | "unet_config": pixart_conf["PixArtMS_XL_2"]["unet_config"],
79 | "sampling_settings": pixart_conf["PixArtMS_XL_2"]["sampling_settings"],
80 | }
81 | })
82 |
83 | pixart_res = {
84 | "PixArtMS_XL_2": { # models/PixArtMS 1024x1024
85 | '0.25': [512, 2048], '0.26': [512, 1984], '0.27': [512, 1920], '0.28': [512, 1856],
86 | '0.32': [576, 1792], '0.33': [576, 1728], '0.35': [576, 1664], '0.40': [640, 1600],
87 | '0.42': [640, 1536], '0.48': [704, 1472], '0.50': [704, 1408], '0.52': [704, 1344],
88 | '0.57': [768, 1344], '0.60': [768, 1280], '0.68': [832, 1216], '0.72': [832, 1152],
89 | '0.78': [896, 1152], '0.82': [896, 1088], '0.88': [960, 1088], '0.94': [960, 1024],
90 | '1.00': [1024,1024], '1.07': [1024, 960], '1.13': [1088, 960], '1.21': [1088, 896],
91 | '1.29': [1152, 896], '1.38': [1152, 832], '1.46': [1216, 832], '1.67': [1280, 768],
92 | '1.75': [1344, 768], '2.00': [1408, 704], '2.09': [1472, 704], '2.40': [1536, 640],
93 | '2.50': [1600, 640], '2.89': [1664, 576], '3.00': [1728, 576], '3.11': [1792, 576],
94 | '3.62': [1856, 512], '3.75': [1920, 512], '3.88': [1984, 512], '4.00': [2048, 512],
95 | },
96 | "PixArt_XL_2": { # models/PixArt 512x512
97 | '0.25': [256,1024], '0.26': [256, 992], '0.27': [256, 960], '0.28': [256, 928],
98 | '0.32': [288, 896], '0.33': [288, 864], '0.35': [288, 832], '0.40': [320, 800],
99 | '0.42': [320, 768], '0.48': [352, 736], '0.50': [352, 704], '0.52': [352, 672],
100 | '0.57': [384, 672], '0.60': [384, 640], '0.68': [416, 608], '0.72': [416, 576],
101 | '0.78': [448, 576], '0.82': [448, 544], '0.88': [480, 544], '0.94': [480, 512],
102 | '1.00': [512, 512], '1.07': [512, 480], '1.13': [544, 480], '1.21': [544, 448],
103 | '1.29': [576, 448], '1.38': [576, 416], '1.46': [608, 416], '1.67': [640, 384],
104 | '1.75': [672, 384], '2.00': [704, 352], '2.09': [736, 352], '2.40': [768, 320],
105 | '2.50': [800, 320], '2.89': [832, 288], '3.00': [864, 288], '3.11': [896, 288],
106 | '3.62': [928, 256], '3.75': [960, 256], '3.88': [992, 256], '4.00': [1024,256]
107 | },
108 | "PixArtMS_Sigma_XL_2_2K": {
109 | '0.25': [1024, 4096], '0.26': [1024, 3968], '0.27': [1024, 3840], '0.28': [1024, 3712],
110 | '0.32': [1152, 3584], '0.33': [1152, 3456], '0.35': [1152, 3328], '0.40': [1280, 3200],
111 | '0.42': [1280, 3072], '0.48': [1408, 2944], '0.50': [1408, 2816], '0.52': [1408, 2688],
112 | '0.57': [1536, 2688], '0.60': [1536, 2560], '0.68': [1664, 2432], '0.72': [1664, 2304],
113 | '0.78': [1792, 2304], '0.82': [1792, 2176], '0.88': [1920, 2176], '0.94': [1920, 2048],
114 | '1.00': [2048, 2048], '1.07': [2048, 1920], '1.13': [2176, 1920], '1.21': [2176, 1792],
115 | '1.29': [2304, 1792], '1.38': [2304, 1664], '1.46': [2432, 1664], '1.67': [2560, 1536],
116 | '1.75': [2688, 1536], '2.00': [2816, 1408], '2.09': [2944, 1408], '2.40': [3072, 1280],
117 | '2.50': [3200, 1280], '2.89': [3328, 1152], '3.00': [3456, 1152], '3.11': [3584, 1152],
118 | '3.62': [3712, 1024], '3.75': [3840, 1024], '3.88': [3968, 1024], '4.00': [4096, 1024]
119 | }
120 | }
121 | # These should be the same
122 | pixart_res.update({
123 | "PixArtMS_Sigma_XL_2": pixart_res["PixArtMS_XL_2"],
124 | "PixArtMS_Sigma_XL_2_512": pixart_res["PixArt_XL_2"],
125 | })
126 |
--------------------------------------------------------------------------------
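The pixart_res tables above map an aspect-ratio key to a resolution pair whose ratio matches that key (e.g. '0.25' -> [512, 2048] in the 1024px table). A minimal sketch of the nearest-bucket lookup such a table supports, assuming the upstream PixArt convention that keys are height/width ratios and values are [height, width]; the helper below is illustrative, and the node that actually consumes the table is not shown in this section.

# Sketch: snap a requested size to the closest aspect-ratio bucket.
# Assumes keys are height/width ratios and values are [height, width].
def closest_bucket(height, width, model="PixArtMS_XL_2"):
    table = pixart_res[model]
    ratio = height / width
    key = min(table, key=lambda k: abs(float(k) - ratio))
    return key, table[key]

# closest_bucket(720, 1280) -> ('0.57', [768, 1344]) in the PixArtMS_XL_2 table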
/PixArt/diffusers_convert.py:
--------------------------------------------------------------------------------
1 | # For using the diffusers format weights
2 | # Based on the original ComfyUI function +
3 | # https://github.com/PixArt-alpha/PixArt-alpha/blob/master/tools/convert_pixart_alpha_to_diffusers.py
4 | import torch
5 |
6 | conversion_map = [ # main conversion map, (PixArt reference key, HF Diffusers key) pairs
7 | # Patch embeddings
8 | ("x_embedder.proj.weight", "pos_embed.proj.weight"),
9 | ("x_embedder.proj.bias", "pos_embed.proj.bias"),
10 | # Caption projection
11 | ("y_embedder.y_embedding", "caption_projection.y_embedding"),
12 | ("y_embedder.y_proj.fc1.weight", "caption_projection.linear_1.weight"),
13 | ("y_embedder.y_proj.fc1.bias", "caption_projection.linear_1.bias"),
14 | ("y_embedder.y_proj.fc2.weight", "caption_projection.linear_2.weight"),
15 | ("y_embedder.y_proj.fc2.bias", "caption_projection.linear_2.bias"),
16 | # AdaLN-single LN
17 | ("t_embedder.mlp.0.weight", "adaln_single.emb.timestep_embedder.linear_1.weight"),
18 | ("t_embedder.mlp.0.bias", "adaln_single.emb.timestep_embedder.linear_1.bias"),
19 | ("t_embedder.mlp.2.weight", "adaln_single.emb.timestep_embedder.linear_2.weight"),
20 | ("t_embedder.mlp.2.bias", "adaln_single.emb.timestep_embedder.linear_2.bias"),
21 | # Shared norm
22 | ("t_block.1.weight", "adaln_single.linear.weight"),
23 | ("t_block.1.bias", "adaln_single.linear.bias"),
24 | # Final block
25 | ("final_layer.linear.weight", "proj_out.weight"),
26 | ("final_layer.linear.bias", "proj_out.bias"),
27 | ("final_layer.scale_shift_table", "scale_shift_table"),
28 | ]
29 |
30 | conversion_map_ms = [ # for multi_scale_train (MS)
31 | # Resolution
32 | ("csize_embedder.mlp.0.weight", "adaln_single.emb.resolution_embedder.linear_1.weight"),
33 | ("csize_embedder.mlp.0.bias", "adaln_single.emb.resolution_embedder.linear_1.bias"),
34 | ("csize_embedder.mlp.2.weight", "adaln_single.emb.resolution_embedder.linear_2.weight"),
35 | ("csize_embedder.mlp.2.bias", "adaln_single.emb.resolution_embedder.linear_2.bias"),
36 | # Aspect ratio
37 | ("ar_embedder.mlp.0.weight", "adaln_single.emb.aspect_ratio_embedder.linear_1.weight"),
38 | ("ar_embedder.mlp.0.bias", "adaln_single.emb.aspect_ratio_embedder.linear_1.bias"),
39 | ("ar_embedder.mlp.2.weight", "adaln_single.emb.aspect_ratio_embedder.linear_2.weight"),
40 | ("ar_embedder.mlp.2.bias", "adaln_single.emb.aspect_ratio_embedder.linear_2.bias"),
41 | ]
42 |
43 | # Add actual transformer blocks
44 | for depth in range(28):
45 | # Transformer blocks
46 | conversion_map += [
47 | (f"blocks.{depth}.scale_shift_table", f"transformer_blocks.{depth}.scale_shift_table"),
48 | # Projection
49 | (f"blocks.{depth}.attn.proj.weight", f"transformer_blocks.{depth}.attn1.to_out.0.weight"),
50 | (f"blocks.{depth}.attn.proj.bias", f"transformer_blocks.{depth}.attn1.to_out.0.bias"),
51 | # Feed-forward
52 | (f"blocks.{depth}.mlp.fc1.weight", f"transformer_blocks.{depth}.ff.net.0.proj.weight"),
53 | (f"blocks.{depth}.mlp.fc1.bias", f"transformer_blocks.{depth}.ff.net.0.proj.bias"),
54 | (f"blocks.{depth}.mlp.fc2.weight", f"transformer_blocks.{depth}.ff.net.2.weight"),
55 | (f"blocks.{depth}.mlp.fc2.bias", f"transformer_blocks.{depth}.ff.net.2.bias"),
56 | # Cross-attention (proj)
57 | (f"blocks.{depth}.cross_attn.proj.weight" ,f"transformer_blocks.{depth}.attn2.to_out.0.weight"),
58 | (f"blocks.{depth}.cross_attn.proj.bias" ,f"transformer_blocks.{depth}.attn2.to_out.0.bias"),
59 | ]
60 |
61 | def find_prefix(state_dict, target_key):
62 | prefix = ""
63 | for k in state_dict.keys():
64 | if k.endswith(target_key):
65 | prefix = k.split(target_key)[0]
66 | break
67 | return prefix
68 |
69 | def convert_state_dict(state_dict):
70 | if "adaln_single.emb.resolution_embedder.linear_1.weight" in state_dict.keys():
71 | cmap = conversion_map + conversion_map_ms
72 | else:
73 | cmap = conversion_map
74 |
75 | missing = [k for k,v in cmap if v not in state_dict]
76 | new_state_dict = {k: state_dict[v] for k,v in cmap if k not in missing}
77 | matched = list(v for k,v in cmap if v in state_dict.keys())
78 |
79 | for depth in range(28):
80 | for wb in ["weight", "bias"]:
81 | # Self Attention
82 | key = lambda a: f"transformer_blocks.{depth}.attn1.to_{a}.{wb}"
83 | new_state_dict[f"blocks.{depth}.attn.qkv.{wb}"] = torch.cat((
84 | state_dict[key('q')], state_dict[key('k')], state_dict[key('v')]
85 | ), dim=0)
86 | matched += [key('q'), key('k'), key('v')]
87 |
88 | # Cross-attention (linear)
89 | key = lambda a: f"transformer_blocks.{depth}.attn2.to_{a}.{wb}"
90 | new_state_dict[f"blocks.{depth}.cross_attn.q_linear.{wb}"] = state_dict[key('q')]
91 | new_state_dict[f"blocks.{depth}.cross_attn.kv_linear.{wb}"] = torch.cat((
92 | state_dict[key('k')], state_dict[key('v')]
93 | ), dim=0)
94 | matched += [key('q'), key('k'), key('v')]
95 |
96 | if len(matched) < len(state_dict):
97 | print(f"PixArt: UNET conversion has leftover keys! ({len(matched)} vs {len(state_dict)})")
98 | print(list( set(state_dict.keys()) - set(matched) ))
99 |
100 | if len(missing) > 0:
101 | print(f"PixArt: UNET conversion has missing keys!")
102 | print(missing)
103 |
104 | return new_state_dict
105 |
106 | # Same as above but for LoRA weights:
107 | def convert_lora_state_dict(state_dict):
108 | # peft
109 | rep_ap = lambda x: x.replace(".weight", ".lora_A.weight")
110 | rep_bp = lambda x: x.replace(".weight", ".lora_B.weight")
111 | # kohya
112 | rep_ak = lambda x: x.replace(".weight", ".lora_down.weight")
113 | rep_bk = lambda x: x.replace(".weight", ".lora_up.weight")
114 |
115 | prefix = find_prefix(state_dict, "adaln_single.linear.lora_A.weight")
116 | state_dict = {k[len(prefix):]:v for k,v in state_dict.items()}
117 |
118 | cmap = []
119 | cmap_unet = conversion_map + conversion_map_ms # todo: 512 model
120 | for k, v in cmap_unet:
121 | if not v.endswith(".weight"):
122 | continue
123 | cmap.append((rep_ak(k), rep_ap(v)))
124 | cmap.append((rep_bk(k), rep_bp(v)))
125 |
126 | missing = [k for k,v in cmap if v not in state_dict]
127 | new_state_dict = {k: state_dict[v] for k,v in cmap if k not in missing}
128 | matched = list(v for k,v in cmap if v in state_dict.keys())
129 |
130 | for fp, fk in ((rep_ap, rep_ak),(rep_bp, rep_bk)):
131 | for depth in range(28):
132 | # Self Attention
133 | key = lambda a: fp(f"transformer_blocks.{depth}.attn1.to_{a}.weight")
134 | new_state_dict[fk(f"blocks.{depth}.attn.qkv.weight")] = torch.cat((
135 | state_dict[key('q')], state_dict[key('k')], state_dict[key('v')]
136 | ), dim=0)
137 | matched += [key('q'), key('k'), key('v')]
138 |
139 | # Cross-attention (linear)
140 | key = lambda a: fp(f"transformer_blocks.{depth}.attn2.to_{a}.weight")
141 | new_state_dict[fk(f"blocks.{depth}.cross_attn.q_linear.weight")] = state_dict[key('q')]
142 | new_state_dict[fk(f"blocks.{depth}.cross_attn.kv_linear.weight")] = torch.cat((
143 | state_dict[key('k')], state_dict[key('v')]
144 | ), dim=0)
145 | matched += [key('q'), key('k'), key('v')]
146 |
147 | if len(matched) < len(state_dict):
148 | print(f"PixArt: LoRA conversion has leftover keys! ({len(matched)} vs {len(state_dict)})")
149 | print(list( set(state_dict.keys()) - set(matched) ))
150 |
151 | if len(missing) > 0:
152 | print(f"PixArt: LoRA conversion has missing keys!")
153 | print(missing)
154 |
155 | return new_state_dict
156 |
--------------------------------------------------------------------------------
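The torch.cat calls in convert_state_dict above implement the following shape arithmetic: Diffusers stores separate to_q/to_k/to_v projections, while the reference PixArt layout fuses them. A small shapes-only sketch with dummy tensors, using hidden_size 1152 as in the XL_2 configs:

import torch

# Dummy projection weights shaped like the XL_2 models (hidden_size = 1152).
hidden = 1152
to_q = torch.zeros(hidden, hidden)
to_k = torch.zeros(hidden, hidden)
to_v = torch.zeros(hidden, hidden)

# Self-attention: three separate weights fuse into one qkv weight.
qkv = torch.cat((to_q, to_k, to_v), dim=0)
print(qkv.shape)  # torch.Size([3456, 1152]) -> blocks.{d}.attn.qkv.weight

# Cross-attention: q stays separate, k and v fuse into kv_linear.
kv = torch.cat((to_k, to_v), dim=0)
print(kv.shape)   # torch.Size([2304, 1152]) -> blocks.{d}.cross_attn.kv_linear.weight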
/PixArt/loader.py:
--------------------------------------------------------------------------------
1 | import comfy.supported_models_base
2 | import comfy.latent_formats
3 | import comfy.model_patcher
4 | import comfy.model_base
5 | import comfy.utils
6 | import comfy.conds
7 | import torch
8 | from comfy import model_management
9 | from .diffusers_convert import convert_state_dict
10 |
11 | class EXM_PixArt(comfy.supported_models_base.BASE):
12 | unet_config = {}
13 | unet_extra_config = {}
14 | latent_format = comfy.latent_formats.SD15
15 |
16 | def __init__(self, model_conf):
17 | self.model_target = model_conf.get("target")
18 | self.unet_config = model_conf.get("unet_config", {})
19 | self.sampling_settings = model_conf.get("sampling_settings", {})
20 | self.latent_format = self.latent_format()
21 | # UNET is handled by extension
22 | self.unet_config["disable_unet_model_creation"] = True
23 |
24 | def model_type(self, state_dict, prefix=""):
25 | return comfy.model_base.ModelType.EPS
26 |
27 | class EXM_PixArt_Model(comfy.model_base.BaseModel):
28 | def __init__(self, *args, **kwargs):
29 | super().__init__(*args, **kwargs)
30 |
31 | def extra_conds(self, **kwargs):
32 | out = super().extra_conds(**kwargs)
33 |
34 | img_hw = kwargs.get("img_hw", None)
35 | if img_hw is not None:
36 | out["img_hw"] = comfy.conds.CONDRegular(torch.tensor(img_hw))
37 |
38 | aspect_ratio = kwargs.get("aspect_ratio", None)
39 | if aspect_ratio is not None:
40 | out["aspect_ratio"] = comfy.conds.CONDRegular(torch.tensor(aspect_ratio))
41 |
42 | cn_hint = kwargs.get("cn_hint", None)
43 | if cn_hint is not None:
44 | out["cn_hint"] = comfy.conds.CONDRegular(cn_hint)
45 |
46 | return out
47 |
48 | def load_pixart(model_path, model_conf):
49 | state_dict = comfy.utils.load_torch_file(model_path)
50 | state_dict = state_dict.get("model", state_dict)
51 |
52 | # prefix
53 | for prefix in ["model.diffusion_model.",]:
54 | if any(True for x in state_dict if x.startswith(prefix)):
55 | state_dict = {k[len(prefix):]:v for k,v in state_dict.items()}
56 |
57 | # diffusers
58 | if "adaln_single.linear.weight" in state_dict:
59 | state_dict = convert_state_dict(state_dict) # Diffusers
60 |
61 | parameters = comfy.utils.calculate_parameters(state_dict)
62 | unet_dtype = model_management.unet_dtype(model_params=parameters)
63 | load_device = comfy.model_management.get_torch_device()
64 | offload_device = comfy.model_management.unet_offload_device()
65 |
66 | # ignore fp8/etc and use directly for now
67 | manual_cast_dtype = model_management.unet_manual_cast(unet_dtype, load_device)
68 | if manual_cast_dtype:
69 | print(f"PixArt: falling back to {manual_cast_dtype}")
70 | unet_dtype = manual_cast_dtype
71 |
72 | model_conf = EXM_PixArt(model_conf) # convert to object
73 | model = EXM_PixArt_Model( # same as comfy.model_base.BaseModel
74 | model_conf,
75 | model_type=comfy.model_base.ModelType.EPS,
76 | device=model_management.get_torch_device()
77 | )
78 |
79 | if model_conf.model_target == "PixArtMS":
80 | from .models.PixArtMS import PixArtMS
81 | model.diffusion_model = PixArtMS(**model_conf.unet_config)
82 | elif model_conf.model_target == "PixArt":
83 | from .models.PixArt import PixArt
84 | model.diffusion_model = PixArt(**model_conf.unet_config)
85 | elif model_conf.model_target == "PixArtMSSigma":
86 | from .models.PixArtMS import PixArtMS
87 | model.diffusion_model = PixArtMS(**model_conf.unet_config)
88 | model.latent_format = comfy.latent_formats.SDXL()
89 | elif model_conf.model_target == "ControlPixArtMSHalf":
90 | from .models.PixArtMS import PixArtMS
91 | from .models.pixart_controlnet import ControlPixArtMSHalf
92 | model.diffusion_model = PixArtMS(**model_conf.unet_config)
93 | model.diffusion_model = ControlPixArtMSHalf(model.diffusion_model)
94 | elif model_conf.model_target == "ControlPixArtHalf":
95 | from .models.PixArt import PixArt
96 | from .models.pixart_controlnet import ControlPixArtHalf
97 | model.diffusion_model = PixArt(**model_conf.unet_config)
98 | model.diffusion_model = ControlPixArtHalf(model.diffusion_model)
99 | else:
100 | raise NotImplementedError(f"Unknown model target '{model_conf.model_target}'")
101 |
102 | m, u = model.diffusion_model.load_state_dict(state_dict, strict=False)
103 | if len(m) > 0: print("Missing UNET keys", m)
104 | if len(u) > 0: print("Leftover UNET keys", u)
105 | model.diffusion_model.dtype = unet_dtype
106 | model.diffusion_model.eval()
107 | model.diffusion_model.to(unet_dtype)
108 |
109 | model_patcher = comfy.model_patcher.ModelPatcher(
110 | model,
111 | load_device = load_device,
112 | offload_device = offload_device,
113 | current_device = "cpu",
114 | )
115 | return model_patcher
116 |
--------------------------------------------------------------------------------
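EXM_PixArt_Model.extra_conds above picks up optional img_hw, aspect_ratio and cn_hint entries from the conditioning. Assuming ComfyUI's usual behaviour of forwarding the keys of a conditioning entry's dict into extra_conds (the extension's own nodes, not shown in this section, normally set them), a hand-rolled sketch of attaching the size hints could look like this; the helper name is made up for illustration.

# Illustrative only: attach img_hw / aspect_ratio to every conditioning entry
# so that EXM_PixArt_Model.extra_conds can wrap them as CONDRegular conds.
def add_pixart_size_cond(conditioning, height, width):
    out = []
    for cond_tensor, extra in conditioning:
        extra = dict(extra)                     # don't mutate the caller's dict
        extra["img_hw"] = [[height, width]]
        extra["aspect_ratio"] = [[height / width]]
        out.append([cond_tensor, extra])
    return out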
/PixArt/lora.py:
--------------------------------------------------------------------------------
1 | import os
2 | import copy
3 | import json
4 | import torch
5 | import comfy.lora
6 | import comfy.model_management
7 | from comfy.model_patcher import ModelPatcher
8 | from .diffusers_convert import convert_lora_state_dict
9 |
10 | class EXM_PixArt_ModelPatcher(ModelPatcher):
11 | def calculate_weight(self, patches, weight, key):
12 | """
13 | This is almost the same as the comfy function, but stripped down to just the LoRA patch code.
14 | The problem with the original code is the q/k/v keys being combined into one for the attention.
15 | In the diffusers code they're treated as separate keys, but in the reference code they're recombined (q + kv for cross-attention, qkv for self-attention).
16 | This means, for example, that three [1152,1152] weights become a single [3456,1152] weight in the state dict.
17 | The issue is that the corresponding LoRA weights are [128,1152],[1152,128] per projection and become [384,1152],[3456,128] after the same concatenation.
18 |
19 | This is the best thing I could think of that would fix that, but it's very fragile.
20 | - Check key shape to determine if it needs the fallback logic
21 | - Cut the input into parts based on the shape (undoing the torch.cat)
22 | - Do the matrix multiplication logic
23 | - Recombine them to match the expected shape
24 | """
25 | for p in patches:
26 | alpha = p[0]
27 | v = p[1]
28 | strength_model = p[2]
29 | if strength_model != 1.0:
30 | weight *= strength_model
31 |
32 | if isinstance(v, list):
33 | v = (self.calculate_weight(v[1:], v[0].clone(), key), )
34 |
35 | if len(v) == 2:
36 | patch_type = v[0]
37 | v = v[1]
38 |
39 | if patch_type == "lora":
40 | mat1 = comfy.model_management.cast_to_device(v[0], weight.device, torch.float32)
41 | mat2 = comfy.model_management.cast_to_device(v[1], weight.device, torch.float32)
42 | if v[2] is not None:
43 | alpha *= v[2] / mat2.shape[0]
44 | try:
45 | mat1 = mat1.flatten(start_dim=1)
46 | mat2 = mat2.flatten(start_dim=1)
47 |
48 | ch1 = mat1.shape[0] // mat2.shape[1]
49 | ch2 = mat2.shape[0] // mat1.shape[1]
50 | ### Fallback logic for shape mismatch ###
51 | if mat1.shape[0] != mat2.shape[1] and ch1 == ch2 and (mat1.shape[0]/mat2.shape[1])%1 == 0:
52 | mat1 = mat1.chunk(ch1, dim=0)
53 | mat2 = mat2.chunk(ch1, dim=0)
54 | weight += torch.cat(
55 | [alpha * torch.mm(mat1[x], mat2[x]) for x in range(ch1)],
56 | dim=0,
57 | ).reshape(weight.shape).type(weight.dtype)
58 | else:
59 | weight += (alpha * torch.mm(mat1, mat2)).reshape(weight.shape).type(weight.dtype)
60 | except Exception as e:
61 | print("ERROR", key, e)
62 | return weight
63 |
64 | def clone(self):
65 | n = EXM_PixArt_ModelPatcher(self.model, self.load_device, self.offload_device, self.size, self.current_device, weight_inplace_update=self.weight_inplace_update)
66 | n.patches = {}
67 | for k in self.patches:
68 | n.patches[k] = self.patches[k][:]
69 |
70 | n.object_patches = self.object_patches.copy()
71 | n.model_options = copy.deepcopy(self.model_options)
72 | n.model_keys = self.model_keys
73 | return n
74 |
75 | def replace_model_patcher(model):
76 | n = EXM_PixArt_ModelPatcher(
77 | model = model.model,
78 | size = model.size,
79 | load_device = model.load_device,
80 | offload_device = model.offload_device,
81 | current_device = model.current_device,
82 | weight_inplace_update = model.weight_inplace_update,
83 | )
84 | n.patches = {}
85 | for k in model.patches:
86 | n.patches[k] = model.patches[k][:]
87 |
88 | n.object_patches = model.object_patches.copy()
89 | n.model_options = copy.deepcopy(model.model_options)
90 | n.model_keys = model.model_keys
91 | return n
92 |
93 | def find_peft_alpha(path):
94 | def load_json(json_path):
95 | with open(json_path) as f:
96 | data = json.load(f)
97 | alpha = data.get("lora_alpha")
98 | alpha = alpha or data.get("alpha")
99 | if not alpha:
100 | print(" Found config but `lora_alpha` is missing!")
101 | else:
102 | print(f" Found config at {json_path} [alpha:{alpha}]")
103 | return alpha
104 |
105 | # For some weird reason peft doesn't include the alpha in the actual model
106 | print("PixArt: Warning! This is a PEFT LoRA. Trying to find config...")
107 | files = [
108 | f"{os.path.splitext(path)[0]}.json",
109 | f"{os.path.splitext(path)[0]}.config.json",
110 | os.path.join(os.path.dirname(path),"adapter_config.json"),
111 | ]
112 | for file in files:
113 | if os.path.isfile(file):
114 | return load_json(file)
115 |
116 | print(" Missing config/alpha! assuming alpha of 8. Consider converting it/adding a config json to it.")
117 | return 8.0
118 |
119 | def load_pixart_lora(model, lora, lora_path, strength):
120 | k_back = lambda x: x.replace(".lora_up.weight", "")
121 | # need to convert the actual weights for this to work.
122 | if any(True for x in lora.keys() if x.endswith("adaln_single.linear.lora_A.weight")):
123 | lora = convert_lora_state_dict(lora)
124 | alpha = find_peft_alpha(lora_path)
125 | lora.update({f"{k_back(x)}.alpha":torch.tensor(alpha) for x in lora.keys() if "lora_up" in x})
126 |
127 | key_map = {k_back(x):f"diffusion_model.{k_back(x)}.weight" for x in lora.keys() if "lora_up" in x} # fake
128 |
129 | loaded = comfy.lora.load_lora(lora, key_map)
130 | if model is not None:
131 | # switch to custom model patcher when using LoRAs
132 | if isinstance(model, EXM_PixArt_ModelPatcher):
133 | new_modelpatcher = model.clone()
134 | else:
135 | new_modelpatcher = replace_model_patcher(model)
136 | k = new_modelpatcher.add_patches(loaded, strength)
137 | else:
138 | k = ()
139 | new_modelpatcher = None
140 |
141 | k = set(k)
142 | for x in loaded:
143 | if (x not in k):
144 | print("NOT LOADED", x)
145 |
146 | return new_modelpatcher
147 |
--------------------------------------------------------------------------------
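A numeric illustration of the fallback described in the calculate_weight docstring above: after convert_lora_state_dict, the fused qkv LoRA matrices no longer admit a single matrix multiply, so they are chunked per projection. Shapes-only sketch, with rank 128 purely as an example:

import torch

# Fused qkv LoRA update (hidden_size 1152, example rank 128).
hidden, rank = 1152, 128
weight = torch.zeros(3 * hidden, hidden)   # blocks.{d}.attn.qkv.weight -> [3456, 1152]
lora_up = torch.randn(3 * hidden, rank)    # concatenated up weights   -> [3456, 128]
lora_down = torch.randn(3 * rank, hidden)  # concatenated down weights -> [384, 1152]

# A single lora_up @ lora_down is impossible (inner dims 128 vs 384), which is
# exactly the mismatch the docstring describes; chunk both per projection instead:
ups = lora_up.chunk(3, dim=0)              # three [1152, 128] blocks
downs = lora_down.chunk(3, dim=0)          # three [128, 1152] blocks
delta = torch.cat([u @ d for u, d in zip(ups, downs)], dim=0)
weight += delta                            # delta.shape == torch.Size([3456, 1152])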
/PixArt/models/PixArt.py:
--------------------------------------------------------------------------------
1 | # Copyright (c) Meta Platforms, Inc. and affiliates.
2 | # All rights reserved.
3 |
4 | # This source code is licensed under the license found in the
5 | # LICENSE file in the root directory of this source tree.
6 | # --------------------------------------------------------
7 | # References:
8 | # GLIDE: https://github.com/openai/glide-text2im
9 | # MAE: https://github.com/facebookresearch/mae/blob/main/models_mae.py
10 | # --------------------------------------------------------
11 | import math
12 | import torch
13 | import torch.nn as nn
14 | import os
15 | import numpy as np
16 | from timm.models.layers import DropPath
17 | from timm.models.vision_transformer import PatchEmbed, Mlp
18 |
19 |
20 | from .utils import auto_grad_checkpoint, to_2tuple
21 | from .PixArt_blocks import t2i_modulate, CaptionEmbedder, AttentionKVCompress, MultiHeadCrossAttention, T2IFinalLayer, TimestepEmbedder, LabelEmbedder, FinalLayer
22 |
23 |
24 | class PixArtBlock(nn.Module):
25 | """
26 | A PixArt block with adaptive layer norm (adaLN-single) conditioning.
27 | """
28 | def __init__(self, hidden_size, num_heads, mlp_ratio=4.0, drop_path=0, input_size=None, sampling=None, sr_ratio=1, qk_norm=False, **block_kwargs):
29 | super().__init__()
30 | self.norm1 = nn.LayerNorm(hidden_size, elementwise_affine=False, eps=1e-6)
31 | self.attn = AttentionKVCompress(
32 | hidden_size, num_heads=num_heads, qkv_bias=True, sampling=sampling, sr_ratio=sr_ratio,
33 | qk_norm=qk_norm, **block_kwargs
34 | )
35 | self.cross_attn = MultiHeadCrossAttention(hidden_size, num_heads, **block_kwargs)
36 | self.norm2 = nn.LayerNorm(hidden_size, elementwise_affine=False, eps=1e-6)
37 | # to be compatible with lower version pytorch
38 | approx_gelu = lambda: nn.GELU(approximate="tanh")
39 | self.mlp = Mlp(in_features=hidden_size, hidden_features=int(hidden_size * mlp_ratio), act_layer=approx_gelu, drop=0)
40 | self.drop_path = DropPath(drop_path) if drop_path > 0. else nn.Identity()
41 | self.scale_shift_table = nn.Parameter(torch.randn(6, hidden_size) / hidden_size ** 0.5)
42 | self.sampling = sampling
43 | self.sr_ratio = sr_ratio
44 |
45 | def forward(self, x, y, t, mask=None, **kwargs):
46 | B, N, C = x.shape
47 |
48 | shift_msa, scale_msa, gate_msa, shift_mlp, scale_mlp, gate_mlp = (self.scale_shift_table[None] + t.reshape(B, 6, -1)).chunk(6, dim=1)
49 | x = x + self.drop_path(gate_msa * self.attn(t2i_modulate(self.norm1(x), shift_msa, scale_msa)).reshape(B, N, C))
50 | x = x + self.cross_attn(x, y, mask)
51 | x = x + self.drop_path(gate_mlp * self.mlp(t2i_modulate(self.norm2(x), shift_mlp, scale_mlp)))
52 |
53 | return x
54 |
55 |
56 | ### Core PixArt Model ###
57 | class PixArt(nn.Module):
58 | """
59 | Diffusion model with a Transformer backbone.
60 | """
61 | def __init__(
62 | self,
63 | input_size=32,
64 | patch_size=2,
65 | in_channels=4,
66 | hidden_size=1152,
67 | depth=28,
68 | num_heads=16,
69 | mlp_ratio=4.0,
70 | class_dropout_prob=0.1,
71 | pred_sigma=True,
72 | drop_path: float = 0.,
73 | caption_channels=4096,
74 | pe_interpolation=1.0,
75 | config=None,
76 | model_max_length=120,
77 | qk_norm=False,
78 | kv_compress_config=None,
79 | **kwargs,
80 | ):
81 | super().__init__()
82 | self.pred_sigma = pred_sigma
83 | self.in_channels = in_channels
84 | self.out_channels = in_channels * 2 if pred_sigma else in_channels
85 | self.patch_size = patch_size
86 | self.num_heads = num_heads
87 | self.pe_interpolation = pe_interpolation
88 | self.depth = depth
89 |
90 | self.x_embedder = PatchEmbed(input_size, patch_size, in_channels, hidden_size, bias=True)
91 | self.t_embedder = TimestepEmbedder(hidden_size)
92 | num_patches = self.x_embedder.num_patches
93 | self.base_size = input_size // self.patch_size
94 | # Will use fixed sin-cos embedding:
95 | self.register_buffer("pos_embed", torch.zeros(1, num_patches, hidden_size))
96 |
97 | approx_gelu = lambda: nn.GELU(approximate="tanh")
98 | self.t_block = nn.Sequential(
99 | nn.SiLU(),
100 | nn.Linear(hidden_size, 6 * hidden_size, bias=True)
101 | )
102 | self.y_embedder = CaptionEmbedder(
103 | in_channels=caption_channels, hidden_size=hidden_size, uncond_prob=class_dropout_prob,
104 | act_layer=approx_gelu, token_num=model_max_length
105 | )
106 | drop_path = [x.item() for x in torch.linspace(0, drop_path, depth)] # stochastic depth decay rule
107 | self.kv_compress_config = kv_compress_config
108 | if kv_compress_config is None:
109 | self.kv_compress_config = {
110 | 'sampling': None,
111 | 'scale_factor': 1,
112 | 'kv_compress_layer': [],
113 | }
114 | self.blocks = nn.ModuleList([
115 | PixArtBlock(
116 | hidden_size, num_heads, mlp_ratio=mlp_ratio, drop_path=drop_path[i],
117 | input_size=(input_size // patch_size, input_size // patch_size),
118 | sampling=self.kv_compress_config['sampling'],
119 | sr_ratio=int(
120 | self.kv_compress_config['scale_factor']
121 | ) if i in self.kv_compress_config['kv_compress_layer'] else 1,
122 | qk_norm=qk_norm,
123 | )
124 | for i in range(depth)
125 | ])
126 | self.final_layer = T2IFinalLayer(hidden_size, patch_size, self.out_channels)
127 |
128 | def forward_raw(self, x, t, y, mask=None, data_info=None):
129 | """
130 | Original forward pass of PixArt.
131 | x: (N, C, H, W) tensor of spatial inputs (images or latent representations of images)
132 | t: (N,) tensor of diffusion timesteps
133 | y: (N, 1, 120, C) tensor of caption embeddings (text conditioning)
134 | """
135 | x = x.to(self.dtype)
136 | timestep = t.to(self.dtype)
137 | y = y.to(self.dtype)
138 | pos_embed = self.pos_embed.to(self.dtype)
139 | self.h, self.w = x.shape[-2]//self.patch_size, x.shape[-1]//self.patch_size
140 | x = self.x_embedder(x) + pos_embed # (N, T, D), where T = H * W / patch_size ** 2
141 | t = self.t_embedder(timestep.to(x.dtype)) # (N, D)
142 | t0 = self.t_block(t)
143 | y = self.y_embedder(y, self.training) # (N, 1, L, D)
144 | if mask is not None:
145 | if mask.shape[0] != y.shape[0]:
146 | mask = mask.repeat(y.shape[0] // mask.shape[0], 1)
147 | mask = mask.squeeze(1).squeeze(1)
148 | y = y.squeeze(1).masked_select(mask.unsqueeze(-1) != 0).view(1, -1, x.shape[-1])
149 | y_lens = mask.sum(dim=1).tolist()
150 | else:
151 | y_lens = [y.shape[2]] * y.shape[0]
152 | y = y.squeeze(1).view(1, -1, x.shape[-1])
153 | for block in self.blocks:
154 | x = auto_grad_checkpoint(block, x, y, t0, y_lens) # (N, T, D) #support grad checkpoint
155 | x = self.final_layer(x, t) # (N, T, patch_size ** 2 * out_channels)
156 | x = self.unpatchify(x) # (N, out_channels, H, W)
157 | return x
158 |
159 | def forward(self, x, timesteps, context, y=None, **kwargs):
160 | """
161 | Forward pass that adapts comfy input to original forward function
162 | x: (N, C, H, W) tensor of spatial inputs (images or latent representations of images)
163 | timesteps: (N,) tensor of diffusion timesteps
164 | context: (N, 1, 120, C) conditioning
165 | y: extra conditioning.
166 | """
167 | ## forward_raw still runs if this dim is missing, but the output is garbage, so add it here
168 | if len(context.shape) == 3:
169 | context = context.unsqueeze(1)
170 |
171 | ## run original forward pass
172 | out = self.forward_raw(
173 | x = x.to(self.dtype),
174 | t = timesteps.to(self.dtype),
175 | y = context.to(self.dtype),
176 | )
177 |
178 | ## only return EPS
179 | out = out.to(torch.float)
180 | eps, rest = out[:, :self.in_channels], out[:, self.in_channels:]
181 | return eps
182 |
183 | def unpatchify(self, x):
184 | """
185 | x: (N, T, patch_size**2 * C)
186 | imgs: (N, H, W, C)
187 | """
188 | c = self.out_channels
189 | p = self.x_embedder.patch_size[0]
190 | h = w = int(x.shape[1] ** 0.5)
191 | assert h * w == x.shape[1]
192 |
193 | x = x.reshape(shape=(x.shape[0], h, w, p, p, c))
194 | x = torch.einsum('nhwpqc->nchpwq', x)
195 | imgs = x.reshape(shape=(x.shape[0], c, h * p, h * p))
196 | return imgs
197 |
198 |
199 | def get_2d_sincos_pos_embed(embed_dim, grid_size, cls_token=False, extra_tokens=0, pe_interpolation=1.0, base_size=16):
200 | """
201 | grid_size: int of the grid height and width
202 | return:
203 | pos_embed: [grid_size*grid_size, embed_dim] or [1+grid_size*grid_size, embed_dim] (w/ or w/o cls_token)
204 | """
205 | if isinstance(grid_size, int):
206 | grid_size = to_2tuple(grid_size)
207 | grid_h = np.arange(grid_size[0], dtype=np.float32) / (grid_size[0]/base_size) / pe_interpolation
208 | grid_w = np.arange(grid_size[1], dtype=np.float32) / (grid_size[1]/base_size) / pe_interpolation
209 | grid = np.meshgrid(grid_w, grid_h) # here w goes first
210 | grid = np.stack(grid, axis=0)
211 | grid = grid.reshape([2, 1, grid_size[1], grid_size[0]])
212 |
213 | pos_embed = get_2d_sincos_pos_embed_from_grid(embed_dim, grid)
214 | if cls_token and extra_tokens > 0:
215 | pos_embed = np.concatenate([np.zeros([extra_tokens, embed_dim]), pos_embed], axis=0)
216 | return pos_embed
217 |
218 |
219 | def get_2d_sincos_pos_embed_from_grid(embed_dim, grid):
220 | assert embed_dim % 2 == 0
221 |
222 | # use half of dimensions to encode grid_h
223 | emb_h = get_1d_sincos_pos_embed_from_grid(embed_dim // 2, grid[0]) # (H*W, D/2)
224 | emb_w = get_1d_sincos_pos_embed_from_grid(embed_dim // 2, grid[1]) # (H*W, D/2)
225 |
226 | emb = np.concatenate([emb_h, emb_w], axis=1) # (H*W, D)
227 | return emb
228 |
229 |
230 | def get_1d_sincos_pos_embed_from_grid(embed_dim, pos):
231 | """
232 | embed_dim: output dimension for each position
233 | pos: a list of positions to be encoded: size (M,)
234 | out: (M, D)
235 | """
236 | assert embed_dim % 2 == 0
237 | omega = np.arange(embed_dim // 2, dtype=np.float64)
238 | omega /= embed_dim / 2.
239 | omega = 1. / 10000 ** omega # (D/2,)
240 |
241 | pos = pos.reshape(-1) # (M,)
242 | out = np.einsum('m,d->md', pos, omega) # (M, D/2), outer product
243 |
244 | emb_sin = np.sin(out) # (M, D/2)
245 | emb_cos = np.cos(out) # (M, D/2)
246 |
247 | emb = np.concatenate([emb_sin, emb_cos], axis=1) # (M, D)
248 | return emb
249 |
--------------------------------------------------------------------------------
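A quick shape check for the sin-cos positional-embedding helpers at the bottom of the file, using the 512px configuration from conf.py (input_size 64, patch_size 2, hidden_size 1152, so base_size = input_size // patch_size = 32 as set in __init__); this assumes get_2d_sincos_pos_embed from the file above is in scope.

pe = get_2d_sincos_pos_embed(embed_dim=1152, grid_size=32,
                             pe_interpolation=1, base_size=32)
print(pe.shape)  # (1024, 1152): one embedding per token of the 32x32 patch grid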
/PixArt/models/PixArtMS.py:
--------------------------------------------------------------------------------
1 | # Copyright (c) Meta Platforms, Inc. and affiliates.
2 | # All rights reserved.
3 |
4 | # This source code is licensed under the license found in the
5 | # LICENSE file in the root directory of this source tree.
6 | # --------------------------------------------------------
7 | # References:
8 | # GLIDE: https://github.com/openai/glide-text2im
9 | # MAE: https://github.com/facebookresearch/mae/blob/main/models_mae.py
10 | # --------------------------------------------------------
11 | import torch
12 | import torch.nn as nn
13 | from tqdm import tqdm
14 | from timm.models.layers import DropPath
15 | from timm.models.vision_transformer import Mlp
16 |
17 | from .utils import auto_grad_checkpoint, to_2tuple
18 | from .PixArt_blocks import t2i_modulate, CaptionEmbedder, AttentionKVCompress, MultiHeadCrossAttention, T2IFinalLayer, TimestepEmbedder, SizeEmbedder
19 | from .PixArt import PixArt, get_2d_sincos_pos_embed
20 |
21 |
22 | class PatchEmbed(nn.Module):
23 | """
24 | 2D Image to Patch Embedding
25 | """
26 | def __init__(
27 | self,
28 | patch_size=16,
29 | in_chans=3,
30 | embed_dim=768,
31 | norm_layer=None,
32 | flatten=True,
33 | bias=True,
34 | ):
35 | super().__init__()
36 | patch_size = to_2tuple(patch_size)
37 | self.patch_size = patch_size
38 | self.flatten = flatten
39 | self.proj = nn.Conv2d(in_chans, embed_dim, kernel_size=patch_size, stride=patch_size, bias=bias)
40 | self.norm = norm_layer(embed_dim) if norm_layer else nn.Identity()
41 |
42 | def forward(self, x):
43 | x = self.proj(x)
44 | if self.flatten:
45 | x = x.flatten(2).transpose(1, 2) # BCHW -> BNC
46 | x = self.norm(x)
47 | return x
48 |
49 |
50 | class PixArtMSBlock(nn.Module):
51 | """
52 | A PixArt block with adaptive layer norm (adaLN-single) conditioning.
53 | """
54 | def __init__(self, hidden_size, num_heads, mlp_ratio=4.0, drop_path=0., input_size=None,
55 | sampling=None, sr_ratio=1, qk_norm=False, **block_kwargs):
56 | super().__init__()
57 | self.hidden_size = hidden_size
58 | self.norm1 = nn.LayerNorm(hidden_size, elementwise_affine=False, eps=1e-6)
59 | self.attn = AttentionKVCompress(
60 | hidden_size, num_heads=num_heads, qkv_bias=True, sampling=sampling, sr_ratio=sr_ratio,
61 | qk_norm=qk_norm, **block_kwargs
62 | )
63 | self.cross_attn = MultiHeadCrossAttention(hidden_size, num_heads, **block_kwargs)
64 | self.norm2 = nn.LayerNorm(hidden_size, elementwise_affine=False, eps=1e-6)
65 | # to be compatible with lower version pytorch
66 | approx_gelu = lambda: nn.GELU(approximate="tanh")
67 | self.mlp = Mlp(in_features=hidden_size, hidden_features=int(hidden_size * mlp_ratio), act_layer=approx_gelu, drop=0)
68 | self.drop_path = DropPath(drop_path) if drop_path > 0. else nn.Identity()
69 | self.scale_shift_table = nn.Parameter(torch.randn(6, hidden_size) / hidden_size ** 0.5)
70 |
71 | def forward(self, x, y, t, mask=None, HW=None, **kwargs):
72 | B, N, C = x.shape
73 |
74 | shift_msa, scale_msa, gate_msa, shift_mlp, scale_mlp, gate_mlp = (self.scale_shift_table[None] + t.reshape(B, 6, -1)).chunk(6, dim=1)
75 | x = x + self.drop_path(gate_msa * self.attn(t2i_modulate(self.norm1(x), shift_msa, scale_msa), HW=HW))
76 | x = x + self.cross_attn(x, y, mask)
77 | x = x + self.drop_path(gate_mlp * self.mlp(t2i_modulate(self.norm2(x), shift_mlp, scale_mlp)))
78 |
79 | return x
80 |
81 |
82 | ### Core PixArt Model ###
83 | class PixArtMS(PixArt):
84 | """
85 | Diffusion model with a Transformer backbone.
86 | """
87 | def __init__(
88 | self,
89 | input_size=32,
90 | patch_size=2,
91 | in_channels=4,
92 | hidden_size=1152,
93 | depth=28,
94 | num_heads=16,
95 | mlp_ratio=4.0,
96 | class_dropout_prob=0.1,
97 | learn_sigma=True,
98 | pred_sigma=True,
99 | drop_path: float = 0.,
100 | caption_channels=4096,
101 | pe_interpolation=1.,
102 | config=None,
103 | model_max_length=120,
104 | micro_condition=True,
105 | qk_norm=False,
106 | kv_compress_config=None,
107 | **kwargs,
108 | ):
109 | super().__init__(
110 | input_size=input_size,
111 | patch_size=patch_size,
112 | in_channels=in_channels,
113 | hidden_size=hidden_size,
114 | depth=depth,
115 | num_heads=num_heads,
116 | mlp_ratio=mlp_ratio,
117 | class_dropout_prob=class_dropout_prob,
118 | learn_sigma=learn_sigma,
119 | pred_sigma=pred_sigma,
120 | drop_path=drop_path,
121 | pe_interpolation=pe_interpolation,
122 | config=config,
123 | model_max_length=model_max_length,
124 | qk_norm=qk_norm,
125 | kv_compress_config=kv_compress_config,
126 | **kwargs,
127 | )
128 | self.dtype = torch.get_default_dtype()
129 | self.h = self.w = 0
130 | approx_gelu = lambda: nn.GELU(approximate="tanh")
131 | self.t_block = nn.Sequential(
132 | nn.SiLU(),
133 | nn.Linear(hidden_size, 6 * hidden_size, bias=True)
134 | )
135 | self.x_embedder = PatchEmbed(patch_size, in_channels, hidden_size, bias=True)
136 | self.y_embedder = CaptionEmbedder(in_channels=caption_channels, hidden_size=hidden_size, uncond_prob=class_dropout_prob, act_layer=approx_gelu, token_num=model_max_length)
137 | self.micro_conditioning = micro_condition
138 | if self.micro_conditioning:
139 | self.csize_embedder = SizeEmbedder(hidden_size//3) # c_size embed
140 | self.ar_embedder = SizeEmbedder(hidden_size//3) # aspect ratio embed
141 | drop_path = [x.item() for x in torch.linspace(0, drop_path, depth)] # stochastic depth decay rule
142 | if kv_compress_config is None:
143 | kv_compress_config = {
144 | 'sampling': None,
145 | 'scale_factor': 1,
146 | 'kv_compress_layer': [],
147 | }
148 | self.blocks = nn.ModuleList([
149 | PixArtMSBlock(
150 | hidden_size, num_heads, mlp_ratio=mlp_ratio, drop_path=drop_path[i],
151 | input_size=(input_size // patch_size, input_size // patch_size),
152 | sampling=kv_compress_config['sampling'],
153 | sr_ratio=int(kv_compress_config['scale_factor']) if i in kv_compress_config['kv_compress_layer'] else 1,
154 | qk_norm=qk_norm,
155 | )
156 | for i in range(depth)
157 | ])
158 | self.final_layer = T2IFinalLayer(hidden_size, patch_size, self.out_channels)
159 |
160 | def forward_raw(self, x, t, y, mask=None, data_info=None, **kwargs):
161 | """
162 | Original forward pass of PixArt.
163 | x: (N, C, H, W) tensor of spatial inputs (images or latent representations of images)
164 | t: (N,) tensor of diffusion timesteps
165 | y: (N, 1, 120, C) tensor of caption embeddings (text conditioning)
166 | """
167 | bs = x.shape[0]
168 | x = x.to(self.dtype)
169 | timestep = t.to(self.dtype)
170 | y = y.to(self.dtype)
171 | self.h, self.w = x.shape[-2]//self.patch_size, x.shape[-1]//self.patch_size
172 | pos_embed = torch.from_numpy(
173 | get_2d_sincos_pos_embed(
174 | self.pos_embed.shape[-1], (self.h, self.w), pe_interpolation=self.pe_interpolation,
175 | base_size=self.base_size
176 | )
177 | ).unsqueeze(0).to(x.device).to(self.dtype)
178 |
179 | x = self.x_embedder(x) + pos_embed # (N, T, D), where T = H * W / patch_size ** 2
180 | t = self.t_embedder(timestep) # (N, D)
181 |
182 | if self.micro_conditioning:
183 | c_size, ar = data_info['img_hw'].to(self.dtype), data_info['aspect_ratio'].to(self.dtype)
184 | csize = self.csize_embedder(c_size, bs) # (N, D)
185 | ar = self.ar_embedder(ar, bs) # (N, D)
186 | t = t + torch.cat([csize, ar], dim=1)
187 |
188 | t0 = self.t_block(t)
189 | y = self.y_embedder(y, self.training) # (N, 1, L, D)
190 |
191 | if mask is not None:
192 | if mask.shape[0] != y.shape[0]:
193 | mask = mask.repeat(y.shape[0] // mask.shape[0], 1)
194 | mask = mask.squeeze(1).squeeze(1)
195 | y = y.squeeze(1).masked_select(mask.unsqueeze(-1) != 0).view(1, -1, x.shape[-1])
196 | y_lens = mask.sum(dim=1).tolist()
197 | else:
198 | y_lens = [y.shape[2]] * y.shape[0]
199 | y = y.squeeze(1).view(1, -1, x.shape[-1])
200 | for block in self.blocks:
201 | x = auto_grad_checkpoint(block, x, y, t0, y_lens, (self.h, self.w), **kwargs) # (N, T, D) #support grad checkpoint
202 |
203 | x = self.final_layer(x, t) # (N, T, patch_size ** 2 * out_channels)
204 | x = self.unpatchify(x) # (N, out_channels, H, W)
205 |
206 | return x
207 |
208 | def forward(self, x, timesteps, context, img_hw=None, aspect_ratio=None, **kwargs):
209 | """
210 | Forward pass that adapts comfy input to original forward function
211 | x: (N, C, H, W) tensor of spatial inputs (images or latent representations of images)
212 | timesteps: (N,) tensor of diffusion timesteps
213 | context: (N, 1, 120, C) conditioning
214 | img_hw: height|width conditioning
215 | aspect_ratio: aspect ratio conditioning
216 | """
217 | ## size/ar from cond with fallback based on the latent image shape.
218 | bs = x.shape[0]
219 | data_info = {}
220 | if img_hw is None:
221 | data_info["img_hw"] = torch.tensor(
222 | [[x.shape[2]*8, x.shape[3]*8]],
223 | dtype=self.dtype,
224 | device=x.device
225 | ).repeat(bs, 1)
226 | else:
227 | data_info["img_hw"] = img_hw.to(x.dtype).to(x.device)
228 | if aspect_ratio is None or True: # 'or True' always takes the latent-shape fallback; the aspect_ratio cond is effectively ignored
229 | data_info["aspect_ratio"] = torch.tensor(
230 | [[x.shape[2]/x.shape[3]]],
231 | dtype=self.dtype,
232 | device=x.device
233 | ).repeat(bs, 1)
234 | else:
235 | data_info["aspect_ratio"] = aspect_ratio.to(x.dtype).to(x.device)
236 |
237 | ## forward_raw still runs if this dim is missing, but the output is garbage, so add it here
238 | if len(context.shape) == 3:
239 | context = context.unsqueeze(1)
240 |
241 | ## run original forward pass
242 | out = self.forward_raw(
243 | x = x.to(self.dtype),
244 | t = timesteps.to(self.dtype),
245 | y = context.to(self.dtype),
246 | data_info=data_info,
247 | )
248 |
249 | ## only return EPS
250 | out = out.to(torch.float)
251 | eps, rest = out[:, :self.in_channels], out[:, self.in_channels:]
252 | return eps
253 |
254 | def unpatchify(self, x):
255 | """
256 | x: (N, T, patch_size**2 * C)
257 | imgs: (N, H, W, C)
258 | """
259 | c = self.out_channels
260 | p = self.x_embedder.patch_size[0]
261 | assert self.h * self.w == x.shape[1]
262 |
263 | x = x.reshape(shape=(x.shape[0], self.h, self.w, p, p, c))
264 | x = torch.einsum('nhwpqc->nchpwq', x)
265 | imgs = x.reshape(shape=(x.shape[0], c, self.h * p, self.w * p))
266 | return imgs
267 |
--------------------------------------------------------------------------------
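The unpatchify step at the end of PixArtMS maps the token sequence back to latent space. A shape walk-through for the 1024px case (128x128 latent, patch_size 2, and out_channels 8 because pred_sigma doubles the 4 latent channels):

import torch

# Mirror of PixArtMS.unpatchify with explicit shapes.
N, h, w, p, c = 1, 64, 64, 2, 8
tokens = torch.randn(N, h * w, p * p * c)  # (N, 4096, 32) from the final layer
x = tokens.reshape(N, h, w, p, p, c)
x = torch.einsum('nhwpqc->nchpwq', x)
imgs = x.reshape(N, c, h * p, w * p)
print(imgs.shape)                          # torch.Size([1, 8, 128, 128])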
/PixArt/models/utils.py:
--------------------------------------------------------------------------------
1 | import torch
2 | import torch.nn as nn
3 | import torch.nn.functional as F
4 | from torch.utils.checkpoint import checkpoint, checkpoint_sequential
5 | from collections.abc import Iterable
6 | from itertools import repeat
7 |
8 | def _ntuple(n):
9 | def parse(x):
10 | if isinstance(x, Iterable) and not isinstance(x, str):
11 | return x
12 | return tuple(repeat(x, n))
13 | return parse
14 |
15 | to_1tuple = _ntuple(1)
16 | to_2tuple = _ntuple(2)
17 |
18 | def set_grad_checkpoint(model, use_fp32_attention=False, gc_step=1):
19 | assert isinstance(model, nn.Module)
20 |
21 | def set_attr(module):
22 | module.grad_checkpointing = True
23 | module.fp32_attention = use_fp32_attention
24 | module.grad_checkpointing_step = gc_step
25 | model.apply(set_attr)
26 |
27 | def auto_grad_checkpoint(module, *args, **kwargs):
28 | if getattr(module, 'grad_checkpointing', False):
29 | if isinstance(module, Iterable):
30 | gc_step = module[0].grad_checkpointing_step
31 | return checkpoint_sequential(module, gc_step, *args, **kwargs)
32 | else:
33 | return checkpoint(module, *args, **kwargs)
34 | return module(*args, **kwargs)
35 |
36 | def checkpoint_sequential(functions, step, input, *args, **kwargs):
37 |
38 | # Hack for keyword-only parameter in a python 2.7-compliant way
39 | preserve = kwargs.pop('preserve_rng_state', True)
40 | if kwargs:
41 | raise ValueError("Unexpected keyword arguments: " + ",".join(arg for arg in kwargs))
42 |
43 | def run_function(start, end, functions):
44 | def forward(input):
45 | for j in range(start, end + 1):
46 | input = functions[j](input, *args)
47 | return input
48 | return forward
49 |
50 | if isinstance(functions, torch.nn.Sequential):
51 | functions = list(functions.children())
52 |
53 | # the last chunk has to be non-volatile
54 | end = -1
55 | segment = len(functions) // step
56 | for start in range(0, step * (segment - 1), step):
57 | end = start + step - 1
58 | input = checkpoint(run_function(start, end, functions), input, preserve_rng_state=preserve)
59 | return run_function(end + 1, len(functions) - 1, functions)(input)
60 |
61 | def get_rel_pos(q_size, k_size, rel_pos):
62 | """
63 | Get relative positional embeddings according to the relative positions of
64 | query and key sizes.
65 | Args:
66 | q_size (int): size of query q.
67 | k_size (int): size of key k.
68 | rel_pos (Tensor): relative position embeddings (L, C).
69 |
70 | Returns:
71 | Extracted positional embeddings according to relative positions.
72 | """
73 | max_rel_dist = int(2 * max(q_size, k_size) - 1)
74 | # Interpolate rel pos if needed.
75 | if rel_pos.shape[0] != max_rel_dist:
76 | # Interpolate rel pos.
77 | rel_pos_resized = F.interpolate(
78 | rel_pos.reshape(1, rel_pos.shape[0], -1).permute(0, 2, 1),
79 | size=max_rel_dist,
80 | mode="linear",
81 | )
82 | rel_pos_resized = rel_pos_resized.reshape(-1, max_rel_dist).permute(1, 0)
83 | else:
84 | rel_pos_resized = rel_pos
85 |
86 | # Scale the coords with short length if shapes for q and k are different.
87 | q_coords = torch.arange(q_size)[:, None] * max(k_size / q_size, 1.0)
88 | k_coords = torch.arange(k_size)[None, :] * max(q_size / k_size, 1.0)
89 | relative_coords = (q_coords - k_coords) + (k_size - 1) * max(q_size / k_size, 1.0)
90 |
91 | return rel_pos_resized[relative_coords.long()]
92 |
93 | def add_decomposed_rel_pos(attn, q, rel_pos_h, rel_pos_w, q_size, k_size):
94 | """
95 | Calculate decomposed Relative Positional Embeddings from :paper:`mvitv2`.
96 | https://github.com/facebookresearch/mvit/blob/19786631e330df9f3622e5402b4a419a263a2c80/mvit/models/attention.py # noqa B950
97 | Args:
98 | attn (Tensor): attention map.
99 | q (Tensor): query q in the attention layer with shape (B, q_h * q_w, C).
100 | rel_pos_h (Tensor): relative position embeddings (Lh, C) for height axis.
101 | rel_pos_w (Tensor): relative position embeddings (Lw, C) for width axis.
102 | q_size (Tuple): spatial sequence size of query q with (q_h, q_w).
103 | k_size (Tuple): spatial sequence size of key k with (k_h, k_w).
104 |
105 | Returns:
106 | attn (Tensor): attention map with added relative positional embeddings.
107 | """
108 | q_h, q_w = q_size
109 | k_h, k_w = k_size
110 | Rh = get_rel_pos(q_h, k_h, rel_pos_h)
111 | Rw = get_rel_pos(q_w, k_w, rel_pos_w)
112 |
113 | B, _, dim = q.shape
114 | r_q = q.reshape(B, q_h, q_w, dim)
115 | rel_h = torch.einsum("bhwc,hkc->bhwk", r_q, Rh)
116 | rel_w = torch.einsum("bhwc,wkc->bhwk", r_q, Rw)
117 |
118 | attn = (
119 | attn.view(B, q_h, q_w, k_h, k_w) + rel_h[:, :, :, :, None] + rel_w[:, :, :, None, :]
120 | ).view(B, q_h * q_w, k_h * k_w)
121 |
122 | return attn
123 |
--------------------------------------------------------------------------------
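A minimal usage sketch for the checkpointing helpers above, not part of the repository: `set_grad_checkpoint` flags every sub-module, and `auto_grad_checkpoint` then routes the call through `torch.utils.checkpoint`. The import path is an assumption and depends on how this package is installed.

```
import torch
import torch.nn as nn
# Assumed import path; adjust to wherever this package lives on your system.
from PixArt.models.utils import set_grad_checkpoint, auto_grad_checkpoint

block = nn.Linear(16, 16)
set_grad_checkpoint(block)          # sets grad_checkpointing=True on the module

x = torch.randn(2, 16, requires_grad=True)
y = auto_grad_checkpoint(block, x)  # runs through checkpoint() instead of a plain call
y.sum().backward()                  # activations are recomputed during backward
```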
/PixArt/nodes.py:
--------------------------------------------------------------------------------
1 | import os
2 | import json
3 | import torch
4 | import folder_paths
5 |
6 | from comfy import utils
7 | from .conf import pixart_conf, pixart_res
8 | from .lora import load_pixart_lora
9 | from .loader import load_pixart
10 |
11 | class PixArtCheckpointLoader:
12 | @classmethod
13 | def INPUT_TYPES(s):
14 | return {
15 | "required": {
16 | "ckpt_name": (folder_paths.get_filename_list("checkpoints"),),
17 | "model": (list(pixart_conf.keys()),),
18 | }
19 | }
20 | RETURN_TYPES = ("MODEL",)
21 | RETURN_NAMES = ("model",)
22 | FUNCTION = "load_checkpoint"
23 | CATEGORY = "ExtraModels/PixArt"
24 | TITLE = "PixArt Checkpoint Loader"
25 |
26 | def load_checkpoint(self, ckpt_name, model):
27 | ckpt_path = folder_paths.get_full_path("checkpoints", ckpt_name)
28 | model_conf = pixart_conf[model]
29 | model = load_pixart(
30 | model_path = ckpt_path,
31 | model_conf = model_conf,
32 | )
33 | return (model,)
34 |
35 | class PixArtResolutionSelect():
36 | @classmethod
37 | def INPUT_TYPES(s):
38 | return {
39 | "required": {
40 | "model": (list(pixart_res.keys()),),
41 | # keys are the same for both
42 | "ratio": (list(pixart_res["PixArtMS_XL_2"].keys()),{"default":"1.00"}),
43 | }
44 | }
45 | RETURN_TYPES = ("INT","INT")
46 | RETURN_NAMES = ("width","height")
47 | FUNCTION = "get_res"
48 | CATEGORY = "ExtraModels/PixArt"
49 | TITLE = "PixArt Resolution Select"
50 |
51 | def get_res(self, model, ratio):
52 | width, height = pixart_res[model][ratio]
53 | return (width,height)
54 |
55 | class PixArtLoraLoader:
56 | def __init__(self):
57 | self.loaded_lora = None
58 |
59 | @classmethod
60 | def INPUT_TYPES(s):
61 | return {
62 | "required": {
63 | "model": ("MODEL",),
64 | "lora_name": (folder_paths.get_filename_list("loras"), ),
65 | "strength": ("FLOAT", {"default": 1.0, "min": -20.0, "max": 20.0, "step": 0.01}),
66 | }
67 | }
68 | RETURN_TYPES = ("MODEL",)
69 | FUNCTION = "load_lora"
70 | CATEGORY = "ExtraModels/PixArt"
71 | TITLE = "PixArt Load LoRA"
72 |
73 | def load_lora(self, model, lora_name, strength,):
74 | if strength == 0:
 75 |             return (model,)
76 |
77 | lora_path = folder_paths.get_full_path("loras", lora_name)
78 | lora = None
79 | if self.loaded_lora is not None:
80 | if self.loaded_lora[0] == lora_path:
81 | lora = self.loaded_lora[1]
82 | else:
83 | temp = self.loaded_lora
84 | self.loaded_lora = None
85 | del temp
86 |
87 | if lora is None:
88 | lora = utils.load_torch_file(lora_path, safe_load=True)
89 | self.loaded_lora = (lora_path, lora)
90 |
91 | model_lora = load_pixart_lora(model, lora, lora_path, strength,)
92 | return (model_lora,)
93 |
94 | class PixArtResolutionCond:
95 | @classmethod
96 | def INPUT_TYPES(s):
97 | return {
98 | "required": {
99 | "cond": ("CONDITIONING", ),
100 |                 "width": ("INT", {"default": 1024, "min": 0, "max": 8192}),
101 |                 "height": ("INT", {"default": 1024, "min": 0, "max": 8192}),
102 | }
103 | }
104 |
105 | RETURN_TYPES = ("CONDITIONING",)
106 | RETURN_NAMES = ("cond",)
107 | FUNCTION = "add_cond"
108 | CATEGORY = "ExtraModels/PixArt"
109 | TITLE = "PixArt Resolution Conditioning"
110 |
111 | def add_cond(self, cond, width, height):
112 | for c in range(len(cond)):
113 | cond[c][1].update({
114 | "img_hw": [[height, width]],
115 | "aspect_ratio": [[height/width]],
116 | })
117 | return (cond,)
118 |
119 | class PixArtControlNetCond:
120 | @classmethod
121 | def INPUT_TYPES(s):
122 | return {
123 | "required": {
124 | "cond": ("CONDITIONING",),
125 | "latent": ("LATENT",),
126 | # "image": ("IMAGE",),
127 | # "vae": ("VAE",),
128 | # "strength": ("FLOAT", {"default": 1.0, "min": 0.0, "max": 10.0, "step": 0.01})
129 | }
130 | }
131 |
132 | RETURN_TYPES = ("CONDITIONING",)
133 | RETURN_NAMES = ("cond",)
134 | FUNCTION = "add_cond"
135 | CATEGORY = "ExtraModels/PixArt"
136 | TITLE = "PixArt ControlNet Conditioning"
137 |
138 | def add_cond(self, cond, latent):
139 | for c in range(len(cond)):
140 | cond[c][1]["cn_hint"] = latent["samples"] * 0.18215
141 | return (cond,)
142 |
143 | class PixArtT5TextEncode:
144 | """
145 | Reference code, mostly to verify compatibility.
146 | Once everything works, this should instead inherit from the
147 | T5 text encode node and simply add the extra conds (res/ar).
148 | """
149 | @classmethod
150 | def INPUT_TYPES(s):
151 | return {
152 | "required": {
153 | "text": ("STRING", {"multiline": True}),
154 | "T5": ("T5",),
155 | }
156 | }
157 |
158 | RETURN_TYPES = ("CONDITIONING",)
159 | FUNCTION = "encode"
160 | CATEGORY = "ExtraModels/PixArt"
161 | TITLE = "PixArt T5 Text Encode [Reference]"
162 |
163 | def mask_feature(self, emb, mask):
164 | if emb.shape[0] == 1:
165 | keep_index = mask.sum().item()
166 | return emb[:, :, :keep_index, :], keep_index
167 | else:
168 | masked_feature = emb * mask[:, None, :, None]
169 | return masked_feature, emb.shape[2]
170 |
171 | def encode(self, text, T5):
172 | text = text.lower().strip()
173 | tokenizer_out = T5.tokenizer.tokenizer(
174 | text,
175 | max_length = 120,
176 | padding = 'max_length',
177 | truncation = True,
178 | return_attention_mask = True,
179 | add_special_tokens = True,
180 | return_tensors = 'pt'
181 | )
182 | tokens = tokenizer_out["input_ids"]
183 | mask = tokenizer_out["attention_mask"]
184 | embs = T5.cond_stage_model.transformer(
185 | input_ids = tokens.to(T5.load_device),
186 | attention_mask = mask.to(T5.load_device),
187 | )['last_hidden_state'].float()[:, None]
188 | masked_embs, keep_index = self.mask_feature(
189 | embs.detach().to("cpu"),
190 | mask.detach().to("cpu")
191 | )
192 | masked_embs = masked_embs.squeeze(0) # match CLIP/internal
193 | print("Encoded T5:", masked_embs.shape)
194 | return ([[masked_embs, {}]], )
195 |
196 | NODE_CLASS_MAPPINGS = {
197 | "PixArtCheckpointLoader" : PixArtCheckpointLoader,
198 | "PixArtResolutionSelect" : PixArtResolutionSelect,
199 | "PixArtLoraLoader" : PixArtLoraLoader,
200 | "PixArtT5TextEncode" : PixArtT5TextEncode,
201 | "PixArtResolutionCond" : PixArtResolutionCond,
202 | "PixArtControlNetCond" : PixArtControlNetCond,
203 | }
204 |
--------------------------------------------------------------------------------
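For reference, the node classes above can also be called directly from Python inside a working ComfyUI environment (the imports in nodes.py require `comfy` and `folder_paths`). A rough sketch of the resolution helpers; the dummy conditioning tensor and its 4096-channel width are assumptions:

```
import torch
# Assumed import path inside a ComfyUI install.
from PixArt.nodes import PixArtResolutionSelect, PixArtResolutionCond

# Width/height for the chosen model and aspect-ratio bucket.
width, height = PixArtResolutionSelect().get_res("PixArtMS_XL_2", "1.00")

# `cond` would normally come from a T5 text encode node; a dummy stands in here.
cond = [[torch.zeros(1, 120, 4096), {}]]
(cond,) = PixArtResolutionCond().add_cond(cond, width, height)
# cond[0][1] now carries the "img_hw" and "aspect_ratio" entries read by forward().
```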
/README.md:
--------------------------------------------------------------------------------
1 | # ComfyUI HunyuanDiT (WIP)
2 |
3 | [HunyuanDiT](https://github.com/Tencent/HunyuanDiT)
4 |
5 | ```
6 | huggingface-cli download --resume-download Tencent-Hunyuan/HunyuanDiT --local-dir ComfyUI/models/diffusers --local-dir-use-symlinks False
7 | ```
8 |
9 | Requires the SDXL VAE for decoding the latents.
10 |
11 | ## Workflow
12 |
13 | [Recommended complete Workflow](https://github.com/chaojie/ComfyUI_ExtraModels/blob/main/HunYuan/wf.json)
14 |
15 |
--------------------------------------------------------------------------------
/T5/LICENSE-T5:
--------------------------------------------------------------------------------
1 | Apache License
2 | Version 2.0, January 2004
3 | http://www.apache.org/licenses/
4 |
5 | TERMS AND CONDITIONS FOR USE, REPRODUCTION, AND DISTRIBUTION
6 |
7 | 1. Definitions.
8 |
9 | "License" shall mean the terms and conditions for use, reproduction,
10 | and distribution as defined by Sections 1 through 9 of this document.
11 |
12 | "Licensor" shall mean the copyright owner or entity authorized by
13 | the copyright owner that is granting the License.
14 |
15 | "Legal Entity" shall mean the union of the acting entity and all
16 | other entities that control, are controlled by, or are under common
17 | control with that entity. For the purposes of this definition,
18 | "control" means (i) the power, direct or indirect, to cause the
19 | direction or management of such entity, whether by contract or
20 | otherwise, or (ii) ownership of fifty percent (50%) or more of the
21 | outstanding shares, or (iii) beneficial ownership of such entity.
22 |
23 | "You" (or "Your") shall mean an individual or Legal Entity
24 | exercising permissions granted by this License.
25 |
26 | "Source" form shall mean the preferred form for making modifications,
27 | including but not limited to software source code, documentation
28 | source, and configuration files.
29 |
30 | "Object" form shall mean any form resulting from mechanical
31 | transformation or translation of a Source form, including but
32 | not limited to compiled object code, generated documentation,
33 | and conversions to other media types.
34 |
35 | "Work" shall mean the work of authorship, whether in Source or
36 | Object form, made available under the License, as indicated by a
37 | copyright notice that is included in or attached to the work
38 | (an example is provided in the Appendix below).
39 |
40 | "Derivative Works" shall mean any work, whether in Source or Object
41 | form, that is based on (or derived from) the Work and for which the
42 | editorial revisions, annotations, elaborations, or other modifications
43 | represent, as a whole, an original work of authorship. For the purposes
44 | of this License, Derivative Works shall not include works that remain
45 | separable from, or merely link (or bind by name) to the interfaces of,
46 | the Work and Derivative Works thereof.
47 |
48 | "Contribution" shall mean any work of authorship, including
49 | the original version of the Work and any modifications or additions
50 | to that Work or Derivative Works thereof, that is intentionally
51 | submitted to Licensor for inclusion in the Work by the copyright owner
52 | or by an individual or Legal Entity authorized to submit on behalf of
53 | the copyright owner. For the purposes of this definition, "submitted"
54 | means any form of electronic, verbal, or written communication sent
55 | to the Licensor or its representatives, including but not limited to
56 | communication on electronic mailing lists, source code control systems,
57 | and issue tracking systems that are managed by, or on behalf of, the
58 | Licensor for the purpose of discussing and improving the Work, but
59 | excluding communication that is conspicuously marked or otherwise
60 | designated in writing by the copyright owner as "Not a Contribution."
61 |
62 | "Contributor" shall mean Licensor and any individual or Legal Entity
63 | on behalf of whom a Contribution has been received by Licensor and
64 | subsequently incorporated within the Work.
65 |
66 | 2. Grant of Copyright License. Subject to the terms and conditions of
67 | this License, each Contributor hereby grants to You a perpetual,
68 | worldwide, non-exclusive, no-charge, royalty-free, irrevocable
69 | copyright license to reproduce, prepare Derivative Works of,
70 | publicly display, publicly perform, sublicense, and distribute the
71 | Work and such Derivative Works in Source or Object form.
72 |
73 | 3. Grant of Patent License. Subject to the terms and conditions of
74 | this License, each Contributor hereby grants to You a perpetual,
75 | worldwide, non-exclusive, no-charge, royalty-free, irrevocable
76 | (except as stated in this section) patent license to make, have made,
77 | use, offer to sell, sell, import, and otherwise transfer the Work,
78 | where such license applies only to those patent claims licensable
79 | by such Contributor that are necessarily infringed by their
80 | Contribution(s) alone or by combination of their Contribution(s)
81 | with the Work to which such Contribution(s) was submitted. If You
82 | institute patent litigation against any entity (including a
83 | cross-claim or counterclaim in a lawsuit) alleging that the Work
84 | or a Contribution incorporated within the Work constitutes direct
85 | or contributory patent infringement, then any patent licenses
86 | granted to You under this License for that Work shall terminate
87 | as of the date such litigation is filed.
88 |
89 | 4. Redistribution. You may reproduce and distribute copies of the
90 | Work or Derivative Works thereof in any medium, with or without
91 | modifications, and in Source or Object form, provided that You
92 | meet the following conditions:
93 |
94 | (a) You must give any other recipients of the Work or
95 | Derivative Works a copy of this License; and
96 |
97 | (b) You must cause any modified files to carry prominent notices
98 | stating that You changed the files; and
99 |
100 | (c) You must retain, in the Source form of any Derivative Works
101 | that You distribute, all copyright, patent, trademark, and
102 | attribution notices from the Source form of the Work,
103 | excluding those notices that do not pertain to any part of
104 | the Derivative Works; and
105 |
106 | (d) If the Work includes a "NOTICE" text file as part of its
107 | distribution, then any Derivative Works that You distribute must
108 | include a readable copy of the attribution notices contained
109 | within such NOTICE file, excluding those notices that do not
110 | pertain to any part of the Derivative Works, in at least one
111 | of the following places: within a NOTICE text file distributed
112 | as part of the Derivative Works; within the Source form or
113 | documentation, if provided along with the Derivative Works; or,
114 | within a display generated by the Derivative Works, if and
115 | wherever such third-party notices normally appear. The contents
116 | of the NOTICE file are for informational purposes only and
117 | do not modify the License. You may add Your own attribution
118 | notices within Derivative Works that You distribute, alongside
119 | or as an addendum to the NOTICE text from the Work, provided
120 | that such additional attribution notices cannot be construed
121 | as modifying the License.
122 |
123 | You may add Your own copyright statement to Your modifications and
124 | may provide additional or different license terms and conditions
125 | for use, reproduction, or distribution of Your modifications, or
126 | for any such Derivative Works as a whole, provided Your use,
127 | reproduction, and distribution of the Work otherwise complies with
128 | the conditions stated in this License.
129 |
130 | 5. Submission of Contributions. Unless You explicitly state otherwise,
131 | any Contribution intentionally submitted for inclusion in the Work
132 | by You to the Licensor shall be under the terms and conditions of
133 | this License, without any additional terms or conditions.
134 | Notwithstanding the above, nothing herein shall supersede or modify
135 | the terms of any separate license agreement you may have executed
136 | with Licensor regarding such Contributions.
137 |
138 | 6. Trademarks. This License does not grant permission to use the trade
139 | names, trademarks, service marks, or product names of the Licensor,
140 | except as required for reasonable and customary use in describing the
141 | origin of the Work and reproducing the content of the NOTICE file.
142 |
143 | 7. Disclaimer of Warranty. Unless required by applicable law or
144 | agreed to in writing, Licensor provides the Work (and each
145 | Contributor provides its Contributions) on an "AS IS" BASIS,
146 | WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or
147 | implied, including, without limitation, any warranties or conditions
148 | of TITLE, NON-INFRINGEMENT, MERCHANTABILITY, or FITNESS FOR A
149 | PARTICULAR PURPOSE. You are solely responsible for determining the
150 | appropriateness of using or redistributing the Work and assume any
151 | risks associated with Your exercise of permissions under this License.
152 |
153 | 8. Limitation of Liability. In no event and under no legal theory,
154 | whether in tort (including negligence), contract, or otherwise,
155 | unless required by applicable law (such as deliberate and grossly
156 | negligent acts) or agreed to in writing, shall any Contributor be
157 | liable to You for damages, including any direct, indirect, special,
158 | incidental, or consequential damages of any character arising as a
159 | result of this License or out of the use or inability to use the
160 | Work (including but not limited to damages for loss of goodwill,
161 | work stoppage, computer failure or malfunction, or any and all
162 | other commercial damages or losses), even if such Contributor
163 | has been advised of the possibility of such damages.
164 |
165 | 9. Accepting Warranty or Additional Liability. While redistributing
166 | the Work or Derivative Works thereof, You may choose to offer,
167 | and charge a fee for, acceptance of support, warranty, indemnity,
168 | or other liability obligations and/or rights consistent with this
169 | License. However, in accepting such obligations, You may act only
170 | on Your own behalf and on Your sole responsibility, not on behalf
171 | of any other Contributor, and only if You agree to indemnify,
172 | defend, and hold each Contributor harmless for any liability
173 | incurred by, or claims asserted against, such Contributor by reason
174 | of your accepting any such warranty or additional liability.
175 |
176 | END OF TERMS AND CONDITIONS
177 |
178 | APPENDIX: How to apply the Apache License to your work.
179 |
180 | To apply the Apache License to your work, attach the following
181 | boilerplate notice, with the fields enclosed by brackets "[]"
182 | replaced with your own identifying information. (Don't include
183 | the brackets!) The text should be enclosed in the appropriate
184 | comment syntax for the file format. We also recommend that a
185 | file or class name and description of purpose be included on the
186 | same "printed page" as the copyright notice for easier
187 | identification within third-party archives.
188 |
189 | Copyright [yyyy] [name of copyright owner]
190 |
191 | Licensed under the Apache License, Version 2.0 (the "License");
192 | you may not use this file except in compliance with the License.
193 | You may obtain a copy of the License at
194 |
195 | http://www.apache.org/licenses/LICENSE-2.0
196 |
197 | Unless required by applicable law or agreed to in writing, software
198 | distributed under the License is distributed on an "AS IS" BASIS,
199 | WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
200 | See the License for the specific language governing permissions and
201 | limitations under the License.
202 |
--------------------------------------------------------------------------------
/T5/loader.py:
--------------------------------------------------------------------------------
1 | import os
2 | import torch
3 | import comfy.utils
4 | import comfy.model_patcher
5 | from comfy import model_management
6 | import folder_paths
7 |
8 | from .t5v11 import T5v11Model, T5v11Tokenizer
9 |
10 | class EXM_T5v11:
11 | def __init__(self, textmodel_ver="xxl", embedding_directory=None, textmodel_path=None, no_init=False, device="cpu", dtype=None):
12 | if no_init:
13 | return
14 |
15 | if device == "auto":
16 | size = 0
17 | self.load_device = model_management.text_encoder_device()
18 | self.offload_device = model_management.text_encoder_offload_device()
19 | self.init_device = "cpu"
20 | elif dtype == "bnb8bit":
21 | # BNB doesn't support size enum
22 | size = 12.4 * (1024**3)
23 | # Or moving between devices
24 | self.load_device = model_management.get_torch_device()
25 | self.offload_device = self.load_device
26 | self.init_device = self.load_device
27 | elif dtype == "bnb4bit":
28 | # This seems to use the same VRAM as 8bit on Pascal?
29 | size = 6.2 * (1024**3)
30 | self.load_device = model_management.get_torch_device()
31 | self.offload_device = self.load_device
32 | self.init_device = self.load_device
33 | elif device == "cpu":
34 | size = 0
35 | self.load_device = "cpu"
36 | self.offload_device = "cpu"
37 | self.init_device="cpu"
38 | elif device.startswith("cuda"):
39 | print("Direct CUDA device override!\nVRAM will not be freed by default.")
40 | size = 0
41 | self.load_device = device
42 | self.offload_device = device
43 | self.init_device = device
44 | else:
45 | size = 0
46 | self.load_device = model_management.get_torch_device()
47 | self.offload_device = "cpu"
48 | self.init_device="cpu"
49 |
50 | self.cond_stage_model = T5v11Model(
51 | textmodel_ver = textmodel_ver,
52 | textmodel_path = textmodel_path,
53 | device = device,
54 | dtype = dtype,
55 | )
56 | self.tokenizer = T5v11Tokenizer(embedding_directory=embedding_directory)
57 | self.patcher = comfy.model_patcher.ModelPatcher(
58 | self.cond_stage_model,
59 | load_device = self.load_device,
60 | offload_device = self.offload_device,
61 | current_device = self.load_device,
62 | size = size,
63 | )
64 |
65 | def clone(self):
66 |         n = EXM_T5v11(no_init=True)
67 | n.patcher = self.patcher.clone()
68 | n.cond_stage_model = self.cond_stage_model
69 | n.tokenizer = self.tokenizer
70 | return n
71 |
72 | def tokenize(self, text, return_word_ids=False):
73 | return self.tokenizer.tokenize_with_weights(text, return_word_ids)
74 |
75 | def encode_from_tokens(self, tokens):
76 | self.load_model()
77 | return self.cond_stage_model.encode_token_weights(tokens)
78 |
79 | def encode(self, text):
80 | tokens = self.tokenize(text)
81 | return self.encode_from_tokens(tokens)
82 |
83 | def load_sd(self, sd):
84 | return self.cond_stage_model.load_sd(sd)
85 |
86 | def get_sd(self):
87 | return self.cond_stage_model.state_dict()
88 |
89 | def load_model(self):
90 | if self.load_device != "cpu":
91 | model_management.load_model_gpu(self.patcher)
92 | return self.patcher
93 |
94 | def add_patches(self, patches, strength_patch=1.0, strength_model=1.0):
95 | return self.patcher.add_patches(patches, strength_patch, strength_model)
96 |
97 | def get_key_patches(self):
98 | return self.patcher.get_key_patches()
99 |
100 |
101 | def load_t5(model_type, model_ver, model_path, path_type="file", device="cpu", dtype=None):
102 | assert model_type in ["t5v11"] # Only supported model for now
103 | model_args = {
104 | "textmodel_ver" : model_ver,
105 | "device" : device,
106 | "dtype" : dtype,
107 | }
108 |
109 | if path_type == "folder":
110 | # pass directly to transformers and initialize there
111 | # this is to avoid having to handle multi-file state dict loading for now.
112 | model_args["textmodel_path"] = os.path.dirname(model_path)
113 | return EXM_T5v11(**model_args)
114 | else:
115 | # for some reason this returns garbage with torch.int8 weights, or just OOMs
116 | model = EXM_T5v11(**model_args)
117 | sd = comfy.utils.load_torch_file(model_path)
118 | model.load_sd(sd)
119 | return model
120 |
--------------------------------------------------------------------------------
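A hedged usage sketch for `load_t5` above, mirroring what T5/nodes.py does without the node wrapper. It needs a ComfyUI environment plus `transformers`, the path below is a placeholder, and with `path_type="folder"` the parent directory must hold a Hugging Face style T5 v1.1 XXL checkpoint.

```
# Assumed import path inside a ComfyUI install.
from T5.loader import load_t5

t5 = load_t5(
    model_type = "t5v11",
    model_ver = "xxl",
    model_path = "/path/to/t5-v1_1-xxl/model.safetensors",  # placeholder
    path_type = "folder",
    device = "cpu",
    dtype = None,
)
cond = t5.encode("a photo of a cat")  # tokenize_with_weights + encode_token_weights
```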
/T5/nodes.py:
--------------------------------------------------------------------------------
1 | import os
2 | import json
3 | import torch
4 | import folder_paths
5 |
6 | from .loader import load_t5
7 | from ..utils.dtype import string_to_dtype
8 |
9 | # initialize custom folder path
10 | os.makedirs(
11 | os.path.join(folder_paths.models_dir,"t5"),
12 | exist_ok = True,
13 | )
14 | folder_paths.folder_names_and_paths["t5"] = (
15 | [
16 | os.path.join(folder_paths.models_dir,"t5"),
17 | *folder_paths.folder_names_and_paths.get("t5", [[],set()])[0]
18 | ],
19 | folder_paths.supported_pt_extensions
20 | )
21 |
22 | dtypes = [
23 | "default",
24 | "auto (comfy)",
25 | "FP32",
26 | "FP16",
27 | # Note: remove these at some point
28 | "bnb8bit",
29 | "bnb4bit",
30 | ]
31 | try: torch.float8_e5m2
32 | except AttributeError: print("Torch version too old for FP8")
33 | else: dtypes += ["FP8 E4M3", "FP8 E5M2"]
34 |
35 | class T5v11Loader:
36 | @classmethod
37 | def INPUT_TYPES(s):
38 | devices = ["auto", "cpu", "gpu"]
39 | # hack for using second GPU as offload
40 | for k in range(1, torch.cuda.device_count()):
41 | devices.append(f"cuda:{k}")
42 | return {
43 | "required": {
44 | "t5v11_name": (folder_paths.get_filename_list("t5"),),
45 | "t5v11_ver": (["xxl"],),
46 | "path_type": (["folder", "file"],),
47 | "device": (devices, {"default":"cpu"}),
48 | "dtype": (dtypes,),
49 | }
50 | }
51 | RETURN_TYPES = ("T5",)
52 | FUNCTION = "load_model"
53 | CATEGORY = "ExtraModels/T5"
54 | TITLE = "T5v1.1 Loader"
55 |
56 | def load_model(self, t5v11_name, t5v11_ver, path_type, device, dtype):
57 | if "bnb" in dtype:
58 | assert device == "gpu" or device.startswith("cuda"), "BitsAndBytes only works on CUDA! Set device to 'gpu'."
59 | dtype = string_to_dtype(dtype, "text_encoder")
60 | if device == "cpu":
61 | assert dtype in [None, torch.float32], f"Can't use dtype '{dtype}' with CPU! Set dtype to 'default'."
62 |
63 | return (load_t5(
64 | model_type = "t5v11",
65 | model_ver = t5v11_ver,
66 | model_path = folder_paths.get_full_path("t5", t5v11_name),
67 | path_type = path_type,
68 | device = device,
69 | dtype = dtype,
70 | ),)
71 |
72 | class T5TextEncode:
73 | @classmethod
74 | def INPUT_TYPES(s):
75 | return {
76 | "required": {
77 | "text": ("STRING", {"multiline": True}),
78 | "T5": ("T5",),
79 | }
80 | }
81 |
82 | RETURN_TYPES = ("CONDITIONING",)
83 | FUNCTION = "encode"
84 | CATEGORY = "ExtraModels/T5"
85 | TITLE = "T5 Text Encode"
86 |
87 | def encode(self, text, T5=None):
88 | tokens = T5.tokenize(text)
89 | cond = T5.encode_from_tokens(tokens)
90 | return ([[cond, {}]], )
91 |
92 | NODE_CLASS_MAPPINGS = {
93 | "T5v11Loader" : T5v11Loader,
94 | "T5TextEncode" : T5TextEncode,
95 | }
96 |
--------------------------------------------------------------------------------
/T5/t5_tokenizer/special_tokens_map.json:
--------------------------------------------------------------------------------
1 | {"eos_token": "", "unk_token": "", "pad_token": "", "additional_special_tokens": ["", "", "", "", "", "", "", "", "", "", "", "", "", "", "", "", "", "", "", "", "", "", "", "", "", "", "", "", "", "", "", "", "", "", "", "", "", "", "", "", "", "", "", "", "", "", "", "", "", "", "", "", "", "", "", "", "", "", "", "", "", "", "", "", "", "", "", "", "", "", "", "", "", "", "", "", "", "", "", "", "", "", "", "", "", "", "", "", "", "", "", "", "", "", "", "", "", "", "", ""]}
--------------------------------------------------------------------------------
/T5/t5_tokenizer/spiece.model:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/chaojie/ComfyUI_ExtraModels/d8b11e401de830ccfb27fa84bdd0091b52408af8/T5/t5_tokenizer/spiece.model
--------------------------------------------------------------------------------
/T5/t5_tokenizer/tokenizer_config.json:
--------------------------------------------------------------------------------
1 | {"eos_token": "", "unk_token": "", "pad_token": "", "extra_ids": 100, "additional_special_tokens": ["", "", "", "", "", "", "", "", "", "", "", "", "", "", "", "", "", "", "", "", "", "", "", "", "", "", "", "", "", "", "", "", "", "", "", "", "", "", "", "", "", "", "", "", "", "", "", "", "", "", "", "", "", "", "", "", "", "", "", "", "", "", "", "", "", "", "", "", "", "", "", "", "", "", "", "", "", "", "", "", "", "", "", "", "", "", "", "", "", "", "", "", "", "", "", "", "", "", "", ""], "model_max_length": 512, "name_or_path": "t5-small"}
--------------------------------------------------------------------------------
/T5/t5v11-xxl_config.json:
--------------------------------------------------------------------------------
1 | {
2 | "_name_or_path": "google/t5-v1_1-xxl",
3 | "architectures": [
4 | "T5EncoderModel"
5 | ],
6 | "d_ff": 10240,
7 | "d_kv": 64,
8 | "d_model": 4096,
9 | "decoder_start_token_id": 0,
10 | "dense_act_fn": "gelu_new",
11 | "dropout_rate": 0.1,
12 | "eos_token_id": 1,
13 | "feed_forward_proj": "gated-gelu",
14 | "initializer_factor": 1.0,
15 | "is_encoder_decoder": true,
16 | "is_gated_act": true,
17 | "layer_norm_epsilon": 1e-06,
18 | "model_type": "t5",
19 | "num_decoder_layers": 24,
20 | "num_heads": 64,
21 | "num_layers": 24,
22 | "output_past": true,
23 | "pad_token_id": 0,
24 | "relative_attention_max_distance": 128,
25 | "relative_attention_num_buckets": 32,
26 | "tie_word_embeddings": false,
27 | "torch_dtype": "float32",
28 | "transformers_version": "4.21.1",
29 | "use_cache": true,
30 | "vocab_size": 32128
31 | }
32 |
--------------------------------------------------------------------------------
/T5/t5v11.py:
--------------------------------------------------------------------------------
1 | """
2 | Adapted from comfyui CLIP code.
3 | https://github.com/comfyanonymous/ComfyUI/blob/master/comfy/sd1_clip.py
4 | """
5 |
6 | import os
7 |
8 | from transformers import T5Tokenizer, T5EncoderModel, T5Config, modeling_utils
9 | import torch
10 | import traceback
11 | import zipfile
12 | from comfy import model_management
13 |
14 | from comfy.sd1_clip import parse_parentheses, token_weights, escape_important, unescape_important, safe_load_embed_zip, expand_directory_list, load_embed
15 |
16 | class T5v11Model(torch.nn.Module):
17 | def __init__(self, textmodel_ver="xxl", textmodel_json_config=None, textmodel_path=None, device="cpu", max_length=120, freeze=True, dtype=None):
18 | super().__init__()
19 |
20 | self.num_layers = 24
21 | self.max_length = max_length
22 | self.bnb = False
23 |
24 | if textmodel_path is not None:
25 | model_args = {}
26 | model_args["low_cpu_mem_usage"] = True # Don't take 2x system ram on cpu
27 | if dtype == "bnb8bit":
28 | self.bnb = True
29 | model_args["load_in_8bit"] = True
30 | elif dtype == "bnb4bit":
31 | self.bnb = True
32 | model_args["load_in_4bit"] = True
33 | else:
34 | if dtype: model_args["torch_dtype"] = dtype
35 | self.bnb = False
36 | # second GPU offload hack part 2
37 | if device.startswith("cuda"):
38 | model_args["device_map"] = device
39 | print(f"Loading T5 from '{textmodel_path}'")
40 | self.transformer = T5EncoderModel.from_pretrained(textmodel_path, **model_args)
41 | else:
42 | if textmodel_json_config is None:
43 | textmodel_json_config = os.path.join(
44 | os.path.dirname(os.path.realpath(__file__)),
45 | f"t5v11-{textmodel_ver}_config.json"
46 | )
47 | config = T5Config.from_json_file(textmodel_json_config)
48 | self.num_layers = config.num_hidden_layers
49 | with modeling_utils.no_init_weights():
50 | self.transformer = T5EncoderModel(config)
51 |
52 | if freeze:
53 | self.freeze()
54 |         self.empty_tokens = [[0] * self.max_length] # <pad> token
55 |
56 | def freeze(self):
57 | self.transformer = self.transformer.eval()
58 | for param in self.parameters():
59 | param.requires_grad = False
60 |
61 | def forward(self, tokens):
62 | device = self.transformer.get_input_embeddings().weight.device
63 | tokens = torch.LongTensor(tokens).to(device)
64 | attention_mask = torch.zeros_like(tokens)
65 |         max_token = 1 # </s> token
66 | for x in range(attention_mask.shape[0]):
67 | for y in range(attention_mask.shape[1]):
68 | attention_mask[x, y] = 1
69 | if tokens[x, y] == max_token:
70 | break
71 |
72 | outputs = self.transformer(input_ids=tokens, attention_mask=attention_mask)
73 |
74 | z = outputs['last_hidden_state']
75 |         z = z.detach().cpu().float()
76 | return z
77 |
78 | def encode(self, tokens):
79 | return self(tokens)
80 |
81 | def load_sd(self, sd):
82 | return self.transformer.load_state_dict(sd, strict=False)
83 |
84 | def to(self, *args, **kwargs):
85 | """BNB complains if you try to change the device or dtype"""
86 | if self.bnb:
87 | print("Thanks to BitsAndBytes, T5 becomes an immovable rock.", args, kwargs)
88 | else:
89 | self.transformer.to(*args, **kwargs)
90 |
91 | def encode_token_weights(self, token_weight_pairs, return_padded=False):
92 | to_encode = list(self.empty_tokens)
93 | for x in token_weight_pairs:
94 | tokens = list(map(lambda a: a[0], x))
95 | to_encode.append(tokens)
96 |
97 | out = self.encode(to_encode)
98 | z_empty = out[0:1]
99 |
100 | output = []
101 | for k in range(1, out.shape[0]):
102 | z = out[k:k+1]
103 | for i in range(len(z)):
104 | for j in range(len(z[i])):
105 | weight = token_weight_pairs[k - 1][j][1]
106 | z[i][j] = (z[i][j] - z_empty[0][j]) * weight + z_empty[0][j]
107 | output.append(z)
108 |
109 | if (len(output) == 0):
110 | return z_empty.cpu()
111 |
112 | out = torch.cat(output, dim=-2)
113 | if not return_padded:
114 |             # Count number of tokens that aren't <pad>, then use that number as an index.
115 | keep_index = sum([sum([1 for y in x if y[0] != 0]) for x in token_weight_pairs])
116 | out = out[:, :keep_index, :]
117 | return out
118 |
119 |
120 | class T5v11Tokenizer:
121 | """
122 | This is largely just based on the ComfyUI CLIP code.
123 | """
124 | def __init__(self, tokenizer_path=None, max_length=120, embedding_directory=None, embedding_size=4096, embedding_key='t5'):
125 | if tokenizer_path is None:
126 | tokenizer_path = os.path.join(os.path.dirname(os.path.realpath(__file__)), "t5_tokenizer")
127 | self.tokenizer = T5Tokenizer.from_pretrained(tokenizer_path)
128 | self.max_length = max_length
129 |         self.max_tokens_per_section = self.max_length - 1 # room for </s> but no <bos>
130 |
131 |         self.pad_token = self.tokenizer("<pad>", add_special_tokens=False)["input_ids"][0]
132 |         self.end_token = self.tokenizer("</s>", add_special_tokens=False)["input_ids"][0]
133 | vocab = self.tokenizer.get_vocab()
134 | self.inv_vocab = {v: k for k, v in vocab.items()}
135 | self.embedding_directory = embedding_directory
136 | self.max_word_length = 8 # haven't verified this
137 | self.embedding_identifier = "embedding:"
138 | self.embedding_size = embedding_size
139 | self.embedding_key = embedding_key
140 |
141 | def _try_get_embedding(self, embedding_name:str):
142 | '''
143 | Takes a potential embedding name and tries to retrieve it.
144 | Returns a Tuple consisting of the embedding and any leftover string, embedding can be None.
145 | '''
146 | embed = load_embed(embedding_name, self.embedding_directory, self.embedding_size, self.embedding_key)
147 | if embed is None:
148 | stripped = embedding_name.strip(',')
149 | if len(stripped) < len(embedding_name):
150 | embed = load_embed(stripped, self.embedding_directory, self.embedding_size, self.embedding_key)
151 | return (embed, embedding_name[len(stripped):])
152 | return (embed, "")
153 |
154 | def tokenize_with_weights(self, text:str, return_word_ids=False):
155 | '''
156 | Takes a prompt and converts it to a list of (token, weight, word id) elements.
157 |         Tokens can be either integer tokens or pre-computed T5 embedding tensors.
158 | Word id values are unique per word and embedding, where the id 0 is reserved for non word tokens.
159 | Returned list has the dimensions NxM where M is the input size of T5
160 | '''
161 | pad_token = self.pad_token
162 | text = escape_important(text)
163 | parsed_weights = token_weights(text, 1.0)
164 |
165 | #tokenize words
166 | tokens = []
167 | for weighted_segment, weight in parsed_weights:
168 | to_tokenize = unescape_important(weighted_segment).replace("\n", " ").split(' ')
169 | to_tokenize = [x for x in to_tokenize if x != ""]
170 | for word in to_tokenize:
171 | #if we find an embedding, deal with the embedding
172 | if word.startswith(self.embedding_identifier) and self.embedding_directory is not None:
173 | embedding_name = word[len(self.embedding_identifier):].strip('\n')
174 | embed, leftover = self._try_get_embedding(embedding_name)
175 | if embed is None:
176 | print(f"warning, embedding:{embedding_name} does not exist, ignoring")
177 | else:
178 | if len(embed.shape) == 1:
179 | tokens.append([(embed, weight)])
180 | else:
181 | tokens.append([(embed[x], weight) for x in range(embed.shape[0])])
182 | #if we accidentally have leftover text, continue parsing using leftover, else move on to next word
183 | if leftover != "":
184 | word = leftover
185 | else:
186 | continue
187 | #parse word
188 | tokens.append([(t, weight) for t in self.tokenizer(word, add_special_tokens=False)["input_ids"]])
189 |
190 | #reshape token array to T5 input size
191 | batched_tokens = []
192 | batch = []
193 | batched_tokens.append(batch)
194 | for i, t_group in enumerate(tokens):
195 | #determine if we're going to try and keep the tokens in a single batch
196 | is_large = len(t_group) >= self.max_word_length
197 |
198 | while len(t_group) > 0:
199 | if len(t_group) + len(batch) > self.max_length - 1:
200 | remaining_length = self.max_length - len(batch) - 1
201 | #break word in two and add end token
202 | if is_large:
203 | batch.extend([(t,w,i+1) for t,w in t_group[:remaining_length]])
204 | batch.append((self.end_token, 1.0, 0))
205 | t_group = t_group[remaining_length:]
206 | #add end token and pad
207 | else:
208 | batch.append((self.end_token, 1.0, 0))
209 | batch.extend([(self.pad_token, 1.0, 0)] * (remaining_length))
210 | #start new batch
211 | batch = []
212 | batched_tokens.append(batch)
213 | else:
214 | batch.extend([(t,w,i+1) for t,w in t_group])
215 | t_group = []
216 |
217 | # fill last batch
218 | batch.extend([(self.end_token, 1.0, 0)] + [(self.pad_token, 1.0, 0)] * (self.max_length - len(batch) - 1))
219 | # instead of filling, just add EOS (DEBUG)
220 | # batch.extend([(self.end_token, 1.0, 0)])
221 |
222 | if not return_word_ids:
223 | batched_tokens = [[(t, w) for t, w,_ in x] for x in batched_tokens]
224 | return batched_tokens
225 |
226 | def untokenize(self, token_weight_pair):
227 | return list(map(lambda a: (a, self.inv_vocab[a[0]]), token_weight_pair))
228 |
--------------------------------------------------------------------------------
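The weighting loop in `encode_token_weights` above linearly interpolates each token embedding toward the empty-prompt embedding, so a weight of 1.0 is a no-op, 0.0 collapses the token to the unconditional value, and values above 1.0 exaggerate it. A standalone illustration with dummy tensors (the 120x4096 shape is just the T5-XXL default, not pulled from this class):

```
import torch

z = torch.randn(1, 120, 4096)        # encoder output for one prompt chunk
z_empty = torch.randn(1, 120, 4096)  # encoder output for the all-<pad> prompt
weight = 1.25
z_weighted = (z - z_empty) * weight + z_empty  # same formula as the loop above
```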
/VAE/conf.py:
--------------------------------------------------------------------------------
1 | """
2 | List of all VAE configs, with training parts stripped.
3 | """
4 | vae_conf = {
5 | ### AutoencoderKL ###
6 | "kl-f4": {
7 | "type" : "AutoencoderKL",
8 | "embed_scale" : 4,
9 | "embed_dim" : 3,
10 | "z_channels" : 3,
11 | "double_z" : True,
12 | "resolution" : 256,
13 | "in_channels" : 3,
14 | "out_ch" : 3,
15 | "ch" : 128,
16 | "ch_mult" : [1,2,4],
17 | "num_res_blocks" : 2,
18 | "attn_resolutions" : [],
19 | },
20 | "kl-f8": { # Default SD1.5 VAE
21 | "type" : "AutoencoderKL",
22 | "embed_scale" : 8,
23 | "embed_dim" : 4,
24 | "z_channels" : 4,
25 | "double_z" : True,
26 | "resolution" : 256,
27 | "in_channels" : 3,
28 | "out_ch" : 3,
29 | "ch" : 128,
30 | "ch_mult" : [1,2,4,4],
31 | "num_res_blocks" : 2,
32 | "attn_resolutions" : [],
33 | },
34 | "kl-f16": {
35 | "type" : "AutoencoderKL",
36 | "embed_scale" : 16,
37 | "embed_dim" : 16,
38 | "z_channels" : 16,
39 | "double_z" : True,
40 | "resolution" : 256,
41 | "in_channels" : 3,
42 | "out_ch" : 3,
43 | "ch" : 128,
44 | "ch_mult" : [1,1,2,2,4],
45 | "num_res_blocks" : 2,
46 | "attn_resolutions" : [16],
47 | },
48 | "kl-f32": {
49 | "type" : "AutoencoderKL",
50 | "embed_scale" : 32,
51 | "embed_dim" : 64,
52 | "z_channels" : 64,
53 | "double_z" : True,
54 | "resolution" : 256,
55 | "in_channels" : 3,
56 | "out_ch" : 3,
57 | "ch" : 128,
58 | "ch_mult" : [1,1,2,2,4,4],
59 | "num_res_blocks" : 2,
60 | "attn_resolutions" : [16,8],
61 | },
62 | ### VQModel ###
63 | "vq-f4": {
64 | "type" : "VQModel",
65 | "embed_scale" : 4,
66 | "n_embed" : 8192,
67 | "embed_dim" : 3,
68 | "z_channels" : 3,
69 | "double_z" : False,
70 | "resolution" : 256,
71 | "in_channels" : 3,
72 | "out_ch" : 3,
73 | "ch" : 128,
74 | "ch_mult" : [1,2,4],
75 | "num_res_blocks" : 2,
76 | "attn_resolutions" : [],
77 | },
78 | "vq-f8": {
79 | "type" : "VQModel",
80 | "embed_scale" : 8,
81 | "n_embed" : 16384,
82 | "embed_dim" : 4,
83 | "z_channels" : 4,
84 | "double_z" : False,
85 | "resolution" : 256,
86 | "in_channels" : 3,
87 | "out_ch" : 3,
88 | "ch" : 128,
89 | "ch_mult" : [1,2,2,4],
90 | "num_res_blocks" : 2,
91 | "attn_resolutions" : [32],
92 | },
93 | "vq-f16": {
94 | "type" : "VQModel",
95 | "embed_scale" : 16,
96 | "n_embed" : 16384,
97 | "embed_dim" : 8,
98 | "z_channels" : 8,
99 | "double_z" : False,
100 | "resolution" : 256,
101 | "in_channels" : 3,
102 | "out_ch" : 3,
103 | "ch" : 128,
104 | "ch_mult" : [1,1,2,2,4],
105 | "num_res_blocks" : 2,
106 | "attn_resolutions" : [16],
107 | },
108 | # OpenAI Consistency Decoder
109 | "Consistency-Decoder": {
110 | "type" : "ConsistencyDecoder",
111 | "embed_scale" : 8,
112 | "embed_dim" : 4,
113 | },
114 | # SAI Video Decoder
115 | "SDV-VideoDecoder": {
116 | "type" : "AutoencoderKL-VideoDecoder",
117 | "embed_scale" : 8,
118 | "embed_dim" : 4,
119 | "z_channels" : 4,
120 | "double_z" : True,
121 | "resolution" : 256,
122 | "in_channels" : 3,
123 | "out_ch" : 3,
124 | "ch" : 128,
125 | "ch_mult" : [1,2,4,4],
126 | "num_res_blocks" : 2,
127 | "attn_resolutions" : [],
128 | "video_kernel_size": [3, 1, 1]
129 | },
130 | # Kandinsky-3
131 | "MoVQ3": {
132 | "type" : "MoVQ3",
133 | "embed_scale" : 8,
134 | "embed_dim" : 4,
135 | "double_z" : False,
136 | "z_channels" : 4,
137 | "resolution" : 256,
138 | "in_channels" : 3,
139 | "out_ch" : 3,
140 | "ch" : 256,
141 | "ch_mult" : [1, 2, 2, 4],
142 | "num_res_blocks" : 2,
143 | "attn_resolutions" : [32],
144 | }
145 | }
146 |
--------------------------------------------------------------------------------
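To make the `embed_scale` / `embed_dim` fields above concrete, here is a quick check of the latent shape they imply; the 512x512 input size is only an example and the import path is an assumption.

```
# Assumed import path; adjust to wherever this package lives on your system.
from VAE.conf import vae_conf

conf = vae_conf["kl-f8"]  # default SD1.5 VAE: 8x downscale, 4 latent channels
h = w = 512
latent_shape = (conf["embed_dim"], h // conf["embed_scale"], w // conf["embed_scale"])
print(latent_shape)       # (4, 64, 64)
```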
/VAE/loader.py:
--------------------------------------------------------------------------------
1 | import torch
2 | import comfy.sd
3 | import comfy.utils
4 | from comfy import model_management
5 | from comfy import diffusers_convert
6 |
7 | class EXVAE(comfy.sd.VAE):
8 | def __init__(self, model_path, model_conf, dtype=torch.float32):
9 | self.latent_dim = model_conf["embed_dim"]
10 | self.latent_scale = model_conf["embed_scale"]
11 | self.device = model_management.vae_device()
12 | self.offload_device = model_management.vae_offload_device()
13 | self.vae_dtype = dtype
14 |
15 | sd = comfy.utils.load_torch_file(model_path)
16 | model = None
17 | if model_conf["type"] == "AutoencoderKL":
18 | from .models.kl import AutoencoderKL
19 | model = AutoencoderKL(config=model_conf)
20 | if 'decoder.up_blocks.0.resnets.0.norm1.weight' in sd.keys():
21 | sd = diffusers_convert.convert_vae_state_dict(sd)
22 | elif model_conf["type"] == "AutoencoderKL-VideoDecoder":
23 | from .models.temporal_ae import AutoencoderKL
24 | model = AutoencoderKL(config=model_conf)
25 | elif model_conf["type"] == "VQModel":
26 | from .models.vq import VQModel
27 | model = VQModel(config=model_conf)
28 | elif model_conf["type"] == "ConsistencyDecoder":
29 | from .models.consistencydecoder import ConsistencyDecoder
30 | model = ConsistencyDecoder()
31 | sd = {f"model.{k}":v for k,v in sd.items()}
32 | elif model_conf["type"] == "MoVQ3":
33 | from .models.movq3 import MoVQ
34 | model = MoVQ(model_conf)
35 | else:
36 | raise NotImplementedError(f"Unknown VAE type '{model_conf['type']}'")
37 |
38 | self.first_stage_model = model.eval()
39 | m, u = self.first_stage_model.load_state_dict(sd, strict=False)
40 | if len(m) > 0: print("Missing VAE keys", m)
41 | if len(u) > 0: print("Leftover VAE keys", u)
42 |
43 | self.first_stage_model.to(self.vae_dtype).to(self.offload_device)
44 |
45 |     ### Encode/decode overrides below are needed because the parent comfy.sd.VAE hardcodes 4 latent channels and a scale factor of 8
46 | def decode_tiled_(self, samples, tile_x=64, tile_y=64, overlap = 16):
47 | steps = samples.shape[0] * comfy.utils.get_tiled_scale_steps(samples.shape[3], samples.shape[2], tile_x, tile_y, overlap)
48 | steps += samples.shape[0] * comfy.utils.get_tiled_scale_steps(samples.shape[3], samples.shape[2], tile_x // 2, tile_y * 2, overlap)
49 | steps += samples.shape[0] * comfy.utils.get_tiled_scale_steps(samples.shape[3], samples.shape[2], tile_x * 2, tile_y // 2, overlap)
50 | pbar = comfy.utils.ProgressBar(steps)
51 |
52 | decode_fn = lambda a: (self.first_stage_model.decode(a.to(self.vae_dtype).to(self.device)) + 1.0).float()
53 | output = torch.clamp((
54 | (comfy.utils.tiled_scale(samples, decode_fn, tile_x // 2, tile_y * 2, overlap, upscale_amount = self.latent_scale, pbar = pbar) +
55 | comfy.utils.tiled_scale(samples, decode_fn, tile_x * 2, tile_y // 2, overlap, upscale_amount = self.latent_scale, pbar = pbar) +
56 | comfy.utils.tiled_scale(samples, decode_fn, tile_x, tile_y, overlap, upscale_amount = self.latent_scale, pbar = pbar))
57 | / 3.0) / 2.0, min=0.0, max=1.0)
58 | return output
59 |
60 | def encode_tiled_(self, pixel_samples, tile_x=512, tile_y=512, overlap = 64):
61 | steps = pixel_samples.shape[0] * comfy.utils.get_tiled_scale_steps(pixel_samples.shape[3], pixel_samples.shape[2], tile_x, tile_y, overlap)
62 | steps += pixel_samples.shape[0] * comfy.utils.get_tiled_scale_steps(pixel_samples.shape[3], pixel_samples.shape[2], tile_x // 2, tile_y * 2, overlap)
63 | steps += pixel_samples.shape[0] * comfy.utils.get_tiled_scale_steps(pixel_samples.shape[3], pixel_samples.shape[2], tile_x * 2, tile_y // 2, overlap)
64 | pbar = comfy.utils.ProgressBar(steps)
65 |
66 | encode_fn = lambda a: self.first_stage_model.encode((2. * a - 1.).to(self.vae_dtype).to(self.device)).float()
67 | samples = comfy.utils.tiled_scale(pixel_samples, encode_fn, tile_x, tile_y, overlap, upscale_amount = (1/self.latent_scale), out_channels=self.latent_dim, pbar=pbar)
68 | samples += comfy.utils.tiled_scale(pixel_samples, encode_fn, tile_x * 2, tile_y // 2, overlap, upscale_amount = (1/self.latent_scale), out_channels=self.latent_dim, pbar=pbar)
69 | samples += comfy.utils.tiled_scale(pixel_samples, encode_fn, tile_x // 2, tile_y * 2, overlap, upscale_amount = (1/self.latent_scale), out_channels=self.latent_dim, pbar=pbar)
70 | samples /= 3.0
71 | return samples
72 |
73 | def decode(self, samples_in):
74 | self.first_stage_model = self.first_stage_model.to(self.device)
75 | try:
76 | memory_used = (2562 * samples_in.shape[2] * samples_in.shape[3] * 64) * 1.7
77 | model_management.free_memory(memory_used, self.device)
78 | free_memory = model_management.get_free_memory(self.device)
79 | batch_number = int(free_memory / memory_used)
80 | batch_number = max(1, batch_number)
81 |
82 | pixel_samples = torch.empty((samples_in.shape[0], 3, round(samples_in.shape[2] * self.latent_scale), round(samples_in.shape[3] * self.latent_scale)), device="cpu")
83 | for x in range(0, samples_in.shape[0], batch_number):
84 | samples = samples_in[x:x+batch_number].to(self.vae_dtype).to(self.device)
85 | pixel_samples[x:x+batch_number] = torch.clamp((self.first_stage_model.decode(samples).cpu().float() + 1.0) / 2.0, min=0.0, max=1.0)
86 | except model_management.OOM_EXCEPTION as e:
87 | print("Warning: Ran out of memory when regular VAE decoding, retrying with tiled VAE decoding.")
88 | pixel_samples = self.decode_tiled_(samples_in)
89 |
90 | self.first_stage_model = self.first_stage_model.to(self.offload_device)
91 | pixel_samples = pixel_samples.cpu().movedim(1,-1)
92 | return pixel_samples
93 |
94 | def encode(self, pixel_samples):
95 | self.first_stage_model = self.first_stage_model.to(self.device)
96 | pixel_samples = pixel_samples.movedim(-1,1)
97 | try:
98 | memory_used = (2078 * pixel_samples.shape[2] * pixel_samples.shape[3]) * 1.7 #NOTE: this constant along with the one in the decode above are estimated from the mem usage for the VAE and could change.
99 | model_management.free_memory(memory_used, self.device)
100 | free_memory = model_management.get_free_memory(self.device)
101 | batch_number = int(free_memory / memory_used)
102 | batch_number = max(1, batch_number)
103 | samples = torch.empty((pixel_samples.shape[0], self.latent_dim, round(pixel_samples.shape[2] // self.latent_scale), round(pixel_samples.shape[3] // self.latent_scale)), device="cpu")
104 | for x in range(0, pixel_samples.shape[0], batch_number):
105 | pixels_in = (2. * pixel_samples[x:x+batch_number] - 1.).to(self.vae_dtype).to(self.device)
106 | samples[x:x+batch_number] = self.first_stage_model.encode(pixels_in).cpu().float()
107 |
108 | except model_management.OOM_EXCEPTION as e:
109 | print("Warning: Ran out of memory when regular VAE encoding, retrying with tiled VAE encoding.")
110 | samples = self.encode_tiled_(pixel_samples)
111 |
112 | self.first_stage_model = self.first_stage_model.to(self.offload_device)
113 | return samples
114 |
--------------------------------------------------------------------------------
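A hedged sketch of decoding with the `EXVAE` wrapper above. It only runs inside a ComfyUI environment (for `comfy` and `model_management`), the checkpoint path is a placeholder, and the import paths are assumptions.

```
import torch
# Assumed import paths inside a ComfyUI install.
from VAE.conf import vae_conf
from VAE.loader import EXVAE

vae = EXVAE("/path/to/vae.safetensors", vae_conf["kl-f8"], dtype=torch.float16)  # placeholder path

latent = torch.randn(1, 4, 64, 64)  # (N, embed_dim, H/8, W/8)
image = vae.decode(latent)          # (N, 512, 512, 3), clamped to [0, 1], channels-last
```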
/VAE/models/LICENSE-Consistency-Decoder:
--------------------------------------------------------------------------------
1 | MIT License
2 |
3 | Copyright (c) 2023 OpenAI
4 |
5 | Permission is hereby granted, free of charge, to any person obtaining a copy
6 | of this software and associated documentation files (the "Software"), to deal
7 | in the Software without restriction, including without limitation the rights
8 | to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
9 | copies of the Software, and to permit persons to whom the Software is
10 | furnished to do so, subject to the following conditions:
11 |
12 | The above copyright notice and this permission notice shall be included in all
13 | copies or substantial portions of the Software.
14 |
15 | THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
16 | IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
17 | FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
18 | AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
19 | LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
20 | OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE
21 | SOFTWARE.
22 |
--------------------------------------------------------------------------------
/VAE/models/LICENSE-Kandinsky-3:
--------------------------------------------------------------------------------
1 | Apache License
2 | Version 2.0, January 2004
3 | http://www.apache.org/licenses/
4 |
5 | TERMS AND CONDITIONS FOR USE, REPRODUCTION, AND DISTRIBUTION
6 |
7 | 1. Definitions.
8 |
9 | "License" shall mean the terms and conditions for use, reproduction,
10 | and distribution as defined by Sections 1 through 9 of this document.
11 |
12 | "Licensor" shall mean the copyright owner or entity authorized by
13 | the copyright owner that is granting the License.
14 |
15 | "Legal Entity" shall mean the union of the acting entity and all
16 | other entities that control, are controlled by, or are under common
17 | control with that entity. For the purposes of this definition,
18 | "control" means (i) the power, direct or indirect, to cause the
19 | direction or management of such entity, whether by contract or
20 | otherwise, or (ii) ownership of fifty percent (50%) or more of the
21 | outstanding shares, or (iii) beneficial ownership of such entity.
22 |
23 | "You" (or "Your") shall mean an individual or Legal Entity
24 | exercising permissions granted by this License.
25 |
26 | "Source" form shall mean the preferred form for making modifications,
27 | including but not limited to software source code, documentation
28 | source, and configuration files.
29 |
30 | "Object" form shall mean any form resulting from mechanical
31 | transformation or translation of a Source form, including but
32 | not limited to compiled object code, generated documentation,
33 | and conversions to other media types.
34 |
35 | "Work" shall mean the work of authorship, whether in Source or
36 | Object form, made available under the License, as indicated by a
37 | copyright notice that is included in or attached to the work
38 | (an example is provided in the Appendix below).
39 |
40 | "Derivative Works" shall mean any work, whether in Source or Object
41 | form, that is based on (or derived from) the Work and for which the
42 | editorial revisions, annotations, elaborations, or other modifications
43 | represent, as a whole, an original work of authorship. For the purposes
44 | of this License, Derivative Works shall not include works that remain
45 | separable from, or merely link (or bind by name) to the interfaces of,
46 | the Work and Derivative Works thereof.
47 |
48 | "Contribution" shall mean any work of authorship, including
49 | the original version of the Work and any modifications or additions
50 | to that Work or Derivative Works thereof, that is intentionally
51 | submitted to Licensor for inclusion in the Work by the copyright owner
52 | or by an individual or Legal Entity authorized to submit on behalf of
53 | the copyright owner. For the purposes of this definition, "submitted"
54 | means any form of electronic, verbal, or written communication sent
55 | to the Licensor or its representatives, including but not limited to
56 | communication on electronic mailing lists, source code control systems,
57 | and issue tracking systems that are managed by, or on behalf of, the
58 | Licensor for the purpose of discussing and improving the Work, but
59 | excluding communication that is conspicuously marked or otherwise
60 | designated in writing by the copyright owner as "Not a Contribution."
61 |
62 | "Contributor" shall mean Licensor and any individual or Legal Entity
63 | on behalf of whom a Contribution has been received by Licensor and
64 | subsequently incorporated within the Work.
65 |
66 | 2. Grant of Copyright License. Subject to the terms and conditions of
67 | this License, each Contributor hereby grants to You a perpetual,
68 | worldwide, non-exclusive, no-charge, royalty-free, irrevocable
69 | copyright license to reproduce, prepare Derivative Works of,
70 | publicly display, publicly perform, sublicense, and distribute the
71 | Work and such Derivative Works in Source or Object form.
72 |
73 | 3. Grant of Patent License. Subject to the terms and conditions of
74 | this License, each Contributor hereby grants to You a perpetual,
75 | worldwide, non-exclusive, no-charge, royalty-free, irrevocable
76 | (except as stated in this section) patent license to make, have made,
77 | use, offer to sell, sell, import, and otherwise transfer the Work,
78 | where such license applies only to those patent claims licensable
79 | by such Contributor that are necessarily infringed by their
80 | Contribution(s) alone or by combination of their Contribution(s)
81 | with the Work to which such Contribution(s) was submitted. If You
82 | institute patent litigation against any entity (including a
83 | cross-claim or counterclaim in a lawsuit) alleging that the Work
84 | or a Contribution incorporated within the Work constitutes direct
85 | or contributory patent infringement, then any patent licenses
86 | granted to You under this License for that Work shall terminate
87 | as of the date such litigation is filed.
88 |
89 | 4. Redistribution. You may reproduce and distribute copies of the
90 | Work or Derivative Works thereof in any medium, with or without
91 | modifications, and in Source or Object form, provided that You
92 | meet the following conditions:
93 |
94 | (a) You must give any other recipients of the Work or
95 | Derivative Works a copy of this License; and
96 |
97 | (b) You must cause any modified files to carry prominent notices
98 | stating that You changed the files; and
99 |
100 | (c) You must retain, in the Source form of any Derivative Works
101 | that You distribute, all copyright, patent, trademark, and
102 | attribution notices from the Source form of the Work,
103 | excluding those notices that do not pertain to any part of
104 | the Derivative Works; and
105 |
106 | (d) If the Work includes a "NOTICE" text file as part of its
107 | distribution, then any Derivative Works that You distribute must
108 | include a readable copy of the attribution notices contained
109 | within such NOTICE file, excluding those notices that do not
110 | pertain to any part of the Derivative Works, in at least one
111 | of the following places: within a NOTICE text file distributed
112 | as part of the Derivative Works; within the Source form or
113 | documentation, if provided along with the Derivative Works; or,
114 | within a display generated by the Derivative Works, if and
115 | wherever such third-party notices normally appear. The contents
116 | of the NOTICE file are for informational purposes only and
117 | do not modify the License. You may add Your own attribution
118 | notices within Derivative Works that You distribute, alongside
119 | or as an addendum to the NOTICE text from the Work, provided
120 | that such additional attribution notices cannot be construed
121 | as modifying the License.
122 |
123 | You may add Your own copyright statement to Your modifications and
124 | may provide additional or different license terms and conditions
125 | for use, reproduction, or distribution of Your modifications, or
126 | for any such Derivative Works as a whole, provided Your use,
127 | reproduction, and distribution of the Work otherwise complies with
128 | the conditions stated in this License.
129 |
130 | 5. Submission of Contributions. Unless You explicitly state otherwise,
131 | any Contribution intentionally submitted for inclusion in the Work
132 | by You to the Licensor shall be under the terms and conditions of
133 | this License, without any additional terms or conditions.
134 | Notwithstanding the above, nothing herein shall supersede or modify
135 | the terms of any separate license agreement you may have executed
136 | with Licensor regarding such Contributions.
137 |
138 | 6. Trademarks. This License does not grant permission to use the trade
139 | names, trademarks, service marks, or product names of the Licensor,
140 | except as required for reasonable and customary use in describing the
141 | origin of the Work and reproducing the content of the NOTICE file.
142 |
143 | 7. Disclaimer of Warranty. Unless required by applicable law or
144 | agreed to in writing, Licensor provides the Work (and each
145 | Contributor provides its Contributions) on an "AS IS" BASIS,
146 | WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or
147 | implied, including, without limitation, any warranties or conditions
148 | of TITLE, NON-INFRINGEMENT, MERCHANTABILITY, or FITNESS FOR A
149 | PARTICULAR PURPOSE. You are solely responsible for determining the
150 | appropriateness of using or redistributing the Work and assume any
151 | risks associated with Your exercise of permissions under this License.
152 |
153 | 8. Limitation of Liability. In no event and under no legal theory,
154 | whether in tort (including negligence), contract, or otherwise,
155 | unless required by applicable law (such as deliberate and grossly
156 | negligent acts) or agreed to in writing, shall any Contributor be
157 | liable to You for damages, including any direct, indirect, special,
158 | incidental, or consequential damages of any character arising as a
159 | result of this License or out of the use or inability to use the
160 | Work (including but not limited to damages for loss of goodwill,
161 | work stoppage, computer failure or malfunction, or any and all
162 | other commercial damages or losses), even if such Contributor
163 | has been advised of the possibility of such damages.
164 |
165 | 9. Accepting Warranty or Additional Liability. While redistributing
166 | the Work or Derivative Works thereof, You may choose to offer,
167 | and charge a fee for, acceptance of support, warranty, indemnity,
168 | or other liability obligations and/or rights consistent with this
169 | License. However, in accepting such obligations, You may act only
170 | on Your own behalf and on Your sole responsibility, not on behalf
171 | of any other Contributor, and only if You agree to indemnify,
172 | defend, and hold each Contributor harmless for any liability
173 | incurred by, or claims asserted against, such Contributor by reason
174 | of your accepting any such warranty or additional liability.
175 |
176 | END OF TERMS AND CONDITIONS
177 |
178 | APPENDIX: How to apply the Apache License to your work.
179 |
180 | To apply the Apache License to your work, attach the following
181 | boilerplate notice, with the fields enclosed by brackets "[]"
182 | replaced with your own identifying information. (Don't include
183 | the brackets!) The text should be enclosed in the appropriate
184 | comment syntax for the file format. We also recommend that a
185 | file or class name and description of purpose be included on the
186 | same "printed page" as the copyright notice for easier
187 | identification within third-party archives.
188 |
189 | Copyright [yyyy] [name of copyright owner]
190 |
191 | Licensed under the Apache License, Version 2.0 (the "License");
192 | you may not use this file except in compliance with the License.
193 | You may obtain a copy of the License at
194 |
195 | http://www.apache.org/licenses/LICENSE-2.0
196 |
197 | Unless required by applicable law or agreed to in writing, software
198 | distributed under the License is distributed on an "AS IS" BASIS,
199 | WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
200 | See the License for the specific language governing permissions and
201 | limitations under the License.
202 |
--------------------------------------------------------------------------------
/VAE/models/LICENSE-Latent-Diffusion:
--------------------------------------------------------------------------------
1 | MIT License
2 |
3 | Copyright (c) 2022 Machine Vision and Learning Group, LMU Munich
4 |
5 | Permission is hereby granted, free of charge, to any person obtaining a copy
6 | of this software and associated documentation files (the "Software"), to deal
7 | in the Software without restriction, including without limitation the rights
8 | to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
9 | copies of the Software, and to permit persons to whom the Software is
10 | furnished to do so, subject to the following conditions:
11 |
12 | The above copyright notice and this permission notice shall be included in all
13 | copies or substantial portions of the Software.
14 |
15 | THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
16 | IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
17 | FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
18 | AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
19 | LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
20 | OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE
21 | SOFTWARE.
22 |
--------------------------------------------------------------------------------
/VAE/models/LICENSE-SAI:
--------------------------------------------------------------------------------
1 | MIT License
2 |
3 | Copyright (c) 2023 Stability AI
4 |
5 | Permission is hereby granted, free of charge, to any person obtaining a copy
6 | of this software and associated documentation files (the "Software"), to deal
7 | in the Software without restriction, including without limitation the rights
8 | to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
9 | copies of the Software, and to permit persons to whom the Software is
10 | furnished to do so, subject to the following conditions:
11 |
12 | The above copyright notice and this permission notice shall be included in all
13 | copies or substantial portions of the Software.
14 |
15 | THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
16 | IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
17 | FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
18 | AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
19 | LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
20 | OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE
21 | SOFTWARE.
--------------------------------------------------------------------------------
/VAE/models/LICENSE-SDV:
--------------------------------------------------------------------------------
1 | STABLE VIDEO DIFFUSION NON-COMMERCIAL COMMUNITY LICENSE AGREEMENT
2 | Dated: November 21, 2023
3 |
4 | “AUP” means the Stability AI Acceptable Use Policy available at https://stability.ai/use-policy, as may be updated from time to time.
5 |
6 | "Agreement" means the terms and conditions for use, reproduction, distribution and modification of the Software Products set forth herein.
7 | "Derivative Work(s)” means (a) any derivative work of the Software Products as recognized by U.S. copyright laws and (b) any modifications to a Model, and any other model created which is based on or derived from the Model or the Model’s output. For clarity, Derivative Works do not include the output of any Model.
8 | “Documentation” means any specifications, manuals, documentation, and other written information provided by Stability AI related to the Software.
9 |
10 | "Licensee" or "you" means you, or your employer or any other person or entity (if you are entering into this Agreement on such person or entity's behalf), of the age required under applicable laws, rules or regulations to provide legal consent and that has legal authority to bind your employer or such other person or entity if you are entering in this Agreement on their behalf.
11 |
12 | "Stability AI" or "we" means Stability AI Ltd.
13 |
14 | "Software" means, collectively, Stability AI’s proprietary models and algorithms, including machine-learning models, trained model weights and other elements of the foregoing, made available under this Agreement.
15 |
16 | “Software Products” means Software and Documentation.
17 |
18 | By using or distributing any portion or element of the Software Products, you agree to be bound by this Agreement.
19 |
20 |
21 |
22 | 1. License Rights and Redistribution.
23 | a. Subject to your compliance with this Agreement, the AUP (which is hereby incorporated herein by reference), and the Documentation, Stability AI grants you a non-exclusive, worldwide, non-transferable, non-sublicensable, revocable, royalty free and limited license under Stability AI’s intellectual property or other rights owned by Stability AI embodied in the Software Products to reproduce, distribute, and create Derivative Works of the Software Products for purposes other than commercial or production use.
24 | b. If you distribute or make the Software Products, or any Derivative Works thereof, available to a third party, the Software Products, Derivative Works, or any portion thereof, respectively, will remain subject to this Agreement and you must (i) provide a copy of this Agreement to such third party, and (ii) retain the following attribution notice within a "Notice" text file distributed as a part of such copies: "Stable Video Diffusion is licensed under the Stable Video Diffusion Research License, Copyright (c) Stability AI Ltd. All Rights Reserved.” If you create a Derivative Work of a Software Product, you may add your own attribution notices to the Notice file included with the Software Product, provided that you clearly indicate which attributions apply to the Software Product and you must state in the NOTICE file that you changed the Software Product and how it was modified.
25 | 2. Disclaimer of Warranty. UNLESS REQUIRED BY APPLICABLE LAW, THE SOFTWARE PRODUCTS AND ANY OUTPUT AND RESULTS THEREFROM ARE PROVIDED ON AN "AS IS" BASIS, WITHOUT WARRANTIES OF ANY KIND, EITHER EXPRESS OR IMPLIED, INCLUDING, WITHOUT LIMITATION, ANY WARRANTIES OF TITLE, NON-INFRINGEMENT, MERCHANTABILITY, OR FITNESS FOR A PARTICULAR PURPOSE. YOU ARE SOLELY RESPONSIBLE FOR DETERMINING THE APPROPRIATENESS OF USING OR REDISTRIBUTING THE SOFTWARE PRODUCTS AND ASSUME ANY RISKS ASSOCIATED WITH YOUR USE OF THE SOFTWARE PRODUCTS AND ANY OUTPUT AND RESULTS.
26 | 3. Limitation of Liability. IN NO EVENT WILL STABILITY AI OR ITS AFFILIATES BE LIABLE UNDER ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, TORT, NEGLIGENCE, PRODUCTS LIABILITY, OR OTHERWISE, ARISING OUT OF THIS AGREEMENT, FOR ANY LOST PROFITS OR ANY INDIRECT, SPECIAL, CONSEQUENTIAL, INCIDENTAL, EXEMPLARY OR PUNITIVE DAMAGES, EVEN IF STABILITY AI OR ITS AFFILIATES HAVE BEEN ADVISED OF THE POSSIBILITY OF ANY OF THE FOREGOING.
27 | 4. Intellectual Property.
28 | a. No trademark licenses are granted under this Agreement, and in connection with the Software Products, neither Stability AI nor Licensee may use any name or mark owned by or associated with the other or any of its affiliates, except as required for reasonable and customary use in describing and redistributing the Software Products.
29 | b. Subject to Stability AI’s ownership of the Software Products and Derivative Works made by or for Stability AI, with respect to any Derivative Works that are made by you, as between you and Stability AI, you are and will be the owner of such Derivative Works.
30 | c. If you institute litigation or other proceedings against Stability AI (including a cross-claim or counterclaim in a lawsuit) alleging that the Software Products or associated outputs or results, or any portion of any of the foregoing, constitutes infringement of intellectual property or other rights owned or licensable by you, then any licenses granted to you under this Agreement shall terminate as of the date such litigation or claim is filed or instituted. You will indemnify and hold harmless Stability AI from and against any claim by any third party arising out of or related to your use or distribution of the Software Products in violation of this Agreement.
31 | 5. Term and Termination. The term of this Agreement will commence upon your acceptance of this Agreement or access to the Software Products and will continue in full force and effect until terminated in accordance with the terms and conditions herein. Stability AI may terminate this Agreement if you are in breach of any term or condition of this Agreement. Upon termination of this Agreement, you shall delete and cease use of the Software Products. Sections 2-4 shall survive the termination of this Agreement.
32 |
--------------------------------------------------------------------------------
/VAE/models/LICENSE-Taming-Transformers:
--------------------------------------------------------------------------------
1 | Copyright (c) 2020 Patrick Esser and Robin Rombach and Björn Ommer
2 |
3 | Permission is hereby granted, free of charge, to any person obtaining a copy
4 | of this software and associated documentation files (the "Software"), to deal
5 | in the Software without restriction, including without limitation the rights
6 | to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
7 | copies of the Software, and to permit persons to whom the Software is
8 | furnished to do so, subject to the following conditions:
9 |
10 | The above copyright notice and this permission notice shall be included in all
11 | copies or substantial portions of the Software.
12 |
13 | THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND,
14 | EXPRESS OR IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF
15 | MERCHANTABILITY, FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT.
16 | IN NO EVENT SHALL THE AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM,
17 | DAMAGES OR OTHER LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR
18 | OTHERWISE, ARISING FROM, OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE
19 | OR OTHER DEALINGS IN THE SOFTWARE.
20 |
--------------------------------------------------------------------------------
/VAE/models/consistencydecoder.py:
--------------------------------------------------------------------------------
1 | import math
2 | import torch
3 | import torch.nn.functional as F
4 | import torch.nn as nn
5 |
6 | """
7 | Code below ported from https://github.com/openai/consistencydecoder
8 | """
9 |
10 | def _extract_into_tensor(arr, timesteps, broadcast_shape):
11 |     # from: https://github.com/openai/guided-diffusion/blob/22e0df8183507e13a7813f8d38d51b072ca1e67c/guided_diffusion/gaussian_diffusion.py#L895
12 | res = arr[timesteps.to(torch.int).cpu()].float().to(timesteps.device)
13 | dims_to_append = len(broadcast_shape) - len(res.shape)
14 | return res[(...,) + (None,) * dims_to_append]
15 |
16 | def betas_for_alpha_bar(num_diffusion_timesteps, alpha_bar, max_beta=0.999):
17 | # from: https://github.com/openai/guided-diffusion/blob/22e0df8183507e13a7813f8d38d51b072ca1e67c/guided_diffusion/gaussian_diffusion.py#L45
18 | betas = []
19 | for i in range(num_diffusion_timesteps):
20 | t1 = i / num_diffusion_timesteps
21 | t2 = (i + 1) / num_diffusion_timesteps
22 | betas.append(min(1 - alpha_bar(t2) / alpha_bar(t1), max_beta))
23 | return torch.tensor(betas)
24 |
25 | class ConsistencyDecoder(torch.nn.Module):
26 | # From https://github.com/openai/consistencydecoder
27 | def __init__(self):
28 | super().__init__()
29 | self.model = ConvUNetVAE()
30 | self.n_distilled_steps = 64
31 |
32 | sigma_data = 0.5
33 | betas = betas_for_alpha_bar(
34 | 1024, lambda t: math.cos((t + 0.008) / 1.008 * math.pi / 2) ** 2
35 | )
36 | alphas = 1.0 - betas
37 | alphas_cumprod = torch.cumprod(alphas, dim=0)
38 | self.sqrt_alphas_cumprod = torch.sqrt(alphas_cumprod)
39 | self.sqrt_one_minus_alphas_cumprod = torch.sqrt(1.0 - alphas_cumprod)
40 | sqrt_recip_alphas_cumprod = torch.sqrt(1.0 / alphas_cumprod)
41 | sigmas = torch.sqrt(1.0 / alphas_cumprod - 1)
42 | self.c_skip = (
43 | sqrt_recip_alphas_cumprod
44 | * sigma_data**2
45 | / (sigmas**2 + sigma_data**2)
46 | )
47 | self.c_out = sigmas * sigma_data / (sigmas**2 + sigma_data**2) ** 0.5
48 | self.c_in = sqrt_recip_alphas_cumprod / (sigmas**2 + sigma_data**2) ** 0.5
49 |
50 | @staticmethod
51 | def round_timesteps(timesteps, total_timesteps, n_distilled_steps, truncate_start=True):
52 | with torch.no_grad():
53 | space = torch.div(total_timesteps, n_distilled_steps, rounding_mode="floor")
54 | rounded_timesteps = (
55 | torch.div(timesteps, space, rounding_mode="floor") + 1
56 | ) * space
57 | if truncate_start:
58 | rounded_timesteps[rounded_timesteps == total_timesteps] -= space
59 | else:
60 | rounded_timesteps[rounded_timesteps == total_timesteps] -= space
61 | rounded_timesteps[rounded_timesteps == 0] += space
62 | return rounded_timesteps
63 |
64 | @staticmethod
65 | def ldm_transform_latent(z, extra_scale_factor=1):
66 | channel_means = [0.38862467, 0.02253063, 0.07381133, -0.0171294]
67 | channel_stds = [0.9654121, 1.0440036, 0.76147926, 0.77022034]
68 |
69 | if len(z.shape) != 4:
70 |             raise ValueError("expected a 4D latent tensor of shape (B, C, H, W)")
71 |
72 | z = z * 0.18215
73 | channels = [z[:, i] for i in range(z.shape[1])]
74 |
75 | channels = [
76 | extra_scale_factor * (c - channel_means[i]) / channel_stds[i]
77 | for i, c in enumerate(channels)
78 | ]
79 | return torch.stack(channels, dim=1)
80 |
81 | @torch.no_grad()
82 |     def decode(self, features: torch.Tensor, schedule=(1.0, 0.5)):
83 | features = self.ldm_transform_latent(features)
84 | ts = self.round_timesteps(
85 | torch.arange(0, 1024),
86 | 1024,
87 | self.n_distilled_steps,
88 | truncate_start=False,
89 | )
90 | shape = (
91 | features.size(0),
92 | 3,
93 | 8 * features.size(2),
94 | 8 * features.size(3),
95 | )
96 | x_start = torch.zeros(shape, device=features.device, dtype=features.dtype)
97 | schedule_timesteps = [int((1024 - 1) * s) for s in schedule]
98 | for i in schedule_timesteps:
99 | t = ts[i].item()
100 | t_ = torch.tensor([t] * features.shape[0], device=features.device)
101 | noise = torch.randn_like(x_start, device=features.device)
102 | x_start = (
103 | _extract_into_tensor(self.sqrt_alphas_cumprod, t_, x_start.shape)
104 | * x_start
105 | + _extract_into_tensor(
106 | self.sqrt_one_minus_alphas_cumprod, t_, x_start.shape
107 | )
108 | * noise
109 | )
110 | c_in = _extract_into_tensor(self.c_in, t_, x_start.shape)
111 | model_output = self.model((c_in * x_start).to(features.dtype), t_, features=features)
112 | B, C = x_start.shape[:2]
113 | model_output, _ = torch.split(model_output, C, dim=1)
114 | pred_xstart = (
115 | _extract_into_tensor(self.c_out, t_, x_start.shape) * model_output
116 | + _extract_into_tensor(self.c_skip, t_, x_start.shape) * x_start
117 | ).clamp(-1, 1)
118 | x_start = pred_xstart
119 | return x_start
120 |
121 | def encode(self, *args, **kwargs):
122 | raise NotImplementedError("ConsistencyDecoder can't be used for encoding!")
123 |
124 | """
125 | Model definitions ported from:
126 | https://gist.github.com/madebyollin/865fa6a18d9099351ddbdfbe7299ccbf
127 | https://gist.github.com/mrsteyk/74ad3ec2f6f823111ae4c90e168505ac
128 | """
129 |
130 | class TimestepEmbedding(nn.Module):
131 | def __init__(self, n_time=1024, n_emb=320, n_out=1280) -> None:
132 | super().__init__()
133 | self.emb = nn.Embedding(n_time, n_emb)
134 | self.f_1 = nn.Linear(n_emb, n_out)
135 | self.f_2 = nn.Linear(n_out, n_out)
136 |
137 | def forward(self, x) -> torch.Tensor:
138 | x = self.emb(x)
139 | x = self.f_1(x)
140 | x = F.silu(x)
141 | return self.f_2(x)
142 |
143 |
144 | class ImageEmbedding(nn.Module):
145 | def __init__(self, in_channels=7, out_channels=320) -> None:
146 | super().__init__()
147 | self.f = nn.Conv2d(in_channels, out_channels, kernel_size=3, padding=1)
148 |
149 | def forward(self, x) -> torch.Tensor:
150 | return self.f(x)
151 |
152 |
153 | class ImageUnembedding(nn.Module):
154 | def __init__(self, in_channels=320, out_channels=6) -> None:
155 | super().__init__()
156 | self.gn = nn.GroupNorm(32, in_channels)
157 | self.f = nn.Conv2d(in_channels, out_channels, kernel_size=3, padding=1)
158 |
159 | def forward(self, x) -> torch.Tensor:
160 | return self.f(F.silu(self.gn(x)))
161 |
162 |
163 | class ConvResblock(nn.Module):
164 | def __init__(self, in_features=320, out_features=320) -> None:
165 | super().__init__()
166 | self.f_t = nn.Linear(1280, out_features * 2)
167 |
168 | self.gn_1 = nn.GroupNorm(32, in_features)
169 | self.f_1 = nn.Conv2d(in_features, out_features, kernel_size=3, padding=1)
170 |
171 | self.gn_2 = nn.GroupNorm(32, out_features)
172 | self.f_2 = nn.Conv2d(out_features, out_features, kernel_size=3, padding=1)
173 |
174 | skip_conv = in_features != out_features
175 | self.f_s = (
176 | nn.Conv2d(in_features, out_features, kernel_size=1, padding=0)
177 | if skip_conv
178 | else nn.Identity()
179 | )
180 |
181 | def forward(self, x, t):
182 | x_skip = x
183 | t = self.f_t(F.silu(t))
184 | t = t.chunk(2, dim=1)
185 | t_1 = t[0].unsqueeze(dim=2).unsqueeze(dim=3) + 1
186 | t_2 = t[1].unsqueeze(dim=2).unsqueeze(dim=3)
187 |
188 | gn_1 = F.silu(self.gn_1(x))
189 | f_1 = self.f_1(gn_1)
190 |
191 | gn_2 = self.gn_2(f_1)
192 |
193 | return self.f_s(x_skip) + self.f_2(F.silu(gn_2 * t_1 + t_2))
194 |
195 |
196 | # Also ConvResblock
197 | class Downsample(nn.Module):
198 | def __init__(self, in_channels=320) -> None:
199 | super().__init__()
200 | self.f_t = nn.Linear(1280, in_channels * 2)
201 |
202 | self.gn_1 = nn.GroupNorm(32, in_channels)
203 | self.f_1 = nn.Conv2d(in_channels, in_channels, kernel_size=3, padding=1)
204 | self.gn_2 = nn.GroupNorm(32, in_channels)
205 |
206 | self.f_2 = nn.Conv2d(in_channels, in_channels, kernel_size=3, padding=1)
207 |
208 | def forward(self, x, t) -> torch.Tensor:
209 | x_skip = x
210 |
211 | t = self.f_t(F.silu(t))
212 | t_1, t_2 = t.chunk(2, dim=1)
213 | t_1 = t_1.unsqueeze(2).unsqueeze(3) + 1
214 | t_2 = t_2.unsqueeze(2).unsqueeze(3)
215 |
216 | gn_1 = F.silu(self.gn_1(x))
217 | avg_pool2d = F.avg_pool2d(gn_1, kernel_size=(2, 2), stride=None)
218 | f_1 = self.f_1(avg_pool2d)
219 | gn_2 = self.gn_2(f_1)
220 |
221 | f_2 = self.f_2(F.silu(t_2 + (t_1 * gn_2)))
222 |
223 | return f_2 + F.avg_pool2d(x_skip, kernel_size=(2, 2), stride=None)
224 |
225 |
226 | # Also ConvResblock
227 | class Upsample(nn.Module):
228 | def __init__(self, in_channels=1024) -> None:
229 | super().__init__()
230 | self.f_t = nn.Linear(1280, in_channels * 2)
231 |
232 | self.gn_1 = nn.GroupNorm(32, in_channels)
233 | self.f_1 = nn.Conv2d(in_channels, in_channels, kernel_size=3, padding=1)
234 | self.gn_2 = nn.GroupNorm(32, in_channels)
235 |
236 | self.f_2 = nn.Conv2d(in_channels, in_channels, kernel_size=3, padding=1)
237 |
238 | def forward(self, x, t) -> torch.Tensor:
239 | x_skip = x
240 |
241 | t = self.f_t(F.silu(t))
242 | t_1, t_2 = t.chunk(2, dim=1)
243 | t_1 = t_1.unsqueeze(2).unsqueeze(3) + 1
244 | t_2 = t_2.unsqueeze(2).unsqueeze(3)
245 |
246 | gn_1 = F.silu(self.gn_1(x))
247 | upsample = F.interpolate(gn_1.float(), scale_factor=2, mode="nearest").to(gn_1.dtype)
248 |
249 | f_1 = self.f_1(upsample)
250 | gn_2 = self.gn_2(f_1)
251 |
252 | f_2 = self.f_2(F.silu(t_2 + (t_1 * gn_2)))
253 |
254 | return f_2 + F.interpolate(x_skip.float(), scale_factor=2, mode="nearest").to(x_skip.dtype)
255 |
256 |
257 | class ConvUNetVAE(nn.Module):
258 | def __init__(self) -> None:
259 | super().__init__()
260 | self.embed_image = ImageEmbedding()
261 | self.embed_time = TimestepEmbedding()
262 |
263 | down_0 = nn.ModuleList(
264 | [
265 | ConvResblock(320, 320),
266 | ConvResblock(320, 320),
267 | ConvResblock(320, 320),
268 | Downsample(320),
269 | ]
270 | )
271 | down_1 = nn.ModuleList(
272 | [
273 | ConvResblock(320, 640),
274 | ConvResblock(640, 640),
275 | ConvResblock(640, 640),
276 | Downsample(640),
277 | ]
278 | )
279 | down_2 = nn.ModuleList(
280 | [
281 | ConvResblock(640, 1024),
282 | ConvResblock(1024, 1024),
283 | ConvResblock(1024, 1024),
284 | Downsample(1024),
285 | ]
286 | )
287 | down_3 = nn.ModuleList(
288 | [
289 | ConvResblock(1024, 1024),
290 | ConvResblock(1024, 1024),
291 | ConvResblock(1024, 1024),
292 | ]
293 | )
294 | self.down = nn.ModuleList(
295 | [
296 | down_0,
297 | down_1,
298 | down_2,
299 | down_3,
300 | ]
301 | )
302 |
303 | self.mid = nn.ModuleList(
304 | [
305 | ConvResblock(1024, 1024),
306 | ConvResblock(1024, 1024),
307 | ]
308 | )
309 |
310 | up_3 = nn.ModuleList(
311 | [
312 | ConvResblock(1024 * 2, 1024),
313 | ConvResblock(1024 * 2, 1024),
314 | ConvResblock(1024 * 2, 1024),
315 | ConvResblock(1024 * 2, 1024),
316 | Upsample(1024),
317 | ]
318 | )
319 | up_2 = nn.ModuleList(
320 | [
321 | ConvResblock(1024 * 2, 1024),
322 | ConvResblock(1024 * 2, 1024),
323 | ConvResblock(1024 * 2, 1024),
324 | ConvResblock(1024 + 640, 1024),
325 | Upsample(1024),
326 | ]
327 | )
328 | up_1 = nn.ModuleList(
329 | [
330 | ConvResblock(1024 + 640, 640),
331 | ConvResblock(640 * 2, 640),
332 | ConvResblock(640 * 2, 640),
333 | ConvResblock(320 + 640, 640),
334 | Upsample(640),
335 | ]
336 | )
337 | up_0 = nn.ModuleList(
338 | [
339 | ConvResblock(320 + 640, 320),
340 | ConvResblock(320 * 2, 320),
341 | ConvResblock(320 * 2, 320),
342 | ConvResblock(320 * 2, 320),
343 | ]
344 | )
345 | self.up = nn.ModuleList(
346 | [
347 | up_0,
348 | up_1,
349 | up_2,
350 | up_3,
351 | ]
352 | )
353 |
354 | self.output = ImageUnembedding()
355 |
356 | def forward(self, x, t, features) -> torch.Tensor:
357 | x = torch.cat([x, F.interpolate(features.float(),scale_factor=8,mode="nearest").to(features.dtype)], dim=1)
358 | t = self.embed_time(t)
359 | x = self.embed_image(x)
360 |
361 | skips = [x]
362 | for down in self.down:
363 | for block in down:
364 | x = block(x, t)
365 | skips.append(x)
366 |
367 | for i in range(2):
368 | x = self.mid[i](x, t)
369 |
370 | for up in self.up[::-1]:
371 | for block in up:
372 | if isinstance(block, ConvResblock):
373 | x = torch.concat([x, skips.pop()], dim=1)
374 | x = block(x, t)
375 |
376 | return self.output(x)
377 |
--------------------------------------------------------------------------------
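The ConsistencyDecoder above is a bare nn.Module with no checkpoint handling of its own; a minimal usage sketch, assuming a state dict that matches the ConvUNetVAE layout (the filename below is a placeholder, not part of this repository), would be:

import torch

# Hypothetical stand-alone use of the ConsistencyDecoder defined above.
# "consistency_decoder.pt" is a placeholder; its weights must match ConvUNetVAE.
decoder = ConsistencyDecoder()
decoder.load_state_dict(torch.load("consistency_decoder.pt", map_location="cpu"))
decoder.eval()

latents = torch.randn(1, 4, 64, 64)   # SD-style latents: (B, 4, H/8, W/8)
image = decoder.decode(latents)       # (B, 3, 512, 512), clamped to [-1, 1]
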
/VAE/models/vq.py:
--------------------------------------------------------------------------------
1 | import torch
2 | import numpy as np
3 | from torch import nn
4 | from einops import rearrange
5 |
6 | from .kl import Encoder, Decoder
7 |
8 | class VQModel(nn.Module):
9 | def __init__(self,
10 | config,
11 | remap=None,
12 | sane_index_shape=False, # tell vector quantizer to return indices as bhw
13 | ):
14 | super().__init__()
15 | self.embed_dim = config["embed_dim"]
16 | self.n_embed = config["n_embed"]
17 | self.encoder = Encoder(**config)
18 | self.decoder = Decoder(**config)
19 | self.quantize = VectorQuantizer(self.n_embed, self.embed_dim, beta=0.25,
20 | remap=remap,
21 | sane_index_shape=sane_index_shape)
22 | self.quant_conv = torch.nn.Conv2d(config["z_channels"], self.embed_dim, 1)
23 | self.post_quant_conv = torch.nn.Conv2d(self.embed_dim, config["z_channels"], 1)
24 |
25 | def encode(self, x):
26 | h = self.encoder(x)
27 | h = self.quant_conv(h)
28 | return h
29 |
30 | def decode(self, h, force_not_quantize=False):
31 | # also go through quantization layer
32 | if not force_not_quantize:
33 | quant, emb_loss, info = self.quantize(h)
34 | else:
35 | quant = h
36 | quant = self.post_quant_conv(quant)
37 | dec = self.decoder(quant)
38 | return dec
39 |
40 |     def forward(self, input, return_pred_indices=False):
41 |         quant, diff, (_, _, ind) = self.quantize(self.encode(input))  # encode() returns pre-quant latents
42 |         dec = self.decode(quant, force_not_quantize=True)
43 |         if return_pred_indices:
44 |             return dec, diff, ind
45 |         return dec, diff
46 |
47 |
48 | class VectorQuantizer(nn.Module):
49 | """
50 | Improved version over VectorQuantizer, can be used as a drop-in replacement. Mostly
51 | avoids costly matrix multiplications and allows for post-hoc remapping of indices.
52 | """
53 | # NOTE: due to a bug the beta term was applied to the wrong term. for
54 | # backwards compatibility we use the buggy version by default, but you can
55 | # specify legacy=False to fix it.
56 | def __init__(self, n_e, e_dim, beta, remap=None, unknown_index="random",
57 | sane_index_shape=False, legacy=True):
58 | super().__init__()
59 | self.n_e = n_e
60 | self.e_dim = e_dim
61 | self.beta = beta
62 | self.legacy = legacy
63 |
64 | self.embedding = nn.Embedding(self.n_e, self.e_dim)
65 | self.embedding.weight.data.uniform_(-1.0 / self.n_e, 1.0 / self.n_e)
66 |
67 | self.remap = remap
68 | if self.remap is not None:
69 | self.register_buffer("used", torch.tensor(np.load(self.remap)))
70 | self.re_embed = self.used.shape[0]
71 | self.unknown_index = unknown_index # "random" or "extra" or integer
72 | if self.unknown_index == "extra":
73 | self.unknown_index = self.re_embed
74 | self.re_embed = self.re_embed+1
75 | print(f"Remapping {self.n_e} indices to {self.re_embed} indices. "
76 | f"Using {self.unknown_index} for unknown indices.")
77 | else:
78 | self.re_embed = n_e
79 |
80 | self.sane_index_shape = sane_index_shape
81 |
82 | def remap_to_used(self, inds):
83 | ishape = inds.shape
84 | assert len(ishape)>1
85 | inds = inds.reshape(ishape[0],-1)
86 | used = self.used.to(inds)
87 | match = (inds[:,:,None]==used[None,None,...]).long()
88 | new = match.argmax(-1)
89 | unknown = match.sum(2)<1
90 | if self.unknown_index == "random":
91 | new[unknown]=torch.randint(0,self.re_embed,size=new[unknown].shape).to(device=new.device)
92 | else:
93 | new[unknown] = self.unknown_index
94 | return new.reshape(ishape)
95 |
96 | def unmap_to_all(self, inds):
97 | ishape = inds.shape
98 | assert len(ishape)>1
99 | inds = inds.reshape(ishape[0],-1)
100 | used = self.used.to(inds)
101 | if self.re_embed > self.used.shape[0]: # extra token
102 | inds[inds>=self.used.shape[0]] = 0 # simply set to zero
103 | back=torch.gather(used[None,:][inds.shape[0]*[0],:], 1, inds)
104 | return back.reshape(ishape)
105 |
106 | def forward(self, z, temp=None, rescale_logits=False, return_logits=False):
107 | assert temp is None or temp==1.0, "Only for interface compatible with Gumbel"
108 | assert rescale_logits==False, "Only for interface compatible with Gumbel"
109 | assert return_logits==False, "Only for interface compatible with Gumbel"
110 | # reshape z -> (batch, height, width, channel) and flatten
111 | z = rearrange(z, 'b c h w -> b h w c').contiguous()
112 | z_flattened = z.view(-1, self.e_dim)
113 | # distances from z to embeddings e_j (z - e)^2 = z^2 + e^2 - 2 e * z
114 |
115 | d = torch.sum(z_flattened ** 2, dim=1, keepdim=True) + \
116 | torch.sum(self.embedding.weight**2, dim=1) - 2 * \
117 | torch.einsum('bd,dn->bn', z_flattened, rearrange(self.embedding.weight, 'n d -> d n'))
118 |
119 | min_encoding_indices = torch.argmin(d, dim=1)
120 | z_q = self.embedding(min_encoding_indices).view(z.shape)
121 | perplexity = None
122 | min_encodings = None
123 |
124 | # compute loss for embedding
125 | if not self.legacy:
126 | loss = self.beta * torch.mean((z_q.detach()-z)**2) + \
127 | torch.mean((z_q - z.detach()) ** 2)
128 | else:
129 | loss = torch.mean((z_q.detach()-z)**2) + self.beta * \
130 | torch.mean((z_q - z.detach()) ** 2)
131 |
132 | # preserve gradients
133 | z_q = z + (z_q - z).detach()
134 |
135 | # reshape back to match original input shape
136 | z_q = rearrange(z_q, 'b h w c -> b c h w').contiguous()
137 |
138 | if self.remap is not None:
139 | min_encoding_indices = min_encoding_indices.reshape(z.shape[0],-1) # add batch axis
140 | min_encoding_indices = self.remap_to_used(min_encoding_indices)
141 | min_encoding_indices = min_encoding_indices.reshape(-1,1) # flatten
142 |
143 | if self.sane_index_shape:
144 | min_encoding_indices = min_encoding_indices.reshape(
145 | z_q.shape[0], z_q.shape[2], z_q.shape[3])
146 |
147 | return z_q, loss, (perplexity, min_encodings, min_encoding_indices)
148 |
149 | def get_codebook_entry(self, indices, shape):
150 | # shape specifying (batch, height, width, channel)
151 | if self.remap is not None:
152 | indices = indices.reshape(shape[0],-1) # add batch axis
153 | indices = self.unmap_to_all(indices)
154 | indices = indices.reshape(-1) # flatten again
155 |
156 | # get quantized latent vectors
157 | z_q = self.embedding(indices)
158 |
159 | if shape is not None:
160 | z_q = z_q.view(shape)
161 | # reshape back to match original input shape
162 | z_q = z_q.permute(0, 3, 1, 2).contiguous()
163 |
164 | return z_q
165 |
--------------------------------------------------------------------------------
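For orientation, the quantizer snaps each latent vector to its nearest codebook entry; a small illustrative check of the API above (codebook size and latent shape are arbitrary example values) could look like:

import torch

# Illustrative round trip through the VectorQuantizer defined above.
vq = VectorQuantizer(n_e=16384, e_dim=4, beta=0.25, sane_index_shape=True)
z = torch.randn(2, 4, 32, 32)                    # (B, C, H, W) latents
z_q, codebook_loss, (_, _, indices) = vq(z)      # indices: (B, H, W) codebook ids
restored = vq.get_codebook_entry(indices.reshape(-1), shape=(2, 32, 32, 4))
assert torch.allclose(z_q, restored)             # both paths yield the same codebook vectors
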
/VAE/nodes.py:
--------------------------------------------------------------------------------
1 | import folder_paths
2 |
3 | from .conf import vae_conf
4 | from .loader import EXVAE
5 |
6 | from ..utils.dtype import string_to_dtype
7 |
8 | dtypes = [
9 | "auto",
10 | "FP32",
11 | "FP16",
12 | "BF16"
13 | ]
14 |
15 | class ExtraVAELoader:
16 | @classmethod
17 | def INPUT_TYPES(s):
18 | return {
19 | "required": {
20 | "vae_name": (folder_paths.get_filename_list("vae"),),
21 | "vae_type": (list(vae_conf.keys()), {"default":"kl-f8"}),
22 | "dtype" : (dtypes,),
23 | }
24 | }
25 | RETURN_TYPES = ("VAE",)
26 | FUNCTION = "load_vae"
27 | CATEGORY = "ExtraModels"
28 | TITLE = "ExtraVAELoader"
29 |
30 | def load_vae(self, vae_name, vae_type, dtype):
31 | model_path = folder_paths.get_full_path("vae", vae_name)
32 | model_conf = vae_conf[vae_type]
33 | vae = EXVAE(model_path, model_conf, string_to_dtype(dtype, "vae"))
34 | return (vae,)
35 |
36 | NODE_CLASS_MAPPINGS = {
37 | "ExtraVAELoader" : ExtraVAELoader,
38 | }
39 |
--------------------------------------------------------------------------------
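Outside of the ComfyUI graph, the loader node can also be called directly; a hedged sketch, assuming a running ComfyUI environment and a VAE file in the configured folder (the filename is a placeholder):

# Hypothetical direct invocation; "example.safetensors" must exist in the
# folder that folder_paths resolves for "vae".
node = ExtraVAELoader()
(vae,) = node.load_vae(vae_name="example.safetensors", vae_type="kl-f8", dtype="FP16")
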
/__init__.py:
--------------------------------------------------------------------------------
1 | # only import if running as a custom node
2 | try:
3 | import comfy.utils
4 | except ImportError:
5 | pass
6 | else:
7 | NODE_CLASS_MAPPINGS = {}
8 |
9 | # Deci Diffusion
10 | # from .DeciDiffusion.nodes import NODE_CLASS_MAPPINGS as DeciDiffusion_Nodes
11 | # NODE_CLASS_MAPPINGS.update(DeciDiffusion_Nodes)
12 |
13 | # HunYuan
14 | from .HunYuan.nodes import NODE_CLASS_MAPPINGS as HunYuan_Nodes
15 | NODE_CLASS_MAPPINGS.update(HunYuan_Nodes)
16 |
17 | # DiT
18 | from .DiT.nodes import NODE_CLASS_MAPPINGS as DiT_Nodes
19 | NODE_CLASS_MAPPINGS.update(DiT_Nodes)
20 |
21 | # PixArt
22 | from .PixArt.nodes import NODE_CLASS_MAPPINGS as PixArt_Nodes
23 | NODE_CLASS_MAPPINGS.update(PixArt_Nodes)
24 |
25 | # T5
26 | from .T5.nodes import NODE_CLASS_MAPPINGS as T5_Nodes
27 | NODE_CLASS_MAPPINGS.update(T5_Nodes)
28 |
29 | # VAE
30 | from .VAE.nodes import NODE_CLASS_MAPPINGS as VAE_Nodes
31 | NODE_CLASS_MAPPINGS.update(VAE_Nodes)
32 |
33 | NODE_DISPLAY_NAME_MAPPINGS = {k:v.TITLE for k,v in NODE_CLASS_MAPPINGS.items()}
34 | __all__ = ['NODE_CLASS_MAPPINGS', 'NODE_DISPLAY_NAME_MAPPINGS']
35 |
--------------------------------------------------------------------------------
/requirements.txt:
--------------------------------------------------------------------------------
1 | timm==0.6.13
2 | sentencepiece>=0.1.97
3 | transformers>=4.34.1
4 | accelerate>=0.23.0
5 | einops
--------------------------------------------------------------------------------
/utils/dtype.py:
--------------------------------------------------------------------------------
1 | import torch
2 | from comfy import model_management
3 |
4 | def string_to_dtype(s="none", mode=None):
5 |     s = (s or "none").lower().strip()
6 | if s in ["default", "as-is"]:
7 | return None
8 | elif s in ["auto", "auto (comfy)"]:
9 | if mode == "vae":
10 |             return model_management.vae_dtype()
11 | elif mode == "text_encoder":
12 | return model_management.text_encoder_dtype()
13 | elif mode == "unet":
14 | return model_management.unet_dtype()
15 | else:
16 | raise NotImplementedError(f"Unknown dtype mode '{mode}'")
17 | elif s in ["none", "auto (hf)", "auto (hf/bnb)"]:
18 | return None
19 | elif s in ["fp32", "float32", "float"]:
20 | return torch.float32
21 | elif s in ["bf16", "bfloat16"]:
22 | return torch.bfloat16
23 | elif s in ["fp16", "float16", "half"]:
24 | return torch.float16
25 | elif "fp8" in s or "float8" in s:
26 | if "e5m2" in s:
27 | return torch.float8_e5m2
28 | elif "e4m3" in s:
29 | return torch.float8_e4m3fn
30 | else:
31 | raise NotImplementedError(f"Unknown 8bit dtype '{s}'")
32 | elif "bnb" in s:
33 | assert s in ["bnb8bit", "bnb4bit"], f"Unknown bnb mode '{s}'"
34 | return s
35 | elif s is None:
36 | return None
37 | else:
38 | raise NotImplementedError(f"Unknown dtype '{s}'")
39 |
--------------------------------------------------------------------------------
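A few illustrative checks of the static mappings above (importing this module requires ComfyUI's comfy package on the path; the "auto" branch additionally needs ComfyUI's model-management context, so only the static branches are exercised here):

import torch

# Quick sanity checks of string_to_dtype's static branches.
assert string_to_dtype("FP16") is torch.float16
assert string_to_dtype("bf16") is torch.bfloat16
assert string_to_dtype("fp8_e4m3") is torch.float8_e4m3fn
assert string_to_dtype("default") is None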