├── HunYuanDiT ├── mt5_tokenizer │ ├── special_tokens_map.json │ ├── spiece.model │ ├── tokenizer_config.json │ └── config.json ├── tokenizer │ ├── special_tokens_map.json │ ├── tokenizer_config.json │ └── config.json ├── config_mt5.json ├── config_clip.json ├── models │ ├── poolers.py │ ├── norm_layers.py │ ├── embedders.py │ └── posemb_layers.py ├── conf.py ├── loader.py ├── nodes.py └── tenc.py ├── T5 ├── t5_tokenizer │ ├── spiece.model │ ├── special_tokens_map.json │ └── tokenizer_config.json ├── t5v11-xxl_config.json ├── nodes.py ├── loader.py └── t5v11.py ├── requirements.txt ├── utils ├── nodes.py ├── dtype.py ├── offload.py └── IPEX │ └── attention.py ├── MiaoBi ├── tokenizer.py ├── tokenizer │ ├── special_tokens_map.json │ └── tokenizer_config.json └── nodes.py ├── VAE ├── models │ ├── LICENSE-SAI │ ├── LICENSE-Consistency-Decoder │ ├── LICENSE-Taming-Transformers │ ├── LICENSE-Latent-Diffusion │ ├── LICENSE-SDV │ └── vq.py ├── nodes.py ├── conf.py └── loader.py ├── __init__.py ├── Sana ├── models │ ├── act.py │ ├── utils.py │ └── norms.py ├── loader.py ├── conf.py ├── nodes.py └── diffusers_convert.py ├── DiT ├── loader.py ├── nodes.py └── conf.py ├── .gitignore ├── Gemma └── nodes.py └── PixArt ├── models └── utils.py ├── conf.py ├── lora.py ├── loader.py ├── nodes.py └── diffusers_convert.py /HunYuanDiT/mt5_tokenizer/special_tokens_map.json: -------------------------------------------------------------------------------- 1 | {"eos_token": "", "unk_token": "", "pad_token": ""} -------------------------------------------------------------------------------- /T5/t5_tokenizer/spiece.model: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/city96/ComfyUI_ExtraModels/HEAD/T5/t5_tokenizer/spiece.model -------------------------------------------------------------------------------- /HunYuanDiT/mt5_tokenizer/spiece.model: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/city96/ComfyUI_ExtraModels/HEAD/HunYuanDiT/mt5_tokenizer/spiece.model -------------------------------------------------------------------------------- /requirements.txt: -------------------------------------------------------------------------------- 1 | timm>=0.6.13 2 | sentencepiece>=0.1.97 3 | transformers>=4.34.1 4 | accelerate>=0.23.0 5 | einops>=0.6.0 6 | protobuf>=3.20.3 7 | bitsandbytes>=0.41.0 8 | -------------------------------------------------------------------------------- /HunYuanDiT/tokenizer/special_tokens_map.json: -------------------------------------------------------------------------------- 1 | { 2 | "cls_token": "[CLS]", 3 | "mask_token": "[MASK]", 4 | "pad_token": "[PAD]", 5 | "sep_token": "[SEP]", 6 | "unk_token": "[UNK]" 7 | } 8 | -------------------------------------------------------------------------------- /utils/nodes.py: -------------------------------------------------------------------------------- 1 | NODE_CLASS_MAPPINGS = {} 2 | 3 | from .offload import NODE_CLASS_MAPPINGS as Offload_Nodes 4 | NODE_CLASS_MAPPINGS.update(Offload_Nodes) 5 | 6 | for name, node in NODE_CLASS_MAPPINGS.items(): 7 | cat = node.CATEGORY 8 | if not cat.startswith("ExtraModels/"): 9 | node.CATEGORY = f"ExtraModels/{cat}" 10 | 11 | __all__ = ["NODE_CLASS_MAPPINGS"] 12 | -------------------------------------------------------------------------------- /HunYuanDiT/mt5_tokenizer/tokenizer_config.json: -------------------------------------------------------------------------------- 1 | 
{"eos_token": "", "unk_token": "", "pad_token": "", "extra_ids": 0, "additional_special_tokens": null, "special_tokens_map_file": "/home/patrick/.cache/torch/transformers/685ac0ca8568ec593a48b61b0a3c272beee9bc194a3c7241d15dcadb5f875e53.f76030f3ec1b96a8199b2593390c610e76ca8028ef3d24680000619ffb646276", "tokenizer_file": null, "name_or_path": "google/mt5-small"} -------------------------------------------------------------------------------- /HunYuanDiT/tokenizer/tokenizer_config.json: -------------------------------------------------------------------------------- 1 | { 2 | "cls_token": "[CLS]", 3 | "do_basic_tokenize": true, 4 | "do_lower_case": true, 5 | "mask_token": "[MASK]", 6 | "name_or_path": "hfl/chinese-roberta-wwm-ext", 7 | "never_split": null, 8 | "pad_token": "[PAD]", 9 | "sep_token": "[SEP]", 10 | "special_tokens_map_file": "/home/chenweifeng/.cache/huggingface/hub/models--hfl--chinese-roberta-wwm-ext/snapshots/5c58d0b8ec1d9014354d691c538661bf00bfdb44/special_tokens_map.json", 11 | "strip_accents": null, 12 | "tokenize_chinese_chars": true, 13 | "tokenizer_class": "BertTokenizer", 14 | "unk_token": "[UNK]", 15 | "model_max_length": 77 16 | } 17 | -------------------------------------------------------------------------------- /MiaoBi/tokenizer.py: -------------------------------------------------------------------------------- 1 | import os 2 | from transformers import AutoTokenizer 3 | from comfy.sd1_clip import SDTokenizer 4 | 5 | class MiaoBiTokenizer(SDTokenizer): 6 | def __init__(self, **kwargs): 7 | super().__init__(**kwargs) 8 | tokenizer_path = os.path.join( 9 | os.path.dirname(os.path.realpath(__file__)), 10 | f"tokenizer" 11 | ) 12 | # remote code ok, see `clip_tokenizer_roberta.py`, no ckpt vocab 13 | self.tokenizer = AutoTokenizer.from_pretrained(tokenizer_path, trust_remote_code=True) 14 | 15 | empty = self.tokenizer('')["input_ids"] 16 | if self.tokens_start: 17 | self.start_token = empty[0] 18 | self.end_token = empty[1] 19 | else: 20 | self.start_token = None 21 | self.end_token = empty[0] 22 | 23 | vocab = self.tokenizer.get_vocab() 24 | self.inv_vocab = {v: k for k, v in vocab.items()} -------------------------------------------------------------------------------- /MiaoBi/tokenizer/special_tokens_map.json: -------------------------------------------------------------------------------- 1 | { 2 | "cls_token": { 3 | "content": "[CLS]", 4 | "lstrip": false, 5 | "normalized": false, 6 | "rstrip": false, 7 | "single_word": false 8 | }, 9 | "mask_token": { 10 | "content": "[MASK]", 11 | "lstrip": false, 12 | "normalized": false, 13 | "rstrip": false, 14 | "single_word": false 15 | }, 16 | "pad_token": { 17 | "content": "[PAD]", 18 | "lstrip": false, 19 | "normalized": false, 20 | "rstrip": false, 21 | "single_word": false 22 | }, 23 | "sep_token": { 24 | "content": "[SEP]", 25 | "lstrip": false, 26 | "normalized": false, 27 | "rstrip": false, 28 | "single_word": false 29 | }, 30 | "unk_token": { 31 | "content": "[UNK]", 32 | "lstrip": false, 33 | "normalized": false, 34 | "rstrip": false, 35 | "single_word": false 36 | } 37 | } 38 | -------------------------------------------------------------------------------- /T5/t5v11-xxl_config.json: -------------------------------------------------------------------------------- 1 | { 2 | "_name_or_path": "google/t5-v1_1-xxl", 3 | "architectures": [ 4 | "T5EncoderModel" 5 | ], 6 | "d_ff": 10240, 7 | "d_kv": 64, 8 | "d_model": 4096, 9 | "decoder_start_token_id": 0, 10 | "dense_act_fn": "gelu_new", 11 | "dropout_rate": 0.1, 12 
| "eos_token_id": 1, 13 | "feed_forward_proj": "gated-gelu", 14 | "initializer_factor": 1.0, 15 | "is_encoder_decoder": true, 16 | "is_gated_act": true, 17 | "layer_norm_epsilon": 1e-06, 18 | "model_type": "t5", 19 | "num_decoder_layers": 24, 20 | "num_heads": 64, 21 | "num_layers": 24, 22 | "output_past": true, 23 | "pad_token_id": 0, 24 | "relative_attention_max_distance": 128, 25 | "relative_attention_num_buckets": 32, 26 | "tie_word_embeddings": false, 27 | "torch_dtype": "float32", 28 | "transformers_version": "4.21.1", 29 | "use_cache": true, 30 | "vocab_size": 32128 31 | } 32 | -------------------------------------------------------------------------------- /HunYuanDiT/config_mt5.json: -------------------------------------------------------------------------------- 1 | { 2 | "_name_or_path": "mt5", 3 | "architectures": [ 4 | "MT5EncoderModel" 5 | ], 6 | "classifier_dropout": 0.0, 7 | "d_ff": 5120, 8 | "d_kv": 64, 9 | "d_model": 2048, 10 | "decoder_start_token_id": 0, 11 | "dense_act_fn": "gelu_new", 12 | "dropout_rate": 0.1, 13 | "eos_token_id": 1, 14 | "feed_forward_proj": "gated-gelu", 15 | "initializer_factor": 1.0, 16 | "is_encoder_decoder": true, 17 | "is_gated_act": true, 18 | "layer_norm_epsilon": 1e-06, 19 | "model_type": "mt5", 20 | "num_decoder_layers": 24, 21 | "num_heads": 32, 22 | "num_layers": 24, 23 | "output_past": true, 24 | "pad_token_id": 0, 25 | "relative_attention_max_distance": 128, 26 | "relative_attention_num_buckets": 32, 27 | "tie_word_embeddings": false, 28 | "tokenizer_class": "T5Tokenizer", 29 | "torch_dtype": "float16", 30 | "transformers_version": "4.40.2", 31 | "use_cache": true, 32 | "vocab_size": 250112 33 | } 34 | -------------------------------------------------------------------------------- /HunYuanDiT/mt5_tokenizer/config.json: -------------------------------------------------------------------------------- 1 | { 2 | "_name_or_path": "mt5", 3 | "architectures": [ 4 | "MT5ForConditionalGeneration" 5 | ], 6 | "classifier_dropout": 0.0, 7 | "d_ff": 5120, 8 | "d_kv": 64, 9 | "d_model": 2048, 10 | "decoder_start_token_id": 0, 11 | "dense_act_fn": "gelu_new", 12 | "dropout_rate": 0.1, 13 | "eos_token_id": 1, 14 | "feed_forward_proj": "gated-gelu", 15 | "initializer_factor": 1.0, 16 | "is_encoder_decoder": true, 17 | "is_gated_act": true, 18 | "layer_norm_epsilon": 1e-06, 19 | "model_type": "mt5", 20 | "num_decoder_layers": 24, 21 | "num_heads": 32, 22 | "num_layers": 24, 23 | "output_past": true, 24 | "pad_token_id": 0, 25 | "relative_attention_max_distance": 128, 26 | "relative_attention_num_buckets": 32, 27 | "tie_word_embeddings": false, 28 | "tokenizer_class": "T5Tokenizer", 29 | "torch_dtype": "float16", 30 | "transformers_version": "4.40.2", 31 | "use_cache": true, 32 | "vocab_size": 250112 33 | } 34 | -------------------------------------------------------------------------------- /HunYuanDiT/config_clip.json: -------------------------------------------------------------------------------- 1 | { 2 | "_name_or_path": "hfl/chinese-roberta-wwm-ext-large", 3 | "architectures": [ 4 | "BertModel" 5 | ], 6 | "attention_probs_dropout_prob": 0.1, 7 | "bos_token_id": 0, 8 | "classifier_dropout": null, 9 | "directionality": "bidi", 10 | "eos_token_id": 2, 11 | "hidden_act": "gelu", 12 | "hidden_dropout_prob": 0.1, 13 | "hidden_size": 1024, 14 | "initializer_range": 0.02, 15 | "intermediate_size": 4096, 16 | "layer_norm_eps": 1e-12, 17 | "max_position_embeddings": 512, 18 | "model_type": "bert", 19 | "num_attention_heads": 16, 20 | "num_hidden_layers": 
24, 21 | "output_past": true, 22 | "pad_token_id": 0, 23 | "pooler_fc_size": 768, 24 | "pooler_num_attention_heads": 12, 25 | "pooler_num_fc_layers": 3, 26 | "pooler_size_per_head": 128, 27 | "pooler_type": "first_token_transform", 28 | "position_embedding_type": "absolute", 29 | "torch_dtype": "float32", 30 | "transformers_version": "4.22.1", 31 | "type_vocab_size": 2, 32 | "use_cache": true, 33 | "vocab_size": 47020 34 | } 35 | -------------------------------------------------------------------------------- /HunYuanDiT/tokenizer/config.json: -------------------------------------------------------------------------------- 1 | { 2 | "_name_or_path": "hfl/chinese-roberta-wwm-ext-large", 3 | "architectures": [ 4 | "BertModel" 5 | ], 6 | "attention_probs_dropout_prob": 0.1, 7 | "bos_token_id": 0, 8 | "classifier_dropout": null, 9 | "directionality": "bidi", 10 | "eos_token_id": 2, 11 | "hidden_act": "gelu", 12 | "hidden_dropout_prob": 0.1, 13 | "hidden_size": 1024, 14 | "initializer_range": 0.02, 15 | "intermediate_size": 4096, 16 | "layer_norm_eps": 1e-12, 17 | "max_position_embeddings": 512, 18 | "model_type": "bert", 19 | "num_attention_heads": 16, 20 | "num_hidden_layers": 24, 21 | "output_past": true, 22 | "pad_token_id": 0, 23 | "pooler_fc_size": 768, 24 | "pooler_num_attention_heads": 12, 25 | "pooler_num_fc_layers": 3, 26 | "pooler_size_per_head": 128, 27 | "pooler_type": "first_token_transform", 28 | "position_embedding_type": "absolute", 29 | "torch_dtype": "float32", 30 | "transformers_version": "4.22.1", 31 | "type_vocab_size": 2, 32 | "use_cache": true, 33 | "vocab_size": 47020 34 | } 35 | -------------------------------------------------------------------------------- /VAE/models/LICENSE-SAI: -------------------------------------------------------------------------------- 1 | MIT License 2 | 3 | Copyright (c) 2023 Stability AI 4 | 5 | Permission is hereby granted, free of charge, to any person obtaining a copy 6 | of this software and associated documentation files (the "Software"), to deal 7 | in the Software without restriction, including without limitation the rights 8 | to use, copy, modify, merge, publish, distribute, sublicense, and/or sell 9 | copies of the Software, and to permit persons to whom the Software is 10 | furnished to do so, subject to the following conditions: 11 | 12 | The above copyright notice and this permission notice shall be included in all 13 | copies or substantial portions of the Software. 14 | 15 | THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR 16 | IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, 17 | FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE 18 | AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER 19 | LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, 20 | OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE 21 | SOFTWARE. 
-------------------------------------------------------------------------------- /VAE/models/LICENSE-Consistency-Decoder: -------------------------------------------------------------------------------- 1 | MIT License 2 | 3 | Copyright (c) 2023 OpenAI 4 | 5 | Permission is hereby granted, free of charge, to any person obtaining a copy 6 | of this software and associated documentation files (the "Software"), to deal 7 | in the Software without restriction, including without limitation the rights 8 | to use, copy, modify, merge, publish, distribute, sublicense, and/or sell 9 | copies of the Software, and to permit persons to whom the Software is 10 | furnished to do so, subject to the following conditions: 11 | 12 | The above copyright notice and this permission notice shall be included in all 13 | copies or substantial portions of the Software. 14 | 15 | THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR 16 | IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, 17 | FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE 18 | AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER 19 | LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, 20 | OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE 21 | SOFTWARE. 22 | -------------------------------------------------------------------------------- /VAE/models/LICENSE-Taming-Transformers: -------------------------------------------------------------------------------- 1 | Copyright (c) 2020 Patrick Esser and Robin Rombach and Björn Ommer 2 | 3 | Permission is hereby granted, free of charge, to any person obtaining a copy 4 | of this software and associated documentation files (the "Software"), to deal 5 | in the Software without restriction, including without limitation the rights 6 | to use, copy, modify, merge, publish, distribute, sublicense, and/or sell 7 | copies of the Software, and to permit persons to whom the Software is 8 | furnished to do so, subject to the following conditions: 9 | 10 | The above copyright notice and this permission notice shall be included in all 11 | copies or substantial portions of the Software. 12 | 13 | THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, 14 | EXPRESS OR IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF 15 | MERCHANTABILITY, FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. 
16 | IN NO EVENT SHALL THE AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, 17 | DAMAGES OR OTHER LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR 18 | OTHERWISE, ARISING FROM, OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE 19 | OR OTHER DEALINGS IN THE SOFTWARE. 20 | -------------------------------------------------------------------------------- /VAE/models/LICENSE-Latent-Diffusion: -------------------------------------------------------------------------------- 1 | MIT License 2 | 3 | Copyright (c) 2022 Machine Vision and Learning Group, LMU Munich 4 | 5 | Permission is hereby granted, free of charge, to any person obtaining a copy 6 | of this software and associated documentation files (the "Software"), to deal 7 | in the Software without restriction, including without limitation the rights 8 | to use, copy, modify, merge, publish, distribute, sublicense, and/or sell 9 | copies of the Software, and to permit persons to whom the Software is 10 | furnished to do so, subject to the following conditions: 11 | 12 | The above copyright notice and this permission notice shall be included in all 13 | copies or substantial portions of the Software. 14 | 15 | THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR 16 | IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, 17 | FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE 18 | AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER 19 | LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, 20 | OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE 21 | SOFTWARE. 22 | -------------------------------------------------------------------------------- /utils/dtype.py: -------------------------------------------------------------------------------- 1 | import torch 2 | from comfy import model_management 3 | 4 | def string_to_dtype(s="none", mode=None): 5 | s = s.lower().strip() 6 | if s in ["default", "as-is"]: 7 | return None 8 | elif s in ["auto", "auto (comfy)"]: 9 | if mode == "vae": 10 | return model_management.vae_dtype() 11 | elif mode == "text_encoder": 12 | return model_management.text_encoder_dtype() 13 | elif mode == "unet": 14 | return model_management.unet_dtype() 15 | else: 16 | raise NotImplementedError(f"Unknown dtype mode '{mode}'") 17 | elif s in ["none", "auto (hf)", "auto (hf/bnb)"]: 18 | return None 19 | elif s in ["fp32", "float32", "float"]: 20 | return torch.float32 21 | elif s in ["bf16", "bfloat16"]: 22 | return torch.bfloat16 23 | elif s in ["fp16", "float16", "half"]: 24 | return torch.float16 25 | elif "fp8" in s or "float8" in s: 26 | if "e5m2" in s: 27 | return torch.float8_e5m2 28 | elif "e4m3" in s: 29 | return torch.float8_e4m3fn 30 | else: 31 | raise NotImplementedError(f"Unknown 8bit dtype '{s}'") 32 | elif "bnb" in s: 33 | assert s in ["bnb8bit", "bnb4bit"], f"Unknown bnb mode '{s}'" 34 | return s 35 | elif s is None: 36 | return None 37 | else: 38 | raise NotImplementedError(f"Unknown dtype '{s}'") 39 | -------------------------------------------------------------------------------- /__init__.py: -------------------------------------------------------------------------------- 1 | # only import if running as a custom node 2 | try: 3 | import comfy.utils 4 | except ImportError: 5 | pass 6 | else: 7 | NODE_CLASS_MAPPINGS = {} 8 | 9 | # Deci Diffusion 10 | # from .DeciDiffusion.nodes import NODE_CLASS_MAPPINGS as DeciDiffusion_Nodes 11 | #
NODE_CLASS_MAPPINGS.update(DeciDiffusion_Nodes) 12 | 13 | # DiT 14 | from .DiT.nodes import NODE_CLASS_MAPPINGS as DiT_Nodes 15 | NODE_CLASS_MAPPINGS.update(DiT_Nodes) 16 | 17 | # PixArt 18 | from .PixArt.nodes import NODE_CLASS_MAPPINGS as PixArt_Nodes 19 | NODE_CLASS_MAPPINGS.update(PixArt_Nodes) 20 | 21 | # T5 22 | from .T5.nodes import NODE_CLASS_MAPPINGS as T5_Nodes 23 | NODE_CLASS_MAPPINGS.update(T5_Nodes) 24 | 25 | # HYDiT 26 | from .HunYuanDiT.nodes import NODE_CLASS_MAPPINGS as HunYuanDiT_Nodes 27 | NODE_CLASS_MAPPINGS.update(HunYuanDiT_Nodes) 28 | 29 | # VAE 30 | from .VAE.nodes import NODE_CLASS_MAPPINGS as VAE_Nodes 31 | NODE_CLASS_MAPPINGS.update(VAE_Nodes) 32 | 33 | # MiaoBi 34 | from .MiaoBi.nodes import NODE_CLASS_MAPPINGS as MiaoBi_Nodes 35 | NODE_CLASS_MAPPINGS.update(MiaoBi_Nodes) 36 | 37 | # Extra 38 | from .utils.nodes import NODE_CLASS_MAPPINGS as Extra_Nodes 39 | NODE_CLASS_MAPPINGS.update(Extra_Nodes) 40 | 41 | # Sana 42 | from .Sana.nodes import NODE_CLASS_MAPPINGS as Sana_Nodes 43 | NODE_CLASS_MAPPINGS.update(Sana_Nodes) 44 | 45 | # Gemma 46 | from .Gemma.nodes import NODE_CLASS_MAPPINGS as Gemma_Nodes 47 | NODE_CLASS_MAPPINGS.update(Gemma_Nodes) 48 | 49 | NODE_DISPLAY_NAME_MAPPINGS = {k:v.TITLE for k,v in NODE_CLASS_MAPPINGS.items()} 50 | __all__ = ['NODE_CLASS_MAPPINGS', 'NODE_DISPLAY_NAME_MAPPINGS'] 51 | 52 | -------------------------------------------------------------------------------- /MiaoBi/tokenizer/tokenizer_config.json: -------------------------------------------------------------------------------- 1 | { 2 | "added_tokens_decoder": { 3 | "0": { 4 | "content": "[PAD]", 5 | "lstrip": false, 6 | "normalized": false, 7 | "rstrip": false, 8 | "single_word": false, 9 | "special": true 10 | }, 11 | "100": { 12 | "content": "[UNK]", 13 | "lstrip": false, 14 | "normalized": false, 15 | "rstrip": false, 16 | "single_word": false, 17 | "special": true 18 | }, 19 | "101": { 20 | "content": "[CLS]", 21 | "lstrip": false, 22 | "normalized": false, 23 | "rstrip": false, 24 | "single_word": false, 25 | "special": true 26 | }, 27 | "102": { 28 | "content": "[SEP]", 29 | "lstrip": false, 30 | "normalized": false, 31 | "rstrip": false, 32 | "single_word": false, 33 | "special": true 34 | }, 35 | "103": { 36 | "content": "[MASK]", 37 | "lstrip": false, 38 | "normalized": false, 39 | "rstrip": false, 40 | "single_word": false, 41 | "special": true 42 | } 43 | }, 44 | "auto_map": { 45 | "AutoTokenizer": [ 46 | "clip_tokenizer_roberta.CLIPTokenizerRoberta", 47 | null 48 | ] 49 | }, 50 | "clean_up_tokenization_spaces": true, 51 | "cls_token": "[CLS]", 52 | "do_basic_tokenize": true, 53 | "do_lower_case": true, 54 | "mask_token": "[MASK]", 55 | "model_max_length": 77, 56 | "never_split": null, 57 | "pad_token": "[PAD]", 58 | "sep_token": "[SEP]", 59 | "strip_accents": null, 60 | "tokenize_chinese_chars": true, 61 | "tokenizer_class": "CLIPTokenizerRoberta", 62 | "unk_token": "[UNK]", 63 | "use_fast": true 64 | } 65 | -------------------------------------------------------------------------------- /HunYuanDiT/models/poolers.py: -------------------------------------------------------------------------------- 1 | import torch 2 | import torch.nn as nn 3 | import torch.nn.functional as F 4 | 5 | 6 | class AttentionPool(nn.Module): 7 | def __init__(self, spacial_dim: int, embed_dim: int, num_heads: int, output_dim: int = None): 8 | super().__init__() 9 | self.positional_embedding = nn.Parameter(torch.randn(spacial_dim + 1, embed_dim) / embed_dim ** 0.5) 10 | self.k_proj = 
nn.Linear(embed_dim, embed_dim) 11 | self.q_proj = nn.Linear(embed_dim, embed_dim) 12 | self.v_proj = nn.Linear(embed_dim, embed_dim) 13 | self.c_proj = nn.Linear(embed_dim, output_dim or embed_dim) 14 | self.num_heads = num_heads 15 | 16 | def forward(self, x): 17 | x = x.permute(1, 0, 2) # NLC -> LNC 18 | x = torch.cat([x.mean(dim=0, keepdim=True), x], dim=0) # (L+1)NC 19 | x = x + self.positional_embedding[:, None, :].to(x.dtype) # (L+1)NC 20 | x, _ = F.multi_head_attention_forward( 21 | query=x[:1], key=x, value=x, 22 | embed_dim_to_check=x.shape[-1], 23 | num_heads=self.num_heads, 24 | q_proj_weight=self.q_proj.weight, 25 | k_proj_weight=self.k_proj.weight, 26 | v_proj_weight=self.v_proj.weight, 27 | in_proj_weight=None, 28 | in_proj_bias=torch.cat([self.q_proj.bias, self.k_proj.bias, self.v_proj.bias]), 29 | bias_k=None, 30 | bias_v=None, 31 | add_zero_attn=False, 32 | dropout_p=0, 33 | out_proj_weight=self.c_proj.weight, 34 | out_proj_bias=self.c_proj.bias, 35 | use_separate_proj_weight=True, 36 | training=self.training, 37 | need_weights=False 38 | ) 39 | return x.squeeze(0) 40 | -------------------------------------------------------------------------------- /T5/t5_tokenizer/special_tokens_map.json: -------------------------------------------------------------------------------- 1 | {"eos_token": "", "unk_token": "", "pad_token": "", "additional_special_tokens": ["", "", "", "", "", "", "", "", "", "", "", "", "", "", "", "", "", "", "", "", "", "", "", "", "", "", "", "", "", "", "", "", "", "", "", "", "", "", "", "", "", "", "", "", "", "", "", "", "", "", "", "", "", "", "", "", "", "", "", "", "", "", "", "", "", "", "", "", "", "", "", "", "", "", "", "", "", "", "", "", "", "", "", "", "", "", "", "", "", "", "", "", "", "", "", "", "", "", "", ""]} -------------------------------------------------------------------------------- /HunYuanDiT/conf.py: -------------------------------------------------------------------------------- 1 | """ 2 | List of all HYDiT model types / settings 3 | """ 4 | from argparse import Namespace 5 | hydit_args = Namespace(**{ # normally from argparse 6 | "infer_mode": "torch", 7 | "norm": "layer", 8 | "learn_sigma": True, 9 | "text_states_dim": 1024, 10 | "text_states_dim_t5": 2048, 11 | "text_len": 77, 12 | "text_len_t5": 256, 13 | }) 14 | 15 | hydit_conf = { 16 | "G/2": { # Seems to be the main one 17 | "unet_config": { 18 | "depth" : 40, 19 | "num_heads" : 16, 20 | "patch_size" : 2, 21 | "hidden_size" : 1408, 22 | "mlp_ratio" : 4.3637, 23 | "input_size": (1024//8, 1024//8), 24 | "args": hydit_args, 25 | }, 26 | "sampling_settings" : { 27 | "beta_schedule" : "linear", 28 | "linear_start" : 0.00085, 29 | "linear_end" : 0.03, 30 | "timesteps" : 1000, 31 | }, 32 | }, 33 | "G/2-1.2": { 34 | "unet_config": { 35 | "depth" : 40, 36 | "num_heads" : 16, 37 | "patch_size" : 2, 38 | "hidden_size" : 1408, 39 | "mlp_ratio" : 4.3637, 40 | "input_size": (1024//8, 1024//8), 41 | "cond_style": False, 42 | "cond_res" : False, 43 | "args": hydit_args, 44 | }, 45 | "sampling_settings" : { 46 | "beta_schedule" : "linear", 47 | "linear_start" : 0.00085, 48 | "linear_end" : 0.018, 49 | "timesteps" : 1000, 50 | }, 51 | } 52 | } 53 | 54 | # these are the same as regular DiT, I think 55 | from ..DiT.conf import dit_conf 56 | for name in ["XL/2", "L/2", "B/2"]: 57 | hydit_conf[name] = { 58 | "unet_config": dit_conf[name]["unet_config"].copy(), 59 | "sampling_settings": hydit_conf["G/2"]["sampling_settings"], 60 | } 61 | hydit_conf[name]["unet_config"]["args"] = 
hydit_args 62 | -------------------------------------------------------------------------------- /T5/t5_tokenizer/tokenizer_config.json: -------------------------------------------------------------------------------- 1 | {"eos_token": "", "unk_token": "", "pad_token": "", "extra_ids": 100, "additional_special_tokens": ["", "", "", "", "", "", "", "", "", "", "", "", "", "", "", "", "", "", "", "", "", "", "", "", "", "", "", "", "", "", "", "", "", "", "", "", "", "", "", "", "", "", "", "", "", "", "", "", "", "", "", "", "", "", "", "", "", "", "", "", "", "", "", "", "", "", "", "", "", "", "", "", "", "", "", "", "", "", "", "", "", "", "", "", "", "", "", "", "", "", "", "", "", "", "", "", "", "", "", ""], "model_max_length": 512, "name_or_path": "t5-small"} -------------------------------------------------------------------------------- /HunYuanDiT/models/norm_layers.py: -------------------------------------------------------------------------------- 1 | import torch 2 | import torch.nn as nn 3 | 4 | 5 | class RMSNorm(nn.Module): 6 | def __init__(self, dim: int, elementwise_affine=True, eps: float = 1e-6): 7 | """ 8 | Initialize the RMSNorm normalization layer. 9 | 10 | Args: 11 | dim (int): The dimension of the input tensor. 12 | eps (float, optional): A small value added to the denominator for numerical stability. Default is 1e-6. 13 | 14 | Attributes: 15 | eps (float): A small value added to the denominator for numerical stability. 16 | weight (nn.Parameter): Learnable scaling parameter. 17 | 18 | """ 19 | super().__init__() 20 | self.eps = eps 21 | if elementwise_affine: 22 | self.weight = nn.Parameter(torch.ones(dim)) 23 | 24 | def _norm(self, x): 25 | """ 26 | Apply the RMSNorm normalization to the input tensor. 27 | 28 | Args: 29 | x (torch.Tensor): The input tensor. 30 | 31 | Returns: 32 | torch.Tensor: The normalized tensor. 33 | 34 | """ 35 | return x * torch.rsqrt(x.pow(2).mean(-1, keepdim=True) + self.eps) 36 | 37 | def forward(self, x): 38 | """ 39 | Forward pass through the RMSNorm layer. 40 | 41 | Args: 42 | x (torch.Tensor): The input tensor. 43 | 44 | Returns: 45 | torch.Tensor: The output tensor after applying RMSNorm. 46 | 47 | """ 48 | output = self._norm(x.float()).type_as(x) 49 | if hasattr(self, "weight"): 50 | output = output * self.weight 51 | return output 52 | 53 | 54 | class GroupNorm32(nn.GroupNorm): 55 | def __init__(self, num_groups, num_channels, eps=1e-5, dtype=None): 56 | super().__init__(num_groups=num_groups, num_channels=num_channels, eps=eps, dtype=dtype) 57 | 58 | def forward(self, x): 59 | y = super().forward(x).to(x.dtype) 60 | return y 61 | 62 | def normalization(channels, dtype=None): 63 | """ 64 | Make a standard normalization layer. 65 | :param channels: number of input channels. 66 | :return: an nn.Module for normalization. 67 | """ 68 | return GroupNorm32(num_channels=channels, num_groups=32, dtype=dtype) 69 | -------------------------------------------------------------------------------- /Sana/models/act.py: -------------------------------------------------------------------------------- 1 | # Copyright 2024 NVIDIA CORPORATION & AFFILIATES 2 | # 3 | # Licensed under the Apache License, Version 2.0 (the "License"); 4 | # you may not use this file except in compliance with the License. 
5 | # You may obtain a copy of the License at 6 | # 7 | # http://www.apache.org/licenses/LICENSE-2.0 8 | # 9 | # Unless required by applicable law or agreed to in writing, software 10 | # distributed under the License is distributed on an "AS IS" BASIS, 11 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 12 | # See the License for the specific language governing permissions and 13 | # limitations under the License. 14 | # 15 | # SPDX-License-Identifier: Apache-2.0 16 | 17 | import copy 18 | 19 | import torch.nn as nn 20 | 21 | __all__ = ["build_act", "get_act_name"] 22 | 23 | # register activation function here 24 | # name: module, kwargs with default values 25 | REGISTERED_ACT_DICT: dict[str, tuple[type, dict[str, any]]] = { 26 | "relu": (nn.ReLU, {"inplace": True}), 27 | "relu6": (nn.ReLU6, {"inplace": True}), 28 | "hswish": (nn.Hardswish, {"inplace": True}), 29 | "hsigmoid": (nn.Hardsigmoid, {"inplace": True}), 30 | "swish": (nn.SiLU, {"inplace": True}), 31 | "silu": (nn.SiLU, {"inplace": True}), 32 | "tanh": (nn.Tanh, {}), 33 | "sigmoid": (nn.Sigmoid, {}), 34 | "gelu": (nn.GELU, {"approximate": "tanh"}), 35 | "mish": (nn.Mish, {"inplace": True}), 36 | "identity": (nn.Identity, {}), 37 | } 38 | 39 | 40 | def build_act(name: str or None, **kwargs) -> nn.Module or None: 41 | if name in REGISTERED_ACT_DICT: 42 | act_cls, default_args = copy.deepcopy(REGISTERED_ACT_DICT[name]) 43 | for key in default_args: 44 | if key in kwargs: 45 | default_args[key] = kwargs[key] 46 | return act_cls(**default_args) 47 | elif name is None or name.lower() == "none": 48 | return None 49 | else: 50 | raise ValueError(f"do not support: {name}") 51 | 52 | 53 | def get_act_name(act: nn.Module or None) -> str or None: 54 | if act is None: 55 | return None 56 | module2name = {} 57 | for key, config in REGISTERED_ACT_DICT.items(): 58 | module2name[config[0].__name__] = key 59 | return module2name.get(type(act).__name__, "unknown") 60 | -------------------------------------------------------------------------------- /DiT/loader.py: -------------------------------------------------------------------------------- 1 | import comfy.supported_models_base 2 | import comfy.latent_formats 3 | import comfy.model_patcher 4 | import comfy.model_base 5 | import comfy.utils 6 | import torch 7 | from comfy import model_management 8 | 9 | class EXM_DiT(comfy.supported_models_base.BASE): 10 | unet_config = {} 11 | unet_extra_config = {} 12 | latent_format = comfy.latent_formats.SD15 13 | 14 | def __init__(self, model_conf): 15 | self.unet_config = model_conf.get("unet_config", {}) 16 | self.sampling_settings = model_conf.get("sampling_settings", {}) 17 | self.latent_format = self.latent_format() 18 | # UNET is handled by extension 19 | self.unet_config["disable_unet_model_creation"] = True 20 | 21 | def model_type(self, state_dict, prefix=""): 22 | return comfy.model_base.ModelType.EPS 23 | 24 | def load_dit(model_path, model_conf): 25 | state_dict = comfy.utils.load_torch_file(model_path) 26 | state_dict = state_dict.get("model", state_dict) 27 | parameters = comfy.utils.calculate_parameters(state_dict) 28 | unet_dtype = model_management.unet_dtype(model_params=parameters) 29 | load_device = comfy.model_management.get_torch_device() 30 | offload_device = comfy.model_management.unet_offload_device() 31 | 32 | # ignore fp8/etc and use directly for now 33 | manual_cast_dtype = model_management.unet_manual_cast(unet_dtype, load_device) 34 | if manual_cast_dtype: 35 | print(f"DiT: falling back to 
{manual_cast_dtype}") 36 | unet_dtype = manual_cast_dtype 37 | 38 | model_conf["unet_config"]["num_classes"] = state_dict["y_embedder.embedding_table.weight"].shape[0] - 1 # adj. for empty 39 | 40 | model_conf = EXM_DiT(model_conf) 41 | model = comfy.model_base.BaseModel( 42 | model_conf, 43 | model_type=comfy.model_base.ModelType.EPS, 44 | device=model_management.get_torch_device() 45 | ) 46 | 47 | from .model import DiT 48 | model.diffusion_model = DiT(**model_conf.unet_config) 49 | 50 | model.diffusion_model.load_state_dict(state_dict) 51 | model.diffusion_model.dtype = unet_dtype 52 | model.diffusion_model.eval() 53 | model.diffusion_model.to(unet_dtype) 54 | 55 | model_patcher = comfy.model_patcher.ModelPatcher( 56 | model, 57 | load_device = load_device, 58 | offload_device = offload_device, 59 | ) 60 | return model_patcher 61 | -------------------------------------------------------------------------------- /VAE/nodes.py: -------------------------------------------------------------------------------- 1 | import folder_paths 2 | import torch 3 | import comfy 4 | 5 | from .conf import vae_conf 6 | from .loader import EXVAE 7 | 8 | from ..utils.dtype import string_to_dtype 9 | 10 | dtypes = [ 11 | "auto", 12 | "FP32", 13 | "FP16", 14 | "BF16" 15 | ] 16 | 17 | MAX_RESOLUTION=16384 18 | 19 | class ExtraVAELoader: 20 | @classmethod 21 | def INPUT_TYPES(s): 22 | return { 23 | "required": { 24 | "vae_name": (folder_paths.get_filename_list("vae"),), 25 | "vae_type": (list(vae_conf.keys()), {"default":"kl-f8"}), 26 | "dtype" : (dtypes,), 27 | } 28 | } 29 | RETURN_TYPES = ("VAE",) 30 | FUNCTION = "load_vae" 31 | CATEGORY = "ExtraModels" 32 | TITLE = "ExtraVAELoader" 33 | 34 | def load_vae(self, vae_name, vae_type, dtype): 35 | model_path = folder_paths.get_full_path("vae", vae_name) 36 | model_conf = vae_conf[vae_type] 37 | vae = EXVAE(model_path, model_conf, string_to_dtype(dtype, "vae")) 38 | return (vae,) 39 | 40 | 41 | class EmptyDCAELatentImage: 42 | def __init__(self): 43 | self.device = comfy.model_management.intermediate_device() 44 | 45 | @classmethod 46 | def INPUT_TYPES(s): 47 | return { 48 | "required": { 49 | "width": ("INT", {"default": 512, "min": 16, "max": MAX_RESOLUTION, "step": 8, "tooltip": "The width of the latent images in pixels."}), 50 | "height": ("INT", {"default": 512, "min": 16, "max": MAX_RESOLUTION, "step": 8, "tooltip": "The height of the latent images in pixels."}), 51 | "batch_size": ("INT", {"default": 1, "min": 1, "max": 4096, "tooltip": "The number of latent images in the batch."}) 52 | } 53 | } 54 | RETURN_TYPES = ("LATENT",) 55 | OUTPUT_TOOLTIPS = ("The empty latent image batch.",) 56 | FUNCTION = "generate" 57 | TITLE = "Empty DCAE Latent Image" 58 | 59 | CATEGORY = "latent" 60 | DESCRIPTION = "Create a new batch of empty latent images to be denoised via sampling." 
61 | 62 | def generate(self, width, height, batch_size=1): 63 | latent = torch.zeros([batch_size, 32, height // 32, width // 32], device=self.device) 64 | return ({"samples":latent}, ) 65 | 66 | 67 | NODE_CLASS_MAPPINGS = { 68 | "ExtraVAELoader" : ExtraVAELoader, 69 | "EmptyDCAELatentImage" : EmptyDCAELatentImage, 70 | } 71 | -------------------------------------------------------------------------------- /utils/offload.py: -------------------------------------------------------------------------------- 1 | # 2 | # Force model to always use specified device 3 | # City96 [Apache2] 4 | # 5 | import types 6 | import torch 7 | import comfy.model_management 8 | 9 | class OverrideDevice: 10 | @classmethod 11 | def INPUT_TYPES(s): 12 | devices = ["cpu",] 13 | for k in range(0, torch.cuda.device_count()): 14 | devices.append(f"cuda:{k}") 15 | 16 | return { 17 | "required": { 18 | "device": (devices, {"default":"cpu"}), 19 | } 20 | } 21 | 22 | FUNCTION = "patch" 23 | CATEGORY = "other" 24 | 25 | def override(self, model, model_attr, device): 26 | # set model/patcher attributes 27 | model.device = device 28 | patcher = getattr(model, "patcher", model) #.clone() 29 | for name in ["device", "load_device", "offload_device", "current_device", "output_device"]: 30 | setattr(patcher, name, device) 31 | 32 | # move model to device 33 | py_model = getattr(model, model_attr) 34 | py_model.to = types.MethodType(torch.nn.Module.to, py_model) 35 | py_model.to(device) 36 | 37 | # remove ability to move model 38 | def to(*args, **kwargs): 39 | pass 40 | py_model.to = types.MethodType(to, py_model) 41 | return (model,) 42 | 43 | def patch(self, *args, **kwargs): 44 | raise NotImplementedError 45 | 46 | class OverrideCLIPDevice(OverrideDevice): 47 | @classmethod 48 | def INPUT_TYPES(s): 49 | k = super().INPUT_TYPES() 50 | k["required"]["clip"] = ("CLIP",) 51 | return k 52 | 53 | RETURN_TYPES = ("CLIP",) 54 | TITLE = "Force/Set CLIP Device" 55 | 56 | def patch(self, clip, device): 57 | return self.override(clip, "cond_stage_model", torch.device(device)) 58 | 59 | class OverrideVAEDevice(OverrideDevice): 60 | @classmethod 61 | def INPUT_TYPES(s): 62 | k = super().INPUT_TYPES() 63 | k["required"]["vae"] = ("VAE",) 64 | return k 65 | 66 | RETURN_TYPES = ("VAE",) 67 | TITLE = "Force/Set VAE Device" 68 | 69 | def patch(self, vae, device): 70 | return self.override(vae, "first_stage_model", torch.device(device)) 71 | 72 | 73 | NODE_CLASS_MAPPINGS = { 74 | "OverrideCLIPDevice": OverrideCLIPDevice, 75 | "OverrideVAEDevice": OverrideVAEDevice, 76 | } 77 | NODE_DISPLAY_NAME_MAPPINGS = {k:v.TITLE for k,v in NODE_CLASS_MAPPINGS.items()} 78 | -------------------------------------------------------------------------------- /MiaoBi/nodes.py: -------------------------------------------------------------------------------- 1 | import os 2 | import folder_paths 3 | 4 | import comfy.sd 5 | import comfy.diffusers_load 6 | from .tokenizer import MiaoBiTokenizer 7 | 8 | class MiaoBiCLIPLoader: 9 | @classmethod 10 | def INPUT_TYPES(s): 11 | return { 12 | "required": { 13 | "clip_name": (folder_paths.get_filename_list("clip"),), 14 | } 15 | } 16 | 17 | RETURN_TYPES = ("CLIP",) 18 | FUNCTION = "load_mbclip" 19 | CATEGORY = "ExtraModels/MiaoBi" 20 | TITLE = "MiaoBi CLIP Loader" 21 | 22 | def load_mbclip(self, clip_name): 23 | clip_type = comfy.sd.CLIPType.STABLE_DIFFUSION 24 | clip_path = folder_paths.get_full_path("clip", clip_name) 25 | clip = comfy.sd.load_clip( 26 | ckpt_paths=[clip_path], 27 | 
embedding_directory=folder_paths.get_folder_paths("embeddings"), 28 | clip_type=clip_type 29 | ) 30 | # override tokenizer 31 | clip.tokenizer.clip_l = MiaoBiTokenizer() 32 | return (clip,) 33 | 34 | 35 | class MiaoBiDiffusersLoader: 36 | @classmethod 37 | def INPUT_TYPES(cls): 38 | paths = [] 39 | for search_path in folder_paths.get_folder_paths("diffusers"): 40 | if os.path.exists(search_path): 41 | for root, subdir, files in os.walk(search_path, followlinks=True): 42 | if "model_index.json" in files: 43 | paths.append(os.path.relpath(root, start=search_path)) 44 | 45 | return { 46 | "required": { 47 | "model_path": (paths,), 48 | } 49 | } 50 | 51 | RETURN_TYPES = ("MODEL", "CLIP", "VAE") 52 | FUNCTION = "load_mbcheckpoint" 53 | CATEGORY = "ExtraModels/MiaoBi" 54 | TITLE = "MiaoBi Checkpoint Loader (Diffusers)" 55 | 56 | def load_mbcheckpoint(self, model_path, output_vae=True, output_clip=True): 57 | for search_path in folder_paths.get_folder_paths("diffusers"): 58 | if os.path.exists(search_path): 59 | path = os.path.join(search_path, model_path) 60 | if os.path.exists(path): 61 | model_path = path 62 | break 63 | unet, clip, vae = comfy.diffusers_load.load_diffusers( 64 | model_path, 65 | output_vae = output_vae, 66 | output_clip = output_clip, 67 | embedding_directory = folder_paths.get_folder_paths("embeddings") 68 | ) 69 | # override tokenizer 70 | clip.tokenizer.clip_l = MiaoBiTokenizer() 71 | return (unet, clip, vae) 72 | 73 | NODE_CLASS_MAPPINGS = { 74 | "MiaoBiCLIPLoader": MiaoBiCLIPLoader, 75 | "MiaoBiDiffusersLoader": MiaoBiDiffusersLoader, 76 | } -------------------------------------------------------------------------------- /T5/nodes.py: -------------------------------------------------------------------------------- 1 | import os 2 | import json 3 | import torch 4 | import folder_paths 5 | 6 | from .loader import load_t5 7 | from ..utils.dtype import string_to_dtype 8 | 9 | # initialize custom folder path 10 | os.makedirs( 11 | os.path.join(folder_paths.models_dir,"t5"), 12 | exist_ok = True, 13 | ) 14 | folder_paths.folder_names_and_paths["t5"] = ( 15 | [ 16 | os.path.join(folder_paths.models_dir,"t5"), 17 | *folder_paths.folder_names_and_paths.get("t5", [[],set()])[0] 18 | ], 19 | folder_paths.supported_pt_extensions 20 | ) 21 | 22 | dtypes = [ 23 | "default", 24 | "auto (comfy)", 25 | "FP32", 26 | "FP16", 27 | # Note: remove these at some point 28 | "bnb8bit", 29 | "bnb4bit", 30 | ] 31 | try: torch.float8_e5m2 32 | except AttributeError: print("Torch version too old for FP8") 33 | else: dtypes += ["FP8 E4M3", "FP8 E5M2"] 34 | 35 | class T5v11Loader: 36 | @classmethod 37 | def INPUT_TYPES(s): 38 | devices = ["auto", "cpu", "gpu"] 39 | # hack for using second GPU as offload 40 | for k in range(1, torch.cuda.device_count()): 41 | devices.append(f"cuda:{k}") 42 | return { 43 | "required": { 44 | "t5v11_name": (folder_paths.get_filename_list("t5"),), 45 | "t5v11_ver": (["xxl"],), 46 | "path_type": (["folder", "file"],), 47 | "device": (devices, {"default":"cpu"}), 48 | "dtype": (dtypes,), 49 | } 50 | } 51 | RETURN_TYPES = ("T5",) 52 | FUNCTION = "load_model" 53 | CATEGORY = "ExtraModels/T5" 54 | TITLE = "T5v1.1 Loader" 55 | 56 | def load_model(self, t5v11_name, t5v11_ver, path_type, device, dtype): 57 | if "bnb" in dtype: 58 | assert device == "gpu" or device.startswith("cuda"), "BitsAndBytes only works on CUDA! Set device to 'gpu'." 
59 | dtype = string_to_dtype(dtype, "text_encoder") 60 | if device == "cpu": 61 | assert dtype in [None, torch.float32], f"Can't use dtype '{dtype}' with CPU! Set dtype to 'default'." 62 | 63 | return (load_t5( 64 | model_type = "t5v11", 65 | model_ver = t5v11_ver, 66 | model_path = folder_paths.get_full_path("t5", t5v11_name), 67 | path_type = path_type, 68 | device = device, 69 | dtype = dtype, 70 | ),) 71 | 72 | class T5TextEncode: 73 | @classmethod 74 | def INPUT_TYPES(s): 75 | return { 76 | "required": { 77 | "text": ("STRING", {"multiline": True}), 78 | "T5": ("T5",), 79 | } 80 | } 81 | 82 | RETURN_TYPES = ("CONDITIONING",) 83 | FUNCTION = "encode" 84 | CATEGORY = "ExtraModels/T5" 85 | TITLE = "T5 Text Encode" 86 | 87 | def encode(self, text, T5=None): 88 | tokens = T5.tokenize(text) 89 | cond = T5.encode_from_tokens(tokens) 90 | return ([[cond, {}]], ) 91 | 92 | NODE_CLASS_MAPPINGS = { 93 | "T5v11Loader" : T5v11Loader, 94 | "T5TextEncode" : T5TextEncode, 95 | } 96 | -------------------------------------------------------------------------------- /HunYuanDiT/loader.py: -------------------------------------------------------------------------------- 1 | import comfy.supported_models_base 2 | import comfy.latent_formats 3 | import comfy.model_patcher 4 | import comfy.model_base 5 | import comfy.utils 6 | import comfy.conds 7 | import torch 8 | from comfy import model_management 9 | from tqdm import tqdm 10 | 11 | class EXM_HYDiT(comfy.supported_models_base.BASE): 12 | unet_config = {} 13 | unet_extra_config = {} 14 | latent_format = comfy.latent_formats.SDXL 15 | 16 | def __init__(self, model_conf): 17 | self.unet_config = model_conf.get("unet_config", {}) 18 | self.sampling_settings = model_conf.get("sampling_settings", {}) 19 | self.latent_format = self.latent_format() 20 | # UNET is handled by extension 21 | self.unet_config["disable_unet_model_creation"] = True 22 | 23 | def model_type(self, state_dict, prefix=""): 24 | return comfy.model_base.ModelType.V_PREDICTION 25 | 26 | class EXM_HYDiT_Model(comfy.model_base.BaseModel): 27 | def __init__(self, *args, **kwargs): 28 | super().__init__(*args, **kwargs) 29 | 30 | def extra_conds(self, **kwargs): 31 | out = super().extra_conds(**kwargs) 32 | 33 | for name in ["context_t5", "context_mask", "context_t5_mask"]: 34 | out[name] = comfy.conds.CONDRegular(kwargs[name]) 35 | 36 | src_size_cond = kwargs.get("src_size_cond", None) 37 | if src_size_cond is not None: 38 | out["src_size_cond"] = comfy.conds.CONDRegular(torch.tensor(src_size_cond)) 39 | 40 | return out 41 | 42 | def load_hydit(model_path, model_conf): 43 | state_dict = comfy.utils.load_torch_file(model_path) 44 | state_dict = state_dict.get("model", state_dict) 45 | 46 | parameters = comfy.utils.calculate_parameters(state_dict) 47 | unet_dtype = model_management.unet_dtype(model_params=parameters) 48 | load_device = comfy.model_management.get_torch_device() 49 | offload_device = comfy.model_management.unet_offload_device() 50 | 51 | # ignore fp8/etc and use directly for now 52 | manual_cast_dtype = model_management.unet_manual_cast(unet_dtype, load_device) 53 | if manual_cast_dtype: 54 | print(f"HunYuanDiT: falling back to {manual_cast_dtype}") 55 | unet_dtype = manual_cast_dtype 56 | 57 | model_conf = EXM_HYDiT(model_conf) 58 | model = EXM_HYDiT_Model( 59 | model_conf, 60 | model_type=comfy.model_base.ModelType.V_PREDICTION, 61 | device=model_management.get_torch_device() 62 | ) 63 | 64 | from .models.models import HunYuanDiT 65 | model.diffusion_model = HunYuanDiT( 66 | 
**model_conf.unet_config, 67 | log_fn=tqdm.write, 68 | ) 69 | 70 | model.diffusion_model.load_state_dict(state_dict) 71 | model.diffusion_model.dtype = unet_dtype 72 | model.diffusion_model.eval() 73 | model.diffusion_model.to(unet_dtype) 74 | 75 | model_patcher = comfy.model_patcher.ModelPatcher( 76 | model, 77 | load_device = load_device, 78 | offload_device = offload_device, 79 | ) 80 | return model_patcher 81 | -------------------------------------------------------------------------------- /DiT/nodes.py: -------------------------------------------------------------------------------- 1 | import os 2 | import json 3 | import torch 4 | import folder_paths 5 | 6 | from .conf import dit_conf 7 | from .loader import load_dit 8 | 9 | class DitCheckpointLoader: 10 | @classmethod 11 | def INPUT_TYPES(s): 12 | return { 13 | "required": { 14 | "ckpt_name": (folder_paths.get_filename_list("checkpoints"),), 15 | "model": (list(dit_conf.keys()),), 16 | "image_size": ([256, 512],), 17 | # "num_classes": ("INT", {"default": 1000, "min": 0,}), 18 | } 19 | } 20 | RETURN_TYPES = ("MODEL",) 21 | RETURN_NAMES = ("model",) 22 | FUNCTION = "load_checkpoint" 23 | CATEGORY = "ExtraModels/DiT" 24 | TITLE = "DitCheckpointLoader" 25 | 26 | def load_checkpoint(self, ckpt_name, model, image_size): 27 | ckpt_path = folder_paths.get_full_path("checkpoints", ckpt_name) 28 | model_conf = dit_conf[model] 29 | model_conf["unet_config"]["input_size"] = image_size // 8 30 | # model_conf["unet_config"]["num_classes"] = num_classes 31 | dit = load_dit( 32 | model_path = ckpt_path, 33 | model_conf = model_conf, 34 | ) 35 | return (dit,) 36 | 37 | # todo: this needs frontend code to display properly 38 | def get_label_data(label_file="labels/imagenet1000.json"): 39 | label_path = os.path.join( 40 | os.path.dirname(os.path.realpath(__file__)), 41 | label_file, 42 | ) 43 | label_data = {0: "None"} 44 | with open(label_path, "r") as f: 45 | label_data = json.loads(f.read()) 46 | return label_data 47 | label_data = get_label_data() 48 | 49 | class DiTCondLabelSelect: 50 | @classmethod 51 | def INPUT_TYPES(s): 52 | global label_data 53 | return { 54 | "required": { 55 | "model" : ("MODEL",), 56 | "label_name": (list(label_data.values()),), 57 | } 58 | } 59 | 60 | RETURN_TYPES = ("CONDITIONING",) 61 | RETURN_NAMES = ("class",) 62 | FUNCTION = "cond_label" 63 | CATEGORY = "ExtraModels/DiT" 64 | TITLE = "DiTCondLabelSelect" 65 | 66 | def cond_label(self, model, label_name): 67 | global label_data 68 | class_labels = [int(k) for k,v in label_data.items() if v == label_name] 69 | y = torch.tensor([[class_labels[0]]]).to(torch.int) 70 | return ([[y, {}]], ) 71 | 72 | class DiTCondLabelEmpty: 73 | @classmethod 74 | def INPUT_TYPES(s): 75 | global label_data 76 | return { 77 | "required": { 78 | "model" : ("MODEL",), 79 | } 80 | } 81 | 82 | RETURN_TYPES = ("CONDITIONING",) 83 | RETURN_NAMES = ("empty",) 84 | FUNCTION = "cond_empty" 85 | CATEGORY = "ExtraModels/DiT" 86 | TITLE = "DiTCondLabelEmpty" 87 | 88 | def cond_empty(self, model): 89 | # [ID of last class + 1] == [num_classes] 90 | y_null = model.model.model_config.unet_config["num_classes"] 91 | y = torch.tensor([[y_null]]).to(torch.int) 92 | return ([[y, {}]], ) 93 | 94 | NODE_CLASS_MAPPINGS = { 95 | "DitCheckpointLoader" : DitCheckpointLoader, 96 | "DiTCondLabelSelect" : DiTCondLabelSelect, 97 | "DiTCondLabelEmpty" : DiTCondLabelEmpty, 98 | } 99 | -------------------------------------------------------------------------------- /DiT/conf.py: 
-------------------------------------------------------------------------------- 1 | """ 2 | List of all DiT model types / settings 3 | """ 4 | sampling_settings = { 5 | "beta_schedule" : "sqrt_linear", 6 | "linear_start" : 0.0001, 7 | "linear_end" : 0.02, 8 | "timesteps" : 1000, 9 | } 10 | 11 | dit_conf = { 12 | "XL/2": { # DiT_XL_2 13 | "unet_config": { 14 | "depth" : 28, 15 | "num_heads" : 16, 16 | "patch_size" : 2, 17 | "hidden_size" : 1152, 18 | }, 19 | "sampling_settings" : sampling_settings, 20 | }, 21 | "XL/4": { # DiT_XL_4 22 | "unet_config": { 23 | "depth" : 28, 24 | "num_heads" : 16, 25 | "patch_size" : 4, 26 | "hidden_size" : 1152, 27 | }, 28 | "sampling_settings" : sampling_settings, 29 | }, 30 | "XL/8": { # DiT_XL_8 31 | "unet_config": { 32 | "depth" : 28, 33 | "num_heads" : 16, 34 | "patch_size" : 8, 35 | "hidden_size" : 1152, 36 | }, 37 | "sampling_settings" : sampling_settings, 38 | }, 39 | "L/2": { # DiT_L_2 40 | "unet_config": { 41 | "depth" : 24, 42 | "num_heads" : 16, 43 | "patch_size" : 2, 44 | "hidden_size" : 1024, 45 | }, 46 | "sampling_settings" : sampling_settings, 47 | }, 48 | "L/4": { # DiT_L_4 49 | "unet_config": { 50 | "depth" : 24, 51 | "num_heads" : 16, 52 | "patch_size" : 4, 53 | "hidden_size" : 1024, 54 | }, 55 | "sampling_settings" : sampling_settings, 56 | }, 57 | "L/8": { # DiT_L_8 58 | "unet_config": { 59 | "depth" : 24, 60 | "num_heads" : 16, 61 | "patch_size" : 8, 62 | "hidden_size" : 1024, 63 | }, 64 | "sampling_settings" : sampling_settings, 65 | }, 66 | "B/2": { # DiT_B_2 67 | "unet_config": { 68 | "depth" : 12, 69 | "num_heads" : 12, 70 | "patch_size" : 2, 71 | "hidden_size" : 768, 72 | }, 73 | "sampling_settings" : sampling_settings, 74 | }, 75 | "B/4": { # DiT_B_4 76 | "unet_config": { 77 | "depth" : 12, 78 | "num_heads" : 12, 79 | "patch_size" : 4, 80 | "hidden_size" : 768, 81 | }, 82 | "sampling_settings" : sampling_settings, 83 | }, 84 | "B/8": { # DiT_B_8 85 | "unet_config": { 86 | "depth" : 12, 87 | "num_heads" : 12, 88 | "patch_size" : 8, 89 | "hidden_size" : 768, 90 | }, 91 | "sampling_settings" : sampling_settings, 92 | }, 93 | "S/2": { # DiT_S_2 94 | "unet_config": { 95 | "depth" : 12, 96 | "num_heads" : 6, 97 | "patch_size" : 2, 98 | "hidden_size" : 384, 99 | }, 100 | "sampling_settings" : sampling_settings, 101 | }, 102 | "S/4": { # DiT_S_4 103 | "unet_config": { 104 | "depth" : 12, 105 | "num_heads" : 6, 106 | "patch_size" : 4, 107 | "hidden_size" : 384, 108 | }, 109 | "sampling_settings" : sampling_settings, 110 | }, 111 | "S/8": { # DiT_S_8 112 | "unet_config": { 113 | "depth" : 12, 114 | "num_heads" : 6, 115 | "patch_size" : 8, 116 | "hidden_size" : 384, 117 | }, 118 | "sampling_settings" : sampling_settings, 119 | }, 120 | } 121 | -------------------------------------------------------------------------------- /Sana/loader.py: -------------------------------------------------------------------------------- 1 | import comfy.supported_models_base 2 | import comfy.latent_formats 3 | import comfy.model_patcher 4 | import comfy.model_base 5 | import comfy.utils 6 | import comfy.conds 7 | import torch 8 | import math 9 | from comfy import model_management 10 | from comfy.latent_formats import LatentFormat 11 | from .diffusers_convert import convert_state_dict 12 | 13 | 14 | class SanaLatent(LatentFormat): 15 | latent_channels = 32 16 | def __init__(self): 17 | self.scale_factor = 0.41407 18 | 19 | 20 | class EXM_Sana(comfy.supported_models_base.BASE): 21 | unet_config = {} 22 | unet_extra_config = {} 23 | latent_format = SanaLatent 24 
| 25 | def __init__(self, model_conf): 26 | self.model_target = model_conf.get("target") 27 | self.unet_config = model_conf.get("unet_config", {}) 28 | self.sampling_settings = model_conf.get("sampling_settings", {}) 29 | self.latent_format = self.latent_format() 30 | # UNET is handled by extension 31 | self.unet_config["disable_unet_model_creation"] = True 32 | 33 | def model_type(self, state_dict, prefix=""): 34 | return comfy.model_base.ModelType.FLOW 35 | 36 | 37 | class EXM_Sana_Model(comfy.model_base.BaseModel): 38 | def __init__(self, *args, **kwargs): 39 | super().__init__(*args, **kwargs) 40 | 41 | def extra_conds(self, **kwargs): 42 | out = super().extra_conds(**kwargs) 43 | 44 | cn_hint = kwargs.get("cn_hint", None) 45 | if cn_hint is not None: 46 | out["cn_hint"] = comfy.conds.CONDRegular(cn_hint) 47 | 48 | return out 49 | 50 | 51 | def load_sana(model_path, model_conf): 52 | state_dict = comfy.utils.load_torch_file(model_path) 53 | state_dict = state_dict.get("model", state_dict) 54 | 55 | # prefix 56 | for prefix in ["model.diffusion_model.",]: 57 | if any(True for x in state_dict if x.startswith(prefix)): 58 | state_dict = {k[len(prefix):]:v for k,v in state_dict.items()} 59 | 60 | # diffusers 61 | if "adaln_single.linear.weight" in state_dict: 62 | state_dict = convert_state_dict(state_dict) # Diffusers 63 | 64 | parameters = comfy.utils.calculate_parameters(state_dict) 65 | unet_dtype = comfy.model_management.unet_dtype() 66 | load_device = comfy.model_management.get_torch_device() 67 | offload_device = comfy.model_management.unet_offload_device() 68 | 69 | # ignore fp8/etc and use directly for now 70 | manual_cast_dtype = model_management.unet_manual_cast(unet_dtype, load_device) 71 | if manual_cast_dtype: 72 | print(f"Sana: falling back to {manual_cast_dtype}") 73 | unet_dtype = manual_cast_dtype 74 | 75 | model_conf = EXM_Sana(model_conf) # convert to object 76 | model = EXM_Sana_Model( # same as comfy.model_base.BaseModel 77 | model_conf, 78 | model_type=comfy.model_base.ModelType.FLOW, 79 | device=model_management.get_torch_device() 80 | ) 81 | 82 | if model_conf.model_target == "SanaMS": 83 | from .models.sana_multi_scale import SanaMS 84 | model.diffusion_model = SanaMS(**model_conf.unet_config) 85 | else: 86 | raise NotImplementedError(f"Unknown model target '{model_conf.model_target}'") 87 | 88 | m, u = model.diffusion_model.load_state_dict(state_dict, strict=False) 89 | if len(m) > 0: print("Missing UNET keys", m) 90 | if len(u) > 0: print("Leftover UNET keys", u) 91 | model.diffusion_model.dtype = unet_dtype 92 | model.diffusion_model.eval() 93 | model.diffusion_model.to(unet_dtype) 94 | 95 | model_patcher = comfy.model_patcher.ModelPatcher( 96 | model, 97 | load_device = load_device, 98 | offload_device = offload_device, 99 | ) 100 | return model_patcher 101 | -------------------------------------------------------------------------------- /Sana/models/utils.py: -------------------------------------------------------------------------------- 1 | # Copyright 2024 NVIDIA CORPORATION & AFFILIATES 2 | # 3 | # Licensed under the Apache License, Version 2.0 (the "License"); 4 | # you may not use this file except in compliance with the License. 
5 | # You may obtain a copy of the License at 6 | # 7 | # http://www.apache.org/licenses/LICENSE-2.0 8 | # 9 | # Unless required by applicable law or agreed to in writing, software 10 | # distributed under the License is distributed on an "AS IS" BASIS, 11 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 12 | # See the License for the specific language governing permissions and 13 | # limitations under the License. 14 | # 15 | # SPDX-License-Identifier: Apache-2.0 16 | 17 | from collections.abc import Iterable 18 | from itertools import repeat 19 | from typing import Union, Tuple 20 | 21 | import torch 22 | from torch.utils.checkpoint import checkpoint, checkpoint_sequential 23 | 24 | 25 | def _ntuple(n): 26 | def parse(x): 27 | if isinstance(x, Iterable) and not isinstance(x, str): 28 | return x 29 | return tuple(repeat(x, n)) 30 | 31 | return parse 32 | 33 | 34 | to_1tuple = _ntuple(1) 35 | to_2tuple = _ntuple(2) 36 | 37 | 38 | def auto_grad_checkpoint(module, *args, **kwargs): 39 | if getattr(module, "grad_checkpointing", False): 40 | if isinstance(module, Iterable): 41 | gc_step = module[0].grad_checkpointing_step 42 | return checkpoint_sequential(module, gc_step, *args, **kwargs) 43 | else: 44 | return checkpoint(module, *args, **kwargs) 45 | return module(*args, **kwargs) 46 | 47 | 48 | def checkpoint_sequential(functions, step, input, *args, **kwargs): 49 | 50 | # Hack for keyword-only parameter in a python 2.7-compliant way 51 | preserve = kwargs.pop("preserve_rng_state", True) 52 | if kwargs: 53 | raise ValueError("Unexpected keyword arguments: " + ",".join(arg for arg in kwargs)) 54 | 55 | def run_function(start, end, functions): 56 | def forward(input): 57 | for j in range(start, end + 1): 58 | input = functions[j](input, *args) 59 | return input 60 | 61 | return forward 62 | 63 | if isinstance(functions, torch.nn.Sequential): 64 | functions = list(functions.children()) 65 | 66 | # the last chunk has to be non-volatile 67 | end = -1 68 | segment = len(functions) // step 69 | for start in range(0, step * (segment - 1), step): 70 | end = start + step - 1 71 | input = checkpoint(run_function(start, end, functions), input, preserve_rng_state=preserve) 72 | return run_function(end + 1, len(functions) - 1, functions)(input) 73 | 74 | def val2list(x: list or tuple or any, repeat_time=1) -> list: # type: ignore 75 | """Repeat `val` for `repeat_time` times and return the list or val if list/tuple.""" 76 | if isinstance(x, (list, tuple)): 77 | return list(x) 78 | return [x for _ in range(repeat_time)] 79 | 80 | def val2tuple(x: list or tuple or any, min_len: int = 1, idx_repeat: int = -1) -> tuple: # type: ignore 81 | """Return tuple with min_len by repeating element at idx_repeat.""" 82 | # convert to list first 83 | x = val2list(x) 84 | 85 | # repeat elements if necessary 86 | if len(x) > 0: 87 | x[idx_repeat:idx_repeat] = [x[idx_repeat] for _ in range(min_len - len(x))] 88 | 89 | return tuple(x) 90 | 91 | def get_same_padding(kernel_size: Union[int, Tuple[int, ...]]) -> Union[int, Tuple[int, ...]]: 92 | if isinstance(kernel_size, tuple): 93 | return tuple([get_same_padding(ks) for ks in kernel_size]) 94 | else: 95 | assert kernel_size % 2 > 0, f"kernel size {kernel_size} should be odd number" 96 | return kernel_size // 2 97 | -------------------------------------------------------------------------------- /.gitignore: -------------------------------------------------------------------------------- 1 | # Byte-compiled / optimized / DLL files 2 | __pycache__/ 
3 | *.py[cod] 4 | *$py.class 5 | 6 | # C extensions 7 | *.so 8 | 9 | # Distribution / packaging 10 | .Python 11 | build/ 12 | develop-eggs/ 13 | dist/ 14 | downloads/ 15 | eggs/ 16 | .eggs/ 17 | lib/ 18 | lib64/ 19 | parts/ 20 | sdist/ 21 | var/ 22 | wheels/ 23 | share/python-wheels/ 24 | *.egg-info/ 25 | .installed.cfg 26 | *.egg 27 | MANIFEST 28 | 29 | # PyInstaller 30 | # Usually these files are written by a python script from a template 31 | # before PyInstaller builds the exe, so as to inject date/other infos into it. 32 | *.manifest 33 | *.spec 34 | 35 | # Installer logs 36 | pip-log.txt 37 | pip-delete-this-directory.txt 38 | 39 | # Unit test / coverage reports 40 | htmlcov/ 41 | .tox/ 42 | .nox/ 43 | .coverage 44 | .coverage.* 45 | .cache 46 | nosetests.xml 47 | coverage.xml 48 | *.cover 49 | *.py,cover 50 | .hypothesis/ 51 | .pytest_cache/ 52 | cover/ 53 | 54 | # Translations 55 | *.mo 56 | *.pot 57 | 58 | # Django stuff: 59 | *.log 60 | local_settings.py 61 | db.sqlite3 62 | db.sqlite3-journal 63 | 64 | # Flask stuff: 65 | instance/ 66 | .webassets-cache 67 | 68 | # Scrapy stuff: 69 | .scrapy 70 | 71 | # Sphinx documentation 72 | docs/_build/ 73 | 74 | # PyBuilder 75 | .pybuilder/ 76 | target/ 77 | 78 | # Jupyter Notebook 79 | .ipynb_checkpoints 80 | 81 | # IPython 82 | profile_default/ 83 | ipython_config.py 84 | 85 | # pyenv 86 | # For a library or package, you might want to ignore these files since the code is 87 | # intended to run in multiple environments; otherwise, check them in: 88 | # .python-version 89 | 90 | # pipenv 91 | # According to pypa/pipenv#598, it is recommended to include Pipfile.lock in version control. 92 | # However, in case of collaboration, if having platform-specific dependencies or dependencies 93 | # having no cross-platform support, pipenv may install dependencies that don't work, or not 94 | # install all needed dependencies. 95 | #Pipfile.lock 96 | 97 | # poetry 98 | # Similar to Pipfile.lock, it is generally recommended to include poetry.lock in version control. 99 | # This is especially recommended for binary packages to ensure reproducibility, and is more 100 | # commonly ignored for libraries. 101 | # https://python-poetry.org/docs/basic-usage/#commit-your-poetrylock-file-to-version-control 102 | #poetry.lock 103 | 104 | # pdm 105 | # Similar to Pipfile.lock, it is generally recommended to include pdm.lock in version control. 106 | #pdm.lock 107 | # pdm stores project-wide configurations in .pdm.toml, but it is recommended to not include it 108 | # in version control. 109 | # https://pdm.fming.dev/#use-with-ide 110 | .pdm.toml 111 | 112 | # PEP 582; used by e.g. 
github.com/David-OConnor/pyflow and github.com/pdm-project/pdm 113 | __pypackages__/ 114 | 115 | # Celery stuff 116 | celerybeat-schedule 117 | celerybeat.pid 118 | 119 | # SageMath parsed files 120 | *.sage.py 121 | 122 | # Environments 123 | .env 124 | .venv 125 | env/ 126 | venv/ 127 | ENV/ 128 | env.bak/ 129 | venv.bak/ 130 | 131 | # Spyder project settings 132 | .spyderproject 133 | .spyproject 134 | 135 | # Rope project settings 136 | .ropeproject 137 | 138 | # mkdocs documentation 139 | /site 140 | 141 | # mypy 142 | .mypy_cache/ 143 | .dmypy.json 144 | dmypy.json 145 | 146 | # Pyre type checker 147 | .pyre/ 148 | 149 | # pytype static type analyzer 150 | .pytype/ 151 | 152 | # Cython debug symbols 153 | cython_debug/ 154 | 155 | # PyCharm 156 | # JetBrains specific template is maintained in a separate JetBrains.gitignore that can 157 | # be found at https://github.com/github/gitignore/blob/main/Global/JetBrains.gitignore 158 | # and can be added to the global gitignore or merged into this file. For a more nuclear 159 | # option (not recommended) you can uncomment the following to ignore the entire idea folder. 160 | #.idea/ 161 | -------------------------------------------------------------------------------- /T5/loader.py: -------------------------------------------------------------------------------- 1 | import os 2 | import torch 3 | import comfy.utils 4 | import comfy.model_patcher 5 | from comfy import model_management 6 | import folder_paths 7 | 8 | from .t5v11 import T5v11Model, T5v11Tokenizer 9 | 10 | class EXM_T5v11: 11 | def __init__(self, textmodel_ver="xxl", embedding_directory=None, textmodel_path=None, no_init=False, device="cpu", dtype=None): 12 | if no_init: 13 | return 14 | 15 | if device == "auto": 16 | size = 0 17 | self.load_device = model_management.text_encoder_device() 18 | self.offload_device = model_management.text_encoder_offload_device() 19 | self.init_device = "cpu" 20 | elif dtype == "bnb8bit": 21 | # BNB doesn't support size enum 22 | size = 12.4 * (1024**3) 23 | # Or moving between devices 24 | self.load_device = model_management.get_torch_device() 25 | self.offload_device = self.load_device 26 | self.init_device = self.load_device 27 | elif dtype == "bnb4bit": 28 | # This seems to use the same VRAM as 8bit on Pascal? 
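# (the 6.2 GiB below is just half of the 12.4 GiB 8bit figure above; both are
# rough, hard-coded estimates handed to the ModelPatcher further down, since
# the usual size detection doesn't work for bitsandbytes checkpoints)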
29 | size = 6.2 * (1024**3) 30 | self.load_device = model_management.get_torch_device() 31 | self.offload_device = self.load_device 32 | self.init_device = self.load_device 33 | elif device == "cpu": 34 | size = 0 35 | self.load_device = "cpu" 36 | self.offload_device = "cpu" 37 | self.init_device="cpu" 38 | elif device.startswith("cuda"): 39 | print("Direct CUDA device override!\nVRAM will not be freed by default.") 40 | size = 0 41 | self.load_device = device 42 | self.offload_device = device 43 | self.init_device = device 44 | else: 45 | size = 0 46 | self.load_device = model_management.get_torch_device() 47 | self.offload_device = "cpu" 48 | self.init_device="cpu" 49 | 50 | self.cond_stage_model = T5v11Model( 51 | textmodel_ver = textmodel_ver, 52 | textmodel_path = textmodel_path, 53 | device = device, 54 | dtype = dtype, 55 | ) 56 | self.tokenizer = T5v11Tokenizer(embedding_directory=embedding_directory) 57 | self.patcher = comfy.model_patcher.ModelPatcher( 58 | self.cond_stage_model, 59 | load_device = self.load_device, 60 | offload_device = self.offload_device, 61 | size = size, 62 | ) 63 | 64 | def clone(self): 65 | n = EXM_T5v11(no_init=True) 66 | n.patcher = self.patcher.clone() 67 | n.cond_stage_model = self.cond_stage_model 68 | n.tokenizer = self.tokenizer 69 | return n 70 | 71 | def tokenize(self, text, return_word_ids=False): 72 | return self.tokenizer.tokenize_with_weights(text, return_word_ids) 73 | 74 | def encode_from_tokens(self, tokens): 75 | self.load_model() 76 | return self.cond_stage_model.encode_token_weights(tokens) 77 | 78 | def encode(self, text): 79 | tokens = self.tokenize(text) 80 | return self.encode_from_tokens(tokens) 81 | 82 | def load_sd(self, sd): 83 | return self.cond_stage_model.load_sd(sd) 84 | 85 | def get_sd(self): 86 | return self.cond_stage_model.state_dict() 87 | 88 | def load_model(self): 89 | if self.load_device != "cpu": 90 | model_management.load_model_gpu(self.patcher) 91 | return self.patcher 92 | 93 | def add_patches(self, patches, strength_patch=1.0, strength_model=1.0): 94 | return self.patcher.add_patches(patches, strength_patch, strength_model) 95 | 96 | def get_key_patches(self): 97 | return self.patcher.get_key_patches() 98 | 99 | 100 | def load_t5(model_type, model_ver, model_path, path_type="file", device="cpu", dtype=None): 101 | assert model_type in ["t5v11"] # Only supported model for now 102 | model_args = { 103 | "textmodel_ver" : model_ver, 104 | "device" : device, 105 | "dtype" : dtype, 106 | } 107 | 108 | if path_type == "folder": 109 | # pass directly to transformers and initialize there 110 | # this is to avoid having to handle multi-file state dict loading for now.
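# presumably this ends up as something along the lines of
#   T5EncoderModel.from_pretrained("models/t5/t5v11-xxl/")  # placeholder path
# inside t5v11.py, with transformers resolving config.json and any sharded
# weight files on its own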
111 | model_args["textmodel_path"] = os.path.dirname(model_path) 112 | return EXM_T5v11(**model_args) 113 | else: 114 | # for some reason this returns garbage with torch.int8 weights, or just OOMs 115 | model = EXM_T5v11(**model_args) 116 | sd = comfy.utils.load_torch_file(model_path) 117 | model.load_sd(sd) 118 | return model 119 | -------------------------------------------------------------------------------- /Sana/conf.py: -------------------------------------------------------------------------------- 1 | """ 2 | List of all Sana model types / settings 3 | """ 4 | 5 | sampling_settings = { 6 | "shift": 3.0, 7 | } 8 | 9 | sana_conf = { 10 | "SanaMS_600M_P1_D28": { 11 | "target": "SanaMS", 12 | "unet_config": { 13 | "in_channels": 32, 14 | "depth": 28, 15 | "hidden_size": 1152, 16 | "patch_size": 1, 17 | "num_heads": 16, 18 | "linear_head_dim": 32, 19 | "model_max_length": 300, 20 | "y_norm": True, 21 | "attn_type": "linear", 22 | "ffn_type": "glumbconv", 23 | "mlp_ratio": 2.5, 24 | "mlp_acts": ["silu", "silu", None], 25 | "use_pe": False, 26 | "pred_sigma": False, 27 | "learn_sigma": False, 28 | "fp32_attention": True, 29 | }, 30 | "sampling_settings" : sampling_settings, 31 | }, 32 | "SanaMS_1600M_P1_D20": { 33 | "target": "SanaMS", 34 | "unet_config": { 35 | "in_channels": 32, 36 | "depth": 20, 37 | "hidden_size": 2240, 38 | "patch_size": 1, 39 | "num_heads": 20, 40 | "linear_head_dim": 32, 41 | "model_max_length": 300, 42 | "y_norm": True, 43 | "attn_type": "linear", 44 | "ffn_type": "glumbconv", 45 | "mlp_ratio": 2.5, 46 | "mlp_acts": ["silu", "silu", None], 47 | "use_pe": False, 48 | "pred_sigma": False, 49 | "learn_sigma": False, 50 | "fp32_attention": True, 51 | }, 52 | "sampling_settings" : sampling_settings, 53 | }, 54 | } 55 | 56 | sana_res = { 57 | "1024px": { # models/SanaMS 1024x1024 58 | '0.25': [512, 2048], '0.26': [512, 1984], '0.27': [512, 1920], '0.28': [512, 1856], 59 | '0.32': [576, 1792], '0.33': [576, 1728], '0.35': [576, 1664], '0.40': [640, 1600], 60 | '0.42': [640, 1536], '0.48': [704, 1472], '0.50': [704, 1408], '0.52': [704, 1344], 61 | '0.57': [768, 1344], '0.60': [768, 1280], '0.68': [832, 1216], '0.72': [832, 1152], 62 | '0.78': [896, 1152], '0.82': [896, 1088], '0.88': [960, 1088], '0.94': [960, 1024], 63 | '1.00': [1024,1024], '1.07': [1024, 960], '1.13': [1088, 960], '1.21': [1088, 896], 64 | '1.29': [1152, 896], '1.38': [1152, 832], '1.46': [1216, 832], '1.67': [1280, 768], 65 | '1.75': [1344, 768], '2.00': [1408, 704], '2.09': [1472, 704], '2.40': [1536, 640], 66 | '2.50': [1600, 640], '2.89': [1664, 576], '3.00': [1728, 576], '3.11': [1792, 576], 67 | '3.62': [1856, 512], '3.75': [1920, 512], '3.88': [1984, 512], '4.00': [2048, 512], 68 | }, 69 | "512px": { # models/SanaMS 512x512 70 | '0.25': [256,1024], '0.26': [256, 992], '0.27': [256, 960], '0.28': [256, 928], 71 | '0.32': [288, 896], '0.33': [288, 864], '0.35': [288, 832], '0.40': [320, 800], 72 | '0.42': [320, 768], '0.48': [352, 736], '0.50': [352, 704], '0.52': [352, 672], 73 | '0.57': [384, 672], '0.60': [384, 640], '0.68': [416, 608], '0.72': [416, 576], 74 | '0.78': [448, 576], '0.82': [448, 544], '0.88': [480, 544], '0.94': [480, 512], 75 | '1.00': [512, 512], '1.07': [512, 480], '1.13': [544, 480], '1.21': [544, 448], 76 | '1.29': [576, 448], '1.38': [576, 416], '1.46': [608, 416], '1.67': [640, 384], 77 | '1.75': [672, 384], '2.00': [704, 352], '2.09': [736, 352], '2.40': [768, 320], 78 | '2.50': [800, 320], '2.89': [832, 288], '3.00': [864, 288], '3.11': [896, 288], 79 | '3.62': [928, 
256], '3.75': [960, 256], '3.88': [992, 256], '4.00': [1024,256] 80 | }, 81 | "2K": { 82 | '0.25': [1024, 4096], '0.26': [1024, 3968], '0.27': [1024, 3840], '0.28': [1024, 3712], 83 | '0.32': [1152, 3584], '0.33': [1152, 3456], '0.35': [1152, 3328], '0.40': [1280, 3200], 84 | '0.42': [1280, 3072], '0.48': [1408, 2944], '0.50': [1408, 2816], '0.52': [1408, 2688], 85 | '0.57': [1536, 2688], '0.60': [1536, 2560], '0.68': [1664, 2432], '0.72': [1664, 2304], 86 | '0.78': [1792, 2304], '0.82': [1792, 2176], '0.88': [1920, 2176], '0.94': [1920, 2048], 87 | '1.00': [2048, 2048], '1.07': [2048, 1920], '1.13': [2176, 1920], '1.21': [2176, 1792], 88 | '1.29': [2304, 1792], '1.38': [2304, 1664], '1.46': [2432, 1664], '1.67': [2560, 1536], 89 | '1.75': [2688, 1536], '2.00': [2816, 1408], '2.09': [2944, 1408], '2.40': [3072, 1280], 90 | '2.50': [3200, 1280], '2.89': [3328, 1152], '3.00': [3456, 1152], '3.11': [3584, 1152], 91 | '3.62': [3712, 1024], '3.75': [3840, 1024], '3.88': [3968, 1024], '4.00': [4096, 1024] 92 | } 93 | } 94 | # These should be the same 95 | sana_res.update({ 96 | "SanaMS_600M_P1_D28": sana_res["1024px"], 97 | "SanaMS_1600M_P1_D20": sana_res["1024px"], 98 | }) 99 | -------------------------------------------------------------------------------- /HunYuanDiT/models/embedders.py: -------------------------------------------------------------------------------- 1 | import math 2 | import torch 3 | import torch.nn as nn 4 | from einops import repeat 5 | 6 | from timm.models.layers import to_2tuple 7 | 8 | 9 | class PatchEmbed(nn.Module): 10 | """ 2D Image to Patch Embedding 11 | 12 | Image to Patch Embedding using Conv2d 13 | 14 | A convolution based approach to patchifying a 2D image w/ embedding projection. 15 | 16 | Based on the impl in https://github.com/google-research/vision_transformer 17 | 18 | Hacked together by / Copyright 2020 Ross Wightman 19 | 20 | Remove the _assert function in forward function to be compatible with multi-resolution images. 21 | """ 22 | def __init__( 23 | self, 24 | img_size=224, 25 | patch_size=16, 26 | in_chans=3, 27 | embed_dim=768, 28 | norm_layer=None, 29 | flatten=True, 30 | bias=True, 31 | ): 32 | super().__init__() 33 | if isinstance(img_size, int): 34 | img_size = to_2tuple(img_size) 35 | elif isinstance(img_size, (tuple, list)) and len(img_size) == 2: 36 | img_size = tuple(img_size) 37 | else: 38 | raise ValueError(f"img_size must be int or tuple/list of length 2. 
Got {img_size}") 39 | patch_size = to_2tuple(patch_size) 40 | self.img_size = img_size 41 | self.patch_size = patch_size 42 | self.grid_size = (img_size[0] // patch_size[0], img_size[1] // patch_size[1]) 43 | self.num_patches = self.grid_size[0] * self.grid_size[1] 44 | self.flatten = flatten 45 | 46 | self.proj = nn.Conv2d(in_chans, embed_dim, kernel_size=patch_size, stride=patch_size, bias=bias) 47 | self.norm = norm_layer(embed_dim) if norm_layer else nn.Identity() 48 | 49 | def update_image_size(self, img_size): 50 | self.img_size = img_size 51 | self.grid_size = (img_size[0] // self.patch_size[0], img_size[1] // self.patch_size[1]) 52 | self.num_patches = self.grid_size[0] * self.grid_size[1] 53 | 54 | def forward(self, x): 55 | # B, C, H, W = x.shape 56 | # _assert(H == self.img_size[0], f"Input image height ({H}) doesn't match model ({self.img_size[0]}).") 57 | # _assert(W == self.img_size[1], f"Input image width ({W}) doesn't match model ({self.img_size[1]}).") 58 | x = self.proj(x) 59 | if self.flatten: 60 | x = x.flatten(2).transpose(1, 2) # BCHW -> BNC 61 | x = self.norm(x) 62 | return x 63 | 64 | 65 | def timestep_embedding(t, dim, max_period=10000, repeat_only=False): 66 | """ 67 | Create sinusoidal timestep embeddings. 68 | :param t: a 1-D Tensor of N indices, one per batch element. 69 | These may be fractional. 70 | :param dim: the dimension of the output. 71 | :param max_period: controls the minimum frequency of the embeddings. 72 | :return: an (N, D) Tensor of positional embeddings. 73 | """ 74 | # https://github.com/openai/glide-text2im/blob/main/glide_text2im/nn.py 75 | if not repeat_only: 76 | half = dim // 2 77 | freqs = torch.exp( 78 | -math.log(max_period) 79 | * torch.arange(start=0, end=half, dtype=torch.float32) 80 | / half 81 | ).to(device=t.device) # size: [dim/2], an exponentially decaying curve 82 | args = t[:, None].float() * freqs[None] 83 | embedding = torch.cat([torch.cos(args), torch.sin(args)], dim=-1) 84 | if dim % 2: 85 | embedding = torch.cat( 86 | [embedding, torch.zeros_like(embedding[:, :1])], dim=-1 87 | ) 88 | else: 89 | embedding = repeat(t, "b -> b d", d=dim) 90 | return embedding 91 | 92 | 93 | class TimestepEmbedder(nn.Module): 94 | """ 95 | Embeds scalar timesteps into vector representations.
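A sinusoidal frequency embedding (timestep_embedding above) followed by a two-layer MLP with a SiLU activation in between.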
96 | """ 97 | def __init__(self, hidden_size, frequency_embedding_size=256, out_size=None): 98 | super().__init__() 99 | if out_size is None: 100 | out_size = hidden_size 101 | self.mlp = nn.Sequential( 102 | nn.Linear(frequency_embedding_size, hidden_size, bias=True), 103 | nn.SiLU(), 104 | nn.Linear(hidden_size, out_size, bias=True), 105 | ) 106 | self.frequency_embedding_size = frequency_embedding_size 107 | 108 | def forward(self, t): 109 | t_freq = timestep_embedding(t, self.frequency_embedding_size).type(self.mlp[0].weight.dtype) 110 | t_emb = self.mlp(t_freq) 111 | return t_emb 112 | -------------------------------------------------------------------------------- /Gemma/nodes.py: -------------------------------------------------------------------------------- 1 | import os 2 | import torch 3 | import folder_paths 4 | from transformers import AutoTokenizer, AutoModelForCausalLM 5 | from ..utils.dtype import string_to_dtype 6 | from huggingface_hub import snapshot_download 7 | 8 | 9 | tenc_root = ( 10 | folder_paths.folder_names_and_paths.get( 11 | "text_encoders", 12 | folder_paths.folder_names_and_paths.get("clip", [[], set()]) 13 | ) 14 | ) 15 | 16 | dtypes = [ 17 | "default", 18 | "auto (comfy)", 19 | "BF16", 20 | "FP32", 21 | "FP16", 22 | ] 23 | try: torch.float8_e5m2 24 | except AttributeError: print("Torch version is too old, FP8 is not supported") 25 | else: dtypes += ["FP8 E4M3", "FP8 E5M2"] 26 | 27 | class GemmaLoader: 28 | @classmethod 29 | def INPUT_TYPES(s): 30 | devices = ["auto", "cpu", "cuda"] 31 | # support multiple GPUs 32 | for k in range(1, torch.cuda.device_count()): 33 | devices.append(f"cuda:{k}") 34 | return { 35 | "required": { 36 | "model_name": (["Efficient-Large-Model/gemma-2-2b-it", "google/gemma-2-2b-it", "unsloth/gemma-2-2b-it-bnb-4bit"],), 37 | "device": (devices, {"default":"cpu"}), 38 | "dtype": (dtypes,), 39 | } 40 | } 41 | RETURN_TYPES = ("GEMMA",) 42 | FUNCTION = "load_model" 43 | CATEGORY = "ExtraModels/Gemma" 44 | TITLE = "Gemma Loader" 45 | 46 | def load_model(self, model_name, device, dtype): 47 | dtype = string_to_dtype(dtype, "text_encoder") 48 | if device == "cpu": 49 | assert dtype in [None, torch.float32], f"Can't use dtype '{dtype}' with CPU! Set dtype to 'default'."
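# the branches below only pre-download the selected checkpoint into
# models/text_encoders/ when model.safetensors is missing; the tokenizer and
# decoder themselves are created through transformers further down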
50 | 51 | if model_name == 'google/gemma-2-2b-it': 52 | text_encoder_dir = os.path.join(folder_paths.models_dir, 'text_encoders', 'models--google--gemma-2-2b-it') 53 | if not os.path.exists(os.path.join(text_encoder_dir, 'model.safetensors')): 54 | snapshot_download('google/gemma-2-2b-it', local_dir=text_encoder_dir) 55 | elif model_name == 'unsloth/gemma-2-2b-it-bnb-4bit': 56 | text_encoder_dir = os.path.join(folder_paths.models_dir, 'text_encoders', 'models--unsloth--gemma-2-2b-it-bnb-4bit') 57 | if not os.path.exists(os.path.join(text_encoder_dir, 'model.safetensors')): 58 | snapshot_download('unsloth/gemma-2-2b-it-bnb-4bit', local_dir=text_encoder_dir) 59 | elif model_name == 'Efficient-Large-Model/gemma-2-2b-it': 60 | text_encoder_dir = os.path.join(folder_paths.models_dir, 'text_encoders', 'models--Efficient-Large-Model--gemma-2-2b-it') 61 | if not os.path.exists(os.path.join(text_encoder_dir, 'model.safetensors')): 62 | snapshot_download('Efficient-Large-Model/gemma-2-2b-it', local_dir=text_encoder_dir) 63 | else: 64 | raise ValueError('Not implemented!') 65 | 66 | tokenizer = AutoTokenizer.from_pretrained(model_name) 67 | text_encoder_model = AutoModelForCausalLM.from_pretrained(model_name, torch_dtype=dtype) 68 | tokenizer.padding_side = "right" 69 | text_encoder = text_encoder_model.get_decoder() 70 | 71 | if device != "cpu": 72 | text_encoder = text_encoder.to(device) 73 | 74 | return ({ 75 | "tokenizer": tokenizer, 76 | "text_encoder": text_encoder, 77 | "text_encoder_model": text_encoder_model 78 | },) 79 | 80 | 81 | class GemmaTextEncode: 82 | @classmethod 83 | def INPUT_TYPES(s): 84 | return { 85 | "required": { 86 | "text": ("STRING", {"multiline": True}), 87 | "GEMMA": ("GEMMA",), 88 | } 89 | } 90 | 91 | RETURN_TYPES = ("CONDITIONING",) 92 | FUNCTION = "encode" 93 | CATEGORY = "ExtraModels/Gemma" 94 | TITLE = "Gemma Text Encode" 95 | 96 | def encode(self, text, GEMMA=None): 97 | print(text) 98 | tokenizer = GEMMA["tokenizer"] 99 | text_encoder = GEMMA["text_encoder"] 100 | 101 | with torch.no_grad(): 102 | tokens = tokenizer( 103 | text, 104 | max_length=300, 105 | padding="max_length", 106 | truncation=True, 107 | return_tensors="pt" 108 | ).to(text_encoder.device) 109 | 110 | cond = text_encoder(tokens.input_ids, tokens.attention_mask)[0] 111 | emb_masks = tokens.attention_mask 112 | 113 | cond = cond * emb_masks.unsqueeze(-1) 114 | 115 | return ([[cond, {}]], ) 116 | 117 | NODE_CLASS_MAPPINGS = { 118 | "GemmaLoader": GemmaLoader, 119 | "GemmaTextEncode": GemmaTextEncode, 120 | } 121 | 122 | NODE_DISPLAY_NAME_MAPPINGS = { 123 | "GemmaLoader": "Gemma Loader", 124 | "GemmaTextEncode": "Gemma Text Encode", 125 | } 126 | -------------------------------------------------------------------------------- /PixArt/models/utils.py: -------------------------------------------------------------------------------- 1 | import torch 2 | import torch.nn as nn 3 | import torch.nn.functional as F 4 | from torch.utils.checkpoint import checkpoint, checkpoint_sequential 5 | from collections.abc import Iterable 6 | from itertools import repeat 7 | 8 | def _ntuple(n): 9 | def parse(x): 10 | if isinstance(x, Iterable) and not isinstance(x, str): 11 | return x 12 | return tuple(repeat(x, n)) 13 | return parse 14 | 15 | to_1tuple = _ntuple(1) 16 | to_2tuple = _ntuple(2) 17 | 18 | def set_grad_checkpoint(model, use_fp32_attention=False, gc_step=1): 19 | assert isinstance(model, nn.Module) 20 | 21 | def set_attr(module): 22 | module.grad_checkpointing = True 23 | module.fp32_attention = 
use_fp32_attention 24 | module.grad_checkpointing_step = gc_step 25 | model.apply(set_attr) 26 | 27 | def auto_grad_checkpoint(module, *args, **kwargs): 28 | if getattr(module, 'grad_checkpointing', False): 29 | if isinstance(module, Iterable): 30 | gc_step = module[0].grad_checkpointing_step 31 | return checkpoint_sequential(module, gc_step, *args, **kwargs) 32 | else: 33 | return checkpoint(module, *args, **kwargs) 34 | return module(*args, **kwargs) 35 | 36 | def checkpoint_sequential(functions, step, input, *args, **kwargs): 37 | 38 | # Hack for keyword-only parameter in a python 2.7-compliant way 39 | preserve = kwargs.pop('preserve_rng_state', True) 40 | if kwargs: 41 | raise ValueError("Unexpected keyword arguments: " + ",".join(arg for arg in kwargs)) 42 | 43 | def run_function(start, end, functions): 44 | def forward(input): 45 | for j in range(start, end + 1): 46 | input = functions[j](input, *args) 47 | return input 48 | return forward 49 | 50 | if isinstance(functions, torch.nn.Sequential): 51 | functions = list(functions.children()) 52 | 53 | # the last chunk has to be non-volatile 54 | end = -1 55 | segment = len(functions) // step 56 | for start in range(0, step * (segment - 1), step): 57 | end = start + step - 1 58 | input = checkpoint(run_function(start, end, functions), input, preserve_rng_state=preserve) 59 | return run_function(end + 1, len(functions) - 1, functions)(input) 60 | 61 | def get_rel_pos(q_size, k_size, rel_pos): 62 | """ 63 | Get relative positional embeddings according to the relative positions of 64 | query and key sizes. 65 | Args: 66 | q_size (int): size of query q. 67 | k_size (int): size of key k. 68 | rel_pos (Tensor): relative position embeddings (L, C). 69 | 70 | Returns: 71 | Extracted positional embeddings according to relative positions. 72 | """ 73 | max_rel_dist = int(2 * max(q_size, k_size) - 1) 74 | # Interpolate rel pos if needed. 75 | if rel_pos.shape[0] != max_rel_dist: 76 | # Interpolate rel pos. 77 | rel_pos_resized = F.interpolate( 78 | rel_pos.reshape(1, rel_pos.shape[0], -1).permute(0, 2, 1), 79 | size=max_rel_dist, 80 | mode="linear", 81 | ) 82 | rel_pos_resized = rel_pos_resized.reshape(-1, max_rel_dist).permute(1, 0) 83 | else: 84 | rel_pos_resized = rel_pos 85 | 86 | # Scale the coords with short length if shapes for q and k are different. 87 | q_coords = torch.arange(q_size)[:, None] * max(k_size / q_size, 1.0) 88 | k_coords = torch.arange(k_size)[None, :] * max(q_size / k_size, 1.0) 89 | relative_coords = (q_coords - k_coords) + (k_size - 1) * max(q_size / k_size, 1.0) 90 | 91 | return rel_pos_resized[relative_coords.long()] 92 | 93 | def add_decomposed_rel_pos(attn, q, rel_pos_h, rel_pos_w, q_size, k_size): 94 | """ 95 | Calculate decomposed Relative Positional Embeddings from :paper:`mvitv2`. 96 | https://github.com/facebookresearch/mvit/blob/19786631e330df9f3622e5402b4a419a263a2c80/mvit/models/attention.py # noqa B950 97 | Args: 98 | attn (Tensor): attention map. 99 | q (Tensor): query q in the attention layer with shape (B, q_h * q_w, C). 100 | rel_pos_h (Tensor): relative position embeddings (Lh, C) for height axis. 101 | rel_pos_w (Tensor): relative position embeddings (Lw, C) for width axis. 102 | q_size (Tuple): spatial sequence size of query q with (q_h, q_w). 103 | k_size (Tuple): spatial sequence size of key k with (k_h, k_w). 104 | 105 | Returns: 106 | attn (Tensor): attention map with added relative positional embeddings. 
107 | """ 108 | q_h, q_w = q_size 109 | k_h, k_w = k_size 110 | Rh = get_rel_pos(q_h, k_h, rel_pos_h) 111 | Rw = get_rel_pos(q_w, k_w, rel_pos_w) 112 | 113 | B, _, dim = q.shape 114 | r_q = q.reshape(B, q_h, q_w, dim) 115 | rel_h = torch.einsum("bhwc,hkc->bhwk", r_q, Rh) 116 | rel_w = torch.einsum("bhwc,wkc->bhwk", r_q, Rw) 117 | 118 | attn = ( 119 | attn.view(B, q_h, q_w, k_h, k_w) + rel_h[:, :, :, :, None] + rel_w[:, :, :, None, :] 120 | ).view(B, q_h * q_w, k_h * k_w) 121 | 122 | return attn 123 | -------------------------------------------------------------------------------- /Sana/nodes.py: -------------------------------------------------------------------------------- 1 | import torch 2 | import folder_paths 3 | from nodes import EmptyLatentImage 4 | 5 | from .conf import sana_conf, sana_res 6 | from .loader import load_sana 7 | 8 | dtypes = [ 9 | "auto", 10 | "FP32", 11 | "FP16", 12 | "BF16" 13 | ] 14 | 15 | class SanaCheckpointLoader: 16 | @classmethod 17 | def INPUT_TYPES(s): 18 | return { 19 | "required": { 20 | "ckpt_name": (folder_paths.get_filename_list("checkpoints"),), 21 | "model": (list(sana_conf.keys()),), 22 | } 23 | } 24 | RETURN_TYPES = ("MODEL",) 25 | RETURN_NAMES = ("model",) 26 | FUNCTION = "load_checkpoint" 27 | CATEGORY = "ExtraModels/Sana" 28 | TITLE = "Sana Checkpoint Loader" 29 | 30 | def load_checkpoint(self, ckpt_name, model): 31 | ckpt_path = folder_paths.get_full_path("checkpoints", ckpt_name) 32 | model_conf = sana_conf[model] 33 | model = load_sana( 34 | model_path = ckpt_path, 35 | model_conf = model_conf, 36 | ) 37 | return (model,) 38 | 39 | 40 | class EmptySanaLatentImage(EmptyLatentImage): 41 | CATEGORY = "ExtraModels/Sana" 42 | TITLE = "Empty Sana Latent Image" 43 | 44 | def generate(self, width, height, batch_size=1): 45 | latent = torch.zeros([batch_size, 32, height // 32, width // 32], device=self.device) 46 | return ({"samples":latent}, ) 47 | 48 | 49 | class SanaResolutionSelect(): 50 | @classmethod 51 | def INPUT_TYPES(s): 52 | return { 53 | "required": { 54 | "model": (list(sana_res.keys()),), 55 | "ratio": (list(sana_res["1024px"].keys()),{"default":"1.00"}), 56 | } 57 | } 58 | RETURN_TYPES = ("INT","INT") 59 | RETURN_NAMES = ("width","height") 60 | FUNCTION = "get_res" 61 | CATEGORY = "ExtraModels/Sana" 62 | TITLE = "Sana Resolution Select" 63 | 64 | def get_res(self, model, ratio): 65 | width, height = sana_res[model][ratio] 66 | return (width,height) 67 | 68 | 69 | class SanaResolutionCond: 70 | @classmethod 71 | def INPUT_TYPES(s): 72 | return { 73 | "required": { 74 | "cond": ("CONDITIONING", ), 75 | "width": ("INT", {"default": 1024.0, "min": 0, "max": 8192}), 76 | "height": ("INT", {"default": 1024.0, "min": 0, "max": 8192}), 77 | } 78 | } 79 | 80 | RETURN_TYPES = ("CONDITIONING",) 81 | RETURN_NAMES = ("cond",) 82 | FUNCTION = "add_cond" 83 | CATEGORY = "ExtraModels/Sana" 84 | TITLE = "Sana Resolution Conditioning" 85 | 86 | def add_cond(self, cond, width, height): 87 | for c in range(len(cond)): 88 | cond[c][1].update({ 89 | "img_hw": [[height, width]], 90 | "aspect_ratio": [[height/width]], 91 | }) 92 | return (cond,) 93 | 94 | 95 | class SanaTextEncode: 96 | @classmethod 97 | def INPUT_TYPES(s): 98 | return { 99 | "required": { 100 | "text": ("STRING", {"multiline": True}), 101 | "GEMMA": ("GEMMA",), 102 | } 103 | } 104 | 105 | RETURN_TYPES = ("CONDITIONING",) 106 | FUNCTION = "encode" 107 | CATEGORY = "ExtraModels/Sana" 108 | TITLE = "Sana Text Encode" 109 | 110 | def encode(self, text, GEMMA=None): 111 | tokenizer = 
GEMMA["tokenizer"] 112 | text_encoder = GEMMA["text_encoder"] 113 | 114 | with torch.no_grad(): 115 | chi_prompt = "\n".join(preset_te_prompt) 116 | full_prompt = chi_prompt + text 117 | num_chi_tokens = len(tokenizer.encode(chi_prompt)) 118 | max_length = num_chi_tokens + 300 - 2 119 | 120 | tokens = tokenizer( 121 | [full_prompt], 122 | max_length=max_length, 123 | padding="max_length", 124 | truncation=True, 125 | return_tensors="pt" 126 | ).to(text_encoder.device) 127 | 128 | select_idx = [0] + list(range(-300 + 1, 0)) 129 | embs = text_encoder(tokens.input_ids, tokens.attention_mask)[0][:, None][:, :, select_idx] 130 | emb_masks = tokens.attention_mask[:, select_idx] 131 | embs = embs * emb_masks.unsqueeze(-1) 132 | 133 | return ([[embs, {}]], ) 134 | 135 | preset_te_prompt = [ 136 | 'Given a user prompt, generate an "Enhanced prompt" that provides detailed visual descriptions suitable for image generation. Evaluate the level of detail in the user prompt:', 137 | '- If the prompt is simple, focus on adding specifics about colors, shapes, sizes, textures, and spatial relationships to create vivid and concrete scenes.', 138 | '- If the prompt is already detailed, refine and enhance the existing details slightly without overcomplicating.', 139 | 'Here are examples of how to transform or refine prompts:', 140 | '- User Prompt: A cat sleeping -> Enhanced: A small, fluffy white cat curled up in a round shape, sleeping peacefully on a warm sunny windowsill, surrounded by pots of blooming red flowers.', 141 | '- User Prompt: A busy city street -> Enhanced: A bustling city street scene at dusk, featuring glowing street lamps, a diverse crowd of people in colorful clothing, and a double-decker bus passing by towering glass skyscrapers.', 142 | 'Please generate only the enhanced description for the prompt below and avoid including any additional commentary or evaluations:', 143 | 'User Prompt: ' 144 | ] 145 | 146 | NODE_CLASS_MAPPINGS = { 147 | "SanaCheckpointLoader" : SanaCheckpointLoader, 148 | "SanaResolutionSelect" : SanaResolutionSelect, 149 | "SanaTextEncode" : SanaTextEncode, 150 | "SanaResolutionCond" : SanaResolutionCond, 151 | "EmptySanaLatentImage": EmptySanaLatentImage, 152 | } 153 | -------------------------------------------------------------------------------- /VAE/models/LICENSE-SDV: -------------------------------------------------------------------------------- 1 | STABLE VIDEO DIFFUSION NON-COMMERCIAL COMMUNITY LICENSE AGREEMENT 2 | Dated: November 21, 2023 3 | 4 | “AUP” means the Stability AI Acceptable Use Policy available at https://stability.ai/use-policy, as may be updated from time to time. 5 | 6 | "Agreement" means the terms and conditions for use, reproduction, distribution and modification of the Software Products set forth herein. 7 | "Derivative Work(s)” means (a) any derivative work of the Software Products as recognized by U.S. copyright laws and (b) any modifications to a Model, and any other model created which is based on or derived from the Model or the Model’s output. For clarity, Derivative Works do not include the output of any Model. 8 | “Documentation” means any specifications, manuals, documentation, and other written information provided by Stability AI related to the Software. 
9 | 10 | "Licensee" or "you" means you, or your employer or any other person or entity (if you are entering into this Agreement on such person or entity's behalf), of the age required under applicable laws, rules or regulations to provide legal consent and that has legal authority to bind your employer or such other person or entity if you are entering in this Agreement on their behalf. 11 | 12 | "Stability AI" or "we" means Stability AI Ltd. 13 | 14 | "Software" means, collectively, Stability AI’s proprietary models and algorithms, including machine-learning models, trained model weights and other elements of the foregoing, made available under this Agreement. 15 | 16 | “Software Products” means Software and Documentation. 17 | 18 | By using or distributing any portion or element of the Software Products, you agree to be bound by this Agreement. 19 | 20 | 21 | 22 | License Rights and Redistribution. 23 | Subject to your compliance with this Agreement, the AUP (which is hereby incorporated herein by reference), and the Documentation, Stability AI grants you a non-exclusive, worldwide, non-transferable, non-sublicensable, revocable, royalty free and limited license under Stability AI’s intellectual property or other rights owned by Stability AI embodied in the Software Products to reproduce, distribute, and create Derivative Works of the Software Products for purposes other than commercial or production use. 24 | b. If you distribute or make the Software Products, or any Derivative Works thereof, available to a third party, the Software Products, Derivative Works, or any portion thereof, respectively, will remain subject to this Agreement and you must (i) provide a copy of this Agreement to such third party, and (ii) retain the following attribution notice within a "Notice" text file distributed as a part of such copies: "Stable Video Diffusion is licensed under the Stable Video Diffusion Research License, Copyright (c) Stability AI Ltd. All Rights Reserved.” If you create a Derivative Work of a Software Product, you may add your own attribution notices to the Notice file included with the Software Product, provided that you clearly indicate which attributions apply to the Software Product and you must state in the NOTICE file that you changed the Software Product and how it was modified. 25 | 2. Disclaimer of Warranty. UNLESS REQUIRED BY APPLICABLE LAW, THE SOFTWARE PRODUCTS AND ANY OUTPUT AND RESULTS THEREFROM ARE PROVIDED ON AN "AS IS" BASIS, WITHOUT WARRANTIES OF ANY KIND, EITHER EXPRESS OR IMPLIED, INCLUDING, WITHOUT LIMITATION, ANY WARRANTIES OF TITLE, NON-INFRINGEMENT, MERCHANTABILITY, OR FITNESS FOR A PARTICULAR PURPOSE. YOU ARE SOLELY RESPONSIBLE FOR DETERMINING THE APPROPRIATENESS OF USING OR REDISTRIBUTING THE SOFTWARE PRODUCTS AND ASSUME ANY RISKS ASSOCIATED WITH YOUR USE OF THE SOFTWARE PRODUCTS AND ANY OUTPUT AND RESULTS. 26 | 3. Limitation of Liability. IN NO EVENT WILL STABILITY AI OR ITS AFFILIATES BE LIABLE UNDER ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, TORT, NEGLIGENCE, PRODUCTS LIABILITY, OR OTHERWISE, ARISING OUT OF THIS AGREEMENT, FOR ANY LOST PROFITS OR ANY INDIRECT, SPECIAL, CONSEQUENTIAL, INCIDENTAL, EXEMPLARY OR PUNITIVE DAMAGES, EVEN IF STABILITY AI OR ITS AFFILIATES HAVE BEEN ADVISED OF THE POSSIBILITY OF ANY OF THE FOREGOING. 27 | 3. Intellectual Property. 28 | a. 
No trademark licenses are granted under this Agreement, and in connection with the Software Products, neither Stability AI nor Licensee may use any name or mark owned by or associated with the other or any of its affiliates, except as required for reasonable and customary use in describing and redistributing the Software Products. 29 | Subject to Stability AI’s ownership of the Software Products and Derivative Works made by or for Stability AI, with respect to any Derivative Works that are made by you, as between you and Stability AI, you are and will be the owner of such Derivative Works. 30 | If you institute litigation or other proceedings against Stability AI (including a cross-claim or counterclaim in a lawsuit) alleging that the Software Products or associated outputs or results, or any portion of any of the foregoing, constitutes infringement of intellectual property or other rights owned or licensable by you, then any licenses granted to you under this Agreement shall terminate as of the date such litigation or claim is filed or instituted. You will indemnify and hold harmless Stability AI from and against any claim by any third party arising out of or related to your use or distribution of the Software Products in violation of this Agreement. 31 | 4. Term and Termination. The term of this Agreement will commence upon your acceptance of this Agreement or access to the Software Products and will continue in full force and effect until terminated in accordance with the terms and conditions herein. Stability AI may terminate this Agreement if you are in breach of any term or condition of this Agreement. Upon termination of this Agreement, you shall delete and cease use of the Software Products. Sections 2-4 shall survive the termination of this Agreement. 
32 | -------------------------------------------------------------------------------- /PixArt/conf.py: -------------------------------------------------------------------------------- 1 | """ 2 | List of all PixArt model types / settings 3 | """ 4 | 5 | sampling_settings = { 6 | "beta_schedule" : "sqrt_linear", 7 | "linear_start" : 0.0001, 8 | "linear_end" : 0.02, 9 | "timesteps" : 1000, 10 | } 11 | 12 | pixart_conf = { 13 | "PixArtMS_XL_2": { # models/PixArtMS 14 | "target": "PixArtMS", 15 | "unet_config": { 16 | "input_size" : 1024//8, 17 | "depth" : 28, 18 | "num_heads" : 16, 19 | "patch_size" : 2, 20 | "hidden_size" : 1152, 21 | "pe_interpolation": 2, 22 | }, 23 | "sampling_settings" : sampling_settings, 24 | }, 25 | "PixArtMS_Sigma_XL_2": { 26 | "target": "PixArtMSSigma", 27 | "unet_config": { 28 | "input_size" : 1024//8, 29 | "token_num" : 300, 30 | "depth" : 28, 31 | "num_heads" : 16, 32 | "patch_size" : 2, 33 | "hidden_size" : 1152, 34 | "micro_condition": False, 35 | "pe_interpolation": 2, 36 | "model_max_length": 300, 37 | }, 38 | "sampling_settings" : sampling_settings, 39 | }, 40 | "PixArtMS_Sigma_XL_2_900M": { 41 | "target": "PixArtMSSigma", 42 | "unet_config": { 43 | "input_size": 1024 // 8, 44 | "token_num": 300, 45 | "depth": 42, 46 | "num_heads": 16, 47 | "patch_size": 2, 48 | "hidden_size": 1152, 49 | "micro_condition": False, 50 | "pe_interpolation": 2, 51 | "model_max_length": 300, 52 | }, 53 | "sampling_settings": sampling_settings, 54 | }, 55 | "PixArtMS_Sigma_XL_2_2K": { 56 | "target": "PixArtMSSigma", 57 | "unet_config": { 58 | "input_size" : 2048//8, 59 | "token_num" : 300, 60 | "depth" : 28, 61 | "num_heads" : 16, 62 | "patch_size" : 2, 63 | "hidden_size" : 1152, 64 | "micro_condition": False, 65 | "pe_interpolation": 4, 66 | "model_max_length": 300, 67 | }, 68 | "sampling_settings" : sampling_settings, 69 | }, 70 | "PixArt_XL_2": { # models/PixArt 71 | "target": "PixArt", 72 | "unet_config": { 73 | "input_size" : 512//8, 74 | "token_num" : 120, 75 | "depth" : 28, 76 | "num_heads" : 16, 77 | "patch_size" : 2, 78 | "hidden_size" : 1152, 79 | "pe_interpolation": 1, 80 | }, 81 | "sampling_settings" : sampling_settings, 82 | }, 83 | } 84 | 85 | pixart_conf.update({ # controlnet models 86 | "ControlPixArtHalf": { 87 | "target": "ControlPixArtHalf", 88 | "unet_config": pixart_conf["PixArt_XL_2"]["unet_config"], 89 | "sampling_settings": pixart_conf["PixArt_XL_2"]["sampling_settings"], 90 | }, 91 | "ControlPixArtMSHalf": { 92 | "target": "ControlPixArtMSHalf", 93 | "unet_config": pixart_conf["PixArtMS_XL_2"]["unet_config"], 94 | "sampling_settings": pixart_conf["PixArtMS_XL_2"]["sampling_settings"], 95 | } 96 | }) 97 | 98 | pixart_res = { 99 | "PixArtMS_XL_2": { # models/PixArtMS 1024x1024 100 | '0.25': [512, 2048], '0.26': [512, 1984], '0.27': [512, 1920], '0.28': [512, 1856], 101 | '0.32': [576, 1792], '0.33': [576, 1728], '0.35': [576, 1664], '0.40': [640, 1600], 102 | '0.42': [640, 1536], '0.48': [704, 1472], '0.50': [704, 1408], '0.52': [704, 1344], 103 | '0.57': [768, 1344], '0.60': [768, 1280], '0.68': [832, 1216], '0.72': [832, 1152], 104 | '0.78': [896, 1152], '0.82': [896, 1088], '0.88': [960, 1088], '0.94': [960, 1024], 105 | '1.00': [1024,1024], '1.07': [1024, 960], '1.13': [1088, 960], '1.21': [1088, 896], 106 | '1.29': [1152, 896], '1.38': [1152, 832], '1.46': [1216, 832], '1.67': [1280, 768], 107 | '1.75': [1344, 768], '2.00': [1408, 704], '2.09': [1472, 704], '2.40': [1536, 640], 108 | '2.50': [1600, 640], '2.89': [1664, 576], '3.00': [1728, 576], '3.11': 
[1792, 576], 109 | '3.62': [1856, 512], '3.75': [1920, 512], '3.88': [1984, 512], '4.00': [2048, 512], 110 | }, 111 | "PixArt_XL_2": { # models/PixArt 512x512 112 | '0.25': [256,1024], '0.26': [256, 992], '0.27': [256, 960], '0.28': [256, 928], 113 | '0.32': [288, 896], '0.33': [288, 864], '0.35': [288, 832], '0.40': [320, 800], 114 | '0.42': [320, 768], '0.48': [352, 736], '0.50': [352, 704], '0.52': [352, 672], 115 | '0.57': [384, 672], '0.60': [384, 640], '0.68': [416, 608], '0.72': [416, 576], 116 | '0.78': [448, 576], '0.82': [448, 544], '0.88': [480, 544], '0.94': [480, 512], 117 | '1.00': [512, 512], '1.07': [512, 480], '1.13': [544, 480], '1.21': [544, 448], 118 | '1.29': [576, 448], '1.38': [576, 416], '1.46': [608, 416], '1.67': [640, 384], 119 | '1.75': [672, 384], '2.00': [704, 352], '2.09': [736, 352], '2.40': [768, 320], 120 | '2.50': [800, 320], '2.89': [832, 288], '3.00': [864, 288], '3.11': [896, 288], 121 | '3.62': [928, 256], '3.75': [960, 256], '3.88': [992, 256], '4.00': [1024,256] 122 | }, 123 | "PixArtMS_Sigma_XL_2_2K": { 124 | '0.25': [1024, 4096], '0.26': [1024, 3968], '0.27': [1024, 3840], '0.28': [1024, 3712], 125 | '0.32': [1152, 3584], '0.33': [1152, 3456], '0.35': [1152, 3328], '0.40': [1280, 3200], 126 | '0.42': [1280, 3072], '0.48': [1408, 2944], '0.50': [1408, 2816], '0.52': [1408, 2688], 127 | '0.57': [1536, 2688], '0.60': [1536, 2560], '0.68': [1664, 2432], '0.72': [1664, 2304], 128 | '0.78': [1792, 2304], '0.82': [1792, 2176], '0.88': [1920, 2176], '0.94': [1920, 2048], 129 | '1.00': [2048, 2048], '1.07': [2048, 1920], '1.13': [2176, 1920], '1.21': [2176, 1792], 130 | '1.29': [2304, 1792], '1.38': [2304, 1664], '1.46': [2432, 1664], '1.67': [2560, 1536], 131 | '1.75': [2688, 1536], '2.00': [2816, 1408], '2.09': [2944, 1408], '2.40': [3072, 1280], 132 | '2.50': [3200, 1280], '2.89': [3328, 1152], '3.00': [3456, 1152], '3.11': [3584, 1152], 133 | '3.62': [3712, 1024], '3.75': [3840, 1024], '3.88': [3968, 1024], '4.00': [4096, 1024] 134 | } 135 | } 136 | # These should be the same 137 | pixart_res.update({ 138 | "PixArtMS_Sigma_XL_2": pixart_res["PixArtMS_XL_2"], 139 | "PixArtMS_Sigma_XL_2_512": pixart_res["PixArt_XL_2"], 140 | }) 141 | -------------------------------------------------------------------------------- /PixArt/lora.py: -------------------------------------------------------------------------------- 1 | import os 2 | import copy 3 | import json 4 | import torch 5 | import comfy.lora 6 | import comfy.model_management 7 | from comfy.model_patcher import ModelPatcher 8 | from .diffusers_convert import convert_lora_state_dict 9 | 10 | class EXM_PixArt_ModelPatcher(ModelPatcher): 11 | def calculate_weight(self, patches, weight, key): 12 | """ 13 | This is almost the same as the comfy function, but stripped down to just the LoRA patch code. 14 | The problem with the original code is the q/k/v keys being combined into one for the attention. 15 | In the diffusers code, they're treated as separate keys, but in the reference code they're recombined (q+kv|qkv). 16 | This means, for example, that the [1152,1152] weights become [3456,1152] in the state dict. 17 | The issue with this is that the LoRA weights are [128,1152],[1152,128] and become [384,1162],[3456,128] instead. 18 | 19 | This is the best thing I could think of that would fix that, but it's very fragile. 
20 | - Check key shape to determine if it needs the fallback logic 21 | - Cut the input into parts based on the shape (undoing the torch.cat) 22 | - Do the matrix multiplication logic 23 | - Recombine them to match the expected shape 24 | """ 25 | for p in patches: 26 | alpha = p[0] 27 | v = p[1] 28 | strength_model = p[2] 29 | if strength_model != 1.0: 30 | weight *= strength_model 31 | 32 | if isinstance(v, list): 33 | v = (self.calculate_weight(v[1:], v[0].clone(), key), ) 34 | 35 | if len(v) == 2: 36 | patch_type = v[0] 37 | v = v[1] 38 | 39 | if patch_type == "lora": 40 | mat1 = comfy.model_management.cast_to_device(v[0], weight.device, torch.float32) 41 | mat2 = comfy.model_management.cast_to_device(v[1], weight.device, torch.float32) 42 | if v[2] is not None: 43 | alpha *= v[2] / mat2.shape[0] 44 | try: 45 | mat1 = mat1.flatten(start_dim=1) 46 | mat2 = mat2.flatten(start_dim=1) 47 | 48 | ch1 = mat1.shape[0] // mat2.shape[1] 49 | ch2 = mat2.shape[0] // mat1.shape[1] 50 | ### Fallback logic for shape mismatch ### 51 | if mat1.shape[0] != mat2.shape[1] and ch1 == ch2 and (mat1.shape[0]/mat2.shape[1])%1 == 0: 52 | mat1 = mat1.chunk(ch1, dim=0) 53 | mat2 = mat2.chunk(ch1, dim=0) 54 | weight += torch.cat( 55 | [alpha * torch.mm(mat1[x], mat2[x]) for x in range(ch1)], 56 | dim=0, 57 | ).reshape(weight.shape).type(weight.dtype) 58 | else: 59 | weight += (alpha * torch.mm(mat1, mat2)).reshape(weight.shape).type(weight.dtype) 60 | except Exception as e: 61 | print("ERROR", key, e) 62 | return weight 63 | 64 | def clone(self): 65 | n = EXM_PixArt_ModelPatcher(self.model, self.load_device, self.offload_device, self.size, self.current_device, weight_inplace_update=self.weight_inplace_update) 66 | n.patches = {} 67 | for k in self.patches: 68 | n.patches[k] = self.patches[k][:] 69 | 70 | n.object_patches = self.object_patches.copy() 71 | n.model_options = copy.deepcopy(self.model_options) 72 | n.model_keys = self.model_keys 73 | return n 74 | 75 | def replace_model_patcher(model): 76 | n = EXM_PixArt_ModelPatcher( 77 | model = model.model, 78 | size = model.size, 79 | load_device = model.load_device, 80 | offload_device = model.offload_device, 81 | weight_inplace_update = model.weight_inplace_update, 82 | ) 83 | n.patches = {} 84 | for k in model.patches: 85 | n.patches[k] = model.patches[k][:] 86 | 87 | n.object_patches = model.object_patches.copy() 88 | n.model_options = copy.deepcopy(model.model_options) 89 | return n 90 | 91 | def find_peft_alpha(path): 92 | def load_json(json_path): 93 | with open(json_path) as f: 94 | data = json.load(f) 95 | alpha = data.get("lora_alpha") 96 | alpha = alpha or data.get("alpha") 97 | if not alpha: 98 | print(" Found config but `lora_alpha` is missing!") 99 | else: 100 | print(f" Found config at {json_path} [alpha:{alpha}]") 101 | return alpha 102 | 103 | # For some weird reason peft doesn't include the alpha in the actual model 104 | print("PixArt: Warning! This is a PEFT LoRA. Trying to find config...") 105 | files = [ 106 | f"{os.path.splitext(path)[0]}.json", 107 | f"{os.path.splitext(path)[0]}.config.json", 108 | os.path.join(os.path.dirname(path),"adapter_config.json"), 109 | ] 110 | for file in files: 111 | if os.path.isfile(file): 112 | return load_json(file) 113 | 114 | print(" Missing config/alpha! assuming alpha of 8. 
Consider converting it/adding a config json to it.") 115 | return 8.0 116 | 117 | def load_pixart_lora(model, lora, lora_path, strength): 118 | k_back = lambda x: x.replace(".lora_up.weight", "") 119 | # need to convert the actual weights for this to work. 120 | if any(True for x in lora.keys() if x.endswith("adaln_single.linear.lora_A.weight")): 121 | lora = convert_lora_state_dict(lora, peft=True) 122 | alpha = find_peft_alpha(lora_path) 123 | lora.update({f"{k_back(x)}.alpha":torch.tensor(alpha) for x in lora.keys() if "lora_up" in x}) 124 | else: # OneTrainer 125 | lora = convert_lora_state_dict(lora, peft=False) 126 | 127 | key_map = {k_back(x):f"diffusion_model.{k_back(x)}.weight" for x in lora.keys() if "lora_up" in x} # fake 128 | 129 | loaded = comfy.lora.load_lora(lora, key_map) 130 | if model is not None: 131 | # switch to custom model patcher when using LoRAs 132 | if isinstance(model, EXM_PixArt_ModelPatcher): 133 | new_modelpatcher = model.clone() 134 | else: 135 | new_modelpatcher = replace_model_patcher(model) 136 | k = new_modelpatcher.add_patches(loaded, strength) 137 | else: 138 | k = () 139 | new_modelpatcher = None 140 | 141 | k = set(k) 142 | for x in loaded: 143 | if (x not in k): 144 | print("NOT LOADED", x) 145 | 146 | return new_modelpatcher 147 | -------------------------------------------------------------------------------- /VAE/conf.py: -------------------------------------------------------------------------------- 1 | """ 2 | List of all VAE configs, with training parts stripped. 3 | """ 4 | vae_conf = { 5 | ### AutoencoderKL ### 6 | "kl-f4": { 7 | "type" : "AutoencoderKL", 8 | "embed_scale" : 4, 9 | "embed_dim" : 3, 10 | "z_channels" : 3, 11 | "double_z" : True, 12 | "resolution" : 256, 13 | "in_channels" : 3, 14 | "out_ch" : 3, 15 | "ch" : 128, 16 | "ch_mult" : [1,2,4], 17 | "num_res_blocks" : 2, 18 | "attn_resolutions" : [], 19 | }, 20 | "kl-f8": { # Default SD1.5 VAE 21 | "type" : "AutoencoderKL", 22 | "embed_scale" : 8, 23 | "embed_dim" : 4, 24 | "z_channels" : 4, 25 | "double_z" : True, 26 | "resolution" : 256, 27 | "in_channels" : 3, 28 | "out_ch" : 3, 29 | "ch" : 128, 30 | "ch_mult" : [1,2,4,4], 31 | "num_res_blocks" : 2, 32 | "attn_resolutions" : [], 33 | }, 34 | "kl-f8-d16": { # 16 channel VAE from https://huggingface.co/ostris/vae-kl-f8-d16/tree/main 35 | "type" : "AutoencoderKL", 36 | "embed_scale" : 8, 37 | "embed_dim" : 16, 38 | "z_channels" : 16, 39 | "double_z" : True, 40 | "resolution" : 256, 41 | "in_channels" : 3, 42 | "out_ch" : 3, 43 | "ch" : 128, 44 | "ch_mult" : [1,1,2,4], 45 | "num_res_blocks" : 2, 46 | "attn_resolutions" : [], 47 | }, 48 | "kl-f16": { 49 | "type" : "AutoencoderKL", 50 | "embed_scale" : 16, 51 | "embed_dim" : 16, 52 | "z_channels" : 16, 53 | "double_z" : True, 54 | "resolution" : 256, 55 | "in_channels" : 3, 56 | "out_ch" : 3, 57 | "ch" : 128, 58 | "ch_mult" : [1,1,2,2,4], 59 | "num_res_blocks" : 2, 60 | "attn_resolutions" : [16], 61 | }, 62 | "kl-f32": { 63 | "type" : "AutoencoderKL", 64 | "embed_scale" : 32, 65 | "embed_dim" : 64, 66 | "z_channels" : 64, 67 | "double_z" : True, 68 | "resolution" : 256, 69 | "in_channels" : 3, 70 | "out_ch" : 3, 71 | "ch" : 128, 72 | "ch_mult" : [1,1,2,2,4,4], 73 | "num_res_blocks" : 2, 74 | "attn_resolutions" : [16,8], 75 | }, 76 | ### VQModel ### 77 | "vq-f4": { 78 | "type" : "VQModel", 79 | "embed_scale" : 4, 80 | "n_embed" : 8192, 81 | "embed_dim" : 3, 82 | "z_channels" : 3, 83 | "double_z" : False, 84 | "resolution" : 256, 85 | "in_channels" : 3, 86 | "out_ch" : 3, 87 | "ch" : 
128, 88 | "ch_mult" : [1,2,4], 89 | "num_res_blocks" : 2, 90 | "attn_resolutions" : [], 91 | }, 92 | "vq-f8": { 93 | "type" : "VQModel", 94 | "embed_scale" : 8, 95 | "n_embed" : 16384, 96 | "embed_dim" : 4, 97 | "z_channels" : 4, 98 | "double_z" : False, 99 | "resolution" : 256, 100 | "in_channels" : 3, 101 | "out_ch" : 3, 102 | "ch" : 128, 103 | "ch_mult" : [1,2,2,4], 104 | "num_res_blocks" : 2, 105 | "attn_resolutions" : [32], 106 | }, 107 | "vq-f16": { 108 | "type" : "VQModel", 109 | "embed_scale" : 16, 110 | "n_embed" : 16384, 111 | "embed_dim" : 8, 112 | "z_channels" : 8, 113 | "double_z" : False, 114 | "resolution" : 256, 115 | "in_channels" : 3, 116 | "out_ch" : 3, 117 | "ch" : 128, 118 | "ch_mult" : [1,1,2,2,4], 119 | "num_res_blocks" : 2, 120 | "attn_resolutions" : [16], 121 | }, 122 | # OpenAI Consistency Decoder 123 | "Consistency-Decoder": { 124 | "type" : "ConsistencyDecoder", 125 | "embed_scale" : 8, 126 | "embed_dim" : 4, 127 | }, 128 | # SAI Video Decoder 129 | "SDV-VideoDecoder": { 130 | "type" : "AutoencoderKL-VideoDecoder", 131 | "embed_scale" : 8, 132 | "embed_dim" : 4, 133 | "z_channels" : 4, 134 | "double_z" : True, 135 | "resolution" : 256, 136 | "in_channels" : 3, 137 | "out_ch" : 3, 138 | "ch" : 128, 139 | "ch_mult" : [1,2,4,4], 140 | "num_res_blocks" : 2, 141 | "attn_resolutions" : [], 142 | "video_kernel_size": [3, 1, 1] 143 | }, 144 | # Kandinsky-3 145 | "MoVQ3": { 146 | "type" : "MoVQ3", 147 | "embed_scale" : 8, 148 | "embed_dim" : 4, 149 | "double_z" : False, 150 | "z_channels" : 4, 151 | "resolution" : 256, 152 | "in_channels" : 3, 153 | "out_ch" : 3, 154 | "ch" : 256, 155 | "ch_mult" : [1, 2, 2, 4], 156 | "num_res_blocks" : 2, 157 | "attn_resolutions" : [32], 158 | }, 159 | # DCAE configs 160 | "dcae-f32c32-sana-1.0": { 161 | "type" : "DCAE", 162 | "in_channels" : 3, 163 | "embed_scale" : 32, 164 | "embed_dim" : 32, 165 | "encoder_block_type" : ["ResBlock", "ResBlock", "ResBlock", "EViTS5GLU", "EViTS5GLU", "EViTS5GLU"], 166 | "encoder_width_list" : [128, 256, 512, 512, 1024, 1024], 167 | "encoder_depth_list" : [2, 2, 2, 3, 3, 3], 168 | "encoder_norm" : "rms2d", 169 | "encoder_act" : "silu", 170 | "downsample_block_type" : "Conv", 171 | "decoder_block_type" : ["ResBlock", "ResBlock", "ResBlock", "EViTS5GLU", "EViTS5GLU", "EViTS5GLU"], 172 | "decoder_width_list" : [128, 256, 512, 512, 1024, 1024], 173 | "decoder_depth_list" : [3, 3, 3, 3, 3, 3], 174 | "decoder_norm" : "rms2d", 175 | "decoder_act" : "silu", 176 | "upsample_block_type" : "InterpolateConv", 177 | "scaling_factor" : 0.41407 178 | } 179 | } 180 | -------------------------------------------------------------------------------- /HunYuanDiT/nodes.py: -------------------------------------------------------------------------------- 1 | import os 2 | import folder_paths 3 | from copy import deepcopy 4 | 5 | from .conf import hydit_conf 6 | from .loader import load_hydit 7 | 8 | class HYDiTCheckpointLoader: 9 | @classmethod 10 | def INPUT_TYPES(s): 11 | return { 12 | "required": { 13 | "ckpt_name": (folder_paths.get_filename_list("checkpoints"),), 14 | "model": (list(hydit_conf.keys()),{"default":"G/2"}), 15 | } 16 | } 17 | RETURN_TYPES = ("MODEL",) 18 | RETURN_NAMES = ("model",) 19 | FUNCTION = "load_checkpoint" 20 | CATEGORY = "ExtraModels/HunyuanDiT" 21 | TITLE = "Hunyuan DiT Checkpoint Loader" 22 | 23 | def load_checkpoint(self, ckpt_name, model): 24 | ckpt_path = folder_paths.get_full_path("checkpoints", ckpt_name) 25 | model_conf = hydit_conf[model] 26 | model = load_hydit( 27 | model_path = 
ckpt_path, 28 | model_conf = model_conf, 29 | ) 30 | return (model,) 31 | 32 | #### temp stuff for the text encoder #### 33 | import torch 34 | from .tenc import load_clip, load_t5 35 | from ..utils.dtype import string_to_dtype 36 | dtypes = [ 37 | "default", 38 | "auto (comfy)", 39 | "FP32", 40 | "FP16", 41 | "BF16" 42 | ] 43 | 44 | class HYDiTTextEncoderLoader: 45 | @classmethod 46 | def INPUT_TYPES(s): 47 | devices = ["auto", "cpu", "gpu"] 48 | # hack for using second GPU as offload 49 | for k in range(1, torch.cuda.device_count()): 50 | devices.append(f"cuda:{k}") 51 | return { 52 | "required": { 53 | "clip_name": (folder_paths.get_filename_list("clip"),), 54 | "mt5_name": (folder_paths.get_filename_list("t5"),), 55 | "device": (devices, {"default":"cpu"}), 56 | "dtype": (dtypes,), 57 | } 58 | } 59 | 60 | RETURN_TYPES = ("CLIP", "T5") 61 | FUNCTION = "load_model" 62 | CATEGORY = "ExtraModels/HunyuanDiT" 63 | TITLE = "Hunyuan DiT Text Encoder Loader" 64 | 65 | def load_model(self, clip_name, mt5_name, device, dtype): 66 | dtype = string_to_dtype(dtype, "text_encoder") 67 | if device == "cpu": 68 | assert dtype in [None, torch.float32, torch.bfloat16], f"Can't use dtype '{dtype}' with CPU! Set dtype to 'default' or 'bf16'." 69 | 70 | clip = load_clip( 71 | model_path = folder_paths.get_full_path("clip", clip_name), 72 | device = device, 73 | dtype = dtype, 74 | ) 75 | t5 = load_t5( 76 | model_path = folder_paths.get_full_path("t5", mt5_name), 77 | device = device, 78 | dtype = dtype, 79 | ) 80 | return(clip, t5) 81 | 82 | class HYDiTTextEncode: 83 | @classmethod 84 | def INPUT_TYPES(s): 85 | return { 86 | "required": { 87 | "text": ("STRING", {"multiline": True}), 88 | "text_t5": ("STRING", {"multiline": True}), 89 | "CLIP": ("CLIP",), 90 | "T5": ("T5",), 91 | } 92 | } 93 | 94 | RETURN_TYPES = ("CONDITIONING",) 95 | FUNCTION = "encode" 96 | CATEGORY = "ExtraModels/HunyuanDiT" 97 | TITLE = "Hunyuan DiT Text Encode" 98 | 99 | def encode(self, text, text_t5, CLIP, T5): 100 | # T5 101 | T5.load_model() 102 | t5_pre = T5.tokenizer( 103 | text_t5, 104 | max_length = T5.cond_stage_model.max_length, 105 | padding = 'max_length', 106 | truncation = True, 107 | return_attention_mask = True, 108 | add_special_tokens = True, 109 | return_tensors = 'pt' 110 | ) 111 | t5_mask = t5_pre["attention_mask"] 112 | with torch.no_grad(): 113 | t5_outs = T5.cond_stage_model.transformer( 114 | input_ids = t5_pre["input_ids"].to(T5.load_device), 115 | attention_mask = t5_mask.to(T5.load_device), 116 | output_hidden_states = True, 117 | ) 118 | # to-do: replace -1 for clip skip 119 | t5_embs = t5_outs["hidden_states"][-1].float().cpu() 120 | 121 | # "clip" 122 | CLIP.load_model() 123 | clip_pre = CLIP.tokenizer( 124 | text, 125 | max_length = CLIP.cond_stage_model.max_length, 126 | padding = 'max_length', 127 | truncation = True, 128 | return_attention_mask = True, 129 | add_special_tokens = True, 130 | return_tensors = 'pt' 131 | ) 132 | clip_mask = clip_pre["attention_mask"] 133 | with torch.no_grad(): 134 | clip_outs = CLIP.cond_stage_model.transformer( 135 | input_ids = clip_pre["input_ids"].to(CLIP.load_device), 136 | attention_mask = clip_mask.to(CLIP.load_device), 137 | ) 138 | # to-do: add hidden states 139 | clip_embs = clip_outs[0].float().cpu() 140 | 141 | # combined cond 142 | return ([[ 143 | clip_embs, { 144 | "context_t5": t5_embs, 145 | "context_mask": clip_mask.float(), 146 | "context_t5_mask": t5_mask.float() 147 | } 148 | ]],) 149 | 150 | class HYDiTTextEncodeSimple(HYDiTTextEncode): 151 | 
@classmethod 152 | def INPUT_TYPES(s): 153 | return { 154 | "required": { 155 | "text": ("STRING", {"multiline": True}), 156 | "CLIP": ("CLIP",), 157 | "T5": ("T5",), 158 | } 159 | } 160 | 161 | FUNCTION = "encode_simple" 162 | TITLE = "Hunyuan DiT Text Encode (simple)" 163 | 164 | def encode_simple(self, text, **args): 165 | return self.encode(text=text, text_t5=text, **args) 166 | 167 | class HYDiTSrcSizeCond: 168 | @classmethod 169 | def INPUT_TYPES(s): 170 | return { 171 | "required": { 172 | "cond": ("CONDITIONING", ), 173 | "width": ("INT", {"default": 1024.0, "min": 0, "max": 8192, "step": 16}), 174 | "height": ("INT", {"default": 1024.0, "min": 0, "max": 8192, "step": 16}), 175 | } 176 | } 177 | 178 | RETURN_TYPES = ("CONDITIONING",) 179 | RETURN_NAMES = ("cond",) 180 | FUNCTION = "add_cond" 181 | CATEGORY = "ExtraModels/HunyuanDiT" 182 | TITLE = "Hunyuan DiT Size Conditioning (advanced)" 183 | 184 | def add_cond(self, cond, width, height): 185 | cond = deepcopy(cond) 186 | for c in range(len(cond)): 187 | cond[c][1].update({ 188 | "src_size_cond": [[height, width]], 189 | }) 190 | return (cond,) 191 | 192 | NODE_CLASS_MAPPINGS = { 193 | "HYDiTCheckpointLoader": HYDiTCheckpointLoader, 194 | "HYDiTTextEncoderLoader": HYDiTTextEncoderLoader, 195 | "HYDiTTextEncode": HYDiTTextEncode, 196 | "HYDiTTextEncodeSimple": HYDiTTextEncodeSimple, 197 | "HYDiTSrcSizeCond": HYDiTSrcSizeCond, 198 | } 199 | -------------------------------------------------------------------------------- /HunYuanDiT/tenc.py: -------------------------------------------------------------------------------- 1 | # This is for loading the CLIP (bert?) + mT5 encoder for HunYuanDiT 2 | import os 3 | import torch 4 | from transformers import AutoTokenizer, modeling_utils 5 | from transformers import T5Config, T5EncoderModel, BertConfig, BertModel 6 | 7 | from comfy import model_management 8 | import comfy.model_patcher 9 | import comfy.utils 10 | 11 | class mT5Model(torch.nn.Module): 12 | def __init__(self, textmodel_json_config=None, device="cpu", max_length=256, freeze=True, dtype=None): 13 | super().__init__() 14 | self.device = device 15 | self.dtype = dtype 16 | self.max_length = max_length 17 | if textmodel_json_config is None: 18 | textmodel_json_config = os.path.join( 19 | os.path.dirname(os.path.realpath(__file__)), 20 | f"config_mt5.json" 21 | ) 22 | config = T5Config.from_json_file(textmodel_json_config) 23 | with modeling_utils.no_init_weights(): 24 | self.transformer = T5EncoderModel(config) 25 | self.to(dtype) 26 | if freeze: 27 | self.freeze() 28 | 29 | def freeze(self): 30 | self.transformer = self.transformer.eval() 31 | for param in self.parameters(): 32 | param.requires_grad = False 33 | 34 | def load_sd(self, sd): 35 | return self.transformer.load_state_dict(sd, strict=False) 36 | 37 | def to(self, *args, **kwargs): 38 | return self.transformer.to(*args, **kwargs) 39 | 40 | class hyCLIPModel(torch.nn.Module): 41 | def __init__(self, textmodel_json_config=None, device="cpu", max_length=77, freeze=True, dtype=None): 42 | super().__init__() 43 | self.device = device 44 | self.dtype = dtype 45 | self.max_length = max_length 46 | if textmodel_json_config is None: 47 | textmodel_json_config = os.path.join( 48 | os.path.dirname(os.path.realpath(__file__)), 49 | f"config_clip.json" 50 | ) 51 | config = BertConfig.from_json_file(textmodel_json_config) 52 | with modeling_utils.no_init_weights(): 53 | self.transformer = BertModel(config) 54 | self.to(dtype) 55 | if freeze: 56 | self.freeze() 57 | 58 | def 
freeze(self): 59 | self.transformer = self.transformer.eval() 60 | for param in self.parameters(): 61 | param.requires_grad = False 62 | 63 | def load_sd(self, sd): 64 | return self.transformer.load_state_dict(sd, strict=False) 65 | 66 | def to(self, *args, **kwargs): 67 | return self.transformer.to(*args, **kwargs) 68 | 69 | class EXM_HyDiT_Tenc_Temp: 70 | def __init__(self, no_init=False, device="cpu", dtype=None, model_class="mT5", *kwargs): 71 | if no_init: 72 | return 73 | 74 | size = 8 if model_class == "mT5" else 2 75 | if dtype == torch.float32: 76 | size *= 2 77 | size *= (1024**3) 78 | 79 | if device == "auto": 80 | self.load_device = model_management.text_encoder_device() 81 | self.offload_device = model_management.text_encoder_offload_device() 82 | self.init_device = "cpu" 83 | elif device == "cpu": 84 | size = 0 # doesn't matter 85 | self.load_device = "cpu" 86 | self.offload_device = "cpu" 87 | self.init_device="cpu" 88 | elif device.startswith("cuda"): 89 | print("Direct CUDA device override!\nVRAM will not be freed by default.") 90 | size = 0 # not used 91 | self.load_device = device 92 | self.offload_device = device 93 | self.init_device = device 94 | else: 95 | self.load_device = model_management.get_torch_device() 96 | self.offload_device = "cpu" 97 | self.init_device="cpu" 98 | 99 | self.dtype = dtype 100 | self.device = self.load_device 101 | if model_class == "mT5": 102 | self.cond_stage_model = mT5Model( 103 | device = self.load_device, 104 | dtype = self.dtype, 105 | ) 106 | tokenizer_args = {"subfolder": "t2i/mt5"} # web 107 | tokenizer_path = os.path.join( # local 108 | os.path.dirname(os.path.realpath(__file__)), 109 | "mt5_tokenizer", 110 | ) 111 | else: 112 | self.cond_stage_model = hyCLIPModel( 113 | device = self.load_device, 114 | dtype = self.dtype, 115 | ) 116 | tokenizer_args = {"subfolder": "t2i/tokenizer",} # web 117 | tokenizer_path = os.path.join( # local 118 | os.path.dirname(os.path.realpath(__file__)), 119 | "tokenizer", 120 | ) 121 | # self.tokenizer = AutoTokenizer.from_pretrained( 122 | # "Tencent-Hunyuan/HunyuanDiT", 123 | # **tokenizer_args 124 | # ) 125 | self.tokenizer = AutoTokenizer.from_pretrained(tokenizer_path) 126 | self.patcher = comfy.model_patcher.ModelPatcher( 127 | self.cond_stage_model, 128 | load_device = self.load_device, 129 | offload_device = self.offload_device, 130 | size = size, 131 | ) 132 | 133 | def clone(self): 134 | n = EXM_HyDiT_Tenc_Temp(no_init=True) 135 | n.patcher = self.patcher.clone() 136 | n.cond_stage_model = self.cond_stage_model 137 | n.tokenizer = self.tokenizer 138 | return n 139 | 140 | def load_sd(self, sd): 141 | return self.cond_stage_model.load_sd(sd) 142 | 143 | def get_sd(self): 144 | return self.cond_stage_model.state_dict() 145 | 146 | def load_model(self): 147 | if self.load_device != "cpu": 148 | model_management.load_model_gpu(self.patcher) 149 | return self.patcher 150 | 151 | def add_patches(self, patches, strength_patch=1.0, strength_model=1.0): 152 | return self.patcher.add_patches(patches, strength_patch, strength_model) 153 | 154 | def get_key_patches(self): 155 | return self.patcher.get_key_patches() 156 | 157 | def load_clip(model_path, **kwargs): 158 | model = EXM_HyDiT_Tenc_Temp(model_class="clip", **kwargs) 159 | sd = comfy.utils.load_torch_file(model_path) 160 | 161 | prefix = "bert." 
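# The BERT-style text encoder inside HunYuanDiT checkpoints typically stores its
# weights under a "bert." key prefix; the loop below strips that prefix so the
# keys line up with the bare BertModel constructed by hyCLIPModel.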
162 | state_dict = {} 163 | for key in sd: 164 | nkey = key 165 | if key.startswith(prefix): 166 | nkey = key[len(prefix):] 167 | state_dict[nkey] = sd[key] 168 | 169 | m, e = model.load_sd(state_dict) 170 | if len(m) > 0 or len(e) > 0: 171 | print(f"HYDiT: clip missing {len(m)} keys ({len(e)} extra)") 172 | return model 173 | 174 | def load_t5(model_path, **kwargs): 175 | model = EXM_HyDiT_Tenc_Temp(model_class="mT5", **kwargs) 176 | sd = comfy.utils.load_torch_file(model_path) 177 | m, e = model.load_sd(sd) 178 | if len(m) > 0 or len(e) > 0: 179 | print(f"HYDiT: mT5 missing {len(m)} keys ({len(e)} extra)") 180 | return model 181 | -------------------------------------------------------------------------------- /VAE/models/vq.py: -------------------------------------------------------------------------------- 1 | import torch 2 | import numpy as np 3 | from torch import nn 4 | from einops import rearrange 5 | 6 | from .kl import Encoder, Decoder 7 | 8 | class VQModel(nn.Module): 9 | def __init__(self, 10 | config, 11 | remap=None, 12 | sane_index_shape=False, # tell vector quantizer to return indices as bhw 13 | ): 14 | super().__init__() 15 | self.embed_dim = config["embed_dim"] 16 | self.n_embed = config["n_embed"] 17 | self.encoder = Encoder(**config) 18 | self.decoder = Decoder(**config) 19 | self.quantize = VectorQuantizer(self.n_embed, self.embed_dim, beta=0.25, 20 | remap=remap, 21 | sane_index_shape=sane_index_shape) 22 | self.quant_conv = torch.nn.Conv2d(config["z_channels"], self.embed_dim, 1) 23 | self.post_quant_conv = torch.nn.Conv2d(self.embed_dim, config["z_channels"], 1) 24 | 25 | def encode(self, x): 26 | h = self.encoder(x) 27 | h = self.quant_conv(h) 28 | return h 29 | 30 | def decode(self, h, force_not_quantize=False): 31 | # also go through quantization layer 32 | if not force_not_quantize: 33 | quant, emb_loss, info = self.quantize(h) 34 | else: 35 | quant = h 36 | quant = self.post_quant_conv(quant) 37 | dec = self.decoder(quant) 38 | return dec 39 | 40 | def forward(self, input, return_pred_indices=False): 41 | quant, diff, (_,_,ind) = self.encode(input) 42 | dec = self.decode(quant) 43 | if return_pred_indices: 44 | return dec, diff, ind 45 | return dec, diff 46 | 47 | 48 | class VectorQuantizer(nn.Module): 49 | """ 50 | Improved version over VectorQuantizer, can be used as a drop-in replacement. Mostly 51 | avoids costly matrix multiplications and allows for post-hoc remapping of indices. 52 | """ 53 | # NOTE: due to a bug the beta term was applied to the wrong term. for 54 | # backwards compatibility we use the buggy version by default, but you can 55 | # specify legacy=False to fix it. 56 | def __init__(self, n_e, e_dim, beta, remap=None, unknown_index="random", 57 | sane_index_shape=False, legacy=True): 58 | super().__init__() 59 | self.n_e = n_e 60 | self.e_dim = e_dim 61 | self.beta = beta 62 | self.legacy = legacy 63 | 64 | self.embedding = nn.Embedding(self.n_e, self.e_dim) 65 | self.embedding.weight.data.uniform_(-1.0 / self.n_e, 1.0 / self.n_e) 66 | 67 | self.remap = remap 68 | if self.remap is not None: 69 | self.register_buffer("used", torch.tensor(np.load(self.remap))) 70 | self.re_embed = self.used.shape[0] 71 | self.unknown_index = unknown_index # "random" or "extra" or integer 72 | if self.unknown_index == "extra": 73 | self.unknown_index = self.re_embed 74 | self.re_embed = self.re_embed+1 75 | print(f"Remapping {self.n_e} indices to {self.re_embed} indices. 
" 76 | f"Using {self.unknown_index} for unknown indices.") 77 | else: 78 | self.re_embed = n_e 79 | 80 | self.sane_index_shape = sane_index_shape 81 | 82 | def remap_to_used(self, inds): 83 | ishape = inds.shape 84 | assert len(ishape)>1 85 | inds = inds.reshape(ishape[0],-1) 86 | used = self.used.to(inds) 87 | match = (inds[:,:,None]==used[None,None,...]).long() 88 | new = match.argmax(-1) 89 | unknown = match.sum(2)<1 90 | if self.unknown_index == "random": 91 | new[unknown]=torch.randint(0,self.re_embed,size=new[unknown].shape).to(device=new.device) 92 | else: 93 | new[unknown] = self.unknown_index 94 | return new.reshape(ishape) 95 | 96 | def unmap_to_all(self, inds): 97 | ishape = inds.shape 98 | assert len(ishape)>1 99 | inds = inds.reshape(ishape[0],-1) 100 | used = self.used.to(inds) 101 | if self.re_embed > self.used.shape[0]: # extra token 102 | inds[inds>=self.used.shape[0]] = 0 # simply set to zero 103 | back=torch.gather(used[None,:][inds.shape[0]*[0],:], 1, inds) 104 | return back.reshape(ishape) 105 | 106 | def forward(self, z, temp=None, rescale_logits=False, return_logits=False): 107 | assert temp is None or temp==1.0, "Only for interface compatible with Gumbel" 108 | assert rescale_logits==False, "Only for interface compatible with Gumbel" 109 | assert return_logits==False, "Only for interface compatible with Gumbel" 110 | # reshape z -> (batch, height, width, channel) and flatten 111 | z = rearrange(z, 'b c h w -> b h w c').contiguous() 112 | z_flattened = z.view(-1, self.e_dim) 113 | # distances from z to embeddings e_j (z - e)^2 = z^2 + e^2 - 2 e * z 114 | 115 | d = torch.sum(z_flattened ** 2, dim=1, keepdim=True) + \ 116 | torch.sum(self.embedding.weight**2, dim=1) - 2 * \ 117 | torch.einsum('bd,dn->bn', z_flattened, rearrange(self.embedding.weight, 'n d -> d n')) 118 | 119 | min_encoding_indices = torch.argmin(d, dim=1) 120 | z_q = self.embedding(min_encoding_indices).view(z.shape) 121 | perplexity = None 122 | min_encodings = None 123 | 124 | # compute loss for embedding 125 | if not self.legacy: 126 | loss = self.beta * torch.mean((z_q.detach()-z)**2) + \ 127 | torch.mean((z_q - z.detach()) ** 2) 128 | else: 129 | loss = torch.mean((z_q.detach()-z)**2) + self.beta * \ 130 | torch.mean((z_q - z.detach()) ** 2) 131 | 132 | # preserve gradients 133 | z_q = z + (z_q - z).detach() 134 | 135 | # reshape back to match original input shape 136 | z_q = rearrange(z_q, 'b h w c -> b c h w').contiguous() 137 | 138 | if self.remap is not None: 139 | min_encoding_indices = min_encoding_indices.reshape(z.shape[0],-1) # add batch axis 140 | min_encoding_indices = self.remap_to_used(min_encoding_indices) 141 | min_encoding_indices = min_encoding_indices.reshape(-1,1) # flatten 142 | 143 | if self.sane_index_shape: 144 | min_encoding_indices = min_encoding_indices.reshape( 145 | z_q.shape[0], z_q.shape[2], z_q.shape[3]) 146 | 147 | return z_q, loss, (perplexity, min_encodings, min_encoding_indices) 148 | 149 | def get_codebook_entry(self, indices, shape): 150 | # shape specifying (batch, height, width, channel) 151 | if self.remap is not None: 152 | indices = indices.reshape(shape[0],-1) # add batch axis 153 | indices = self.unmap_to_all(indices) 154 | indices = indices.reshape(-1) # flatten again 155 | 156 | # get quantized latent vectors 157 | z_q = self.embedding(indices) 158 | 159 | if shape is not None: 160 | z_q = z_q.view(shape) 161 | # reshape back to match original input shape 162 | z_q = z_q.permute(0, 3, 1, 2).contiguous() 163 | 164 | return z_q 165 | 
-------------------------------------------------------------------------------- /PixArt/loader.py: -------------------------------------------------------------------------------- 1 | import comfy.supported_models_base 2 | import comfy.latent_formats 3 | import comfy.model_patcher 4 | import comfy.model_base 5 | import comfy.utils 6 | import comfy.conds 7 | import torch 8 | import math 9 | from comfy import model_management 10 | from .diffusers_convert import convert_state_dict 11 | 12 | class EXM_PixArt(comfy.supported_models_base.BASE): 13 | unet_config = {} 14 | unet_extra_config = {} 15 | latent_format = comfy.latent_formats.SD15 16 | 17 | def __init__(self, model_conf): 18 | self.model_target = model_conf.get("target") 19 | self.unet_config = model_conf.get("unet_config", {}) 20 | self.sampling_settings = model_conf.get("sampling_settings", {}) 21 | self.latent_format = self.latent_format() 22 | # UNET is handled by extension 23 | self.unet_config["disable_unet_model_creation"] = True 24 | 25 | def model_type(self, state_dict, prefix=""): 26 | return comfy.model_base.ModelType.EPS 27 | 28 | class EXM_PixArt_Model(comfy.model_base.BaseModel): 29 | def __init__(self, *args, **kwargs): 30 | super().__init__(*args, **kwargs) 31 | 32 | def extra_conds(self, **kwargs): 33 | out = super().extra_conds(**kwargs) 34 | 35 | img_hw = kwargs.get("img_hw", None) 36 | if img_hw is not None: 37 | out["img_hw"] = comfy.conds.CONDRegular(torch.tensor(img_hw)) 38 | 39 | aspect_ratio = kwargs.get("aspect_ratio", None) 40 | if aspect_ratio is not None: 41 | out["aspect_ratio"] = comfy.conds.CONDRegular(torch.tensor(aspect_ratio)) 42 | 43 | cn_hint = kwargs.get("cn_hint", None) 44 | if cn_hint is not None: 45 | out["cn_hint"] = comfy.conds.CONDRegular(cn_hint) 46 | 47 | return out 48 | 49 | def load_pixart(model_path, model_conf=None): 50 | state_dict = comfy.utils.load_torch_file(model_path) 51 | state_dict = state_dict.get("model", state_dict) 52 | 53 | # prefix 54 | for prefix in ["model.diffusion_model.",]: 55 | if any(True for x in state_dict if x.startswith(prefix)): 56 | state_dict = {k[len(prefix):]:v for k,v in state_dict.items()} 57 | 58 | # diffusers 59 | if "adaln_single.linear.weight" in state_dict: 60 | state_dict = convert_state_dict(state_dict) # Diffusers 61 | 62 | # guess auto config 63 | if model_conf is None: 64 | model_conf = guess_pixart_config(state_dict) 65 | 66 | parameters = comfy.utils.calculate_parameters(state_dict) 67 | unet_dtype = model_management.unet_dtype(model_params=parameters) 68 | load_device = comfy.model_management.get_torch_device() 69 | offload_device = comfy.model_management.unet_offload_device() 70 | 71 | # ignore fp8/etc and use directly for now 72 | manual_cast_dtype = model_management.unet_manual_cast(unet_dtype, load_device) 73 | if manual_cast_dtype: 74 | print(f"PixArt: falling back to {manual_cast_dtype}") 75 | unet_dtype = manual_cast_dtype 76 | 77 | model_conf = EXM_PixArt(model_conf) # convert to object 78 | model = EXM_PixArt_Model( # same as comfy.model_base.BaseModel 79 | model_conf, 80 | model_type=comfy.model_base.ModelType.EPS, 81 | device=model_management.get_torch_device() 82 | ) 83 | 84 | if model_conf.model_target == "PixArtMS": 85 | from .models.PixArtMS import PixArtMS 86 | model.diffusion_model = PixArtMS(**model_conf.unet_config) 87 | elif model_conf.model_target == "PixArt": 88 | from .models.PixArt import PixArt 89 | model.diffusion_model = PixArt(**model_conf.unet_config) 90 | elif model_conf.model_target == "PixArtMSSigma": 91 | from 
.models.PixArtMS import PixArtMS 92 | model.diffusion_model = PixArtMS(**model_conf.unet_config) 93 | model.latent_format = comfy.latent_formats.SDXL() 94 | elif model_conf.model_target == "ControlPixArtMSHalf": 95 | from .models.PixArtMS import PixArtMS 96 | from .models.pixart_controlnet import ControlPixArtMSHalf 97 | model.diffusion_model = PixArtMS(**model_conf.unet_config) 98 | model.diffusion_model = ControlPixArtMSHalf(model.diffusion_model) 99 | elif model_conf.model_target == "ControlPixArtHalf": 100 | from .models.PixArt import PixArt 101 | from .models.pixart_controlnet import ControlPixArtHalf 102 | model.diffusion_model = PixArt(**model_conf.unet_config) 103 | model.diffusion_model = ControlPixArtHalf(model.diffusion_model) 104 | else: 105 | raise NotImplementedError(f"Unknown model target '{model_conf.model_target}'") 106 | 107 | m, u = model.diffusion_model.load_state_dict(state_dict, strict=False) 108 | if len(m) > 0: print("Missing UNET keys", m) 109 | if len(u) > 0: print("Leftover UNET keys", u) 110 | model.diffusion_model.dtype = unet_dtype 111 | model.diffusion_model.eval() 112 | model.diffusion_model.to(unet_dtype) 113 | 114 | model_patcher = comfy.model_patcher.ModelPatcher( 115 | model, 116 | load_device = load_device, 117 | offload_device = offload_device, 118 | ) 119 | return model_patcher 120 | 121 | def guess_pixart_config(sd): 122 | """ 123 | Guess config based on converted state dict. 124 | """ 125 | # Shared settings based on DiT_XL_2 - could be enumerated 126 | config = { 127 | "num_heads" : 16, # get from attention 128 | "patch_size" : 2, # final layer I guess? 129 | "hidden_size" : 1152, # pos_embed.shape[2] 130 | } 131 | config["depth"] = sum([key.endswith(".attn.proj.weight") for key in sd.keys()]) or 28 132 | 133 | try: 134 | # this is not present in the diffusers version for sigma? 135 | config["model_max_length"] = sd["y_embedder.y_embedding"].shape[0] 136 | except KeyError: 137 | # need better logic to guess this 138 | config["model_max_length"] = 300 139 | 140 | if "pos_embed" in sd: 141 | config["input_size"] = int(math.sqrt(sd["pos_embed"].shape[1])) * config["patch_size"] 142 | config["pe_interpolation"] = config["input_size"] // (512//8) # dumb guess 143 | 144 | target_arch = "PixArtMS" 145 | if config["model_max_length"] == 300: 146 | # Sigma 147 | target_arch = "PixArtMSSigma" 148 | config["micro_condition"] = False 149 | if "input_size" not in config: 150 | # The diffusers weights for 1K/2K are exactly the same...? 151 | # replace patch embed logic with HyDiT? 
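# For example, the "pos_embed" path above works out as follows for a 1024px
# checkpoint: a 64x64 token grid gives pos_embed.shape[1] == 4096, so
#   input_size       = sqrt(4096) * patch_size = 64 * 2 = 128  (= 1024 // 8)
#   pe_interpolation = 128 // (512 // 8) = 2
# Sigma checkpoints converted from diffusers appear to reach this branch without
# pos_embed (see the note above), hence the hard-coded 1024px fallback below.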
152 | print(f"PixArt: diffusers weights - 2K model will be broken, use manual loading!") 153 | config["input_size"] = 1024//8 154 | else: 155 | # Alpha 156 | if "csize_embedder.mlp.0.weight" in sd: 157 | # MS (microconds) 158 | target_arch = "PixArtMS" 159 | config["micro_condition"] = True 160 | if "input_size" not in config: 161 | config["input_size"] = 1024//8 162 | config["pe_interpolation"] = 2 163 | else: 164 | # PixArt 165 | target_arch = "PixArt" 166 | if "input_size" not in config: 167 | config["input_size"] = 512//8 168 | config["pe_interpolation"] = 1 169 | 170 | print("PixArt guessed config:", target_arch, config) 171 | return { 172 | "target": target_arch, 173 | "unet_config": config, 174 | "sampling_settings": { 175 | "beta_schedule" : "sqrt_linear", 176 | "linear_start" : 0.0001, 177 | "linear_end" : 0.02, 178 | "timesteps" : 1000, 179 | } 180 | } 181 | -------------------------------------------------------------------------------- /VAE/loader.py: -------------------------------------------------------------------------------- 1 | import torch 2 | import comfy.sd 3 | import comfy.utils 4 | from comfy import model_management 5 | from comfy import diffusers_convert 6 | 7 | class EXVAE(comfy.sd.VAE): 8 | def __init__(self, model_path, model_conf, dtype=torch.float32): 9 | self.latent_dim = model_conf["embed_dim"] 10 | self.latent_scale = model_conf["embed_scale"] 11 | self.device = model_management.vae_device() 12 | self.offload_device = model_management.vae_offload_device() 13 | self.vae_dtype = dtype 14 | 15 | sd = comfy.utils.load_torch_file(model_path) 16 | model = None 17 | if model_conf["type"] == "AutoencoderKL": 18 | from .models.kl import AutoencoderKL 19 | model = AutoencoderKL(config=model_conf) 20 | if 'decoder.up_blocks.0.resnets.0.norm1.weight' in sd.keys(): 21 | sd = diffusers_convert.convert_vae_state_dict(sd) 22 | elif model_conf["type"] == "AutoencoderKL-VideoDecoder": 23 | from .models.temporal_ae import AutoencoderKL 24 | model = AutoencoderKL(config=model_conf) 25 | elif model_conf["type"] == "VQModel": 26 | from .models.vq import VQModel 27 | model = VQModel(config=model_conf) 28 | elif model_conf["type"] == "ConsistencyDecoder": 29 | from .models.consistencydecoder import ConsistencyDecoder 30 | model = ConsistencyDecoder() 31 | sd = {f"model.{k}":v for k,v in sd.items()} 32 | elif model_conf["type"] == "MoVQ3": 33 | from .models.movq3 import MoVQ 34 | model = MoVQ(model_conf) 35 | elif model_conf["type"] == "DCAE": 36 | from .models.dcae import DCAE 37 | if 'decoder.project_out.op_list.0.bias' in sd: 38 | from .models import dcae_key_mapping 39 | sd = dcae_key_mapping.convert_sd(sd) 40 | model = DCAE(**model_conf) 41 | else: 42 | raise NotImplementedError(f"Unknown VAE type '{model_conf['type']}'") 43 | 44 | self.first_stage_model = model.eval() 45 | m, u = self.first_stage_model.load_state_dict(sd, strict=False) 46 | if len(m) > 0: print("Missing VAE keys", m) 47 | if len(u) > 0: print("Leftover VAE keys", u) 48 | 49 | self.first_stage_model.to(self.vae_dtype).to(self.offload_device) 50 | 51 | ### Encode/Decode functions below needed due to source repo having 4 VAE channels and a scale factor of 8 hardcoded 52 | def decode_tiled_(self, samples, tile_x=64, tile_y=64, overlap = 16): 53 | steps = samples.shape[0] * comfy.utils.get_tiled_scale_steps(samples.shape[3], samples.shape[2], tile_x, tile_y, overlap) 54 | steps += samples.shape[0] * comfy.utils.get_tiled_scale_steps(samples.shape[3], samples.shape[2], tile_x // 2, tile_y * 2, overlap) 55 | steps 
+= samples.shape[0] * comfy.utils.get_tiled_scale_steps(samples.shape[3], samples.shape[2], tile_x * 2, tile_y // 2, overlap) 56 | pbar = comfy.utils.ProgressBar(steps) 57 | 58 | decode_fn = lambda a: (self.first_stage_model.decode(a.to(self.vae_dtype).to(self.device)) + 1.0).float() 59 | output = torch.clamp(( 60 | (comfy.utils.tiled_scale(samples, decode_fn, tile_x // 2, tile_y * 2, overlap, upscale_amount = self.latent_scale, pbar = pbar) + 61 | comfy.utils.tiled_scale(samples, decode_fn, tile_x * 2, tile_y // 2, overlap, upscale_amount = self.latent_scale, pbar = pbar) + 62 | comfy.utils.tiled_scale(samples, decode_fn, tile_x, tile_y, overlap, upscale_amount = self.latent_scale, pbar = pbar)) 63 | / 3.0) / 2.0, min=0.0, max=1.0) 64 | return output 65 | 66 | def encode_tiled_(self, pixel_samples, tile_x=512, tile_y=512, overlap = 64): 67 | steps = pixel_samples.shape[0] * comfy.utils.get_tiled_scale_steps(pixel_samples.shape[3], pixel_samples.shape[2], tile_x, tile_y, overlap) 68 | steps += pixel_samples.shape[0] * comfy.utils.get_tiled_scale_steps(pixel_samples.shape[3], pixel_samples.shape[2], tile_x // 2, tile_y * 2, overlap) 69 | steps += pixel_samples.shape[0] * comfy.utils.get_tiled_scale_steps(pixel_samples.shape[3], pixel_samples.shape[2], tile_x * 2, tile_y // 2, overlap) 70 | pbar = comfy.utils.ProgressBar(steps) 71 | 72 | encode_fn = lambda a: self.first_stage_model.encode((2. * a - 1.).to(self.vae_dtype).to(self.device)).float() 73 | samples = comfy.utils.tiled_scale(pixel_samples, encode_fn, tile_x, tile_y, overlap, upscale_amount = (1/self.latent_scale), out_channels=self.latent_dim, pbar=pbar) 74 | samples += comfy.utils.tiled_scale(pixel_samples, encode_fn, tile_x * 2, tile_y // 2, overlap, upscale_amount = (1/self.latent_scale), out_channels=self.latent_dim, pbar=pbar) 75 | samples += comfy.utils.tiled_scale(pixel_samples, encode_fn, tile_x // 2, tile_y * 2, overlap, upscale_amount = (1/self.latent_scale), out_channels=self.latent_dim, pbar=pbar) 76 | samples /= 3.0 77 | return samples 78 | 79 | def decode(self, samples_in): 80 | self.first_stage_model = self.first_stage_model.to(self.device) 81 | try: 82 | memory_used = (2562 * samples_in.shape[2] * samples_in.shape[3] * 64) * 1.7 83 | model_management.free_memory(memory_used, self.device) 84 | free_memory = model_management.get_free_memory(self.device) 85 | batch_number = int(free_memory / memory_used) 86 | batch_number = max(1, batch_number) 87 | 88 | pixel_samples = torch.empty((samples_in.shape[0], 3, round(samples_in.shape[2] * self.latent_scale), round(samples_in.shape[3] * self.latent_scale)), device="cpu") 89 | for x in range(0, samples_in.shape[0], batch_number): 90 | samples = samples_in[x:x+batch_number].to(self.vae_dtype).to(self.device) 91 | pixel_samples[x:x+batch_number] = torch.clamp((self.first_stage_model.decode(samples).cpu().float() + 1.0) / 2.0, min=0.0, max=1.0) 92 | except model_management.OOM_EXCEPTION as e: 93 | print("Warning: Ran out of memory when regular VAE decoding, retrying with tiled VAE decoding.") 94 | pixel_samples = self.decode_tiled_(samples_in) 95 | 96 | self.first_stage_model = self.first_stage_model.to(self.offload_device) 97 | pixel_samples = pixel_samples.cpu().movedim(1,-1) 98 | return pixel_samples 99 | 100 | def encode(self, pixel_samples): 101 | self.first_stage_model = self.first_stage_model.to(self.device) 102 | pixel_samples = pixel_samples.movedim(-1,1) 103 | try: 104 | memory_used = (2078 * pixel_samples.shape[2] * pixel_samples.shape[3]) * 1.7 #NOTE: this constant 
along with the one in the decode above are estimated from the mem usage for the VAE and could change. 105 | model_management.free_memory(memory_used, self.device) 106 | free_memory = model_management.get_free_memory(self.device) 107 | batch_number = int(free_memory / memory_used) 108 | batch_number = max(1, batch_number) 109 | samples = torch.empty((pixel_samples.shape[0], self.latent_dim, round(pixel_samples.shape[2] // self.latent_scale), round(pixel_samples.shape[3] // self.latent_scale)), device="cpu") 110 | for x in range(0, pixel_samples.shape[0], batch_number): 111 | pixels_in = (2. * pixel_samples[x:x+batch_number] - 1.).to(self.vae_dtype).to(self.device) 112 | samples[x:x+batch_number] = self.first_stage_model.encode(pixels_in).cpu().float() 113 | 114 | except model_management.OOM_EXCEPTION as e: 115 | print("Warning: Ran out of memory when regular VAE encoding, retrying with tiled VAE encoding.") 116 | samples = self.encode_tiled_(pixel_samples) 117 | 118 | self.first_stage_model = self.first_stage_model.to(self.offload_device) 119 | return samples 120 | -------------------------------------------------------------------------------- /PixArt/nodes.py: -------------------------------------------------------------------------------- 1 | import os 2 | import json 3 | import torch 4 | import folder_paths 5 | 6 | from comfy import utils 7 | from .conf import pixart_conf, pixart_res 8 | from .lora import load_pixart_lora 9 | from .loader import load_pixart 10 | 11 | class PixArtCheckpointLoader: 12 | @classmethod 13 | def INPUT_TYPES(s): 14 | return { 15 | "required": { 16 | "ckpt_name": (folder_paths.get_filename_list("checkpoints"),), 17 | "model": (list(pixart_conf.keys()),), 18 | } 19 | } 20 | RETURN_TYPES = ("MODEL",) 21 | RETURN_NAMES = ("model",) 22 | FUNCTION = "load_checkpoint" 23 | CATEGORY = "ExtraModels/PixArt" 24 | TITLE = "PixArt Checkpoint Loader" 25 | 26 | def load_checkpoint(self, ckpt_name, model): 27 | ckpt_path = folder_paths.get_full_path("checkpoints", ckpt_name) 28 | model_conf = pixart_conf[model] 29 | model = load_pixart( 30 | model_path = ckpt_path, 31 | model_conf = model_conf, 32 | ) 33 | return (model,) 34 | 35 | class PixArtCheckpointLoaderSimple(PixArtCheckpointLoader): 36 | @classmethod 37 | def INPUT_TYPES(s): 38 | return { 39 | "required": { 40 | "ckpt_name": (folder_paths.get_filename_list("checkpoints"),), 41 | } 42 | } 43 | TITLE = "PixArt Checkpoint Loader (auto)" 44 | 45 | def load_checkpoint(self, ckpt_name): 46 | ckpt_path = folder_paths.get_full_path("checkpoints", ckpt_name) 47 | model = load_pixart(model_path=ckpt_path) 48 | return (model,) 49 | 50 | class PixArtResolutionSelect(): 51 | @classmethod 52 | def INPUT_TYPES(s): 53 | return { 54 | "required": { 55 | "model": (list(pixart_res.keys()),), 56 | # keys are the same for both 57 | "ratio": (list(pixart_res["PixArtMS_XL_2"].keys()),{"default":"1.00"}), 58 | } 59 | } 60 | RETURN_TYPES = ("INT","INT") 61 | RETURN_NAMES = ("width","height") 62 | FUNCTION = "get_res" 63 | CATEGORY = "ExtraModels/PixArt" 64 | TITLE = "PixArt Resolution Select" 65 | 66 | def get_res(self, model, ratio): 67 | width, height = pixart_res[model][ratio] 68 | return (width,height) 69 | 70 | class PixArtLoraLoader: 71 | def __init__(self): 72 | self.loaded_lora = None 73 | 74 | @classmethod 75 | def INPUT_TYPES(s): 76 | return { 77 | "required": { 78 | "model": ("MODEL",), 79 | "lora_name": (folder_paths.get_filename_list("loras"), ), 80 | "strength": ("FLOAT", {"default": 1.0, "min": -20.0, "max": 20.0, "step": 
0.01}), 81 | } 82 | } 83 | RETURN_TYPES = ("MODEL",) 84 | FUNCTION = "load_lora" 85 | CATEGORY = "ExtraModels/PixArt" 86 | TITLE = "PixArt Load LoRA" 87 | 88 | def load_lora(self, model, lora_name, strength,): 89 | if strength == 0: 90 | return (model) 91 | 92 | lora_path = folder_paths.get_full_path("loras", lora_name) 93 | lora = None 94 | if self.loaded_lora is not None: 95 | if self.loaded_lora[0] == lora_path: 96 | lora = self.loaded_lora[1] 97 | else: 98 | temp = self.loaded_lora 99 | self.loaded_lora = None 100 | del temp 101 | 102 | if lora is None: 103 | lora = utils.load_torch_file(lora_path, safe_load=True) 104 | self.loaded_lora = (lora_path, lora) 105 | 106 | model_lora = load_pixart_lora(model, lora, lora_path, strength,) 107 | return (model_lora,) 108 | 109 | class PixArtResolutionCond: 110 | @classmethod 111 | def INPUT_TYPES(s): 112 | return { 113 | "required": { 114 | "cond": ("CONDITIONING", ), 115 | "width": ("INT", {"default": 1024.0, "min": 0, "max": 8192}), 116 | "height": ("INT", {"default": 1024.0, "min": 0, "max": 8192}), 117 | } 118 | } 119 | 120 | RETURN_TYPES = ("CONDITIONING",) 121 | RETURN_NAMES = ("cond",) 122 | FUNCTION = "add_cond" 123 | CATEGORY = "ExtraModels/PixArt" 124 | TITLE = "PixArt Resolution Conditioning" 125 | 126 | def add_cond(self, cond, width, height): 127 | for c in range(len(cond)): 128 | cond[c][1].update({ 129 | "img_hw": [[height, width]], 130 | "aspect_ratio": [[height/width]], 131 | }) 132 | return (cond,) 133 | 134 | class PixArtControlNetCond: 135 | @classmethod 136 | def INPUT_TYPES(s): 137 | return { 138 | "required": { 139 | "cond": ("CONDITIONING",), 140 | "latent": ("LATENT",), 141 | # "image": ("IMAGE",), 142 | # "vae": ("VAE",), 143 | # "strength": ("FLOAT", {"default": 1.0, "min": 0.0, "max": 10.0, "step": 0.01}) 144 | } 145 | } 146 | 147 | RETURN_TYPES = ("CONDITIONING",) 148 | RETURN_NAMES = ("cond",) 149 | FUNCTION = "add_cond" 150 | CATEGORY = "ExtraModels/PixArt" 151 | TITLE = "PixArt ControlNet Conditioning" 152 | 153 | def add_cond(self, cond, latent): 154 | for c in range(len(cond)): 155 | cond[c][1]["cn_hint"] = latent["samples"] * 0.18215 156 | return (cond,) 157 | 158 | class PixArtT5TextEncode: 159 | """ 160 | Reference code, mostly to verify compatibility. 161 | Once everything works, this should instead inherit from the 162 | T5 text encode node and simply add the extra conds (res/ar). 
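The prompt is tokenized to a fixed 120 tokens, encoded with T5, and the resulting
embedding is truncated to the attention-masked length (see mask_feature below)
before being returned as conditioning.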
163 | """ 164 | @classmethod 165 | def INPUT_TYPES(s): 166 | return { 167 | "required": { 168 | "text": ("STRING", {"multiline": True}), 169 | "T5": ("T5",), 170 | } 171 | } 172 | 173 | RETURN_TYPES = ("CONDITIONING",) 174 | FUNCTION = "encode" 175 | CATEGORY = "ExtraModels/PixArt" 176 | TITLE = "PixArt T5 Text Encode [Reference]" 177 | 178 | def mask_feature(self, emb, mask): 179 | if emb.shape[0] == 1: 180 | keep_index = mask.sum().item() 181 | return emb[:, :, :keep_index, :], keep_index 182 | else: 183 | masked_feature = emb * mask[:, None, :, None] 184 | return masked_feature, emb.shape[2] 185 | 186 | def encode(self, text, T5): 187 | text = text.lower().strip() 188 | tokenizer_out = T5.tokenizer.tokenizer( 189 | text, 190 | max_length = 120, 191 | padding = 'max_length', 192 | truncation = True, 193 | return_attention_mask = True, 194 | add_special_tokens = True, 195 | return_tensors = 'pt' 196 | ) 197 | tokens = tokenizer_out["input_ids"] 198 | mask = tokenizer_out["attention_mask"] 199 | embs = T5.cond_stage_model.transformer( 200 | input_ids = tokens.to(T5.load_device), 201 | attention_mask = mask.to(T5.load_device), 202 | )['last_hidden_state'].float()[:, None] 203 | masked_embs, keep_index = self.mask_feature( 204 | embs.detach().to("cpu"), 205 | mask.detach().to("cpu") 206 | ) 207 | masked_embs = masked_embs.squeeze(0) # match CLIP/internal 208 | print("Encoded T5:", masked_embs.shape) 209 | return ([[masked_embs, {}]], ) 210 | 211 | class PixArtT5FromSD3CLIP: 212 | """ 213 | Split the T5 text encoder away from SD3 214 | """ 215 | @classmethod 216 | def INPUT_TYPES(s): 217 | return { 218 | "required": { 219 | "sd3_clip": ("CLIP",), 220 | "padding": ("INT", {"default": 1, "min": 1, "max": 300}), 221 | } 222 | } 223 | 224 | RETURN_TYPES = ("CLIP",) 225 | RETURN_NAMES = ("t5",) 226 | FUNCTION = "split" 227 | CATEGORY = "ExtraModels/PixArt" 228 | TITLE = "PixArt T5 from SD3 CLIP" 229 | 230 | def split(self, sd3_clip, padding): 231 | try: 232 | from comfy.text_encoders.sd3_clip import SD3Tokenizer, SD3ClipModel 233 | except ImportError: 234 | # fallback for older ComfyUI versions 235 | from comfy.sd3_clip import SD3Tokenizer, SD3ClipModel 236 | import copy 237 | 238 | clip = sd3_clip.clone() 239 | assert clip.cond_stage_model.t5xxl is not None, "CLIP must have T5 loaded!" 
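# The lines below detach the large T5 transformer before deep-copying the t5xxl
# wrapper and reattach it afterwards, so only the lightweight wrapper state is
# copied and both CLIP objects end up sharing the same transformer weights.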
240 | 241 | # remove transformer 242 | transformer = clip.cond_stage_model.t5xxl.transformer 243 | clip.cond_stage_model.t5xxl.transformer = None 244 | 245 | # clone object 246 | tmp = SD3ClipModel(clip_l=False, clip_g=False, t5=False) 247 | tmp.t5xxl = copy.deepcopy(clip.cond_stage_model.t5xxl) 248 | # put transformer back 249 | clip.cond_stage_model.t5xxl.transformer = transformer 250 | tmp.t5xxl.transformer = transformer 251 | 252 | # override special tokens 253 | tmp.t5xxl.special_tokens = copy.deepcopy(clip.cond_stage_model.t5xxl.special_tokens) 254 | tmp.t5xxl.special_tokens.pop("end") # make sure empty tokens match 255 | 256 | # add attn mask opt if present in original 257 | if hasattr(sd3_clip.cond_stage_model, "t5_attention_mask"): 258 | tmp.t5_attention_mask = False 259 | 260 | # tokenizer 261 | tok = SD3Tokenizer() 262 | tok.t5xxl.min_length = padding 263 | 264 | clip.cond_stage_model = tmp 265 | clip.tokenizer = tok 266 | 267 | return (clip, ) 268 | 269 | NODE_CLASS_MAPPINGS = { 270 | "PixArtCheckpointLoader" : PixArtCheckpointLoader, 271 | "PixArtCheckpointLoaderSimple" : PixArtCheckpointLoaderSimple, 272 | "PixArtResolutionSelect" : PixArtResolutionSelect, 273 | "PixArtLoraLoader" : PixArtLoraLoader, 274 | "PixArtT5TextEncode" : PixArtT5TextEncode, 275 | "PixArtResolutionCond" : PixArtResolutionCond, 276 | "PixArtControlNetCond" : PixArtControlNetCond, 277 | "PixArtT5FromSD3CLIP": PixArtT5FromSD3CLIP, 278 | } 279 | -------------------------------------------------------------------------------- /Sana/models/norms.py: -------------------------------------------------------------------------------- 1 | # Copyright 2024 NVIDIA CORPORATION & AFFILIATES 2 | # 3 | # Licensed under the Apache License, Version 2.0 (the "License"); 4 | # you may not use this file except in compliance with the License. 5 | # You may obtain a copy of the License at 6 | # 7 | # http://www.apache.org/licenses/LICENSE-2.0 8 | # 9 | # Unless required by applicable law or agreed to in writing, software 10 | # distributed under the License is distributed on an "AS IS" BASIS, 11 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 12 | # See the License for the specific language governing permissions and 13 | # limitations under the License. 
14 | # 15 | # SPDX-License-Identifier: Apache-2.0 16 | 17 | import copy 18 | import warnings 19 | 20 | import torch 21 | import torch.nn as nn 22 | from torch.nn.modules.batchnorm import _BatchNorm 23 | 24 | __all__ = ["LayerNorm2d", "build_norm", "get_norm_name", "reset_bn", "remove_bn", "set_norm_eps"] 25 | 26 | 27 | class LayerNorm2d(nn.LayerNorm): 28 | rmsnorm = False 29 | 30 | def forward(self, x: torch.Tensor) -> torch.Tensor: 31 | out = x if LayerNorm2d.rmsnorm else x - torch.mean(x, dim=1, keepdim=True) 32 | out = out / torch.sqrt(torch.square(out).mean(dim=1, keepdim=True) + self.eps) 33 | if self.elementwise_affine: 34 | out = out * self.weight.view(1, -1, 1, 1) + self.bias.view(1, -1, 1, 1) 35 | return out 36 | 37 | def extra_repr(self) -> str: 38 | return f"{self.normalized_shape}, eps={self.eps}, elementwise_affine={self.elementwise_affine}, rmsnorm={self.rmsnorm}" 39 | 40 | 41 | # register normalization function here 42 | # name: module, kwargs with default values 43 | REGISTERED_NORMALIZATION_DICT: dict[str, tuple[type, dict[str, any]]] = { 44 | "bn2d": (nn.BatchNorm2d, {"num_features": None, "eps": 1e-5, "momentum": 0.1, "affine": True}), 45 | "syncbn": (nn.SyncBatchNorm, {"num_features": None, "eps": 1e-5, "momentum": 0.1, "affine": True}), 46 | "ln": (nn.LayerNorm, {"normalized_shape": None, "eps": 1e-5, "elementwise_affine": True}), 47 | "ln2d": (LayerNorm2d, {"normalized_shape": None, "eps": 1e-5, "elementwise_affine": True}), 48 | } 49 | 50 | 51 | def build_norm(name="bn2d", num_features=None, affine=True, **kwargs) -> nn.Module or None: 52 | if name in ["ln", "ln2d"]: 53 | kwargs["normalized_shape"] = num_features 54 | kwargs["elementwise_affine"] = affine 55 | else: 56 | kwargs["num_features"] = num_features 57 | kwargs["affine"] = affine 58 | if name in REGISTERED_NORMALIZATION_DICT: 59 | norm_cls, default_args = copy.deepcopy(REGISTERED_NORMALIZATION_DICT[name]) 60 | for key in default_args: 61 | if key in kwargs: 62 | default_args[key] = kwargs[key] 63 | return norm_cls(**default_args) 64 | elif name is None or name.lower() == "none": 65 | return None 66 | else: 67 | raise ValueError("do not support: %s" % name) 68 | 69 | 70 | def get_norm_name(norm: nn.Module or None) -> str or None: 71 | if norm is None: 72 | return None 73 | module2name = {} 74 | for key, config in REGISTERED_NORMALIZATION_DICT.items(): 75 | module2name[config[0].__name__] = key 76 | return module2name.get(type(norm).__name__, "unknown") 77 | 78 | 79 | def reset_bn( 80 | model: nn.Module, 81 | data_loader: list, 82 | sync=True, 83 | progress_bar=False, 84 | ) -> None: 85 | import copy 86 | 87 | import torch.nn.functional as F 88 | from packages.apps.utils import AverageMeter, is_master, sync_tensor 89 | from packages.models.utils import get_device, list_join 90 | from tqdm import tqdm 91 | 92 | bn_mean = {} 93 | bn_var = {} 94 | 95 | tmp_model = copy.deepcopy(model) 96 | for name, m in tmp_model.named_modules(): 97 | if isinstance(m, _BatchNorm): 98 | bn_mean[name] = AverageMeter(is_distributed=False) 99 | bn_var[name] = AverageMeter(is_distributed=False) 100 | 101 | def new_forward(bn, mean_est, var_est): 102 | def lambda_forward(x): 103 | x = x.contiguous() 104 | if sync: 105 | batch_mean = x.mean(0, keepdim=True).mean(2, keepdim=True).mean(3, keepdim=True) # 1, C, 1, 1 106 | batch_mean = sync_tensor(batch_mean, reduce="cat") 107 | batch_mean = torch.mean(batch_mean, dim=0, keepdim=True) 108 | 109 | batch_var = (x - batch_mean) * (x - batch_mean) 110 | batch_var = batch_var.mean(0, 
keepdim=True).mean(2, keepdim=True).mean(3, keepdim=True) 111 | batch_var = sync_tensor(batch_var, reduce="cat") 112 | batch_var = torch.mean(batch_var, dim=0, keepdim=True) 113 | else: 114 | batch_mean = x.mean(0, keepdim=True).mean(2, keepdim=True).mean(3, keepdim=True) # 1, C, 1, 1 115 | batch_var = (x - batch_mean) * (x - batch_mean) 116 | batch_var = batch_var.mean(0, keepdim=True).mean(2, keepdim=True).mean(3, keepdim=True) 117 | 118 | batch_mean = torch.squeeze(batch_mean) 119 | batch_var = torch.squeeze(batch_var) 120 | 121 | mean_est.update(batch_mean.data, x.size(0)) 122 | var_est.update(batch_var.data, x.size(0)) 123 | 124 | # bn forward using calculated mean & var 125 | _feature_dim = batch_mean.shape[0] 126 | return F.batch_norm( 127 | x, 128 | batch_mean, 129 | batch_var, 130 | bn.weight[:_feature_dim], 131 | bn.bias[:_feature_dim], 132 | False, 133 | 0.0, 134 | bn.eps, 135 | ) 136 | 137 | return lambda_forward 138 | 139 | m.forward = new_forward(m, bn_mean[name], bn_var[name]) 140 | 141 | # skip if there is no batch normalization layers in the network 142 | if len(bn_mean) == 0: 143 | return 144 | 145 | tmp_model.eval() 146 | with torch.inference_mode(): 147 | with tqdm(total=len(data_loader), desc="reset bn", disable=not progress_bar or not is_master()) as t: 148 | for images in data_loader: 149 | images = images.to(get_device(tmp_model)) 150 | tmp_model(images) 151 | t.set_postfix( 152 | { 153 | "bs": images.size(0), 154 | "res": list_join(images.shape[-2:], "x"), 155 | } 156 | ) 157 | t.update() 158 | 159 | for name, m in model.named_modules(): 160 | if name in bn_mean and bn_mean[name].count > 0: 161 | feature_dim = bn_mean[name].avg.size(0) 162 | assert isinstance(m, _BatchNorm) 163 | m.running_mean.data[:feature_dim].copy_(bn_mean[name].avg) 164 | m.running_var.data[:feature_dim].copy_(bn_var[name].avg) 165 | 166 | 167 | def remove_bn(model: nn.Module) -> None: 168 | for m in model.modules(): 169 | if isinstance(m, _BatchNorm): 170 | m.weight = m.bias = None 171 | m.forward = lambda x: x 172 | 173 | 174 | def set_norm_eps(model: nn.Module, eps: float or None = None, momentum: float or None = None) -> None: 175 | for m in model.modules(): 176 | if isinstance(m, (nn.GroupNorm, nn.LayerNorm, _BatchNorm)): 177 | if eps is not None: 178 | m.eps = eps 179 | if momentum is not None: 180 | m.momentum = momentum 181 | 182 | 183 | class RMSNorm(torch.nn.Module): 184 | def __init__(self, dim: int, scale_factor=1.0, eps: float = 1e-6): 185 | """ 186 | Initialize the RMSNorm normalization layer. 187 | 188 | Args: 189 | dim (int): The dimension of the input tensor. 190 | eps (float, optional): A small value added to the denominator for numerical stability. Default is 1e-6. 191 | 192 | Attributes: 193 | eps (float): A small value added to the denominator for numerical stability. 194 | weight (nn.Parameter): Learnable scaling parameter. 195 | 196 | """ 197 | super().__init__() 198 | self.eps = eps 199 | self.weight = nn.Parameter(torch.ones(dim) * scale_factor) 200 | 201 | def _norm(self, x): 202 | """ 203 | Apply the RMSNorm normalization to the input tensor. 204 | 205 | Args: 206 | x (torch.Tensor): The input tensor. 207 | 208 | Returns: 209 | torch.Tensor: The normalized tensor. 210 | 211 | """ 212 | return x * torch.rsqrt(x.pow(2).mean(-1, keepdim=True) + self.eps) 213 | 214 | def forward(self, x): 215 | """ 216 | Forward pass through the RMSNorm layer. 217 | 218 | Args: 219 | x (torch.Tensor): The input tensor. 
220 | 221 | Returns: 222 | torch.Tensor: The output tensor after applying RMSNorm. 223 | 224 | """ 225 | return (self.weight * self._norm(x.float())).type_as(x) 226 | -------------------------------------------------------------------------------- /HunYuanDiT/models/posemb_layers.py: -------------------------------------------------------------------------------- 1 | import torch 2 | import numpy as np 3 | from typing import Union 4 | 5 | 6 | def _to_tuple(x): 7 | if isinstance(x, int): 8 | return x, x 9 | else: 10 | return x 11 | 12 | 13 | def get_fill_resize_and_crop(src, tgt): # src 来源的分辨率 tgt base 分辨率 14 | th, tw = _to_tuple(tgt) 15 | h, w = _to_tuple(src) 16 | 17 | tr = th / tw # base 分辨率 18 | r = h / w # 目标分辨率 19 | 20 | # resize 21 | if r > tr: 22 | resize_height = th 23 | resize_width = int(round(th / h * w)) 24 | else: 25 | resize_width = tw 26 | resize_height = int(round(tw / w * h)) # 根据base分辨率,将目标分辨率resize下来 27 | 28 | crop_top = int(round((th - resize_height) / 2.0)) 29 | crop_left = int(round((tw - resize_width) / 2.0)) 30 | 31 | return (crop_top, crop_left), (crop_top + resize_height, crop_left + resize_width) 32 | 33 | 34 | def get_meshgrid(start, *args): 35 | if len(args) == 0: 36 | # start is grid_size 37 | num = _to_tuple(start) 38 | start = (0, 0) 39 | stop = num 40 | elif len(args) == 1: 41 | # start is start, args[0] is stop, step is 1 42 | start = _to_tuple(start) 43 | stop = _to_tuple(args[0]) 44 | num = (stop[0] - start[0], stop[1] - start[1]) 45 | elif len(args) == 2: 46 | # start is start, args[0] is stop, args[1] is num 47 | start = _to_tuple(start) # 左上角 eg: 12,0 48 | stop = _to_tuple(args[0]) # 右下角 eg: 20,32 49 | num = _to_tuple(args[1]) # 目标大小 eg: 32,124 50 | else: 51 | raise ValueError(f"len(args) should be 0, 1 or 2, but got {len(args)}") 52 | 53 | grid_h = np.linspace(start[0], stop[0], num[0], endpoint=False, dtype=np.float32) # 12-20 中间差值32份 0-32 中间差值124份 54 | grid_w = np.linspace(start[1], stop[1], num[1], endpoint=False, dtype=np.float32) 55 | grid = np.meshgrid(grid_w, grid_h) # here w goes first 56 | grid = np.stack(grid, axis=0) # [2, W, H] 57 | return grid 58 | 59 | ################################################################################# 60 | # Sine/Cosine Positional Embedding Functions # 61 | ################################################################################# 62 | # https://github.com/facebookresearch/mae/blob/main/util/pos_embed.py 63 | 64 | def get_2d_sincos_pos_embed(embed_dim, start, *args, cls_token=False, extra_tokens=0): 65 | """ 66 | grid_size: int of the grid height and width 67 | return: 68 | pos_embed: [grid_size*grid_size, embed_dim] or [1+grid_size*grid_size, embed_dim] (w/ or w/o cls_token) 69 | """ 70 | grid = get_meshgrid(start, *args) # [2, H, w] 71 | # grid_h = np.arange(grid_size, dtype=np.float32) 72 | # grid_w = np.arange(grid_size, dtype=np.float32) 73 | # grid = np.meshgrid(grid_w, grid_h) # here w goes first 74 | # grid = np.stack(grid, axis=0) # [2, W, H] 75 | 76 | grid = grid.reshape([2, 1, *grid.shape[1:]]) 77 | pos_embed = get_2d_sincos_pos_embed_from_grid(embed_dim, grid) 78 | if cls_token and extra_tokens > 0: 79 | pos_embed = np.concatenate([np.zeros([extra_tokens, embed_dim]), pos_embed], axis=0) 80 | return pos_embed 81 | 82 | 83 | def get_2d_sincos_pos_embed_from_grid(embed_dim, grid): 84 | assert embed_dim % 2 == 0 85 | 86 | # use half of dimensions to encode grid_h 87 | emb_h = get_1d_sincos_pos_embed_from_grid(embed_dim // 2, grid[0]) # (H*W, D/2) 88 | emb_w = 
get_1d_sincos_pos_embed_from_grid(embed_dim // 2, grid[1]) # (H*W, D/2) 89 | 90 | emb = np.concatenate([emb_h, emb_w], axis=1) # (H*W, D) 91 | return emb 92 | 93 | 94 | def get_1d_sincos_pos_embed_from_grid(embed_dim, pos): 95 | """ 96 | embed_dim: output dimension for each position 97 | pos: a list of positions to be encoded: size (W,H) 98 | out: (M, D) 99 | """ 100 | assert embed_dim % 2 == 0 101 | omega = np.arange(embed_dim // 2, dtype=np.float64) 102 | omega /= embed_dim / 2. 103 | omega = 1. / 10000**omega # (D/2,) 104 | 105 | pos = pos.reshape(-1) # (M,) 106 | out = np.einsum('m,d->md', pos, omega) # (M, D/2), outer product 107 | 108 | emb_sin = np.sin(out) # (M, D/2) 109 | emb_cos = np.cos(out) # (M, D/2) 110 | 111 | emb = np.concatenate([emb_sin, emb_cos], axis=1) # (M, D) 112 | return emb 113 | 114 | 115 | ################################################################################# 116 | # Rotary Positional Embedding Functions # 117 | ################################################################################# 118 | # https://github.com/facebookresearch/llama/blob/main/llama/model.py#L443 119 | 120 | def get_2d_rotary_pos_embed(embed_dim, start, *args, use_real=True): 121 | """ 122 | This is a 2d version of precompute_freqs_cis, which is a RoPE for image tokens with 2d structure. 123 | 124 | Parameters 125 | ---------- 126 | embed_dim: int 127 | embedding dimension size 128 | start: int or tuple of int 129 | If len(args) == 0, start is num; If len(args) == 1, start is start, args[0] is stop, step is 1; 130 | If len(args) == 2, start is start, args[0] is stop, args[1] is num. 131 | use_real: bool 132 | If True, return real part and imaginary part separately. Otherwise, return complex numbers. 133 | 134 | Returns 135 | ------- 136 | pos_embed: torch.Tensor 137 | [HW, D/2] 138 | """ 139 | grid = get_meshgrid(start, *args) # [2, H, w] 140 | grid = grid.reshape([2, 1, *grid.shape[1:]]) # 返回一个采样矩阵 分辨率与目标分辨率一致 141 | pos_embed = get_2d_rotary_pos_embed_from_grid(embed_dim, grid, use_real=use_real) 142 | return pos_embed 143 | 144 | 145 | def get_2d_rotary_pos_embed_from_grid(embed_dim, grid, use_real=False): 146 | assert embed_dim % 4 == 0 147 | 148 | # use half of dimensions to encode grid_h 149 | emb_h = get_1d_rotary_pos_embed(embed_dim // 2, grid[0].reshape(-1), use_real=use_real) # (H*W, D/4) 150 | emb_w = get_1d_rotary_pos_embed(embed_dim // 2, grid[1].reshape(-1), use_real=use_real) # (H*W, D/4) 151 | 152 | if use_real: 153 | cos = torch.cat([emb_h[0], emb_w[0]], dim=1) # (H*W, D/2) 154 | sin = torch.cat([emb_h[1], emb_w[1]], dim=1) # (H*W, D/2) 155 | return cos, sin 156 | else: 157 | emb = torch.cat([emb_h, emb_w], dim=1) # (H*W, D/2) 158 | return emb 159 | 160 | 161 | def get_1d_rotary_pos_embed(dim: int, pos: Union[np.ndarray, int], theta: float = 10000.0, use_real=False): 162 | """ 163 | Precompute the frequency tensor for complex exponentials (cis) with given dimensions. 164 | 165 | This function calculates a frequency tensor with complex exponentials using the given dimension 'dim' 166 | and the end index 'end'. The 'theta' parameter scales the frequencies. 167 | The returned tensor contains complex values in complex64 data type. 168 | 169 | Args: 170 | dim (int): Dimension of the frequency tensor. 171 | pos (np.ndarray, int): Position indices for the frequency tensor. [S] or scalar 172 | theta (float, optional): Scaling factor for frequency computation. Defaults to 10000.0. 173 | use_real (bool, optional): If True, return real part and imaginary part separately. 
174 | Otherwise, return complex numbers. 175 | 176 | Returns: 177 | torch.Tensor: Precomputed frequency tensor with complex exponentials. [S, D/2] 178 | 179 | """ 180 | if isinstance(pos, int): 181 | pos = np.arange(pos) 182 | freqs = 1.0 / (theta ** (torch.arange(0, dim, 2)[: (dim // 2)].float() / dim)) # [D/2] 183 | t = torch.from_numpy(pos).to(freqs.device) # type: ignore # [S] 184 | freqs = torch.outer(t, freqs).float() # type: ignore # [S, D/2] 185 | if use_real: 186 | freqs_cos = freqs.cos().repeat_interleave(2, dim=1) # [S, D] 187 | freqs_sin = freqs.sin().repeat_interleave(2, dim=1) # [S, D] 188 | return freqs_cos, freqs_sin 189 | else: 190 | freqs_cis = torch.polar(torch.ones_like(freqs), freqs) # complex64 # [S, D/2] 191 | return freqs_cis 192 | 193 | 194 | 195 | def calc_sizes(rope_img, patch_size, th, tw): 196 | """ 计算 RoPE 的尺寸. """ 197 | if rope_img == 'extend': 198 | # 拓展模式 199 | sub_args = [(th, tw)] 200 | elif rope_img.startswith('base'): 201 | # 基于一个尺寸, 其他尺寸插值获得. 202 | base_size = int(rope_img[4:]) // 8 // patch_size # 基于512作为base,其他根据512差值得到 203 | start, stop = get_fill_resize_and_crop((th, tw), base_size) # 需要在32x32里面 crop的左上角和右下角 204 | sub_args = [start, stop, (th, tw)] 205 | else: 206 | raise ValueError(f"Unknown rope_img: {rope_img}") 207 | return sub_args 208 | 209 | 210 | def init_image_posemb(rope_img, 211 | resolutions, 212 | patch_size, 213 | hidden_size, 214 | num_heads, 215 | log_fn, 216 | rope_real=True, 217 | ): 218 | freqs_cis_img = {} 219 | for reso in resolutions: 220 | th, tw = reso.height // 8 // patch_size, reso.width // 8 // patch_size 221 | sub_args = calc_sizes(rope_img, patch_size, th, tw) # [左上角, 右下角, 目标高宽] 需要在32x32里面 crop的左上角和右下角 222 | freqs_cis_img[str(reso)] = get_2d_rotary_pos_embed(hidden_size // num_heads, *sub_args, use_real=rope_real) 223 | log_fn(f" Using image RoPE ({rope_img}) ({'real' if rope_real else 'complex'}): {sub_args} | ({reso}) " 224 | f"{freqs_cis_img[str(reso)][0].shape if rope_real else freqs_cis_img[str(reso)].shape}") 225 | return freqs_cis_img 226 | -------------------------------------------------------------------------------- /Sana/diffusers_convert.py: -------------------------------------------------------------------------------- 1 | # For using the diffusers format weights 2 | # Based on the original ComfyUI function + 3 | # https://github.com/NVlabs/Sana/blob/main/tools/convert_sana_to_diffusers.py 4 | import torch 5 | 6 | 7 | def get_depth(state_dict): 8 | return sum(key.endswith('.attn1.to_k.bias') for key in state_dict.keys()) 9 | 10 | def get_lora_depth(state_dict): 11 | cnt = max([ 12 | sum(key.endswith('.attn1.to_k.lora_A.weight') for key in state_dict.keys()), 13 | sum(key.endswith('_attn1_to_k.lora_A.weight') for key in state_dict.keys()), 14 | sum(key.endswith('.attn1.to_k.lora_up.weight') for key in state_dict.keys()), 15 | sum(key.endswith('_attn1_to_k.lora_up.weight') for key in state_dict.keys()), 16 | ]) 17 | assert cnt > 0, "Unable to detect model depth!" 
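# Depth is inferred by counting the per-block attn1 "to_k" LoRA keys; the four
# patterns above cover peft-style (lora_A) and kohya-style (lora_up) naming with
# either "."- or "_"-separated module paths, whichever convention the exporter used.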
18 | return cnt 19 | 20 | def get_conversion_map(state_dict): 21 | conversion_map = [ # main SD conversion map (Sana reference, HF Diffusers) 22 | # Patch embeddings 23 | ("x_embedder.proj.weight", "pos_embed.proj.weight"), 24 | ("x_embedder.proj.bias", "pos_embed.proj.bias"), 25 | # Caption projection 26 | ("y_embedder.y_embedding", "caption_projection.y_embedding"), 27 | ("y_embedder.y_proj.fc1.weight", "caption_projection.linear_1.weight"), 28 | ("y_embedder.y_proj.fc1.bias", "caption_projection.linear_1.bias"), 29 | ("y_embedder.y_proj.fc2.weight", "caption_projection.linear_2.weight"), 30 | ("y_embedder.y_proj.fc2.bias", "caption_projection.linear_2.bias"), 31 | # AdaLN-single LN 32 | ("t_embedder.mlp.0.weight", "adaln_single.emb.timestep_embedder.linear_1.weight"), 33 | ("t_embedder.mlp.0.bias", "adaln_single.emb.timestep_embedder.linear_1.bias"), 34 | ("t_embedder.mlp.2.weight", "adaln_single.emb.timestep_embedder.linear_2.weight"), 35 | ("t_embedder.mlp.2.bias", "adaln_single.emb.timestep_embedder.linear_2.bias"), 36 | # Shared norm 37 | ("t_block.1.weight", "adaln_single.linear.weight"), 38 | ("t_block.1.bias", "adaln_single.linear.bias"), 39 | # Final block 40 | ("final_layer.linear.weight", "proj_out.weight"), 41 | ("final_layer.linear.bias", "proj_out.bias"), 42 | ("final_layer.scale_shift_table", "scale_shift_table"), 43 | ] 44 | 45 | # Add actual transformer blocks 46 | for depth in range(get_depth(state_dict)): 47 | # Transformer blocks 48 | conversion_map += [ 49 | (f"blocks.{depth}.scale_shift_table", f"transformer_blocks.{depth}.scale_shift_table"), 50 | # Projection 51 | (f"blocks.{depth}.attn.proj.weight", f"transformer_blocks.{depth}.attn1.to_out.0.weight"), 52 | (f"blocks.{depth}.attn.proj.bias", f"transformer_blocks.{depth}.attn1.to_out.0.bias"), 53 | # Feed-forward 54 | (f"blocks.{depth}.mlp.fc1.weight", f"transformer_blocks.{depth}.ff.net.0.proj.weight"), 55 | (f"blocks.{depth}.mlp.fc1.bias", f"transformer_blocks.{depth}.ff.net.0.proj.bias"), 56 | (f"blocks.{depth}.mlp.fc2.weight", f"transformer_blocks.{depth}.ff.net.2.weight"), 57 | (f"blocks.{depth}.mlp.fc2.bias", f"transformer_blocks.{depth}.ff.net.2.bias"), 58 | # Cross-attention (proj) 59 | (f"blocks.{depth}.cross_attn.proj.weight" ,f"transformer_blocks.{depth}.attn2.to_out.0.weight"), 60 | (f"blocks.{depth}.cross_attn.proj.bias" ,f"transformer_blocks.{depth}.attn2.to_out.0.bias"), 61 | ] 62 | return conversion_map 63 | 64 | def find_prefix(state_dict, target_key): 65 | prefix = "" 66 | for k in state_dict.keys(): 67 | if k.endswith(target_key): 68 | prefix = k.split(target_key)[0] 69 | break 70 | return prefix 71 | 72 | def convert_state_dict(state_dict): 73 | cmap = get_conversion_map(state_dict) 74 | 75 | missing = [k for k,v in cmap if v not in state_dict] 76 | new_state_dict = {k: state_dict[v] for k,v in cmap if k not in missing} 77 | matched = list(v for k,v in cmap if v in state_dict.keys()) 78 | 79 | for depth in range(get_depth(state_dict)): 80 | for wb in ["weight", "bias"]: 81 | # Self Attention 82 | key = lambda a: f"transformer_blocks.{depth}.attn1.to_{a}.{wb}" 83 | new_state_dict[f"blocks.{depth}.attn.qkv.{wb}"] = torch.cat(( 84 | state_dict[key('q')], state_dict[key('k')], state_dict[key('v')] 85 | ), dim=0) 86 | matched += [key('q'), key('k'), key('v')] 87 | 88 | # Cross-attention (linear) 89 | key = lambda a: f"transformer_blocks.{depth}.attn2.to_{a}.{wb}" 90 | new_state_dict[f"blocks.{depth}.cross_attn.q_linear.{wb}"] = state_dict[key('q')] 91 | 
new_state_dict[f"blocks.{depth}.cross_attn.kv_linear.{wb}"] = torch.cat(( 92 | state_dict[key('k')], state_dict[key('v')] 93 | ), dim=0) 94 | matched += [key('q'), key('k'), key('v')] 95 | 96 | if len(matched) < len(state_dict): 97 | print(f"Sana: UNET conversion has leftover keys! ({len(matched)} vs {len(state_dict)})") 98 | print(list( set(state_dict.keys()) - set(matched) )) 99 | 100 | if len(missing) > 0: 101 | print(f"Sana: UNET conversion has missing keys!") 102 | print(missing) 103 | 104 | return new_state_dict 105 | 106 | # Same as above but for LoRA weights: 107 | # TODO: Not used yet, need to support LoRA for Sana 108 | def convert_lora_state_dict(state_dict, peft=True): 109 | # koyha 110 | rep_ak = lambda x: x.replace(".weight", ".lora_down.weight") 111 | rep_bk = lambda x: x.replace(".weight", ".lora_up.weight") 112 | rep_pk = lambda x: x.replace(".weight", ".alpha") 113 | if peft: # peft 114 | rep_ap = lambda x: x.replace(".weight", ".lora_A.weight") 115 | rep_bp = lambda x: x.replace(".weight", ".lora_B.weight") 116 | rep_pp = lambda x: x.replace(".weight", ".alpha") 117 | 118 | prefix = find_prefix(state_dict, "adaln_single.linear.lora_A.weight") 119 | state_dict = {k[len(prefix):]:v for k,v in state_dict.items()} 120 | else: # OneTrainer 121 | rep_ap = lambda x: x.replace(".", "_")[:-7] + ".lora_down.weight" 122 | rep_bp = lambda x: x.replace(".", "_")[:-7] + ".lora_up.weight" 123 | rep_pp = lambda x: x.replace(".", "_")[:-7] + ".alpha" 124 | 125 | prefix = "lora_transformer_" 126 | gemma_marker = "lora_te_encoder" 127 | gemma_keys = [] 128 | for key in list(state_dict.keys()): 129 | if key.startswith(prefix): 130 | state_dict[key[len(prefix):]] = state_dict.pop(key) 131 | elif gemma_marker in key: 132 | gemma_keys.append(state_dict.pop(key)) 133 | if len(gemma_keys) > 0: 134 | print(f"Text Encoder not supported for Sana LoRA, ignoring {len(gemma_keys)} keys") 135 | 136 | cmap = [] 137 | cmap_unet = get_conversion_map(state_dict) # todo: 512 model 138 | for k, v in cmap_unet: 139 | if v.endswith(".weight"): 140 | cmap.append((rep_ak(k), rep_ap(v))) 141 | cmap.append((rep_bk(k), rep_bp(v))) 142 | if not peft: 143 | cmap.append((rep_pk(k), rep_pp(v))) 144 | 145 | missing = [k for k,v in cmap if v not in state_dict] 146 | new_state_dict = {k: state_dict[v] for k,v in cmap if k not in missing} 147 | matched = list(v for k,v in cmap if v in state_dict.keys()) 148 | 149 | lora_depth = get_lora_depth(state_dict) 150 | for fp, fk in ((rep_ap, rep_ak),(rep_bp, rep_bk)): 151 | for depth in range(lora_depth): 152 | # Self Attention 153 | key = lambda a: fp(f"transformer_blocks.{depth}.attn1.to_{a}.weight") 154 | new_state_dict[fk(f"blocks.{depth}.attn.qkv.weight")] = torch.cat(( 155 | state_dict[key('q')], state_dict[key('k')], state_dict[key('v')] 156 | ), dim=0) 157 | 158 | matched += [key('q'), key('k'), key('v')] 159 | if not peft: 160 | akey = lambda a: rep_pp(f"transformer_blocks.{depth}.attn1.to_{a}.weight") 161 | new_state_dict[rep_pk((f"blocks.{depth}.attn.qkv.weight"))] = state_dict[akey("q")] 162 | matched += [akey('q'), akey('k'), akey('v')] 163 | 164 | # Self Attention projection? 
165 | key = lambda a: fp(f"transformer_blocks.{depth}.attn1.to_{a}.weight") 166 | new_state_dict[fk(f"blocks.{depth}.attn.proj.weight")] = state_dict[key('out.0')] 167 | matched += [key('out.0')] 168 | 169 | # Cross-attention (linear) 170 | key = lambda a: fp(f"transformer_blocks.{depth}.attn2.to_{a}.weight") 171 | new_state_dict[fk(f"blocks.{depth}.cross_attn.q_linear.weight")] = state_dict[key('q')] 172 | new_state_dict[fk(f"blocks.{depth}.cross_attn.kv_linear.weight")] = torch.cat(( 173 | state_dict[key('k')], state_dict[key('v')] 174 | ), dim=0) 175 | matched += [key('q'), key('k'), key('v')] 176 | if not peft: 177 | akey = lambda a: rep_pp(f"transformer_blocks.{depth}.attn2.to_{a}.weight") 178 | new_state_dict[rep_pk((f"blocks.{depth}.cross_attn.q_linear.weight"))] = state_dict[akey("q")] 179 | new_state_dict[rep_pk((f"blocks.{depth}.cross_attn.kv_linear.weight"))] = state_dict[akey("k")] 180 | matched += [akey('q'), akey('k'), akey('v')] 181 | 182 | # Cross Attention projection? 183 | key = lambda a: fp(f"transformer_blocks.{depth}.attn2.to_{a}.weight") 184 | new_state_dict[fk(f"blocks.{depth}.cross_attn.proj.weight")] = state_dict[key('out.0')] 185 | matched += [key('out.0')] 186 | 187 | try: 188 | key = fp(f"transformer_blocks.{depth}.ff.net.0.proj.weight") 189 | new_state_dict[fk(f"blocks.{depth}.mlp.fc1.weight")] = state_dict[key] 190 | matched += [key] 191 | except KeyError: 192 | pass 193 | 194 | try: 195 | key = fp(f"transformer_blocks.{depth}.ff.net.2.weight") 196 | new_state_dict[fk(f"blocks.{depth}.mlp.fc2.weight")] = state_dict[key] 197 | matched += [key] 198 | except KeyError: 199 | pass 200 | 201 | if len(matched) < len(state_dict): 202 | print(f"Sana: LoRA conversion has leftover keys! ({len(matched)} vs {len(state_dict)})") 203 | print(list( set(state_dict.keys()) - set(matched) )) 204 | 205 | if len(missing) > 0: 206 | print(f"Sana: LoRA conversion has missing keys! (probably)") 207 | print(missing) 208 | 209 | return new_state_dict 210 | -------------------------------------------------------------------------------- /utils/IPEX/attention.py: -------------------------------------------------------------------------------- 1 | # Code lifted from https://github.com/Disty0/ipex_to_cuda/blob/main/attention.py 2 | # Thanks to Disty0! 
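# The torch_bmm_32_bit and scaled_dot_product_attention_32_bit wrappers below are
# intended as drop-in replacements for torch.bmm and
# torch.nn.functional.scaled_dot_product_attention on Intel XPU devices: non-XPU
# tensors are passed straight through to the original ops, while XPU tensors are
# sliced along the first three dimensions until each chunk stays under the
# thresholds set via the IPEX_SDPA_SLICE_TRIGGER_RATE and
# IPEX_ATTENTION_SLICE_RATE environment variables.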
3 | 4 | import os 5 | import torch 6 | import intel_extension_for_pytorch as ipex # pylint: disable=import-error, unused-import 7 | from functools import cache 8 | 9 | # pylint: disable=protected-access, missing-function-docstring, line-too-long 10 | 11 | # ARC GPUs can't allocate more than 4GB to a single block so we slice the attetion layers 12 | 13 | sdpa_slice_trigger_rate = float(os.environ.get('IPEX_SDPA_SLICE_TRIGGER_RATE', 6)) 14 | attention_slice_rate = float(os.environ.get('IPEX_ATTENTION_SLICE_RATE', 4)) 15 | 16 | # Find something divisible with the input_tokens 17 | @cache 18 | def find_slice_size(slice_size, slice_block_size): 19 | while (slice_size * slice_block_size) > attention_slice_rate: 20 | slice_size = slice_size // 2 21 | if slice_size <= 1: 22 | slice_size = 1 23 | break 24 | return slice_size 25 | 26 | # Find slice sizes for SDPA 27 | @cache 28 | def find_sdpa_slice_sizes(query_shape, query_element_size): 29 | if len(query_shape) == 3: 30 | batch_size_attention, query_tokens, shape_three = query_shape 31 | shape_four = 1 32 | else: 33 | batch_size_attention, query_tokens, shape_three, shape_four = query_shape 34 | 35 | slice_block_size = query_tokens * shape_three * shape_four / 1024 / 1024 * query_element_size 36 | block_size = batch_size_attention * slice_block_size 37 | 38 | split_slice_size = batch_size_attention 39 | split_2_slice_size = query_tokens 40 | split_3_slice_size = shape_three 41 | 42 | do_split = False 43 | do_split_2 = False 44 | do_split_3 = False 45 | 46 | if block_size > sdpa_slice_trigger_rate: 47 | do_split = True 48 | split_slice_size = find_slice_size(split_slice_size, slice_block_size) 49 | if split_slice_size * slice_block_size > attention_slice_rate: 50 | slice_2_block_size = split_slice_size * shape_three * shape_four / 1024 / 1024 * query_element_size 51 | do_split_2 = True 52 | split_2_slice_size = find_slice_size(split_2_slice_size, slice_2_block_size) 53 | if split_2_slice_size * slice_2_block_size > attention_slice_rate: 54 | slice_3_block_size = split_slice_size * split_2_slice_size * shape_four / 1024 / 1024 * query_element_size 55 | do_split_3 = True 56 | split_3_slice_size = find_slice_size(split_3_slice_size, slice_3_block_size) 57 | 58 | return do_split, do_split_2, do_split_3, split_slice_size, split_2_slice_size, split_3_slice_size 59 | 60 | # Find slice sizes for BMM 61 | @cache 62 | def find_bmm_slice_sizes(input_shape, input_element_size, mat2_shape): 63 | batch_size_attention, input_tokens, mat2_atten_shape = input_shape[0], input_shape[1], mat2_shape[2] 64 | slice_block_size = input_tokens * mat2_atten_shape / 1024 / 1024 * input_element_size 65 | block_size = batch_size_attention * slice_block_size 66 | 67 | split_slice_size = batch_size_attention 68 | split_2_slice_size = input_tokens 69 | split_3_slice_size = mat2_atten_shape 70 | 71 | do_split = False 72 | do_split_2 = False 73 | do_split_3 = False 74 | 75 | if block_size > attention_slice_rate: 76 | do_split = True 77 | split_slice_size = find_slice_size(split_slice_size, slice_block_size) 78 | if split_slice_size * slice_block_size > attention_slice_rate: 79 | slice_2_block_size = split_slice_size * mat2_atten_shape / 1024 / 1024 * input_element_size 80 | do_split_2 = True 81 | split_2_slice_size = find_slice_size(split_2_slice_size, slice_2_block_size) 82 | if split_2_slice_size * slice_2_block_size > attention_slice_rate: 83 | slice_3_block_size = split_slice_size * split_2_slice_size / 1024 / 1024 * input_element_size 84 | do_split_3 = True 85 | 
split_3_slice_size = find_slice_size(split_3_slice_size, slice_3_block_size) 86 | 87 | return do_split, do_split_2, do_split_3, split_slice_size, split_2_slice_size, split_3_slice_size 88 | 89 | 90 | original_torch_bmm = torch.bmm 91 | def torch_bmm_32_bit(input, mat2, *, out=None): 92 | if input.device.type != "xpu": 93 | return original_torch_bmm(input, mat2, out=out) 94 | do_split, do_split_2, do_split_3, split_slice_size, split_2_slice_size, split_3_slice_size = find_bmm_slice_sizes(input.shape, input.element_size(), mat2.shape) 95 | 96 | # Slice BMM 97 | if do_split: 98 | batch_size_attention, input_tokens, mat2_atten_shape = input.shape[0], input.shape[1], mat2.shape[2] 99 | hidden_states = torch.zeros(input.shape[0], input.shape[1], mat2.shape[2], device=input.device, dtype=input.dtype) 100 | for i in range(batch_size_attention // split_slice_size): 101 | start_idx = i * split_slice_size 102 | end_idx = (i + 1) * split_slice_size 103 | if do_split_2: 104 | for i2 in range(input_tokens // split_2_slice_size): # pylint: disable=invalid-name 105 | start_idx_2 = i2 * split_2_slice_size 106 | end_idx_2 = (i2 + 1) * split_2_slice_size 107 | if do_split_3: 108 | for i3 in range(mat2_atten_shape // split_3_slice_size): # pylint: disable=invalid-name 109 | start_idx_3 = i3 * split_3_slice_size 110 | end_idx_3 = (i3 + 1) * split_3_slice_size 111 | hidden_states[start_idx:end_idx, start_idx_2:end_idx_2, start_idx_3:end_idx_3] = original_torch_bmm( 112 | input[start_idx:end_idx, start_idx_2:end_idx_2, start_idx_3:end_idx_3], 113 | mat2[start_idx:end_idx, start_idx_2:end_idx_2, start_idx_3:end_idx_3], 114 | out=out 115 | ) 116 | else: 117 | hidden_states[start_idx:end_idx, start_idx_2:end_idx_2] = original_torch_bmm( 118 | input[start_idx:end_idx, start_idx_2:end_idx_2], 119 | mat2[start_idx:end_idx, start_idx_2:end_idx_2], 120 | out=out 121 | ) 122 | else: 123 | hidden_states[start_idx:end_idx] = original_torch_bmm( 124 | input[start_idx:end_idx], 125 | mat2[start_idx:end_idx], 126 | out=out 127 | ) 128 | torch.xpu.synchronize(input.device) 129 | else: 130 | return original_torch_bmm(input, mat2, out=out) 131 | return hidden_states 132 | 133 | original_scaled_dot_product_attention = torch.nn.functional.scaled_dot_product_attention 134 | def scaled_dot_product_attention_32_bit(query, key, value, attn_mask=None, dropout_p=0.0, is_causal=False, **kwargs): 135 | if query.device.type != "xpu": 136 | return original_scaled_dot_product_attention(query, key, value, attn_mask=attn_mask, dropout_p=dropout_p, is_causal=is_causal, **kwargs) 137 | do_split, do_split_2, do_split_3, split_slice_size, split_2_slice_size, split_3_slice_size = find_sdpa_slice_sizes(query.shape, query.element_size()) 138 | 139 | # Slice SDPA 140 | if do_split: 141 | batch_size_attention, query_tokens, shape_three = query.shape[0], query.shape[1], query.shape[2] 142 | hidden_states = torch.zeros(query.shape, device=query.device, dtype=query.dtype) 143 | for i in range(batch_size_attention // split_slice_size): 144 | start_idx = i * split_slice_size 145 | end_idx = (i + 1) * split_slice_size 146 | if do_split_2: 147 | for i2 in range(query_tokens // split_2_slice_size): # pylint: disable=invalid-name 148 | start_idx_2 = i2 * split_2_slice_size 149 | end_idx_2 = (i2 + 1) * split_2_slice_size 150 | if do_split_3: 151 | for i3 in range(shape_three // split_3_slice_size): # pylint: disable=invalid-name 152 | start_idx_3 = i3 * split_3_slice_size 153 | end_idx_3 = (i3 + 1) * split_3_slice_size 154 | hidden_states[start_idx:end_idx, 
start_idx_2:end_idx_2, start_idx_3:end_idx_3] = original_scaled_dot_product_attention( 155 | query[start_idx:end_idx, start_idx_2:end_idx_2, start_idx_3:end_idx_3], 156 | key[start_idx:end_idx, start_idx_2:end_idx_2, start_idx_3:end_idx_3], 157 | value[start_idx:end_idx, start_idx_2:end_idx_2, start_idx_3:end_idx_3], 158 | attn_mask=attn_mask[start_idx:end_idx, start_idx_2:end_idx_2, start_idx_3:end_idx_3] if attn_mask is not None else attn_mask, 159 | dropout_p=dropout_p, is_causal=is_causal, **kwargs 160 | ) 161 | else: 162 | hidden_states[start_idx:end_idx, start_idx_2:end_idx_2] = original_scaled_dot_product_attention( 163 | query[start_idx:end_idx, start_idx_2:end_idx_2], 164 | key[start_idx:end_idx, start_idx_2:end_idx_2], 165 | value[start_idx:end_idx, start_idx_2:end_idx_2], 166 | attn_mask=attn_mask[start_idx:end_idx, start_idx_2:end_idx_2] if attn_mask is not None else attn_mask, 167 | dropout_p=dropout_p, is_causal=is_causal, **kwargs 168 | ) 169 | else: 170 | hidden_states[start_idx:end_idx] = original_scaled_dot_product_attention( 171 | query[start_idx:end_idx], 172 | key[start_idx:end_idx], 173 | value[start_idx:end_idx], 174 | attn_mask=attn_mask[start_idx:end_idx] if attn_mask is not None else attn_mask, 175 | dropout_p=dropout_p, is_causal=is_causal, **kwargs 176 | ) 177 | torch.xpu.synchronize(query.device) 178 | else: 179 | return original_scaled_dot_product_attention(query, key, value, attn_mask=attn_mask, dropout_p=dropout_p, is_causal=is_causal, **kwargs) 180 | return hidden_states -------------------------------------------------------------------------------- /PixArt/diffusers_convert.py: -------------------------------------------------------------------------------- 1 | # For using the diffusers format weights 2 | # Based on the original ComfyUI function + 3 | # https://github.com/PixArt-alpha/PixArt-alpha/blob/master/tools/convert_pixart_alpha_to_diffusers.py 4 | import torch 5 | 6 | conversion_map_ms = [ # for multi_scale_train (MS) 7 | # Resolution 8 | ("csize_embedder.mlp.0.weight", "adaln_single.emb.resolution_embedder.linear_1.weight"), 9 | ("csize_embedder.mlp.0.bias", "adaln_single.emb.resolution_embedder.linear_1.bias"), 10 | ("csize_embedder.mlp.2.weight", "adaln_single.emb.resolution_embedder.linear_2.weight"), 11 | ("csize_embedder.mlp.2.bias", "adaln_single.emb.resolution_embedder.linear_2.bias"), 12 | # Aspect ratio 13 | ("ar_embedder.mlp.0.weight", "adaln_single.emb.aspect_ratio_embedder.linear_1.weight"), 14 | ("ar_embedder.mlp.0.bias", "adaln_single.emb.aspect_ratio_embedder.linear_1.bias"), 15 | ("ar_embedder.mlp.2.weight", "adaln_single.emb.aspect_ratio_embedder.linear_2.weight"), 16 | ("ar_embedder.mlp.2.bias", "adaln_single.emb.aspect_ratio_embedder.linear_2.bias"), 17 | ] 18 | 19 | def get_depth(state_dict): 20 | return sum(key.endswith('.attn1.to_k.bias') for key in state_dict.keys()) 21 | 22 | def get_lora_depth(state_dict): 23 | cnt = max([ 24 | sum(key.endswith('.attn1.to_k.lora_A.weight') for key in state_dict.keys()), 25 | sum(key.endswith('_attn1_to_k.lora_A.weight') for key in state_dict.keys()), 26 | sum(key.endswith('.attn1.to_k.lora_up.weight') for key in state_dict.keys()), 27 | sum(key.endswith('_attn1_to_k.lora_up.weight') for key in state_dict.keys()), 28 | ]) 29 | assert cnt > 0, "Unable to detect model depth!" 
30 | return cnt 31 | 32 | def get_conversion_map(state_dict): 33 | conversion_map = [ # main SD conversion map (PixArt reference, HF Diffusers) 34 | # Patch embeddings 35 | ("x_embedder.proj.weight", "pos_embed.proj.weight"), 36 | ("x_embedder.proj.bias", "pos_embed.proj.bias"), 37 | # Caption projection 38 | ("y_embedder.y_embedding", "caption_projection.y_embedding"), 39 | ("y_embedder.y_proj.fc1.weight", "caption_projection.linear_1.weight"), 40 | ("y_embedder.y_proj.fc1.bias", "caption_projection.linear_1.bias"), 41 | ("y_embedder.y_proj.fc2.weight", "caption_projection.linear_2.weight"), 42 | ("y_embedder.y_proj.fc2.bias", "caption_projection.linear_2.bias"), 43 | # AdaLN-single LN 44 | ("t_embedder.mlp.0.weight", "adaln_single.emb.timestep_embedder.linear_1.weight"), 45 | ("t_embedder.mlp.0.bias", "adaln_single.emb.timestep_embedder.linear_1.bias"), 46 | ("t_embedder.mlp.2.weight", "adaln_single.emb.timestep_embedder.linear_2.weight"), 47 | ("t_embedder.mlp.2.bias", "adaln_single.emb.timestep_embedder.linear_2.bias"), 48 | # Shared norm 49 | ("t_block.1.weight", "adaln_single.linear.weight"), 50 | ("t_block.1.bias", "adaln_single.linear.bias"), 51 | # Final block 52 | ("final_layer.linear.weight", "proj_out.weight"), 53 | ("final_layer.linear.bias", "proj_out.bias"), 54 | ("final_layer.scale_shift_table", "scale_shift_table"), 55 | ] 56 | 57 | # Add actual transformer blocks 58 | for depth in range(get_depth(state_dict)): 59 | # Transformer blocks 60 | conversion_map += [ 61 | (f"blocks.{depth}.scale_shift_table", f"transformer_blocks.{depth}.scale_shift_table"), 62 | # Projection 63 | (f"blocks.{depth}.attn.proj.weight", f"transformer_blocks.{depth}.attn1.to_out.0.weight"), 64 | (f"blocks.{depth}.attn.proj.bias", f"transformer_blocks.{depth}.attn1.to_out.0.bias"), 65 | # Feed-forward 66 | (f"blocks.{depth}.mlp.fc1.weight", f"transformer_blocks.{depth}.ff.net.0.proj.weight"), 67 | (f"blocks.{depth}.mlp.fc1.bias", f"transformer_blocks.{depth}.ff.net.0.proj.bias"), 68 | (f"blocks.{depth}.mlp.fc2.weight", f"transformer_blocks.{depth}.ff.net.2.weight"), 69 | (f"blocks.{depth}.mlp.fc2.bias", f"transformer_blocks.{depth}.ff.net.2.bias"), 70 | # Cross-attention (proj) 71 | (f"blocks.{depth}.cross_attn.proj.weight" ,f"transformer_blocks.{depth}.attn2.to_out.0.weight"), 72 | (f"blocks.{depth}.cross_attn.proj.bias" ,f"transformer_blocks.{depth}.attn2.to_out.0.bias"), 73 | ] 74 | return conversion_map 75 | 76 | def find_prefix(state_dict, target_key): 77 | prefix = "" 78 | for k in state_dict.keys(): 79 | if k.endswith(target_key): 80 | prefix = k.split(target_key)[0] 81 | break 82 | return prefix 83 | 84 | def convert_state_dict(state_dict): 85 | if "adaln_single.emb.resolution_embedder.linear_1.weight" in state_dict.keys(): 86 | cmap = get_conversion_map(state_dict) + conversion_map_ms 87 | else: 88 | cmap = get_conversion_map(state_dict) 89 | 90 | missing = [k for k,v in cmap if v not in state_dict] 91 | new_state_dict = {k: state_dict[v] for k,v in cmap if k not in missing} 92 | matched = list(v for k,v in cmap if v in state_dict.keys()) 93 | 94 | for depth in range(get_depth(state_dict)): 95 | for wb in ["weight", "bias"]: 96 | # Self Attention 97 | key = lambda a: f"transformer_blocks.{depth}.attn1.to_{a}.{wb}" 98 | new_state_dict[f"blocks.{depth}.attn.qkv.{wb}"] = torch.cat(( 99 | state_dict[key('q')], state_dict[key('k')], state_dict[key('v')] 100 | ), dim=0) 101 | matched += [key('q'), key('k'), key('v')] 102 | 103 | # Cross-attention (linear) 104 | key = lambda a: 
f"transformer_blocks.{depth}.attn2.to_{a}.{wb}" 105 | new_state_dict[f"blocks.{depth}.cross_attn.q_linear.{wb}"] = state_dict[key('q')] 106 | new_state_dict[f"blocks.{depth}.cross_attn.kv_linear.{wb}"] = torch.cat(( 107 | state_dict[key('k')], state_dict[key('v')] 108 | ), dim=0) 109 | matched += [key('q'), key('k'), key('v')] 110 | 111 | if len(matched) < len(state_dict): 112 | print(f"PixArt: UNET conversion has leftover keys! ({len(matched)} vs {len(state_dict)})") 113 | print(list( set(state_dict.keys()) - set(matched) )) 114 | 115 | if len(missing) > 0: 116 | print(f"PixArt: UNET conversion has missing keys!") 117 | print(missing) 118 | 119 | return new_state_dict 120 | 121 | # Same as above but for LoRA weights: 122 | def convert_lora_state_dict(state_dict, peft=True): 123 | # koyha 124 | rep_ak = lambda x: x.replace(".weight", ".lora_down.weight") 125 | rep_bk = lambda x: x.replace(".weight", ".lora_up.weight") 126 | rep_pk = lambda x: x.replace(".weight", ".alpha") 127 | if peft: # peft 128 | rep_ap = lambda x: x.replace(".weight", ".lora_A.weight") 129 | rep_bp = lambda x: x.replace(".weight", ".lora_B.weight") 130 | rep_pp = lambda x: x.replace(".weight", ".alpha") 131 | 132 | prefix = find_prefix(state_dict, "adaln_single.linear.lora_A.weight") 133 | state_dict = {k[len(prefix):]:v for k,v in state_dict.items()} 134 | else: # OneTrainer 135 | rep_ap = lambda x: x.replace(".", "_")[:-7] + ".lora_down.weight" 136 | rep_bp = lambda x: x.replace(".", "_")[:-7] + ".lora_up.weight" 137 | rep_pp = lambda x: x.replace(".", "_")[:-7] + ".alpha" 138 | 139 | prefix = "lora_transformer_" 140 | t5_marker = "lora_te_encoder" 141 | t5_keys = [] 142 | for key in list(state_dict.keys()): 143 | if key.startswith(prefix): 144 | state_dict[key[len(prefix):]] = state_dict.pop(key) 145 | elif t5_marker in key: 146 | t5_keys.append(state_dict.pop(key)) 147 | if len(t5_keys) > 0: 148 | print(f"Text Encoder not supported for PixArt LoRA, ignoring {len(t5_keys)} keys") 149 | 150 | cmap = [] 151 | cmap_unet = get_conversion_map(state_dict) + conversion_map_ms # todo: 512 model 152 | for k, v in cmap_unet: 153 | if v.endswith(".weight"): 154 | cmap.append((rep_ak(k), rep_ap(v))) 155 | cmap.append((rep_bk(k), rep_bp(v))) 156 | if not peft: 157 | cmap.append((rep_pk(k), rep_pp(v))) 158 | 159 | missing = [k for k,v in cmap if v not in state_dict] 160 | new_state_dict = {k: state_dict[v] for k,v in cmap if k not in missing} 161 | matched = list(v for k,v in cmap if v in state_dict.keys()) 162 | 163 | lora_depth = get_lora_depth(state_dict) 164 | for fp, fk in ((rep_ap, rep_ak),(rep_bp, rep_bk)): 165 | for depth in range(lora_depth): 166 | # Self Attention 167 | key = lambda a: fp(f"transformer_blocks.{depth}.attn1.to_{a}.weight") 168 | new_state_dict[fk(f"blocks.{depth}.attn.qkv.weight")] = torch.cat(( 169 | state_dict[key('q')], state_dict[key('k')], state_dict[key('v')] 170 | ), dim=0) 171 | 172 | matched += [key('q'), key('k'), key('v')] 173 | if not peft: 174 | akey = lambda a: rep_pp(f"transformer_blocks.{depth}.attn1.to_{a}.weight") 175 | new_state_dict[rep_pk((f"blocks.{depth}.attn.qkv.weight"))] = state_dict[akey("q")] 176 | matched += [akey('q'), akey('k'), akey('v')] 177 | 178 | # Self Attention projection? 
179 | key = lambda a: fp(f"transformer_blocks.{depth}.attn1.to_{a}.weight") 180 | new_state_dict[fk(f"blocks.{depth}.attn.proj.weight")] = state_dict[key('out.0')] 181 | matched += [key('out.0')] 182 | 183 | # Cross-attention (linear) 184 | key = lambda a: fp(f"transformer_blocks.{depth}.attn2.to_{a}.weight") 185 | new_state_dict[fk(f"blocks.{depth}.cross_attn.q_linear.weight")] = state_dict[key('q')] 186 | new_state_dict[fk(f"blocks.{depth}.cross_attn.kv_linear.weight")] = torch.cat(( 187 | state_dict[key('k')], state_dict[key('v')] 188 | ), dim=0) 189 | matched += [key('q'), key('k'), key('v')] 190 | if not peft: 191 | akey = lambda a: rep_pp(f"transformer_blocks.{depth}.attn2.to_{a}.weight") 192 | new_state_dict[rep_pk((f"blocks.{depth}.cross_attn.q_linear.weight"))] = state_dict[akey("q")] 193 | new_state_dict[rep_pk((f"blocks.{depth}.cross_attn.kv_linear.weight"))] = state_dict[akey("k")] 194 | matched += [akey('q'), akey('k'), akey('v')] 195 | 196 | # Cross Attention projection? 197 | key = lambda a: fp(f"transformer_blocks.{depth}.attn2.to_{a}.weight") 198 | new_state_dict[fk(f"blocks.{depth}.cross_attn.proj.weight")] = state_dict[key('out.0')] 199 | matched += [key('out.0')] 200 | 201 | try: 202 | key = fp(f"transformer_blocks.{depth}.ff.net.0.proj.weight") 203 | new_state_dict[fk(f"blocks.{depth}.mlp.fc1.weight")] = state_dict[key] 204 | matched += [key] 205 | except KeyError: 206 | pass 207 | 208 | try: 209 | key = fp(f"transformer_blocks.{depth}.ff.net.2.weight") 210 | new_state_dict[fk(f"blocks.{depth}.mlp.fc2.weight")] = state_dict[key] 211 | matched += [key] 212 | except KeyError: 213 | pass 214 | 215 | if len(matched) < len(state_dict): 216 | print(f"PixArt: LoRA conversion has leftover keys! ({len(matched)} vs {len(state_dict)})") 217 | print(list( set(state_dict.keys()) - set(matched) )) 218 | 219 | if len(missing) > 0: 220 | print(f"PixArt: LoRA conversion has missing keys! (probably)") 221 | print(missing) 222 | 223 | return new_state_dict 224 | -------------------------------------------------------------------------------- /T5/t5v11.py: -------------------------------------------------------------------------------- 1 | """ 2 | Adapted from comfyui CLIP code. 
3 | https://github.com/comfyanonymous/ComfyUI/blob/master/comfy/sd1_clip.py 4 | """ 5 | 6 | import os 7 | 8 | from transformers import T5Tokenizer, T5EncoderModel, T5Config, modeling_utils 9 | import torch 10 | import traceback 11 | import zipfile 12 | from comfy import model_management 13 | 14 | from comfy.sd1_clip import parse_parentheses, token_weights, escape_important, unescape_important, safe_load_embed_zip, expand_directory_list, load_embed 15 | 16 | class T5v11Model(torch.nn.Module): 17 | def __init__(self, textmodel_ver="xxl", textmodel_json_config=None, textmodel_path=None, device="cpu", max_length=120, freeze=True, dtype=None): 18 | super().__init__() 19 | 20 | self.num_layers = 24 21 | self.max_length = max_length 22 | self.bnb = False 23 | 24 | if textmodel_path is not None: 25 | model_args = {} 26 | model_args["low_cpu_mem_usage"] = True # Don't take 2x system ram on cpu 27 | if dtype == "bnb8bit": 28 | self.bnb = True 29 | model_args["load_in_8bit"] = True 30 | elif dtype == "bnb4bit": 31 | self.bnb = True 32 | model_args["load_in_4bit"] = True 33 | else: 34 | if dtype: model_args["torch_dtype"] = dtype 35 | self.bnb = False 36 | # second GPU offload hack part 2 37 | if device.startswith("cuda"): 38 | model_args["device_map"] = device 39 | print(f"Loading T5 from '{textmodel_path}'") 40 | self.transformer = T5EncoderModel.from_pretrained(textmodel_path, **model_args) 41 | else: 42 | if textmodel_json_config is None: 43 | textmodel_json_config = os.path.join( 44 | os.path.dirname(os.path.realpath(__file__)), 45 | f"t5v11-{textmodel_ver}_config.json" 46 | ) 47 | config = T5Config.from_json_file(textmodel_json_config) 48 | self.num_layers = config.num_hidden_layers 49 | with modeling_utils.no_init_weights(): 50 | self.transformer = T5EncoderModel(config) 51 | 52 | if freeze: 53 | self.freeze() 54 | self.empty_tokens = [[0] * self.max_length] # token 55 | 56 | def freeze(self): 57 | self.transformer = self.transformer.eval() 58 | for param in self.parameters(): 59 | param.requires_grad = False 60 | 61 | def forward(self, tokens): 62 | device = self.transformer.get_input_embeddings().weight.device 63 | tokens = torch.LongTensor(tokens).to(device) 64 | attention_mask = torch.zeros_like(tokens) 65 | max_token = 1 # token 66 | for x in range(attention_mask.shape[0]): 67 | for y in range(attention_mask.shape[1]): 68 | attention_mask[x, y] = 1 69 | if tokens[x, y] == max_token: 70 | break 71 | 72 | outputs = self.transformer(input_ids=tokens, attention_mask=attention_mask) 73 | 74 | z = outputs['last_hidden_state'] 75 | z.detach().cpu().float() 76 | return z 77 | 78 | def encode(self, tokens): 79 | return self(tokens) 80 | 81 | def load_sd(self, sd): 82 | return self.transformer.load_state_dict(sd, strict=False) 83 | 84 | def to(self, *args, **kwargs): 85 | """BNB complains if you try to change the device or dtype""" 86 | if self.bnb: 87 | print("Thanks to BitsAndBytes, T5 becomes an immovable rock.", args, kwargs) 88 | else: 89 | self.transformer.to(*args, **kwargs) 90 | 91 | def encode_token_weights(self, token_weight_pairs, return_padded=False): 92 | to_encode = list(self.empty_tokens) 93 | for x in token_weight_pairs: 94 | tokens = list(map(lambda a: a[0], x)) 95 | to_encode.append(tokens) 96 | 97 | out = self.encode(to_encode) 98 | z_empty = out[0:1] 99 | 100 | output = [] 101 | for k in range(1, out.shape[0]): 102 | z = out[k:k+1] 103 | for i in range(len(z)): 104 | for j in range(len(z[i])): 105 | weight = token_weight_pairs[k - 1][j][1] 106 | z[i][j] = (z[i][j] - 
z_empty[0][j]) * weight + z_empty[0][j]
107 |             output.append(z)
108 | 
109 |         if (len(output) == 0):
110 |             return z_empty.cpu()
111 | 
112 |         out = torch.cat(output, dim=-2)
113 |         if not return_padded:
114 |             # Count number of tokens that aren't <pad>, then use that number as an index.
115 |             keep_index = sum([sum([1 for y in x if y[0] != 0]) for x in token_weight_pairs])
116 |             out = out[:, :keep_index, :]
117 |         return out
118 | 
119 | 
120 | class T5v11Tokenizer:
121 |     """
122 |     This is largely just based on the ComfyUI CLIP code.
123 |     """
124 |     def __init__(self, tokenizer_path=None, max_length=120, embedding_directory=None, embedding_size=4096, embedding_key='t5'):
125 |         if tokenizer_path is None:
126 |             tokenizer_path = os.path.join(os.path.dirname(os.path.realpath(__file__)), "t5_tokenizer")
127 |         self.tokenizer = T5Tokenizer.from_pretrained(tokenizer_path)
128 |         self.max_length = max_length
129 |         self.max_tokens_per_section = self.max_length - 1 # -1 for the EOS token (T5 has no BOS)
130 | 
131 |         self.pad_token = self.tokenizer("<pad>", add_special_tokens=False)["input_ids"][0]
132 |         self.end_token = self.tokenizer("</s>", add_special_tokens=False)["input_ids"][0]
133 |         vocab = self.tokenizer.get_vocab()
134 |         self.inv_vocab = {v: k for k, v in vocab.items()}
135 |         self.embedding_directory = embedding_directory
136 |         self.max_word_length = 8 # haven't verified this
137 |         self.embedding_identifier = "embedding:"
138 |         self.embedding_size = embedding_size
139 |         self.embedding_key = embedding_key
140 | 
141 |     def _try_get_embedding(self, embedding_name:str):
142 |         '''
143 |         Takes a potential embedding name and tries to retrieve it.
144 |         Returns a Tuple consisting of the embedding and any leftover string, embedding can be None.
145 |         '''
146 |         embed = load_embed(embedding_name, self.embedding_directory, self.embedding_size, self.embedding_key)
147 |         if embed is None:
148 |             stripped = embedding_name.strip(',')
149 |             if len(stripped) < len(embedding_name):
150 |                 embed = load_embed(stripped, self.embedding_directory, self.embedding_size, self.embedding_key)
151 |                 return (embed, embedding_name[len(stripped):])
152 |         return (embed, "")
153 | 
154 |     def tokenize_with_weights(self, text:str, return_word_ids=False):
155 |         '''
156 |         Takes a prompt and converts it to a list of (token, weight, word id) elements.
157 |         Tokens can both be integer tokens and pre computed T5 tensors.
158 |         Word id values are unique per word and embedding, where the id 0 is reserved for non word tokens.
159 | Returned list has the dimensions NxM where M is the input size of T5 160 | ''' 161 | pad_token = self.pad_token 162 | text = escape_important(text) 163 | parsed_weights = token_weights(text, 1.0) 164 | 165 | #tokenize words 166 | tokens = [] 167 | for weighted_segment, weight in parsed_weights: 168 | to_tokenize = unescape_important(weighted_segment).replace("\n", " ").split(' ') 169 | to_tokenize = [x for x in to_tokenize if x != ""] 170 | for word in to_tokenize: 171 | #if we find an embedding, deal with the embedding 172 | if word.startswith(self.embedding_identifier) and self.embedding_directory is not None: 173 | embedding_name = word[len(self.embedding_identifier):].strip('\n') 174 | embed, leftover = self._try_get_embedding(embedding_name) 175 | if embed is None: 176 | print(f"warning, embedding:{embedding_name} does not exist, ignoring") 177 | else: 178 | if len(embed.shape) == 1: 179 | tokens.append([(embed, weight)]) 180 | else: 181 | tokens.append([(embed[x], weight) for x in range(embed.shape[0])]) 182 | #if we accidentally have leftover text, continue parsing using leftover, else move on to next word 183 | if leftover != "": 184 | word = leftover 185 | else: 186 | continue 187 | #parse word 188 | tokens.append([(t, weight) for t in self.tokenizer(word, add_special_tokens=False)["input_ids"]]) 189 | 190 | #reshape token array to T5 input size 191 | batched_tokens = [] 192 | batch = [] 193 | batched_tokens.append(batch) 194 | for i, t_group in enumerate(tokens): 195 | #determine if we're going to try and keep the tokens in a single batch 196 | is_large = len(t_group) >= self.max_word_length 197 | 198 | while len(t_group) > 0: 199 | if len(t_group) + len(batch) > self.max_length - 1: 200 | remaining_length = self.max_length - len(batch) - 1 201 | #break word in two and add end token 202 | if is_large: 203 | batch.extend([(t,w,i+1) for t,w in t_group[:remaining_length]]) 204 | batch.append((self.end_token, 1.0, 0)) 205 | t_group = t_group[remaining_length:] 206 | #add end token and pad 207 | else: 208 | batch.append((self.end_token, 1.0, 0)) 209 | batch.extend([(self.pad_token, 1.0, 0)] * (remaining_length)) 210 | #start new batch 211 | batch = [] 212 | batched_tokens.append(batch) 213 | else: 214 | batch.extend([(t,w,i+1) for t,w in t_group]) 215 | t_group = [] 216 | 217 | # fill last batch 218 | batch.extend([(self.end_token, 1.0, 0)] + [(self.pad_token, 1.0, 0)] * (self.max_length - len(batch) - 1)) 219 | # instead of filling, just add EOS (DEBUG) 220 | # batch.extend([(self.end_token, 1.0, 0)]) 221 | 222 | if not return_word_ids: 223 | batched_tokens = [[(t, w) for t, w,_ in x] for x in batched_tokens] 224 | return batched_tokens 225 | 226 | def untokenize(self, token_weight_pair): 227 | return list(map(lambda a: (a, self.inv_vocab[a[0]]), token_weight_pair)) 228 | --------------------------------------------------------------------------------
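A minimal usage sketch for the T5 classes above, assuming a working ComfyUI environment (the `comfy` package must be importable, since t5v11.py depends on it). The import path, checkpoint directory and example prompt below are illustrative assumptions, not part of the node pack; inside ComfyUI the loader and conditioning nodes normally drive these classes instead.

    from T5.t5v11 import T5v11Model, T5v11Tokenizer

    # With no path given, the tokenizer falls back to the bundled T5/t5_tokenizer folder.
    tokenizer = T5v11Tokenizer()

    # Load encoder weights from a local Hugging Face style folder (placeholder path).
    encoder = T5v11Model(textmodel_ver="xxl", textmodel_path="/path/to/t5-v1_1-xxl")

    # Returns max_length-sized batches of (token_id, weight) pairs, padded with <pad>
    # and terminated with </s>; "(red:1.2)" uses the usual prompt-weighting syntax.
    batched = tokenizer.tokenize_with_weights("a photo of a (red:1.2) fox")

    # Encodes the batches plus an empty prompt, then rescales each hidden state by its
    # prompt weight relative to the empty-prompt baseline and trims trailing padding.
    cond = encoder.encode_token_weights(batched)
    print(cond.shape)  # roughly [1, n_tokens, 4096] with the xxl config

This mirrors how the node code is expected to call the pair: tokenize first, then hand the batched (token, weight) lists to encode_token_weights to obtain the conditioning tensor.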