├── .gitattributes ├── .gitignore ├── AITemplate ├── AITemplate.py ├── ComfyCode │ └── model_management.py ├── __init__.py ├── ait │ ├── __init__.py │ ├── ait.py │ ├── compile │ │ ├── clip.py │ │ ├── controlnet.py │ │ ├── release.py │ │ ├── unet.py │ │ ├── util.py │ │ └── vae.py │ ├── inference.py │ ├── load.py │ ├── modeling │ │ ├── __init__.py │ │ ├── attention.py │ │ ├── clip.py │ │ ├── controlnet.py │ │ ├── embeddings.py │ │ ├── resnet.py │ │ ├── unet_2d_condition.py │ │ ├── unet_blocks.py │ │ └── vae.py │ ├── module │ │ ├── __init__.py │ │ ├── dtype.py │ │ ├── misc.py │ │ ├── model.py │ │ └── torch_utils.py │ └── util │ │ ├── __init__.py │ │ ├── ckpt_convert.py │ │ ├── mapping │ │ ├── __init__.py │ │ ├── clip.py │ │ ├── controlnet.py │ │ ├── unet.py │ │ └── vae.py │ │ └── torch_dtype_from_str.py ├── clip.py ├── controlnet.py ├── download_pipeline.py ├── modules │ ├── modules.json │ └── place_modules_here ├── test.py ├── unet.py └── vae.py ├── LICENSE ├── README.md ├── __init__.py ├── docs ├── clip.md ├── compile.md ├── compvis.md ├── controlnet.md ├── unet.md └── vae.md └── workflows ├── aitemplate_controlnet.json ├── aitemplate_img2img_unet_vae.json ├── aitemplate_two_pass.json ├── aitemplate_two_pass_two_model.json ├── aitemplate_unet_only.json └── aitemplate_unet_vae.json /.gitattributes: -------------------------------------------------------------------------------- 1 | # Auto detect text files and perform LF normalization 2 | * text=auto eol=lf 3 | 4 | -------------------------------------------------------------------------------- /.gitignore: -------------------------------------------------------------------------------- 1 | tmp/ 2 | ait_tmp/ 3 | *.xz 4 | *.png 5 | test*.py 6 | compile_*.py 7 | 8 | .vscode/ 9 | 10 | # Byte-compiled / optimized / DLL files 11 | __pycache__/ 12 | *.py[cod] 13 | *$py.class 14 | 15 | # C extensions 16 | *.so 17 | 18 | # Distribution / packaging 19 | .Python 20 | build/ 21 | develop-eggs/ 22 | dist/ 23 | downloads/ 24 | eggs/ 25 | .eggs/ 26 | lib/ 27 | lib64/ 28 | parts/ 29 | sdist/ 30 | var/ 31 | wheels/ 32 | share/python-wheels/ 33 | *.egg-info/ 34 | .installed.cfg 35 | *.egg 36 | MANIFEST 37 | 38 | # PyInstaller 39 | # Usually these files are written by a python script from a template 40 | # before PyInstaller builds the exe, so as to inject date/other infos into it. 41 | *.manifest 42 | *.spec 43 | 44 | # Installer logs 45 | pip-log.txt 46 | pip-delete-this-directory.txt 47 | 48 | # Unit test / coverage reports 49 | htmlcov/ 50 | .tox/ 51 | .nox/ 52 | .coverage 53 | .coverage.* 54 | .cache 55 | nosetests.xml 56 | coverage.xml 57 | *.cover 58 | *.py,cover 59 | .hypothesis/ 60 | .pytest_cache/ 61 | cover/ 62 | 63 | # Translations 64 | *.mo 65 | *.pot 66 | 67 | # Django stuff: 68 | *.log 69 | local_settings.py 70 | db.sqlite3 71 | db.sqlite3-journal 72 | 73 | # Flask stuff: 74 | instance/ 75 | .webassets-cache 76 | 77 | # Scrapy stuff: 78 | .scrapy 79 | 80 | # Sphinx documentation 81 | docs/_build/ 82 | 83 | # PyBuilder 84 | .pybuilder/ 85 | target/ 86 | 87 | # Jupyter Notebook 88 | .ipynb_checkpoints 89 | 90 | # IPython 91 | profile_default/ 92 | ipython_config.py 93 | 94 | # pyenv 95 | # For a library or package, you might want to ignore these files since the code is 96 | # intended to run in multiple environments; otherwise, check them in: 97 | # .python-version 98 | 99 | # pipenv 100 | # According to pypa/pipenv#598, it is recommended to include Pipfile.lock in version control. 
101 | # However, in case of collaboration, if having platform-specific dependencies or dependencies 102 | # having no cross-platform support, pipenv may install dependencies that don't work, or not 103 | # install all needed dependencies. 104 | #Pipfile.lock 105 | 106 | # poetry 107 | # Similar to Pipfile.lock, it is generally recommended to include poetry.lock in version control. 108 | # This is especially recommended for binary packages to ensure reproducibility, and is more 109 | # commonly ignored for libraries. 110 | # https://python-poetry.org/docs/basic-usage/#commit-your-poetrylock-file-to-version-control 111 | #poetry.lock 112 | 113 | # PEP 582; used by e.g. github.com/David-OConnor/pyflow 114 | __pypackages__/ 115 | 116 | # Celery stuff 117 | celerybeat-schedule 118 | celerybeat.pid 119 | 120 | # SageMath parsed files 121 | *.sage.py 122 | 123 | # Environments 124 | .env 125 | .venv 126 | env/ 127 | venv/ 128 | ENV/ 129 | env.bak/ 130 | venv.bak/ 131 | 132 | # Spyder project settings 133 | .spyderproject 134 | .spyproject 135 | 136 | # Rope project settings 137 | .ropeproject 138 | 139 | # mkdocs documentation 140 | /site 141 | 142 | # mypy 143 | .mypy_cache/ 144 | .dmypy.json 145 | dmypy.json 146 | 147 | # Pyre type checker 148 | .pyre/ 149 | 150 | # pytype static type analyzer 151 | .pytype/ 152 | 153 | # Cython debug symbols 154 | cython_debug/ 155 | 156 | # PyCharm 157 | # JetBrains specific template is maintainted in a separate JetBrains.gitignore that can 158 | # be found at https://github.com/github/gitignore/blob/main/Global/JetBrains.gitignore 159 | # and can be added to the global gitignore or merged into this file. For a more nuclear 160 | # option (not recommended) you can uncomment the following to ignore the entire idea folder. 
161 | #.idea/ 162 | -------------------------------------------------------------------------------- /AITemplate/__init__.py: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/FizzleDorf/AIT/b46f17dd095283c31d12d5ffd6fb90457c85e5e0/AITemplate/__init__.py -------------------------------------------------------------------------------- /AITemplate/ait/__init__.py: -------------------------------------------------------------------------------- 1 | from .load import AITLoader 2 | from .inference import unet_inference, clip_inference, vae_inference, controlnet_inference 3 | from .ait import AIT 4 | 5 | __all__ = ["AIT", "AITLoader", "unet_inference", "clip_inference", "vae_inference", "controlnet_inference"] 6 | -------------------------------------------------------------------------------- /AITemplate/ait/ait.py: -------------------------------------------------------------------------------- 1 | from typing import Literal, Union 2 | 3 | import torch 4 | from safetensors.torch import load_file 5 | 6 | from .load import AITLoader 7 | from .module import Model 8 | from .inference import clip_inference, unet_inference, vae_inference, controlnet_inference 9 | 10 | 11 | class AIT: 12 | def __init__(self, path: str = None) -> None: 13 | self.modules = {} 14 | self.unet = {} 15 | self.vae = {} 16 | self.controlnet = {} 17 | self.clip = {} 18 | self.control_net = None 19 | if path is not None: 20 | self.loader = AITLoader(path) 21 | else: 22 | self.loader = AITLoader() 23 | self.supported = ['clip', 'controlnet', 'unet', 'vae'] 24 | 25 | def load(self, 26 | aitemplate_path: str, 27 | hf_hub_or_path: str, 28 | module_type: str, 29 | ): 30 | if module_type == "clip": 31 | self.modules["clip"] = self.loader.load(aitemplate_path) 32 | clip = self.loader.diffusers_clip(hf_hub_or_path) 33 | self.modules["clip"] = self.loader.apply_clip(self.modules["clip"], clip) 34 | elif module_type == "controlnet": 35 | self.modules["controlnet"] = self.loader.load(aitemplate_path) 36 | controlnet = self.loader.diffusers_controlnet(hf_hub_or_path) 37 | self.modules["controlnet"] = self.loader.apply_controlnet(self.modules["controlnet"], controlnet) 38 | elif module_type == "unet": 39 | self.modules["unet"] = self.loader.load(aitemplate_path) 40 | unet = self.loader.diffusers_unet(hf_hub_or_path) 41 | self.modules["unet"] = self.loader.apply_unet(self.modules["unet"], unet) 42 | elif module_type == "vae_decode": 43 | self.modules["vae_decode"] = self.loader.load(aitemplate_path) 44 | vae = self.loader.diffusers_vae(hf_hub_or_path) 45 | self.modules["vae_decode"] = self.loader.apply_vae(self.modules["vae_decode"], vae) 46 | elif module_type == "vae_encode": 47 | self.modules["vae_encode"] = self.loader.load(aitemplate_path) 48 | vae = self.loader.diffusers_vae(hf_hub_or_path) 49 | self.modules["vae_encode"] = self.loader.apply_vae(self.modules["vae_encode"], vae, encoder=True) 50 | else: 51 | raise ValueError(f"module_type must be one of {self.supported}") 52 | 53 | def load_compvis(self, 54 | aitemplate_path: str, 55 | ckpt_path: str, 56 | module_type: str, 57 | ): 58 | if ckpt_path.endswith(".safetensors"): 59 | state_dict = load_file(ckpt_path) 60 | elif ckpt_path.endswith(".ckpt"): 61 | state_dict = torch.load(ckpt_path, map_location="cpu") 62 | else: 63 | raise ValueError("ckpt_path must be a .safetensors or .ckpt file") 64 | while "state_dict" in state_dict.keys(): 65 | """ 66 | yo dawg i heard you like state dicts so i put a state dict in your state 
dict 67 | 68 | apparently this happens in some models 69 | """ 70 | state_dict = state_dict["state_dict"] 71 | if module_type == "clip": 72 | self.modules["clip"] = self.loader.load(aitemplate_path) 73 | clip = self.loader.compvis_clip(state_dict) 74 | self.modules["clip"] = self.loader.apply_clip(self.modules["clip"], clip) 75 | elif module_type == "controlnet": 76 | self.modules["controlnet"] = self.loader.load(aitemplate_path) 77 | controlnet = self.loader.compvis_controlnet(state_dict) 78 | self.modules["controlnet"] = self.loader.apply_controlnet(self.modules["controlnet"], controlnet) 79 | elif module_type == "unet": 80 | self.modules["unet"] = self.loader.load(aitemplate_path) 81 | unet = self.loader.compvis_unet(state_dict) 82 | self.modules["unet"] = self.loader.apply_unet(self.modules["unet"], unet) 83 | elif module_type == "vae_decode": 84 | self.modules["vae_decode"] = self.loader.load(aitemplate_path) 85 | vae = self.loader.compvis_vae(state_dict) 86 | self.modules["vae_decode"] = self.loader.apply_vae(self.modules["vae_decode"], vae) 87 | elif module_type == "vae_encode": 88 | self.modules["vae_encode"] = self.loader.load(aitemplate_path) 89 | vae = self.loader.compvis_vae(state_dict) 90 | self.modules["vae_encode"] = self.loader.apply_vae(self.modules["vae_encode"], vae, encoder=True) 91 | else: 92 | raise ValueError(f"module_type must be one of {self.supported}") 93 | 94 | 95 | def test_unet( 96 | self, 97 | batch_size: int = 2, 98 | latent_channels: int = 4, 99 | height: int = 64, 100 | width: int = 64, 101 | hidden_dim: int = 768, 102 | sequence_length: int = 77, 103 | dtype="float16", 104 | device="cuda", 105 | benchmark: bool = False, 106 | add_embed_dim:int = 2816, 107 | xl = False, 108 | ): 109 | if "unet" not in self.modules: 110 | raise ValueError("unet module not loaded") 111 | latent_model_input_pt = torch.randn(batch_size, latent_channels, height, width).to(device) 112 | text_embeddings_pt = torch.randn(batch_size, sequence_length, hidden_dim).to(device) 113 | timesteps_pt = torch.Tensor([1] * batch_size).to(device) 114 | if xl: 115 | add_embeds = torch.randn(batch_size, add_embed_dim).to(device) 116 | if dtype == "float16": 117 | latent_model_input_pt = latent_model_input_pt.half() 118 | text_embeddings_pt = text_embeddings_pt.half() 119 | timesteps_pt = timesteps_pt.half() 120 | if xl: 121 | add_embeds = add_embeds.half() 122 | output = unet_inference( 123 | self.modules["unet"], 124 | latent_model_input=latent_model_input_pt, 125 | timesteps=timesteps_pt, 126 | encoder_hidden_states=text_embeddings_pt, 127 | benchmark=benchmark, 128 | add_embeds=add_embeds if xl else None, 129 | ) 130 | print(output.shape) 131 | return output 132 | 133 | def test_vae_encode( 134 | self, 135 | batch_size: int = 1, 136 | channels: int = 3, 137 | height: int = 512, 138 | width: int = 512, 139 | dtype="float16", 140 | device="cuda", 141 | ): 142 | if "vae_encode" not in self.modules: 143 | raise ValueError("vae module not loaded") 144 | vae_input = torch.randn(batch_size, channels, height, width).to(device) 145 | if dtype == "float16": 146 | vae_input = vae_input.half() 147 | output = vae_inference( 148 | self.modules["vae_encode"], 149 | vae_input=vae_input, 150 | encoder=True, 151 | ) 152 | print(output.shape) 153 | return output 154 | 155 | 156 | def test_vae( 157 | self, 158 | batch_size: int = 1, 159 | latent_channels: int = 4, 160 | height: int = 64, 161 | width: int = 64, 162 | dtype="float16", 163 | device="cuda", 164 | benchmark: bool = False, 165 | ): 166 | if 
"vae_decode" not in self.modules: 167 | raise ValueError("vae module not loaded") 168 | vae_input = torch.randn(batch_size, latent_channels, height, width).to(device) 169 | if dtype == "float16": 170 | vae_input = vae_input.half() 171 | output = vae_inference( 172 | self.modules["vae_decode"], 173 | vae_input=vae_input, 174 | benchmark=benchmark, 175 | ) 176 | print(output.shape) 177 | return output 178 | 179 | def test_clip( 180 | self, 181 | batch_size: int = 1, 182 | sequence_length: int = 77, 183 | tokenizer=None, 184 | ): 185 | if "clip" not in self.modules: 186 | raise ValueError("clip module not loaded") 187 | try: 188 | from transformers import CLIPTokenizer 189 | except ImportError: 190 | raise ImportError( 191 | "Please install transformers with `pip install transformers` to use this script." 192 | ) 193 | if tokenizer is None: 194 | tokenizer = CLIPTokenizer.from_pretrained("openai/clip-vit-large-patch14") 195 | text_input = tokenizer( 196 | ["a photo of an astronaut riding a horse on mars"] * batch_size, 197 | padding="max_length", 198 | max_length=sequence_length, 199 | truncation=True, 200 | return_tensors="pt", 201 | ) 202 | input_ids = text_input["input_ids"].cuda() 203 | output = clip_inference( 204 | self.modules["clip"], 205 | input_ids=input_ids, 206 | seqlen=sequence_length, 207 | ) 208 | print(output.shape) 209 | return output 210 | 211 | def test_controlnet( 212 | self, 213 | batch_size: int = 2, 214 | latent_channels: int = 4, 215 | latent_height: int = 64, 216 | latent_width: int = 64, 217 | hidden_dim: int = 768, 218 | sequence_length: int = 77, 219 | control_height: int = 512, 220 | control_width: int = 512, 221 | control_channels: int = 3, 222 | add_embed_dim:int = 2816, 223 | xl: bool = False, 224 | benchmark: bool = False, 225 | device="cuda", 226 | dtype="float16", 227 | ): 228 | latent_model_input_pt = torch.randn(batch_size, latent_channels, latent_height, latent_width).to(device) 229 | text_embeddings_pt = torch.randn(batch_size, sequence_length, hidden_dim).to(device) 230 | timesteps_pt = torch.Tensor([1] * batch_size).to(device) 231 | controlnet_input_pt = torch.randn(batch_size, control_channels, control_height, control_width).to(device) 232 | if xl: 233 | add_embeds = torch.randn(batch_size, add_embed_dim).to(device) 234 | if dtype == "float16": 235 | latent_model_input_pt = latent_model_input_pt.half() 236 | text_embeddings_pt = text_embeddings_pt.half() 237 | timesteps_pt = timesteps_pt.half() 238 | controlnet_input_pt = controlnet_input_pt.half() 239 | if xl: 240 | add_embeds = add_embeds.half() 241 | outputs = controlnet_inference( 242 | self.modules["controlnet"], 243 | latent_model_input=latent_model_input_pt, 244 | timesteps=timesteps_pt, 245 | encoder_hidden_states=text_embeddings_pt, 246 | controlnet_cond=controlnet_input_pt, 247 | add_embeds=add_embeds if xl else None, 248 | benchmark=benchmark, 249 | ) 250 | for block, value in outputs.items(): 251 | print(block, value.shape) 252 | return outputs 253 | -------------------------------------------------------------------------------- /AITemplate/ait/compile/clip.py: -------------------------------------------------------------------------------- 1 | # Copyright (c) Meta Platforms, Inc. and affiliates. 2 | # 3 | # Licensed under the Apache License, Version 2.0 (the "License"); 4 | # you may not use this file except in compliance with the License. 
5 | # You may obtain a copy of the License at 6 | # 7 | # http://www.apache.org/licenses/LICENSE-2.0 8 | # 9 | # Unless required by applicable law or agreed to in writing, software 10 | # distributed under the License is distributed on an "AS IS" BASIS, 11 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 12 | # See the License for the specific language governing permissions and 13 | # limitations under the License. 14 | # 15 | import sys 16 | from aitemplate.compiler import compile_model 17 | from aitemplate.frontend import IntVar, Tensor 18 | from aitemplate.testing import detect_target 19 | 20 | from ..modeling.clip import CLIPTextTransformer as ait_CLIPTextTransformer 21 | from .util import mark_output 22 | from .release import process 23 | 24 | from ait.util.mapping import map_clip 25 | 26 | 27 | def compile_clip( 28 | pt_mod, 29 | batch_size=(1, 8), 30 | seqlen=64, 31 | dim=768, 32 | num_heads=12, 33 | depth=12, 34 | output_hidden_states=False, 35 | text_projection_dim=None, 36 | use_fp16_acc=True, 37 | convert_conv_to_gemm=True, 38 | act_layer="gelu", 39 | constants=True, 40 | model_name="CLIPTextModel", 41 | work_dir="./tmp", 42 | out_dir="./out", 43 | ): 44 | _batch_size = batch_size 45 | mask_seq = 0 46 | causal = True 47 | 48 | ait_mod = ait_CLIPTextTransformer( 49 | num_hidden_layers=depth, 50 | hidden_size=dim, 51 | num_attention_heads=num_heads, 52 | batch_size=batch_size, 53 | seq_len=seqlen, 54 | causal=causal, 55 | mask_seq=mask_seq, 56 | act_layer=act_layer, 57 | output_hidden_states=output_hidden_states, 58 | text_projection_dim=text_projection_dim, 59 | ) 60 | ait_mod.name_parameter_tensor() 61 | 62 | pt_mod = pt_mod.eval() 63 | params_ait = map_clip(pt_mod) 64 | 65 | static_shape = batch_size[0] == batch_size[1] 66 | if static_shape: 67 | batch_size = batch_size[0] 68 | else: 69 | batch_size = IntVar(values=list(batch_size), name="batch_size") 70 | 71 | input_ids_ait = Tensor( 72 | [batch_size, seqlen], name="input_ids", dtype="int64", is_input=True 73 | ) 74 | position_ids_ait = Tensor( 75 | [batch_size, seqlen], name="position_ids", dtype="int64", is_input=True 76 | ) 77 | 78 | Y = ait_mod(input_ids=input_ids_ait, position_ids=position_ids_ait) 79 | mark_output(Y) 80 | 81 | target = detect_target( 82 | use_fp16_acc=use_fp16_acc, convert_conv_to_gemm=convert_conv_to_gemm 83 | ) 84 | dll_name = model_name + ".dll" if sys.platform == "win32" else model_name + ".so" 85 | total_usage = compile_model( 86 | Y, target, work_dir, model_name, constants=params_ait if constants else None, dll_name=dll_name, 87 | ) 88 | sd = "L" 89 | if dim == 1024: 90 | sd = "H" 91 | if dim == 1280: 92 | sd = "G" 93 | vram = round(total_usage / 1024 / 1024) 94 | process(work_dir, model_name, dll_name, target._arch, None, None, _batch_size[-1], vram, out_dir, sd, "clip_text") -------------------------------------------------------------------------------- /AITemplate/ait/compile/controlnet.py: -------------------------------------------------------------------------------- 1 | # Copyright (c) Meta Platforms, Inc. and affiliates. 2 | # 3 | # Licensed under the Apache License, Version 2.0 (the "License"); 4 | # you may not use this file except in compliance with the License. 
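`compile_clip` above maps a Hugging Face `CLIPTextModel` onto the AIT `CLIPTextTransformer`, compiles it, and hands the artifact to `release.process`. A hedged driver sketch for an SD1.x text encoder (model id and directories are placeholders):

# Assumes a working AITemplate CUDA toolchain and the transformers package.
from transformers import CLIPTextModel
from ait.compile.clip import compile_clip

pt_clip = CLIPTextModel.from_pretrained("openai/clip-vit-large-patch14")

compile_clip(
    pt_clip,
    batch_size=(1, 2),   # dynamic batch range baked into the module
    seqlen=77,           # SD prompts are padded/truncated to 77 tokens
    dim=768,             # hidden size of the ViT-L/14 text encoder
    num_heads=12,
    depth=12,
    constants=True,      # fold mapped weights into the compiled module
    model_name="CLIPTextModel_sd1x",
    work_dir="./tmp",
    out_dir="./out",
)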
5 | # You may obtain a copy of the License at 6 | # 7 | # http://www.apache.org/licenses/LICENSE-2.0 8 | # 9 | # Unless required by applicable law or agreed to in writing, software 10 | # distributed under the License is distributed on an "AS IS" BASIS, 11 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 12 | # See the License for the specific language governing permissions and 13 | # limitations under the License. 14 | # 15 | import sys 16 | import torch 17 | from aitemplate.compiler import compile_model 18 | from aitemplate.frontend import IntVar, Tensor 19 | from aitemplate.testing import detect_target 20 | 21 | from ..modeling.controlnet import ( 22 | ControlNetModel as ait_ControlNetModel, 23 | ) 24 | from .util import mark_output 25 | from .release import process 26 | 27 | from ait.util.mapping import map_controlnet 28 | 29 | 30 | def compile_controlnet( 31 | pt_mod, 32 | batch_size=(1, 4), 33 | height=(64, 2048), 34 | width=(64, 2048), 35 | clip_chunks=1, 36 | dim=320, 37 | hidden_dim=768, 38 | use_fp16_acc=False, 39 | convert_conv_to_gemm=False, 40 | model_name="ControlNetModel", 41 | constants=False, 42 | work_dir="./tmp", 43 | out_dir="./out", 44 | down_factor=8, 45 | use_linear_projection=False, 46 | block_out_channels=(320, 640, 1280, 1280), 47 | down_block_types= ( 48 | "CrossAttnDownBlock2D", 49 | "CrossAttnDownBlock2D", 50 | "CrossAttnDownBlock2D", 51 | "DownBlock2D", 52 | ), 53 | in_channels=4, 54 | out_channels=4, 55 | sample_size=64, 56 | class_embed_type=None, 57 | num_class_embeds=None, 58 | time_embedding_dim = None, 59 | conv_in_kernel: int = 3, 60 | projection_class_embeddings_input_dim = None, 61 | addition_embed_type = None, 62 | addition_time_embed_dim = None, 63 | transformer_layers_per_block = 1, 64 | dtype="float16", 65 | ): 66 | _batch_size = batch_size 67 | _height = height 68 | _width = width 69 | xl = False 70 | if projection_class_embeddings_input_dim is not None: 71 | xl = True 72 | if isinstance(transformer_layers_per_block, int): 73 | transformer_layers_per_block = [transformer_layers_per_block] * len(down_block_types) 74 | if isinstance(transformer_layers_per_block, int): 75 | transformer_layers_per_block = [transformer_layers_per_block] * len(down_block_types) 76 | batch_size = batch_size # double batch size for unet 77 | ait_mod = ait_ControlNetModel( 78 | in_channels=in_channels, 79 | down_block_types=down_block_types, 80 | block_out_channels=block_out_channels, 81 | cross_attention_dim=hidden_dim, 82 | transformer_layers_per_block=transformer_layers_per_block, 83 | use_linear_projection=use_linear_projection, 84 | class_embed_type=class_embed_type, 85 | addition_embed_type=addition_embed_type, 86 | num_class_embeds=num_class_embeds, 87 | projection_class_embeddings_input_dim=projection_class_embeddings_input_dim, 88 | dtype="float16", 89 | ) 90 | ait_mod.name_parameter_tensor() 91 | 92 | pt_mod = pt_mod.eval() 93 | params_ait = map_controlnet(pt_mod, dim=dim) 94 | 95 | static_shape = width[0] == width[1] and height[0] == height[1] and batch_size[0] == batch_size[1] 96 | 97 | if static_shape: 98 | batch_size = batch_size[0] * 2 # double batch size for unet 99 | height_d = height[0] // down_factor 100 | width_d = width[0] // down_factor 101 | height_c = height[0] 102 | width_c = width[0] 103 | clip_chunks = 77 104 | embedding_size = clip_chunks 105 | else: 106 | batch_size = batch_size[0], batch_size[1] * 2 # double batch size for unet 107 | batch_size = IntVar(values=list(batch_size), name="batch_size") 108 | height_d = 
height[0] // down_factor, height[1] // down_factor 109 | height_d = IntVar(values=list(height_d), name="height_d") 110 | width_d = width[0] // down_factor, width[1] // down_factor 111 | width_d = IntVar(values=list(width_d), name="width_d") 112 | height_c = height 113 | height_c = IntVar(values=list(height_c), name="height_c") 114 | width_c = width 115 | width_c = IntVar(values=list(width_c), name="width_c") 116 | clip_chunks = 77, 77 * clip_chunks 117 | embedding_size = IntVar(values=list(clip_chunks), name="embedding_size") 118 | 119 | latent_model_input_ait = Tensor( 120 | [batch_size, height_d, width_d, 4], name="latent_model_input", is_input=True 121 | ) 122 | timesteps_ait = Tensor([batch_size], name="timesteps", is_input=True) 123 | text_embeddings_pt_ait = Tensor( 124 | [batch_size, embedding_size, hidden_dim], name="encoder_hidden_states", is_input=True 125 | ) 126 | controlnet_condition_ait = Tensor( 127 | [batch_size, height_c, width_c, 3], name="control_hint", is_input=True 128 | ) 129 | 130 | add_embeds = None 131 | if xl: 132 | add_embeds = Tensor( 133 | [batch_size, projection_class_embeddings_input_dim], name="add_embeds", is_input=True, dtype=dtype 134 | ) 135 | 136 | 137 | Y = ait_mod( 138 | latent_model_input_ait, 139 | timesteps_ait, 140 | text_embeddings_pt_ait, 141 | controlnet_condition_ait, 142 | add_embeds=add_embeds, 143 | ) 144 | mark_output(Y) 145 | 146 | target = detect_target( 147 | use_fp16_acc=use_fp16_acc, convert_conv_to_gemm=convert_conv_to_gemm 148 | ) 149 | dll_name = model_name + ".dll" if sys.platform == "win32" else model_name + ".so" 150 | total_usage = compile_model( 151 | Y, target, work_dir, model_name, constants=params_ait if constants else None, dll_name=dll_name, 152 | ) 153 | sd = "v1" 154 | if hidden_dim == 1024: 155 | sd = "v2" 156 | elif hidden_dim == 2048: 157 | sd = "xl" 158 | vram = round(total_usage / 1024 / 1024) 159 | process(work_dir, model_name, dll_name, target._arch, _height[-1], _width[-1], _batch_size[-1], vram, out_dir, sd, "controlnet") -------------------------------------------------------------------------------- /AITemplate/ait/compile/release.py: -------------------------------------------------------------------------------- 1 | import os 2 | import lzma 3 | import hashlib 4 | import json 5 | import sys 6 | 7 | def sha256sum(filename): 8 | h = hashlib.sha256() 9 | b = bytearray(128 * 1024) 10 | mv = memoryview(b) 11 | with open(filename, "rb", buffering=0) as f: 12 | for n in iter(lambda: f.readinto(mv), 0): 13 | h.update(mv[:n]) 14 | return h.hexdigest() 15 | 16 | def filesize(filename): 17 | return os.stat(filename).st_size 18 | 19 | def compress_file(filename): 20 | with open(filename, "rb") as f: 21 | data = f.read() 22 | with lzma.open(filename + ".xz", "wb", preset=9) as f: 23 | f.write(data) 24 | sha256 = sha256sum(filename + ".xz") 25 | return sha256, filesize(filename + ".xz") 26 | 27 | def process_file(filename): 28 | file_size = filesize(filename) 29 | sha256 = sha256sum(filename) 30 | sha256_xz, file_size_xz = compress_file(filename) 31 | return sha256, file_size, sha256_xz, file_size_xz 32 | 33 | def process(work_dir, model_name, dll_name, arch, height, width, batch_size, vram, out_dir, sd, model_type): 34 | path = os.path.join(work_dir, model_name) 35 | dll_path = os.path.join(path, dll_name) 36 | sha256, file_size, sha256_xz, file_size_xz = process_file(dll_path) 37 | _os = "windows" if sys.platform == "win32" else "linux" 38 | cuda = f"sm{arch}" 39 | if height is None or width is None: 40 | _reso = None 
41 | else: 42 | _reso = max(height, width) 43 | _bs = batch_size 44 | compressed_name = f"{dll_name}.xz" 45 | compressed_path = os.path.join(path, compressed_name) 46 | subpath = f"{_os}/{cuda}/" 47 | if _reso is not None: 48 | subpath = subpath + f"bs{_bs}/{_reso}/" 49 | key = (subpath + compressed_name).replace("\\", "/") 50 | subpath = os.path.join(out_dir, subpath) 51 | os.makedirs(subpath, exist_ok=True) 52 | out_path = os.path.join(subpath, compressed_name) 53 | os.rename(compressed_path, out_path) 54 | data = { 55 | "os": _os, 56 | "cuda": cuda, 57 | "model": model_type, 58 | "sd": sd, 59 | "batch_size": _bs, 60 | "resolution": _reso, 61 | "vram": vram, 62 | "url": key, 63 | "compressed_size": file_size_xz, 64 | "size": file_size, 65 | "compressed_sha256": sha256_xz, 66 | "sha256": sha256, 67 | } 68 | if not os.path.exists(os.path.join(out_dir, "modules.json")): 69 | with open(os.path.join(out_dir, "modules.json"), "w") as f: 70 | json.dump({}, f) 71 | with open(os.path.join(out_dir, "modules.json"), "r") as f: 72 | modules = json.load(f) 73 | modules[key.replace("/", "_")] = data 74 | with open(os.path.join(out_dir, "modules.json"), "w") as f: 75 | json.dump(modules, f) 76 | 77 | -------------------------------------------------------------------------------- /AITemplate/ait/compile/unet.py: -------------------------------------------------------------------------------- 1 | # Copyright (c) Meta Platforms, Inc. and affiliates. 2 | # 3 | # Licensed under the Apache License, Version 2.0 (the "License"); 4 | # you may not use this file except in compliance with the License. 5 | # You may obtain a copy of the License at 6 | # 7 | # http://www.apache.org/licenses/LICENSE-2.0 8 | # 9 | # Unless required by applicable law or agreed to in writing, software 10 | # distributed under the License is distributed on an "AS IS" BASIS, 11 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 12 | # See the License for the specific language governing permissions and 13 | # limitations under the License. 
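For reference, each call to `process` above records one entry in `modules.json` under `out_dir`; the fields mirror the `data` dict built in release.py. An illustrative entry (all values invented for the example):

# Shape of a modules.json entry as written by ait/compile/release.py.
example_entry = {
    "os": "linux",
    "cuda": "sm80",
    "model": "unet",
    "sd": "v1",
    "batch_size": 2,
    "resolution": 512,
    "vram": 2048,                        # MiB, rounded from compile_model's report
    "url": "linux/sm80/bs2/512/UNet2DConditionModel.so.xz",
    "compressed_size": 123_456_789,      # bytes of the .xz artifact
    "size": 987_654_321,                 # bytes of the raw .so/.dll
    "compressed_sha256": "<sha256 of the .xz file>",
    "sha256": "<sha256 of the uncompressed module>",
}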
14 | # 15 | import sys 16 | import torch 17 | from aitemplate.compiler import compile_model 18 | from aitemplate.frontend import IntVar, Tensor, DynamicProfileStrategy 19 | from aitemplate.testing import detect_target 20 | 21 | from ..modeling.unet_2d_condition import ( 22 | UNet2DConditionModel as ait_UNet2DConditionModel, 23 | ) 24 | from .util import mark_output 25 | from .release import process 26 | from ait.util.mapping import map_unet 27 | 28 | def compile_unet( 29 | pt_mod, 30 | batch_size=(1, 8), 31 | height=(64, 2048), 32 | width=(64, 2048), 33 | clip_chunks=1, 34 | out_dir="./out", 35 | work_dir="./tmp", 36 | dim=320, 37 | hidden_dim=1024, 38 | use_fp16_acc=False, 39 | convert_conv_to_gemm=False, 40 | controlnet=False, 41 | attention_head_dim=[5, 10, 20, 20], # noqa: B006 42 | model_name="UNet2DConditionModel", 43 | use_linear_projection=False, 44 | constants=True, 45 | block_out_channels=(320, 640, 1280, 1280), 46 | down_block_types= ( 47 | "CrossAttnDownBlock2D", 48 | "CrossAttnDownBlock2D", 49 | "CrossAttnDownBlock2D", 50 | "DownBlock2D", 51 | ), 52 | up_block_types=( 53 | "UpBlock2D", 54 | "CrossAttnUpBlock2D", 55 | "CrossAttnUpBlock2D", 56 | "CrossAttnUpBlock2D", 57 | ), 58 | in_channels=4, 59 | out_channels=4, 60 | sample_size=64, 61 | class_embed_type=None, 62 | num_class_embeds=None, 63 | only_cross_attention=[ 64 | True, 65 | True, 66 | True, 67 | False 68 | ], 69 | down_factor=8, 70 | time_embedding_dim = None, 71 | conv_in_kernel: int = 3, 72 | projection_class_embeddings_input_dim = None, 73 | addition_embed_type = None, 74 | addition_time_embed_dim = None, 75 | transformer_layers_per_block = 1, 76 | dtype="float16", 77 | ): 78 | _batch_size = batch_size 79 | _height = height 80 | _width = width 81 | xl = False 82 | if projection_class_embeddings_input_dim is not None: 83 | xl = True 84 | if isinstance(only_cross_attention, bool): 85 | only_cross_attention = [only_cross_attention] * len(block_out_channels) 86 | if isinstance(transformer_layers_per_block, int): 87 | transformer_layers_per_block = [transformer_layers_per_block] * len(down_block_types) 88 | if isinstance(attention_head_dim, int): 89 | attention_head_dim = (attention_head_dim,) * len(down_block_types) 90 | 91 | ait_mod = ait_UNet2DConditionModel( 92 | sample_size=sample_size, 93 | cross_attention_dim=hidden_dim, 94 | attention_head_dim=attention_head_dim, 95 | use_linear_projection=use_linear_projection, 96 | up_block_types=up_block_types, 97 | down_block_types=down_block_types, 98 | block_out_channels=block_out_channels, 99 | in_channels=in_channels, 100 | out_channels=out_channels, 101 | class_embed_type=class_embed_type, 102 | num_class_embeds=num_class_embeds, 103 | only_cross_attention=only_cross_attention, 104 | time_embedding_dim=time_embedding_dim, 105 | conv_in_kernel=conv_in_kernel, 106 | projection_class_embeddings_input_dim=projection_class_embeddings_input_dim, 107 | addition_embed_type=addition_embed_type, 108 | addition_time_embed_dim=addition_time_embed_dim, 109 | transformer_layers_per_block=transformer_layers_per_block, 110 | dtype=dtype, 111 | ) 112 | ait_mod.name_parameter_tensor() 113 | 114 | # set AIT parameters 115 | pt_mod = pt_mod.eval() 116 | params_ait = map_unet(pt_mod, dim=dim, in_channels=in_channels, conv_in_key="conv_in_weight", dtype=dtype) 117 | 118 | static_shape = width[0] == width[1] and height[0] == height[1] 119 | 120 | if static_shape: 121 | height = height[0] // down_factor 122 | width = width[0] // down_factor 123 | height_d = height 124 | width_d = width 125 | 
height_1_d = height 126 | width_1_d = width 127 | height_2 = height // 2 128 | width_2 = width // 2 129 | height_4 = height // 4 130 | width_4 = width // 4 131 | height_8 = height // 8 132 | width_8 = width // 8 133 | height_2_d = height_2 134 | width_2_d = width_2 135 | height_4_d = height_4 136 | width_4_d = width_4 137 | height_8_d = height_8 138 | width_8_d = width_8 139 | else: 140 | height = [x // down_factor for x in height] 141 | width = [x // down_factor for x in width] 142 | height_d = IntVar(values=list(height), name="height_d") 143 | width_d = IntVar(values=list(width), name="width_d") 144 | height_1_d = IntVar(values=list(height), name="height_1_d") 145 | width_1_d = IntVar(values=list(width), name="width_1_d") 146 | height_2 = [x // 2 for x in height] 147 | width_2 = [x // 2 for x in width] 148 | height_4 = [x // 4 for x in height] 149 | width_4 = [x // 4 for x in width] 150 | height_8 = [x // 8 for x in height] 151 | width_8 = [x // 8 for x in width] 152 | height_2_d = IntVar(values=list(height_2), name="height_2_d") 153 | width_2_d = IntVar(values=list(width_2), name="width_2_d") 154 | height_4_d = IntVar(values=list(height_4), name="height_4_d") 155 | width_4_d = IntVar(values=list(width_4), name="width_4_d") 156 | height_8_d = IntVar(values=list(height_8), name="height_8_d") 157 | width_8_d = IntVar(values=list(width_8), name="width_8_d") 158 | 159 | batch_size = batch_size[0], batch_size[1] * 2 # double batch size for unet 160 | batch_size = IntVar(values=list(batch_size), name="batch_size") 161 | 162 | if static_shape: 163 | embedding_size = 77 164 | else: 165 | clip_chunks = 77, 77 * clip_chunks 166 | embedding_size = IntVar(values=list(clip_chunks), name="embedding_size") 167 | 168 | 169 | latent_model_input_ait = Tensor( 170 | [batch_size, height_d, width_d, in_channels], name="latent_model_input", is_input=True, dtype=dtype 171 | ) 172 | timesteps_ait = Tensor([batch_size], name="timesteps", is_input=True, dtype=dtype) 173 | text_embeddings_pt_ait = Tensor( 174 | [batch_size, embedding_size, hidden_dim], name="encoder_hidden_states", is_input=True, dtype=dtype 175 | ) 176 | 177 | class_labels = None 178 | #TODO: better way to handle this, enables class_labels for x4-upscaler 179 | if in_channels == 7: 180 | class_labels = Tensor( 181 | [batch_size], name="class_labels", dtype="int64", is_input=True 182 | ) 183 | 184 | add_embeds = None 185 | if xl: 186 | add_embeds = Tensor( 187 | [batch_size, projection_class_embeddings_input_dim], name="add_embeds", is_input=True, dtype=dtype 188 | ) 189 | 190 | down_block_residual_0 = None 191 | down_block_residual_1 = None 192 | down_block_residual_2 = None 193 | down_block_residual_3 = None 194 | down_block_residual_4 = None 195 | down_block_residual_5 = None 196 | down_block_residual_6 = None 197 | down_block_residual_7 = None 198 | down_block_residual_8 = None 199 | down_block_residual_9 = None 200 | down_block_residual_10 = None 201 | down_block_residual_11 = None 202 | mid_block_residual = None 203 | if controlnet: 204 | down_block_residual_0 = Tensor( 205 | [batch_size, height_1_d, width_1_d, block_out_channels[0]], 206 | name="down_block_residual_0", 207 | is_input=True, 208 | ) 209 | down_block_residual_1 = Tensor( 210 | [batch_size, height_1_d, width_1_d, block_out_channels[0]], 211 | name="down_block_residual_1", 212 | is_input=True, 213 | ) 214 | down_block_residual_2 = Tensor( 215 | [batch_size, height_1_d,width_1_d, block_out_channels[0]], 216 | name="down_block_residual_2", 217 | is_input=True, 218 | ) 219 | 
down_block_residual_3 = Tensor( 220 | [batch_size, height_2_d, width_2_d, block_out_channels[0]], 221 | name="down_block_residual_3", 222 | is_input=True, 223 | ) 224 | down_block_residual_4 = Tensor( 225 | [batch_size, height_2_d, width_2_d, block_out_channels[1]], 226 | name="down_block_residual_4", 227 | is_input=True, 228 | ) 229 | down_block_residual_5 = Tensor( 230 | [batch_size, height_2_d, width_2_d, block_out_channels[1]], 231 | name="down_block_residual_5", 232 | is_input=True, 233 | ) 234 | down_block_residual_6 = Tensor( 235 | [batch_size, height_4_d, width_4_d, block_out_channels[1]], 236 | name="down_block_residual_6", 237 | is_input=True, 238 | ) 239 | down_block_residual_7 = Tensor( 240 | [batch_size, height_4_d, width_4_d, block_out_channels[2]], 241 | name="down_block_residual_7", 242 | is_input=True, 243 | ) 244 | down_block_residual_8 = Tensor( 245 | [batch_size, height_4_d, width_4_d, block_out_channels[2]], 246 | name="down_block_residual_8", 247 | is_input=True, 248 | ) 249 | down_block_residual_9 = Tensor( 250 | [batch_size, height_8_d, width_8_d, block_out_channels[2]], 251 | name="down_block_residual_9", 252 | is_input=True, 253 | ) 254 | down_block_residual_10 = Tensor( 255 | [batch_size, height_8_d, width_8_d, block_out_channels[3]], 256 | name="down_block_residual_10", 257 | is_input=True, 258 | ) 259 | down_block_residual_11 = Tensor( 260 | [batch_size, height_8_d, width_8_d, block_out_channels[3]], 261 | name="down_block_residual_11", 262 | is_input=True, 263 | ) 264 | mid_block_residual = Tensor( 265 | [batch_size, height_8_d, width_8_d, block_out_channels[3]], 266 | name="mid_block_residual", 267 | is_input=True, 268 | ) 269 | 270 | 271 | Y = ait_mod( 272 | sample=latent_model_input_ait, 273 | timesteps=timesteps_ait, 274 | encoder_hidden_states=text_embeddings_pt_ait, 275 | down_block_residual_0=down_block_residual_0, 276 | down_block_residual_1=down_block_residual_1, 277 | down_block_residual_2=down_block_residual_2, 278 | down_block_residual_3=down_block_residual_3, 279 | down_block_residual_4=down_block_residual_4, 280 | down_block_residual_5=down_block_residual_5, 281 | down_block_residual_6=down_block_residual_6, 282 | down_block_residual_7=down_block_residual_7, 283 | down_block_residual_8=down_block_residual_8, 284 | down_block_residual_9=down_block_residual_9, 285 | down_block_residual_10=down_block_residual_10, 286 | down_block_residual_11=down_block_residual_11, 287 | mid_block_residual=mid_block_residual, 288 | class_labels=class_labels, 289 | add_embeds=add_embeds, 290 | ) 291 | mark_output(Y) 292 | 293 | target = detect_target( 294 | use_fp16_acc=use_fp16_acc, convert_conv_to_gemm=convert_conv_to_gemm 295 | ) 296 | dll_name = model_name + ".dll" if sys.platform == "win32" else model_name + ".so" 297 | total_usage = compile_model( 298 | Y, target, work_dir, model_name, constants=params_ait if constants else None, dll_name=dll_name, 299 | ) 300 | sd = "v1" 301 | if hidden_dim == 1024: 302 | sd = "v2" 303 | elif hidden_dim == 2048: 304 | sd = "xl" 305 | vram = round(total_usage / 1024 / 1024) 306 | model_type = "unet_control" if controlnet else "unet" 307 | process(work_dir, model_name, dll_name, target._arch, _height[-1], _width[-1], _batch_size[-1], vram, out_dir, sd, model_type) -------------------------------------------------------------------------------- /AITemplate/ait/compile/util.py: -------------------------------------------------------------------------------- 1 | # Copyright (c) Meta Platforms, Inc. and affiliates. 
2 | # 3 | # Licensed under the Apache License, Version 2.0 (the "License"); 4 | # you may not use this file except in compliance with the License. 5 | # You may obtain a copy of the License at 6 | # 7 | # http://www.apache.org/licenses/LICENSE-2.0 8 | # 9 | # Unless required by applicable law or agreed to in writing, software 10 | # distributed under the License is distributed on an "AS IS" BASIS, 11 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 12 | # See the License for the specific language governing permissions and 13 | # limitations under the License. 14 | # 15 | def mark_output(ys): 16 | if type(ys) != tuple: 17 | ys = (ys, ) 18 | for i in range(len(ys)): 19 | y = ys[i] 20 | if type(y) == tuple: 21 | for yy in y: 22 | y_shape = [d._attrs["values"] for d in yy._attrs["shape"]] 23 | y_name = yy._attrs["name"] 24 | print("AIT {} shape: {}".format(y_name, y_shape)) 25 | else: 26 | y_shape = [d._attrs["values"] for d in y._attrs["shape"]] 27 | y_name = y._attrs["name"] 28 | print("AIT {} shape: {}".format(y_name, y_shape)) 29 | -------------------------------------------------------------------------------- /AITemplate/ait/compile/vae.py: -------------------------------------------------------------------------------- 1 | # Copyright (c) Meta Platforms, Inc. and affiliates. 2 | # 3 | # Licensed under the Apache License, Version 2.0 (the "License"); 4 | # you may not use this file except in compliance with the License. 5 | # You may obtain a copy of the License at 6 | # 7 | # http://www.apache.org/licenses/LICENSE-2.0 8 | # 9 | # Unless required by applicable law or agreed to in writing, software 10 | # distributed under the License is distributed on an "AS IS" BASIS, 11 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 12 | # See the License for the specific language governing permissions and 13 | # limitations under the License. 
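`compile_unet` above follows the same pattern as the CLIP and ControlNet builders. A hedged driver sketch for an SD1.x UNet with a static 512x512, batch-2 profile (hub id and directories are placeholders):

# Assumes diffusers is installed and an AITemplate CUDA toolchain is available.
from diffusers import UNet2DConditionModel
from ait.compile.unet import compile_unet

pt_unet = UNet2DConditionModel.from_pretrained(
    "runwayml/stable-diffusion-v1-5", subfolder="unet"
)

compile_unet(
    pt_unet,
    batch_size=(1, 1),          # doubled internally for cond/uncond
    height=(512, 512),          # static profile, in pixel space (divided by 8 internally)
    width=(512, 512),
    hidden_dim=768,             # SD1.x cross-attention dim
    attention_head_dim=8,       # SD1.x uses 8 attention heads per block
    use_linear_projection=False,
    constants=True,
    model_name="UNet2DConditionModel_sd1x",
    work_dir="./tmp",
    out_dir="./out",
)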
14 | # 15 | 16 | import sys 17 | import torch 18 | from aitemplate.compiler import compile_model 19 | from aitemplate.frontend import IntVar, Tensor 20 | from aitemplate.testing import detect_target 21 | 22 | from ..modeling.vae import AutoencoderKL as ait_AutoencoderKL 23 | from .util import mark_output 24 | from .release import process 25 | 26 | from ait.util.mapping import map_vae 27 | 28 | def compile_vae( 29 | pt_mod, 30 | batch_size=(1, 8), 31 | height=(64, 2048), 32 | width=(64, 2048), 33 | use_fp16_acc=True, 34 | convert_conv_to_gemm=True, 35 | model_name="AutoencoderKL", 36 | constants=True, 37 | block_out_channels=[128, 256, 512, 512], 38 | layers_per_block=2, 39 | act_fn="silu", 40 | latent_channels=4, 41 | sample_size=512, 42 | in_channels=3, 43 | out_channels=3, 44 | down_block_types=[ 45 | "DownEncoderBlock2D", 46 | "DownEncoderBlock2D", 47 | "DownEncoderBlock2D", 48 | "DownEncoderBlock2D", 49 | ], 50 | up_block_types=[ 51 | "UpDecoderBlock2D", 52 | "UpDecoderBlock2D", 53 | "UpDecoderBlock2D", 54 | "UpDecoderBlock2D", 55 | ], 56 | input_size=(64, 64), 57 | down_factor=8, 58 | dtype="float16", 59 | work_dir="./tmp", 60 | out_dir="./tmp", 61 | vae_encode=False, 62 | ): 63 | _batch_size = batch_size 64 | _height = height 65 | _width = width 66 | ait_vae = ait_AutoencoderKL( 67 | batch_size[0], 68 | input_size[0], 69 | input_size[1], 70 | in_channels=in_channels, 71 | out_channels=out_channels, 72 | down_block_types=down_block_types, 73 | up_block_types=up_block_types, 74 | block_out_channels=block_out_channels, 75 | layers_per_block=layers_per_block, 76 | act_fn=act_fn, 77 | latent_channels=latent_channels, 78 | sample_size=sample_size, 79 | dtype=dtype 80 | ) 81 | 82 | static_batch = batch_size[0] == batch_size[1] 83 | static_shape = height[0] == height[1] and width[0] == width[1] 84 | if not vae_encode: 85 | height = height[0] // down_factor, height[1] // down_factor 86 | width = width[0] // down_factor, width[1] // down_factor 87 | 88 | if static_batch: 89 | batch_size = batch_size[0] 90 | else: 91 | batch_size = IntVar(values=list(batch_size), name="batch_size") 92 | if static_shape: 93 | height_d = height[0] 94 | width_d = width[0] 95 | else: 96 | height_d = IntVar(values=list(height), name="height") 97 | width_d = IntVar(values=list(width), name="width") 98 | 99 | ait_input = Tensor( 100 | shape=[batch_size, height_d, width_d, 3 if vae_encode else latent_channels], 101 | name="pixels" if vae_encode else "latent", 102 | is_input=True, 103 | dtype=dtype 104 | ) 105 | sample = None 106 | if vae_encode: 107 | sample = Tensor( 108 | shape=[batch_size, height_d, width_d, latent_channels], 109 | name="random_sample", 110 | is_input=True, 111 | dtype=dtype, 112 | ) 113 | ait_vae.name_parameter_tensor() 114 | 115 | pt_mod = pt_mod.eval() 116 | params_ait = map_vae(pt_mod, dtype=dtype, encoder=vae_encode) 117 | if vae_encode: 118 | Y = ait_vae.encode(ait_input, sample) 119 | else: 120 | Y = ait_vae.decode(ait_input) 121 | mark_output(Y) 122 | target = detect_target( 123 | use_fp16_acc=use_fp16_acc, convert_conv_to_gemm=convert_conv_to_gemm 124 | ) 125 | dll_name = model_name + ".dll" if sys.platform == "win32" else model_name + ".so" 126 | total_usage = compile_model( 127 | Y, target, work_dir, model_name, constants=params_ait if constants else None, dll_name=dll_name, 128 | ) 129 | sd = None 130 | vram = round(total_usage / 1024 / 1024) 131 | model_type = "vae_encode" if vae_encode else "vae_decode" 132 | process(work_dir, model_name, dll_name, target._arch, _height[-1], 
_width[-1], _batch_size[-1], vram, out_dir, sd, model_type) -------------------------------------------------------------------------------- /AITemplate/ait/inference.py: -------------------------------------------------------------------------------- 1 | from typing import List 2 | 3 | import torch 4 | 5 | from .module import Model 6 | 7 | 8 | class AITemplateModelWrapper(torch.nn.Module): 9 | def __init__( 10 | self, 11 | unet_ait_exe: Model, 12 | alphas_cumprod: torch.Tensor, 13 | ): 14 | super().__init__() 15 | self.alphas_cumprod = alphas_cumprod 16 | self.unet_ait_exe = unet_ait_exe 17 | 18 | def apply_model( 19 | self, 20 | x: torch.Tensor, 21 | t: torch.Tensor, 22 | c_crossattn = None, 23 | c_concat = None, 24 | control = None, 25 | c_adm = None, 26 | transformer_options = None, 27 | ): 28 | timesteps_pt = t 29 | latent_model_input = x 30 | encoder_hidden_states = None 31 | down_block_residuals = None 32 | mid_block_residual = None 33 | add_embeds = None 34 | if c_crossattn is not None: 35 | encoder_hidden_states = c_crossattn 36 | if c_concat is not None: 37 | latent_model_input = torch.cat([x] + c_concat, dim=1) 38 | if control is not None: 39 | down_block_residuals = control["output"] 40 | mid_block_residual = control["middle"][0] 41 | if c_adm is not None: 42 | add_embeds = c_adm 43 | return unet_inference( 44 | self.unet_ait_exe, 45 | latent_model_input=latent_model_input, 46 | timesteps=timesteps_pt, 47 | encoder_hidden_states=encoder_hidden_states, 48 | down_block_residuals=down_block_residuals, 49 | mid_block_residual=mid_block_residual, 50 | add_embeds=add_embeds, 51 | ) 52 | 53 | 54 | def unet_inference( 55 | exe_module: Model, 56 | latent_model_input: torch.Tensor, 57 | timesteps: torch.Tensor, 58 | encoder_hidden_states: torch.Tensor, 59 | class_labels: torch.Tensor = None, 60 | down_block_residuals: List[torch.Tensor] = None, 61 | mid_block_residual: torch.Tensor = None, 62 | device: str = "cuda", 63 | dtype: str = "float16", 64 | benchmark: bool = False, 65 | add_embeds: torch.Tensor = None, 66 | ): 67 | batch = latent_model_input.shape[0] 68 | height, width = latent_model_input.shape[2], latent_model_input.shape[3] 69 | timesteps_pt = timesteps.expand(batch) 70 | inputs = { 71 | "latent_model_input": latent_model_input.permute((0, 2, 3, 1)) 72 | .contiguous() 73 | .to(device), 74 | "timesteps": timesteps_pt.to(device), 75 | "encoder_hidden_states": encoder_hidden_states.to(device), 76 | } 77 | if class_labels is not None: 78 | inputs["class_labels"] = class_labels.contiguous().to(device) 79 | if down_block_residuals is not None and mid_block_residual is not None: 80 | for i, y in enumerate(down_block_residuals): 81 | inputs[f"down_block_residual_{i}"] = y.permute((0, 2, 3, 1)).contiguous().to(device) 82 | inputs["mid_block_residual"] = mid_block_residual.permute((0, 2, 3, 1)).contiguous().to(device) 83 | if add_embeds is not None: 84 | inputs["add_embeds"] = add_embeds.to(device) 85 | if dtype == "float16": 86 | for k, v in inputs.items(): 87 | if k == "class_labels ": 88 | continue 89 | inputs[k] = v.half() 90 | ys = [] 91 | num_outputs = len(exe_module.get_output_name_to_index_map()) 92 | for i in range(num_outputs): 93 | shape = exe_module.get_output_maximum_shape(i) 94 | shape[0] = batch 95 | shape[1] = height 96 | shape[2] = width 97 | ys.append(torch.empty(shape).cuda().half()) 98 | exe_module.run_with_tensors(inputs, ys, graph_mode=False) 99 | noise_pred = ys[0].permute((0, 3, 1, 2)).float() 100 | if benchmark: 101 | t, _, _ = 
exe_module.benchmark_with_tensors( 102 | inputs=inputs, 103 | outputs=ys, 104 | count=50, 105 | repeat=4, 106 | ) 107 | print(f"unet latency: {t} ms, it/s: {1000 / t}") 108 | return noise_pred.cpu() 109 | 110 | 111 | def controlnet_inference( 112 | exe_module: Model, 113 | latent_model_input: torch.Tensor, 114 | timesteps: torch.Tensor, 115 | encoder_hidden_states: torch.Tensor, 116 | controlnet_cond: torch.Tensor, 117 | add_embeds: torch.Tensor = None, 118 | device: str = "cuda", 119 | dtype: str = "float16", 120 | benchmark: bool = False, 121 | ): 122 | if controlnet_cond.shape[0] != latent_model_input.shape[0]: 123 | controlnet_cond = controlnet_cond.expand(latent_model_input.shape[0], -1, -1, -1) 124 | if type(encoder_hidden_states) == dict: 125 | encoder_hidden_states = encoder_hidden_states['c_crossattn'] 126 | inputs = { 127 | "latent_model_input": latent_model_input.permute((0, 2, 3, 1)) 128 | .contiguous() 129 | .to(device), 130 | "timesteps": timesteps.to(device), 131 | "encoder_hidden_states": encoder_hidden_states.to(device), 132 | "control_hint": controlnet_cond.permute((0, 2, 3, 1)).contiguous().to(device), 133 | } 134 | if add_embeds is not None: 135 | inputs["add_embeds"] = add_embeds.to(device) 136 | if dtype == "float16": 137 | for k, v in inputs.items(): 138 | inputs[k] = v.half() 139 | ys = {} 140 | for name, idx in exe_module.get_output_name_to_index_map().items(): 141 | shape = exe_module.get_output_maximum_shape(idx) 142 | shape = torch.empty(shape).to(device) 143 | if dtype == "float16": 144 | shape = shape.half() 145 | ys[name] = shape 146 | exe_module.run_with_tensors(inputs, ys, graph_mode=False) 147 | ys = {k: y.permute((0, 3, 1, 2)).float() for k, y in ys.items()} 148 | if benchmark: 149 | ys = {} 150 | for name, idx in exe_module.get_output_name_to_index_map().items(): 151 | shape = exe_module.get_output_maximum_shape(idx) 152 | shape = torch.empty(shape).to(device) 153 | if dtype == "float16": 154 | shape = shape.half() 155 | ys[name] = shape 156 | t, _, _ = exe_module.benchmark_with_tensors( 157 | inputs=inputs, 158 | outputs=ys, 159 | count=50, 160 | repeat=4, 161 | ) 162 | print(f"controlnet latency: {t} ms, it/s: {1000 / t}") 163 | return ys 164 | 165 | 166 | 167 | def vae_inference( 168 | exe_module: Model, 169 | vae_input: torch.Tensor, 170 | factor: int = 8, 171 | device: str = "cuda", 172 | dtype: str = "float16", 173 | encoder: bool = False, 174 | latent_channels: int = 4, 175 | ): 176 | batch = vae_input.shape[0] 177 | height, width = vae_input.shape[2], vae_input.shape[3] 178 | if encoder: 179 | height = height // factor 180 | width = width // factor 181 | else: 182 | height = height * factor 183 | width = width * factor 184 | input_name = "pixels" if encoder else "latent" 185 | inputs = { 186 | input_name: torch.permute(vae_input, (0, 2, 3, 1)) 187 | .contiguous() 188 | .to(device), 189 | } 190 | if encoder: 191 | sample = torch.randn(batch, latent_channels, height, width) 192 | inputs["random_sample"] = torch.permute(sample, (0, 2, 3, 1)).contiguous().to(device) 193 | if dtype == "float16": 194 | for k, v in inputs.items(): 195 | inputs[k] = v.half() 196 | ys = [] 197 | num_outputs = len(exe_module.get_output_name_to_index_map()) 198 | for i in range(num_outputs): 199 | shape = exe_module.get_output_maximum_shape(i) 200 | shape[0] = batch 201 | shape[1] = height 202 | shape[2] = width 203 | ys.append(torch.empty(shape).to(device)) 204 | if dtype == "float16": 205 | ys[i] = ys[i].half() 206 | exe_module.run_with_tensors(inputs, ys, 
graph_mode=False) 207 | vae_out = ys[0].permute((0, 3, 1, 2)).cpu().float() 208 | return vae_out 209 | 210 | 211 | def clip_inference( 212 | exe_module: Model, 213 | input_ids: torch.Tensor, 214 | seqlen: int = 77, 215 | device: str = "cuda", 216 | dtype: str = "float16", 217 | ): 218 | batch = input_ids.shape[0] 219 | input_ids = input_ids.to(device) 220 | position_ids = torch.arange(seqlen).expand((batch, -1)).to(device) 221 | inputs = { 222 | "input_ids": input_ids, 223 | "position_ids": position_ids, 224 | } 225 | ys = [] 226 | num_outputs = len(exe_module.get_output_name_to_index_map()) 227 | for i in range(num_outputs): 228 | shape = exe_module.get_output_maximum_shape(i) 229 | shape[0] = batch 230 | ys.append(torch.empty(shape).to(device)) 231 | if dtype == "float16": 232 | ys[i] = ys[i].half() 233 | exe_module.run_with_tensors(inputs, ys, graph_mode=False) 234 | return ys[0].cpu().float() 235 | -------------------------------------------------------------------------------- /AITemplate/ait/load.py: -------------------------------------------------------------------------------- 1 | from typing import Union 2 | 3 | import json 4 | import os 5 | import lzma 6 | import requests 7 | import torch 8 | try: 9 | from diffusers import AutoencoderKL, ControlNetModel, UNet2DConditionModel 10 | except ImportError: 11 | pass 12 | try: 13 | from transformers import CLIPTextModel 14 | except ImportError: 15 | pass 16 | 17 | from .module import Model 18 | from .util import torch_dtype_from_str, convert_ldm_unet_checkpoint, convert_text_enc_state_dict, convert_ldm_vae_checkpoint 19 | from .util.mapping import map_clip, map_controlnet, map_unet, map_vae 20 | 21 | 22 | class AITLoader: 23 | def __init__(self, 24 | modules_path: str = "./modules/", 25 | num_runtimes: int = 1, 26 | device: Union[str, torch.device] = "cuda", 27 | dtype: str = "float16", 28 | ) -> None: 29 | """ 30 | device and dtype can be overriden at the function level 31 | device must be a cuda device 32 | """ 33 | self.device = device 34 | self.dtype = dtype 35 | self.num_runtimes = num_runtimes 36 | self.modules_path = modules_path 37 | self.extension = "dll" if os.name == "nt" else "so" 38 | try: 39 | self.modules = json.load(open(f"{modules_path}/modules.json", "r")) 40 | if type(self.modules) == dict: 41 | self.modules = list(self.modules.values()) 42 | except FileNotFoundError: 43 | raise FileNotFoundError(f"modules.json not found in {modules_path}") 44 | except json.decoder.JSONDecodeError: 45 | raise ValueError(f"modules.json in {modules_path} is not a valid json file") 46 | 47 | def download_module(self, sha256: str, url: str): 48 | module_path = f"{self.modules_path}/{sha256}.{self.extension}" 49 | temp_path = f"{self.modules_path}/{sha256}.{self.extension}.xz" 50 | if os.path.exists(module_path): 51 | return 52 | r = requests.get(url, stream=True) 53 | with open(temp_path, "wb") as f: 54 | for chunk in r.iter_content(chunk_size=8192): 55 | f.write(chunk) 56 | with lzma.open(temp_path, "rb") as f: 57 | with open(module_path, "wb") as g: 58 | g.write(f.read()) 59 | os.remove(temp_path) 60 | 61 | 62 | 63 | def load_module( 64 | self, sha256: str, url: str 65 | ): 66 | module_path = f"{self.modules_path}/{sha256}.{self.extension}" 67 | download = False 68 | if not os.path.exists(module_path): 69 | download = True 70 | if download: 71 | self.download_module(sha256, url) 72 | return self.load(module_path) 73 | 74 | 75 | def filter_modules(self, operating_system: str, sd: str, cuda: str, batch_size: int, resolution: int, 
model_type: str, largest: bool = False): 76 | modules = [x for x in self.modules if x["os"] == operating_system and x["sd"] == sd and x["cuda"] == cuda and x["batch_size"] == batch_size and x["resolution"] >= resolution and model_type == x["model"]] 77 | if len(modules) == 0: 78 | raise ValueError(f"No modules found for {operating_system} {sd} {cuda} {batch_size} {resolution} {model_type}") 79 | print(f"Found {len(modules)} modules for {operating_system} {sd} {cuda} {batch_size} {resolution} {model_type}") 80 | modules = sorted(modules, key=lambda k: k['resolution'], reverse=largest) 81 | print(f"Using {modules[0]['sha256']}") 82 | return modules 83 | 84 | 85 | def load( 86 | self, 87 | path: str, 88 | ) -> Model: 89 | return Model(lib_path=path, num_runtimes=self.num_runtimes) 90 | 91 | def compvis_unet( 92 | self, 93 | state_dict: dict, 94 | ) -> dict: 95 | """ 96 | removes: 97 | model.diffusion_model. 98 | diffusion_model. 99 | from keys if present before conversion 100 | """ 101 | return convert_ldm_unet_checkpoint(state_dict) 102 | 103 | def compvis_clip( 104 | self, 105 | state_dict: dict, 106 | ) -> dict: 107 | """ 108 | removes: 109 | cond_stage_model.transformer. 110 | cond_stage_model.model. 111 | from keys if present before conversion 112 | """ 113 | return convert_text_enc_state_dict(state_dict) 114 | 115 | def compvis_vae( 116 | self, 117 | state_dict: dict, 118 | ) -> dict: 119 | """ 120 | removes: 121 | first_stage_model. 122 | from keys if present before conversion 123 | """ 124 | return convert_ldm_vae_checkpoint(state_dict) 125 | 126 | def compvis_controlnet( 127 | self, 128 | state_dict: dict, 129 | ) -> dict: 130 | """ 131 | removes: 132 | control_model. 133 | from keys if present before conversion 134 | """ 135 | return convert_ldm_unet_checkpoint(state_dict, controlnet=True) 136 | 137 | def diffusers_unet( 138 | self, 139 | hf_hub_or_path: str, 140 | dtype: str = "float16", 141 | subfolder: str = "unet", 142 | revision: str = "fp16", 143 | ): 144 | return UNet2DConditionModel.from_pretrained( 145 | hf_hub_or_path, 146 | subfolder="unet" if not hf_hub_or_path.endswith("unet") else None, 147 | variant="fp16", 148 | use_safetensors=True, 149 | torch_dtype=torch_dtype_from_str(dtype) 150 | ) 151 | 152 | def diffusers_vae( 153 | self, 154 | hf_hub_or_path: str, 155 | dtype: str = "float16", 156 | subfolder: str = "vae", 157 | revision: str = "fp16", 158 | ): 159 | return AutoencoderKL.from_pretrained( 160 | hf_hub_or_path, 161 | subfolder=subfolder, 162 | revision=revision, 163 | torch_dtype=torch_dtype_from_str(dtype) 164 | ) 165 | 166 | def diffusers_controlnet( 167 | self, 168 | hf_hub_or_path: str, 169 | dtype: str = "float16", 170 | subfolder: str = None, 171 | revision: str = None, 172 | ): 173 | return ControlNetModel.from_pretrained( 174 | hf_hub_or_path, 175 | subfolder=subfolder, 176 | revision=revision, 177 | # variant="fp16", 178 | use_safetensors=True, 179 | torch_dtype=torch_dtype_from_str(dtype) 180 | ) 181 | 182 | def diffusers_clip( 183 | self, 184 | hf_hub_or_path: str, 185 | dtype: str = "float16", 186 | subfolder: str = "text_encoder", 187 | revision: str = "fp16", 188 | ): 189 | return CLIPTextModel.from_pretrained( 190 | hf_hub_or_path, 191 | subfolder=subfolder, 192 | revision=revision, 193 | torch_dtype=torch_dtype_from_str(dtype) 194 | ) 195 | 196 | def apply( 197 | self, 198 | aitemplate_module: Model, 199 | ait_params: dict, 200 | ) -> Model: 201 | aitemplate_module.set_many_constants_with_tensors(ait_params) 202 | 
aitemplate_module.fold_constants() 203 | return aitemplate_module 204 | 205 | def apply_unet( 206 | self, 207 | aitemplate_module: Model, 208 | unet,#: Union[UNet2DConditionModel, dict], 209 | in_channels: int = None, 210 | conv_in_key: str = None, 211 | dim: int = 320, 212 | device: Union[str, torch.device] = None, 213 | dtype: str = None, 214 | ) -> Model: 215 | """ 216 | you don't need to set in_channels or conv_in_key unless 217 | you are experimenting with other UNets 218 | """ 219 | device = self.device if device is None else device 220 | dtype = self.dtype if dtype is None else dtype 221 | ait_params = map_unet(unet, in_channels=in_channels, conv_in_key=conv_in_key, dim=dim, device=device, dtype=dtype) 222 | return self.apply(aitemplate_module, ait_params) 223 | 224 | def apply_clip( 225 | self, 226 | aitemplate_module: Model, 227 | clip,#: Union[CLIPTextModel, dict], 228 | device: Union[str, torch.device] = None, 229 | dtype: str = None, 230 | ) -> Model: 231 | device = self.device if device is None else device 232 | dtype = self.dtype if dtype is None else dtype 233 | ait_params = map_clip(clip, device=device, dtype=dtype) 234 | return self.apply(aitemplate_module, ait_params) 235 | 236 | def apply_controlnet( 237 | self, 238 | aitemplate_module: Model, 239 | controlnet,#: Union[ControlNetModel, dict], 240 | dim: int = 320, 241 | device: Union[str, torch.device] = None, 242 | dtype: str = None, 243 | ) -> Model: 244 | device = self.device if device is None else device 245 | dtype = self.dtype if dtype is None else dtype 246 | ait_params = map_controlnet(controlnet, dim=dim, device=device, dtype=dtype) 247 | return self.apply(aitemplate_module, ait_params) 248 | 249 | def apply_vae( 250 | self, 251 | aitemplate_module: Model, 252 | vae,#: Union[AutoencoderKL, dict], 253 | device: Union[str, torch.device] = None, 254 | dtype: str = None, 255 | encoder: bool = False, 256 | ) -> Model: 257 | device = self.device if device is None else device 258 | dtype = self.dtype if dtype is None else dtype 259 | ait_params = map_vae(vae, device=device, dtype=dtype, encoder=encoder) 260 | return self.apply(aitemplate_module, ait_params) 261 | -------------------------------------------------------------------------------- /AITemplate/ait/modeling/__init__.py: -------------------------------------------------------------------------------- 1 | from .vae import AutoencoderKL as AIT_AutoencoderKL 2 | 3 | __all__ = ["AIT_AutoencoderKL"] 4 | -------------------------------------------------------------------------------- /AITemplate/ait/modeling/attention.py: -------------------------------------------------------------------------------- 1 | # Copyright (c) Meta Platforms, Inc. and affiliates. 2 | # 3 | # Licensed under the Apache License, Version 2.0 (the "License"); 4 | # you may not use this file except in compliance with the License. 5 | # You may obtain a copy of the License at 6 | # 7 | # http://www.apache.org/licenses/LICENSE-2.0 8 | # 9 | # Unless required by applicable law or agreed to in writing, software 10 | # distributed under the License is distributed on an "AS IS" BASIS, 11 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 12 | # See the License for the specific language governing permissions and 13 | # limitations under the License. 14 | # 15 | 16 | """ 17 | Implementations are translated from https://github.com/huggingface/diffusers/blob/main/src/diffusers/models/attention.py. 
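A minimal construction sketch (argument values here are illustrative, not repo defaults):

    attn = AttentionBlock(batch_size=1, height=64, width=64, channels=512)
    out = attn(hidden_states)  # hidden_states is an NHWC Tensor: [batch, height, width, channel]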
18 | """ 19 | 20 | from typing import Optional 21 | 22 | from aitemplate.compiler.ops import reshape 23 | from aitemplate.frontend import nn, Tensor 24 | 25 | 26 | class AttentionBlock(nn.Module): 27 | """ 28 | An attention block that allows spatial positions to attend to each other. Originally ported from here, but adapted 29 | to the N-d case. 30 | https://github.com/hojonathanho/diffusion/blob/1e0dceb3b3495bbe19116a5e1b3596cd0706c543/diffusion_tf/models/unet.py#L66. 31 | Uses three q, k, v linear layers to compute attention. 32 | Parameters: 33 | batch_size (:obj:`int`): The number of examples per batch. 34 | height (:obj:`int`): Height of each image example. 35 | width (:obj:`int`): Width of each image example. 36 | channels (:obj:`int`): The number of channels in the input and output. 37 | num_head_channels (:obj:`int`, *optional*): 38 | The number of channels in each head. If None, then `num_heads` = 1. 39 | num_groups (:obj:`int`, *optional*, defaults to 32): The number of groups to use for group norm. 40 | eps (:obj:`float`, *optional*, defaults to 1e-5): The epsilon value to use for group norm. 41 | """ 42 | 43 | def __init__( 44 | self, 45 | batch_size: int, 46 | height: int, 47 | width: int, 48 | channels: int, 49 | num_head_channels: Optional[int] = None, 50 | num_groups: int = 32, 51 | rescale_output_factor: float = 1.0, 52 | eps: float = 1e-5, 53 | dtype="float16", 54 | ): 55 | super().__init__() 56 | self.batch_size = batch_size 57 | self.channels = channels 58 | self.num_heads = ( 59 | channels // num_head_channels if num_head_channels is not None else 1 60 | ) 61 | self.num_head_size = num_head_channels 62 | self.group_norm = nn.GroupNorm(num_groups, channels, eps, dtype=dtype) 63 | self.attention = nn.CrossAttention( 64 | channels, 65 | height * width, 66 | height * width, 67 | self.num_heads, 68 | qkv_bias=True, 69 | dtype=dtype 70 | ) 71 | self.rescale_output_factor = rescale_output_factor 72 | 73 | def forward(self, hidden_states) -> Tensor: 74 | """ 75 | input hidden_states shape: [batch, height, width, channel] 76 | output shape: [batch, height, width, channel] 77 | """ 78 | 79 | residual = hidden_states 80 | 81 | # norm 82 | hidden_states = self.group_norm(hidden_states) 83 | o_shape = hidden_states.shape() 84 | batch_dim = o_shape[0] 85 | 86 | hidden_states = reshape()( 87 | hidden_states, 88 | [batch_dim, -1, self.channels], 89 | ) 90 | 91 | res = self.attention(hidden_states, hidden_states, hidden_states, residual) * ( 92 | 1 / self.rescale_output_factor 93 | ) 94 | 95 | res = reshape()(res, o_shape) 96 | return res 97 | -------------------------------------------------------------------------------- /AITemplate/ait/modeling/embeddings.py: -------------------------------------------------------------------------------- 1 | # Copyright (c) Meta Platforms, Inc. and affiliates. 2 | # 3 | # Licensed under the Apache License, Version 2.0 (the "License"); 4 | # you may not use this file except in compliance with the License. 5 | # You may obtain a copy of the License at 6 | # 7 | # http://www.apache.org/licenses/LICENSE-2.0 8 | # 9 | # Unless required by applicable law or agreed to in writing, software 10 | # distributed under the License is distributed on an "AS IS" BASIS, 11 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 12 | # See the License for the specific language governing permissions and 13 | # limitations under the License. 
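# Note: `get_timestep_embedding` below takes its frequency ramp from a named
# constant tensor ("arange" by default) instead of computing it in-graph; the
# weight mappers in ait/util/mapping/unet.py and controlnet.py supply it as
# torch.arange(start=0, end=dim // 2, dtype=torch.float32).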
14 | # 15 | import math 16 | 17 | from aitemplate.compiler import ops 18 | from aitemplate.frontend import nn, Tensor 19 | 20 | 21 | def get_shape(x): 22 | shape = [it.value() for it in x._attrs["shape"]] 23 | return shape 24 | 25 | 26 | def get_timestep_embedding( 27 | timesteps: Tensor, 28 | embedding_dim: int, 29 | flip_sin_to_cos: bool = False, 30 | downscale_freq_shift: float = 1, 31 | scale: float = 1, 32 | max_period: int = 10000, 33 | dtype: str = "float16", 34 | arange_name = "arange", 35 | ): 36 | """ 37 | This matches the implementation in Denoising Diffusion Probabilistic Models: Create sinusoidal timestep embeddings. 38 | 39 | :param timesteps: a 1-D Tensor of N indices, one per batch element. 40 | These may be fractional. 41 | :param embedding_dim: the dimension of the output. :param max_period: controls the minimum frequency of the 42 | embeddings. :return: an [N x dim] Tensor of positional embeddings. 43 | """ 44 | assert timesteps._rank() == 1, "Timesteps should be a 1d-array" 45 | 46 | half_dim = embedding_dim // 2 47 | 48 | exponent = (-math.log(max_period)) * Tensor( 49 | shape=[half_dim], dtype=dtype, name=arange_name 50 | ) 51 | 52 | exponent = exponent * (1.0 / (half_dim - downscale_freq_shift)) 53 | 54 | emb = ops.exp(exponent) 55 | emb = ops.reshape()(timesteps, [-1, 1]) * ops.reshape()(emb, [1, -1]) 56 | 57 | # scale embeddings 58 | emb = scale * emb 59 | 60 | # concat sine and cosine embeddings 61 | if flip_sin_to_cos: 62 | emb = ops.concatenate()( 63 | [ops.cos(emb), ops.sin(emb)], 64 | dim=-1, 65 | ) 66 | else: 67 | emb = ops.concatenate()( 68 | [ops.sin(emb), ops.cos(emb)], 69 | dim=-1, 70 | ) 71 | return emb 72 | 73 | 74 | class TimestepEmbedding(nn.Module): 75 | def __init__(self, channel: int, time_embed_dim: int, act_fn: str = "silu", dtype: str = "float16"): 76 | super().__init__() 77 | 78 | self.linear_1 = nn.Linear(channel, time_embed_dim, specialization="swish", dtype=dtype) 79 | self.linear_2 = nn.Linear(time_embed_dim, time_embed_dim, dtype=dtype) 80 | 81 | def forward(self, sample): 82 | sample = self.linear_1(sample) 83 | sample = self.linear_2(sample) 84 | return sample 85 | 86 | 87 | class Timesteps(nn.Module): 88 | def __init__( 89 | self, num_channels: int, flip_sin_to_cos: bool, downscale_freq_shift: float, dtype: str = "float16", arange_name = "arange" 90 | ): 91 | super().__init__() 92 | self.num_channels = num_channels 93 | self.flip_sin_to_cos = flip_sin_to_cos 94 | self.downscale_freq_shift = downscale_freq_shift 95 | self.dtype = dtype 96 | self.arange_name = arange_name 97 | 98 | def forward(self, timesteps): 99 | t_emb = get_timestep_embedding( 100 | timesteps, 101 | self.num_channels, 102 | flip_sin_to_cos=self.flip_sin_to_cos, 103 | downscale_freq_shift=self.downscale_freq_shift, 104 | dtype=self.dtype, 105 | arange_name=self.arange_name, 106 | ) 107 | return t_emb 108 | -------------------------------------------------------------------------------- /AITemplate/ait/modeling/resnet.py: -------------------------------------------------------------------------------- 1 | # Copyright (c) Meta Platforms, Inc. and affiliates. 2 | # 3 | # Licensed under the Apache License, Version 2.0 (the "License"); 4 | # you may not use this file except in compliance with the License. 
5 | # You may obtain a copy of the License at 6 | # 7 | # http://www.apache.org/licenses/LICENSE-2.0 8 | # 9 | # Unless required by applicable law or agreed to in writing, software 10 | # distributed under the License is distributed on an "AS IS" BASIS, 11 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 12 | # See the License for the specific language governing permissions and 13 | # limitations under the License. 14 | # 15 | from aitemplate.compiler import ops 16 | from aitemplate.frontend import nn, Tensor 17 | 18 | 19 | def get_shape(x): 20 | shape = [it.value() for it in x._attrs["shape"]] 21 | return shape 22 | 23 | 24 | class Upsample2D(nn.Module): 25 | """ 26 | An upsampling layer with an optional convolution. 27 | 28 | :param channels: channels in the inputs and outputs. :param use_conv: a bool determining if a convolution is 29 | applied. :param dims: determines if the signal is 1D, 2D, or 3D. If 3D, then 30 | upsampling occurs in the inner-two dimensions. 31 | """ 32 | 33 | def __init__( 34 | self, 35 | channels, 36 | use_conv=False, 37 | use_conv_transpose=False, 38 | out_channels=None, 39 | name="conv", 40 | dtype="float16" 41 | ): 42 | super().__init__() 43 | self.channels = channels 44 | self.out_channels = out_channels or channels 45 | self.use_conv = use_conv 46 | self.use_conv_transpose = use_conv_transpose 47 | self.name = name 48 | 49 | conv = None 50 | if use_conv_transpose: 51 | conv = nn.ConvTranspose2dBias(channels, self.out_channels, 4, 2, 1, dtype=dtype) 52 | elif use_conv: 53 | conv = nn.Conv2dBias(self.channels, self.out_channels, 3, 1, 1, dtype=dtype) 54 | 55 | # TODO(Suraj, Patrick) - clean up after weight dicts are correctly renamed 56 | if name == "conv": 57 | self.conv = conv 58 | else: 59 | self.Conv2d_0 = conv 60 | 61 | def forward(self, x, upsample_size=None): 62 | if self.use_conv_transpose: 63 | return self.conv(x) 64 | out = None 65 | if upsample_size is not None: 66 | out = ops.size()(x) 67 | out[1] = upsample_size[1] 68 | out[2] = upsample_size[2] 69 | out = [x._attrs["int_var"] for x in out] 70 | out = Tensor(out) 71 | x = nn.Upsampling2d(scale_factor=2.0, mode="nearest")(x, out) 72 | 73 | # TODO(Suraj, Patrick) - clean up after weight dicts are correctly renamed 74 | if self.use_conv: 75 | if self.name == "conv": 76 | x = self.conv(x) 77 | else: 78 | x = self.Conv2d_0(x) 79 | 80 | return x 81 | 82 | 83 | class Downsample2D(nn.Module): 84 | """ 85 | A downsampling layer with an optional convolution. 86 | 87 | :param channels: channels in the inputs and outputs. :param use_conv: a bool determining if a convolution is 88 | applied. :param dims: determines if the signal is 1D, 2D, or 3D. If 3D, then 89 | downsampling occurs in the inner-two dimensions. 
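    A usage sketch (values are illustrative only):

        down = Downsample2D(channels=320, use_conv=True, padding=1)
        y = down(hidden_states)  # stride-2 conv on an NHWC input, halving height and width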
90 | """ 91 | 92 | def __init__( 93 | self, channels, use_conv=False, out_channels=None, padding=1, name="conv", dtype="float16" 94 | ): 95 | super().__init__() 96 | self.channels = channels 97 | self.out_channels = out_channels or channels 98 | self.use_conv = use_conv 99 | self.padding = padding 100 | stride = 2 101 | self.name = name 102 | self.dtype = dtype 103 | 104 | if use_conv: 105 | conv = nn.Conv2dBias( 106 | self.channels, self.out_channels, 3, stride=stride, dtype=dtype, padding=padding 107 | ) 108 | else: 109 | assert self.channels == self.out_channels 110 | conv = nn.AvgPool2d(kernel_size=stride, stride=stride, padding=0) 111 | 112 | # TODO(Suraj, Patrick) - clean up after weight dicts are correctly renamed 113 | if name == "conv": 114 | self.Conv2d_0 = conv 115 | self.conv = conv 116 | elif name == "Conv2d_0": 117 | self.conv = conv 118 | else: 119 | self.conv = conv 120 | 121 | def forward(self, hidden_states): 122 | if self.use_conv and self.padding == 0: 123 | padding = ops.full()([0, 1, 0, 0], 0.0, dtype=self.dtype) 124 | padding._attrs["shape"][0] = hidden_states._attrs["shape"][0] 125 | padding._attrs["shape"][2] = hidden_states._attrs["shape"][2] 126 | padding._attrs["shape"][3] = hidden_states._attrs["shape"][3] 127 | hidden_states = ops.concatenate()([hidden_states, padding], dim=1) 128 | padding = ops.full()([0, 0, 1, 0], 0.0, dtype=self.dtype) 129 | padding._attrs["shape"][0] = hidden_states._attrs["shape"][0] 130 | padding._attrs["shape"][1] = hidden_states._attrs["shape"][1] 131 | padding._attrs["shape"][3] = hidden_states._attrs["shape"][3] 132 | hidden_states = ops.concatenate()([hidden_states, padding], dim=2) 133 | 134 | hidden_states = self.conv(hidden_states) 135 | return hidden_states 136 | 137 | 138 | class ResnetBlock2D(nn.Module): 139 | def __init__( 140 | self, 141 | *, 142 | in_channels, 143 | out_channels=None, 144 | conv_shortcut=False, 145 | dropout=0.0, 146 | temb_channels=512, 147 | groups=32, 148 | groups_out=None, 149 | pre_norm=True, 150 | eps=1e-6, 151 | non_linearity="swish", 152 | time_embedding_norm="default", 153 | kernel=None, 154 | output_scale_factor=1.0, 155 | use_nin_shortcut=None, 156 | up=False, 157 | down=False, 158 | dtype="float16" 159 | ): 160 | super().__init__() 161 | self.pre_norm = pre_norm 162 | self.pre_norm = True 163 | self.in_channels = in_channels 164 | out_channels = in_channels if out_channels is None else out_channels 165 | self.out_channels = out_channels 166 | self.use_conv_shortcut = conv_shortcut 167 | self.time_embedding_norm = time_embedding_norm 168 | self.up = up 169 | self.down = down 170 | self.output_scale_factor = output_scale_factor 171 | 172 | if groups_out is None: 173 | groups_out = groups 174 | 175 | self.norm1 = nn.GroupNorm( 176 | num_groups=groups, 177 | num_channels=in_channels, 178 | eps=eps, 179 | affine=True, 180 | use_swish=True, 181 | dtype=dtype 182 | ) 183 | 184 | self.conv1 = nn.Conv2dBias( 185 | in_channels, out_channels, kernel_size=3, stride=1, padding=1, dtype=dtype 186 | ) 187 | 188 | if temb_channels is not None: 189 | self.time_emb_proj = nn.Linear(temb_channels, out_channels, dtype=dtype) 190 | else: 191 | self.time_emb_proj = None 192 | 193 | self.norm2 = nn.GroupNorm( 194 | num_groups=groups_out, 195 | num_channels=out_channels, 196 | eps=eps, 197 | affine=True, 198 | use_swish=True, 199 | dtype=dtype 200 | ) 201 | self.dropout = nn.Dropout(dropout, dtype=dtype) 202 | self.conv2 = nn.Conv2dBias( 203 | out_channels, out_channels, kernel_size=3, stride=1, padding=1, dtype=dtype 
204 | ) 205 | 206 | self.upsample = self.downsample = None 207 | 208 | self.use_nin_shortcut = ( 209 | self.in_channels != self.out_channels 210 | if use_nin_shortcut is None 211 | else use_nin_shortcut 212 | ) 213 | 214 | if self.use_nin_shortcut: 215 | self.conv_shortcut = nn.Conv2dBias( 216 | in_channels, out_channels, 1, 1, 0,dtype=dtype 217 | ) # kernel_size=1, stride=1, padding=0) # conv_bias_add 218 | else: 219 | self.conv_shortcut = None 220 | 221 | def forward(self, x, temb=None): 222 | hidden_states = x 223 | 224 | # make sure hidden states is in float32 225 | # when running in half-precision 226 | hidden_states = self.norm1( 227 | hidden_states 228 | ) # .float()).type(hidden_states.dtype) # fused swish 229 | # hidden_states = self.nonlinearity(hidden_states) 230 | 231 | if self.upsample is not None: 232 | x = self.upsample(x) 233 | hidden_states = self.upsample(hidden_states) 234 | elif self.downsample is not None: 235 | x = self.downsample(x) 236 | hidden_states = self.downsample(hidden_states) 237 | 238 | hidden_states = self.conv1(hidden_states) 239 | bs, h, w, dim = hidden_states.shape() 240 | if temb is not None: 241 | temb = self.time_emb_proj(ops.silu(temb)) 242 | bs, dim = temb.shape() 243 | temb = ops.reshape()(temb, [bs, 1, 1, dim]) 244 | hidden_states = hidden_states + temb 245 | 246 | # make sure hidden states is in float32 247 | # when running in half-precision 248 | hidden_states = self.norm2(hidden_states) 249 | 250 | hidden_states = self.dropout(hidden_states) 251 | hidden_states = self.conv2(hidden_states) 252 | 253 | if self.conv_shortcut is not None: 254 | x = self.conv_shortcut(x) 255 | 256 | out = hidden_states + x 257 | 258 | return out 259 | -------------------------------------------------------------------------------- /AITemplate/ait/modeling/vae.py: -------------------------------------------------------------------------------- 1 | # Copyright (c) Meta Platforms, Inc. and affiliates. 2 | # 3 | # Licensed under the Apache License, Version 2.0 (the "License"); 4 | # you may not use this file except in compliance with the License. 5 | # You may obtain a copy of the License at 6 | # 7 | # http://www.apache.org/licenses/LICENSE-2.0 8 | # 9 | # Unless required by applicable law or agreed to in writing, software 10 | # distributed under the License is distributed on an "AS IS" BASIS, 11 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 12 | # See the License for the specific language governing permissions and 13 | # limitations under the License. 14 | """ 15 | Translated from https://github.com/huggingface/diffusers/blob/main/src/diffusers/models/vae.py. 
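The `AutoencoderKL` below marks its results as named AIT graph outputs. A sketch
(variable names are illustrative):

    dec = vae.decode(z)                # decoded image tensor, output name "pixels"
    lat = vae.encode(x, sample=noise)  # reparameterized latent, output name "latent"
    moments = vae.encode(x)            # mean/logvar moments when no `sample` is given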
16 | """ 17 | 18 | from typing import Tuple 19 | 20 | from aitemplate.frontend import nn, Tensor 21 | from aitemplate.compiler import ops 22 | 23 | from .unet_blocks import get_down_block, get_up_block, UNetMidBlock2D 24 | 25 | 26 | class Decoder(nn.Module): 27 | def __init__( 28 | self, 29 | batch_size, 30 | height, 31 | width, 32 | in_channels=3, 33 | out_channels=3, 34 | up_block_types=("UpDecoderBlock2D",), 35 | block_out_channels=(64,), 36 | layers_per_block=2, 37 | act_fn="silu", 38 | dtype="float16" 39 | ): 40 | super().__init__() 41 | self.layers_per_block = layers_per_block 42 | 43 | self.conv_in = nn.Conv2dBias( 44 | in_channels, block_out_channels[-1], kernel_size=3, stride=1, padding=1, dtype=dtype 45 | ) 46 | 47 | # mid 48 | self.mid_block = UNetMidBlock2D( 49 | batch_size, 50 | height, 51 | width, 52 | in_channels=block_out_channels[-1], 53 | resnet_eps=1e-6, 54 | resnet_act_fn=act_fn, 55 | output_scale_factor=1, 56 | resnet_time_scale_shift="default", 57 | attn_num_head_channels=None, 58 | resnet_groups=32, 59 | temb_channels=None, 60 | dtype=dtype 61 | ) 62 | 63 | # up 64 | self.up_blocks = nn.ModuleList([]) 65 | reversed_block_out_channels = list(reversed(block_out_channels)) 66 | output_channel = reversed_block_out_channels[0] 67 | for i, up_block_type in enumerate(up_block_types): 68 | prev_output_channel = output_channel 69 | output_channel = reversed_block_out_channels[i] 70 | 71 | is_final_block = i == len(block_out_channels) - 1 72 | 73 | up_block = get_up_block( 74 | up_block_type, 75 | num_layers=self.layers_per_block + 1, 76 | in_channels=prev_output_channel, 77 | out_channels=output_channel, 78 | prev_output_channel=None, 79 | temb_channels=None, 80 | add_upsample=not is_final_block, 81 | resnet_eps=1e-6, 82 | resnet_act_fn=act_fn, 83 | attn_num_head_channels=None, 84 | dtype=dtype 85 | ) 86 | self.up_blocks.append(up_block) 87 | prev_output_channel = output_channel 88 | 89 | # out 90 | num_groups_out = 32 91 | self.conv_norm_out = nn.GroupNorm( 92 | num_channels=block_out_channels[0], 93 | num_groups=num_groups_out, 94 | eps=1e-6, 95 | use_swish=True, 96 | dtype=dtype 97 | ) 98 | self.conv_out = nn.Conv2dBias( 99 | block_out_channels[0], out_channels, kernel_size=3, padding=1, stride=1,dtype=dtype 100 | ) 101 | 102 | def forward(self, z) -> Tensor: 103 | sample = z 104 | sample = self.conv_in(sample) 105 | 106 | # middle 107 | sample = self.mid_block(sample) 108 | 109 | # up 110 | for up_block in self.up_blocks: 111 | sample = up_block(sample) 112 | 113 | sample = self.conv_norm_out(sample) 114 | sample = self.conv_out(sample) 115 | 116 | return sample 117 | 118 | 119 | class Encoder(nn.Module): 120 | def __init__( 121 | self, 122 | batch_size, 123 | height, 124 | width, 125 | in_channels=3, 126 | out_channels=3, 127 | down_block_types=("DownEncoderBlock2D",), 128 | block_out_channels=(64,), 129 | layers_per_block=2, 130 | norm_num_groups=32, 131 | act_fn="silu", 132 | double_z=True, 133 | dtype="float16", 134 | ): 135 | super().__init__() 136 | self.layers_per_block = layers_per_block 137 | 138 | self.conv_in = nn.Conv2dBiasFewChannels( 139 | in_channels, 140 | block_out_channels[0], 141 | kernel_size=3, 142 | stride=1, 143 | padding=1, 144 | dtype=dtype, 145 | ) 146 | 147 | self.mid_block = None 148 | self.down_blocks = nn.ModuleList([]) 149 | 150 | # down 151 | output_channel = block_out_channels[0] 152 | for i, down_block_type in enumerate(down_block_types): 153 | input_channel = output_channel 154 | output_channel = block_out_channels[i] 155 | 
is_final_block = i == len(block_out_channels) - 1 156 | 157 | down_block = get_down_block( 158 | down_block_type, 159 | num_layers=self.layers_per_block, 160 | in_channels=input_channel, 161 | out_channels=output_channel, 162 | add_downsample=not is_final_block, 163 | resnet_eps=1e-6, 164 | downsample_padding=0, 165 | resnet_act_fn=act_fn, 166 | resnet_groups=norm_num_groups, 167 | attn_num_head_channels=None, 168 | temb_channels=None, 169 | dtype=dtype, 170 | ) 171 | self.down_blocks.append(down_block) 172 | 173 | # mid 174 | self.mid_block = UNetMidBlock2D( 175 | batch_size, 176 | height, 177 | width, 178 | in_channels=block_out_channels[-1], 179 | resnet_eps=1e-6, 180 | resnet_act_fn=act_fn, 181 | output_scale_factor=1, 182 | resnet_time_scale_shift="default", 183 | attn_num_head_channels=None, 184 | resnet_groups=norm_num_groups, 185 | temb_channels=None, 186 | dtype=dtype, 187 | ) 188 | 189 | # out 190 | self.conv_norm_out = nn.GroupNorm(num_channels=block_out_channels[-1], num_groups=norm_num_groups, eps=1e-6, dtype=dtype) 191 | self.conv_act = ops.silu 192 | 193 | conv_out_channels = 2 * out_channels if double_z else out_channels 194 | self.conv_out = nn.Conv2dBias(block_out_channels[-1], conv_out_channels, kernel_size=3, stride=1, padding=1, dtype=dtype) 195 | 196 | def forward(self, x): 197 | sample = x 198 | 199 | sample = self.conv_in(sample) 200 | 201 | for down_block in self.down_blocks: 202 | sample = down_block(sample) 203 | 204 | # middle 205 | sample = self.mid_block(sample) 206 | 207 | # post-process 208 | sample = self.conv_norm_out(sample) 209 | sample = self.conv_act(sample) 210 | sample = self.conv_out(sample) 211 | 212 | return sample 213 | 214 | 215 | class AutoencoderKL(nn.Module): 216 | def __init__( 217 | self, 218 | batch_size: int, 219 | height: int, 220 | width: int, 221 | in_channels: int = 3, 222 | out_channels: int = 3, 223 | down_block_types: Tuple[str] = ("DownEncoderBlock2D",), 224 | up_block_types: Tuple[str] = ("UpDecoderBlock2D",), 225 | block_out_channels: Tuple[int] = (64,), 226 | layers_per_block: int = 1, 227 | act_fn: str = "silu", 228 | latent_channels: int = 4, 229 | norm_num_groups: int = 32, 230 | sample_size: int = 32, 231 | dtype="float16" 232 | ): 233 | super().__init__() 234 | self.decoder = Decoder( 235 | batch_size, 236 | height, 237 | width, 238 | in_channels=latent_channels, 239 | out_channels=out_channels, 240 | up_block_types=up_block_types, 241 | block_out_channels=block_out_channels, 242 | layers_per_block=layers_per_block, 243 | act_fn=act_fn, 244 | dtype=dtype 245 | ) 246 | self.post_quant_conv = nn.Conv2dBias( 247 | latent_channels, latent_channels, kernel_size=1, stride=1, padding=0, dtype=dtype 248 | ) 249 | 250 | self.encoder = Encoder( 251 | batch_size, 252 | height, 253 | width, 254 | in_channels=in_channels, 255 | out_channels=latent_channels, 256 | down_block_types=down_block_types, 257 | block_out_channels=block_out_channels, 258 | layers_per_block=layers_per_block, 259 | act_fn=act_fn, 260 | norm_num_groups=norm_num_groups, 261 | double_z=True, 262 | dtype=dtype, 263 | ) 264 | self.quant_conv = nn.Conv2dBias( 265 | 2 * latent_channels, 2 * latent_channels, kernel_size=1, stride=1, padding=0, dtype=dtype 266 | ) 267 | 268 | def decode(self, z: Tensor, return_dict: bool = True): 269 | z = self.post_quant_conv(z) 270 | dec = self.decoder(z) 271 | dec._attrs["is_output"] = True 272 | dec._attrs["name"] = "pixels" 273 | return dec 274 | 275 | def encode(self, x: Tensor, sample: Tensor = None, return_dict: bool = True, 
deterministic: bool = False): 276 | h = self.encoder(x) 277 | moments = self.quant_conv(h) 278 | if sample is None: 279 | return moments 280 | mean, logvar = ops.chunk()(moments, 2, dim=3) 281 | logvar = ops.clamp()(logvar, -30.0, 20.0) 282 | std = ops.exp(0.5 * logvar) 283 | var = ops.exp(logvar) 284 | if deterministic: 285 | var = std = Tensor(mean.shape(), value=0.0, dtype=mean._attrs["dtype"]) 286 | sample._attrs["shape"] = mean._attrs["shape"] 287 | std._attrs["shape"] = mean._attrs["shape"] 288 | z = mean + std * sample 289 | z._attrs["is_output"] = True 290 | z._attrs["name"] = "latent" 291 | return z 292 | 293 | 294 | -------------------------------------------------------------------------------- /AITemplate/ait/module/__init__.py: -------------------------------------------------------------------------------- 1 | from .model import Model 2 | 3 | __all__ = ["Model"] 4 | -------------------------------------------------------------------------------- /AITemplate/ait/module/dtype.py: -------------------------------------------------------------------------------- 1 | # Copyright (c) Meta Platforms, Inc. and affiliates. 2 | # 3 | # Licensed under the Apache License, Version 2.0 (the "License"); 4 | # you may not use this file except in compliance with the License. 5 | # You may obtain a copy of the License at 6 | # 7 | # http://www.apache.org/licenses/LICENSE-2.0 8 | # 9 | # Unless required by applicable law or agreed to in writing, software 10 | # distributed under the License is distributed on an "AS IS" BASIS, 11 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 12 | # See the License for the specific language governing permissions and 13 | # limitations under the License. 14 | # 15 | """ 16 | dtype definitions and utility functions of AITemplate 17 | """ 18 | 19 | 20 | _DTYPE2BYTE = { 21 | "bool": 1, 22 | "float16": 2, 23 | "float32": 4, 24 | "float": 4, 25 | "int": 4, 26 | "int32": 4, 27 | "int64": 8, 28 | "bfloat16": 2, 29 | } 30 | 31 | 32 | # Maps dtype strings to AITemplateDtype enum in model_interface.h. 33 | # Must be kept in sync! 34 | # We can consider defining an AITemplateDtype enum to use on the Python 35 | # side at some point, but stick to strings for now to keep things consistent 36 | # with other Python APIs. 37 | _DTYPE_TO_ENUM = { 38 | "float16": 1, 39 | "float32": 2, 40 | "float": 2, 41 | "int": 3, 42 | "int32": 3, 43 | "int64": 4, 44 | "bool": 5, 45 | "bfloat16": 6, 46 | } 47 | 48 | 49 | def get_dtype_size(dtype: str) -> int: 50 | """Returns size (in bytes) of the given dtype str. 51 | 52 | Parameters 53 | ---------- 54 | dtype: str 55 | A data type string. 56 | 57 | Returns 58 | ---------- 59 | int 60 | Size (in bytes) of this dtype. 61 | """ 62 | 63 | if dtype not in _DTYPE2BYTE: 64 | raise KeyError(f"Unknown dtype: {dtype}. Expected one of {_DTYPE2BYTE.keys()}") 65 | return _DTYPE2BYTE[dtype] 66 | 67 | 68 | def normalize_dtype(dtype: str) -> str: 69 | """Returns a normalized dtype str. 70 | 71 | Parameters 72 | ---------- 73 | dtype: str 74 | A data type string. 75 | 76 | Returns 77 | ---------- 78 | str 79 | normalized dtype str. 80 | """ 81 | if dtype == "int": 82 | return "int32" 83 | if dtype == "float": 84 | return "float32" 85 | return dtype 86 | 87 | 88 | def dtype_str_to_enum(dtype: str) -> int: 89 | """Returns the AITemplateDtype enum value (defined in model_interface.h) of 90 | the given dtype str. 91 | 92 | Parameters 93 | ---------- 94 | dtype: str 95 | A data type string. 
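        For example, "float16" maps to enum value 1 and "bfloat16" to 6.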
96 | 97 | Returns 98 | ---------- 99 | int 100 | the AITemplateDtype enum value. 101 | """ 102 | if dtype not in _DTYPE_TO_ENUM: 103 | raise ValueError( 104 | f"Got unsupported input dtype {dtype}! Supported dtypes are: {list(_DTYPE_TO_ENUM.keys())}" 105 | ) 106 | return _DTYPE_TO_ENUM[dtype] 107 | 108 | 109 | def dtype_to_enumerator(dtype: str) -> str: 110 | """Returns the string representation of the AITemplateDtype enum 111 | (defined in model_interface.h) for the given dtype str. 112 | 113 | Parameters 114 | ---------- 115 | dtype: str 116 | A data type string. 117 | 118 | Returns 119 | ---------- 120 | str 121 | the AITemplateDtype enum string representation. 122 | """ 123 | 124 | def _impl(dtype): 125 | if dtype == "float16": 126 | return "kHalf" 127 | elif dtype == "float32" or dtype == "float": 128 | return "kFloat" 129 | elif dtype == "int32" or dtype == "int": 130 | return "kInt" 131 | elif dtype == "int64": 132 | return "kLong" 133 | elif dtype == "bool": 134 | return "kBool" 135 | elif dtype == "bfloat16": 136 | return "kBFloat16" 137 | else: 138 | raise AssertionError(f"unknown dtype {dtype}") 139 | 140 | return f"AITemplateDtype::{_impl(dtype)}" 141 | 142 | 143 | def is_same_dtype(dtype1: str, dtype2: str) -> bool: 144 | """Returns True if dtype1 and dtype2 are the same dtype and False otherwise. 145 | 146 | Parameters 147 | ---------- 148 | dtype1: str 149 | A data type string. 150 | dtype2: str 151 | A data type string. 152 | 153 | Returns 154 | ---------- 155 | bool 156 | whether dtype1 and dtype2 are the same dtype 157 | """ 158 | return normalize_dtype(dtype1) == normalize_dtype(dtype2) 159 | -------------------------------------------------------------------------------- /AITemplate/ait/module/misc.py: -------------------------------------------------------------------------------- 1 | # Copyright (c) Meta Platforms, Inc. and affiliates. 2 | # 3 | # Licensed under the Apache License, Version 2.0 (the "License"); 4 | # you may not use this file except in compliance with the License. 5 | # You may obtain a copy of the License at 6 | # 7 | # http://www.apache.org/licenses/LICENSE-2.0 8 | # 9 | # Unless required by applicable law or agreed to in writing, software 10 | # distributed under the License is distributed on an "AS IS" BASIS, 11 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 12 | # See the License for the specific language governing permissions and 13 | # limitations under the License. 
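# Note: `setup_logger` reads the LOGLEVEL environment variable when it is set
# (e.g. LOGLEVEL=debug or LOGLEVEL=DEBUG; the value is upper-cased before use)
# and otherwise falls back to the root logger's current level.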
14 | # 15 | """ 16 | miscellaneous utilities 17 | """ 18 | import hashlib 19 | import logging 20 | import os 21 | import platform 22 | 23 | 24 | def is_debug(): 25 | logger = logging.getLogger("aitemplate") 26 | return logger.level == logging.DEBUG 27 | 28 | 29 | def is_linux() -> bool: 30 | return platform.system() == "Linux" 31 | 32 | 33 | def is_windows() -> bool: 34 | return os.name == "nt" 35 | 36 | 37 | def setup_logger(name): 38 | root_logger = logging.getLogger(name) 39 | info_handle = logging.StreamHandler() 40 | formatter = logging.Formatter("%(asctime)s %(levelname)s <%(name)s> %(message)s") 41 | info_handle.setFormatter(formatter) 42 | root_logger.addHandler(info_handle) 43 | root_logger.propagate = False 44 | 45 | DEFAULT_LOGLEVEL = logging.getLogger().level 46 | log_level_str = os.environ.get("LOGLEVEL", None) 47 | LOG_LEVEL = ( 48 | getattr(logging, log_level_str.upper()) 49 | if log_level_str is not None 50 | else DEFAULT_LOGLEVEL 51 | ) 52 | root_logger.setLevel(LOG_LEVEL) 53 | return root_logger 54 | 55 | 56 | def short_str(s, length=8) -> str: 57 | """ 58 | Returns a hashed string, somewhat similar to URL shortener. 59 | """ 60 | hash_str = hashlib.sha256(s.encode()).hexdigest() 61 | return hash_str[0:length] 62 | 63 | 64 | def callstack_stats(enable=False): 65 | if enable: 66 | 67 | def decorator(f): 68 | import cProfile 69 | import io 70 | import pstats 71 | 72 | logger = logging.getLogger(__name__) 73 | 74 | def inner_function(*args, **kwargs): 75 | pr = cProfile.Profile() 76 | pr.enable() 77 | result = f(*args, **kwargs) 78 | pr.disable() 79 | s = io.StringIO() 80 | pstats.Stats(pr, stream=s).sort_stats( 81 | pstats.SortKey.CUMULATIVE 82 | ).print_stats(30) 83 | logger.debug(s.getvalue()) 84 | return result 85 | 86 | return inner_function 87 | 88 | return decorator 89 | else: 90 | 91 | def decorator(f): 92 | def inner_function(*args, **kwargs): 93 | return f(*args, **kwargs) 94 | 95 | return inner_function 96 | 97 | return decorator 98 | -------------------------------------------------------------------------------- /AITemplate/ait/module/torch_utils.py: -------------------------------------------------------------------------------- 1 | # Copyright (c) Meta Platforms, Inc. and affiliates. 2 | # 3 | # Licensed under the Apache License, Version 2.0 (the "License"); 4 | # you may not use this file except in compliance with the License. 5 | # You may obtain a copy of the License at 6 | # 7 | # http://www.apache.org/licenses/LICENSE-2.0 8 | # 9 | # Unless required by applicable law or agreed to in writing, software 10 | # distributed under the License is distributed on an "AS IS" BASIS, 11 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 12 | # See the License for the specific language governing permissions and 13 | # limitations under the License. 14 | # 15 | """ 16 | Functions for working with torch Tensors. 17 | AITemplate doesn't depend on PyTorch, but it exposes 18 | many APIs that work with torch Tensors anyways. 19 | 20 | The functions in this file may assume that 21 | `import torch` will work. 
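For example (assuming `import torch` succeeds):

    torch_dtype_to_string(torch.float16)  # -> "float16"
    string_to_torch_dtype("float16")      # -> torch.float16
    string_to_torch_dtype(None)           # -> None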
22 | """ 23 | 24 | 25 | def types_mapping(): 26 | from torch import bfloat16, bool, float16, float32, int32, int64 27 | 28 | yield (float16, "float16") 29 | yield (bfloat16, "bfloat16") 30 | yield (float32, "float32") 31 | yield (int32, "int32") 32 | yield (int64, "int64") 33 | yield (bool, "bool") 34 | 35 | 36 | def torch_dtype_to_string(dtype): 37 | for (torch_dtype, ait_dtype) in types_mapping(): 38 | if dtype == torch_dtype: 39 | return ait_dtype 40 | raise ValueError( 41 | f"Got unsupported input dtype {dtype}! " 42 | f"Supported dtypes are: {list(types_mapping())}" 43 | ) 44 | 45 | 46 | def string_to_torch_dtype(string_dtype): 47 | if string_dtype is None: 48 | # Many torch functions take optional dtypes, so 49 | # handling None is useful here. 50 | return None 51 | 52 | for (torch_dtype, ait_dtype) in types_mapping(): 53 | if string_dtype == ait_dtype: 54 | return torch_dtype 55 | raise ValueError( 56 | f"Got unsupported ait dtype {string_dtype}! " 57 | f"Supported dtypes are: {list(types_mapping())}" 58 | ) 59 | -------------------------------------------------------------------------------- /AITemplate/ait/util/__init__.py: -------------------------------------------------------------------------------- 1 | from .ckpt_convert import convert_ldm_unet_checkpoint, convert_ldm_vae_checkpoint, convert_text_enc_state_dict 2 | from .torch_dtype_from_str import torch_dtype_from_str 3 | 4 | __all__ = ["convert_ldm_unet_checkpoint", "torch_dtype_from_str"] -------------------------------------------------------------------------------- /AITemplate/ait/util/mapping/__init__.py: -------------------------------------------------------------------------------- 1 | from .clip import map_clip 2 | from .controlnet import map_controlnet 3 | from .vae import map_vae 4 | from .unet import map_unet 5 | 6 | __all__ = ["map_clip", "map_controlnet", "map_vae", "map_unet"] 7 | -------------------------------------------------------------------------------- /AITemplate/ait/util/mapping/clip.py: -------------------------------------------------------------------------------- 1 | try: 2 | from transformers import CLIPTextConfig, CLIPTextModel 3 | except ImportError: 4 | raise ImportError( 5 | "Please install transformers with `pip install transformers` to use this script." 
6 | ) 7 | 8 | import torch 9 | from ...util import torch_dtype_from_str 10 | 11 | 12 | def map_clip(pt_mod, device="cuda", dtype="float16"): 13 | pt_params = dict(pt_mod.named_parameters()) 14 | params_ait = {} 15 | for key, arr in pt_params.items(): 16 | arr = arr.to(device, dtype=torch_dtype_from_str(dtype)) 17 | name = key.replace("text_model.", "") 18 | ait_name = name.replace(".", "_") 19 | if name.endswith("out_proj.weight"): 20 | ait_name = ait_name.replace("out_proj", "proj") 21 | elif name.endswith("out_proj.bias"): 22 | ait_name = ait_name.replace("out_proj", "proj") 23 | elif "q_proj" in name: 24 | ait_name = ait_name.replace("q_proj", "proj_q") 25 | elif "k_proj" in name: 26 | ait_name = ait_name.replace("k_proj", "proj_k") 27 | elif "v_proj" in name: 28 | ait_name = ait_name.replace("v_proj", "proj_v") 29 | params_ait[ait_name] = arr 30 | return params_ait 31 | -------------------------------------------------------------------------------- /AITemplate/ait/util/mapping/controlnet.py: -------------------------------------------------------------------------------- 1 | import torch 2 | from ...util import torch_dtype_from_str 3 | 4 | 5 | def map_controlnet(pt_mod, dim=320, device="cuda", dtype="float16"): 6 | if not isinstance(pt_mod, dict): 7 | pt_params = dict(pt_mod.named_parameters()) 8 | else: 9 | pt_params = pt_mod 10 | params_ait = {} 11 | for key, arr in pt_params.items(): 12 | arr = arr.to(device, dtype=torch_dtype_from_str(dtype)) 13 | if len(arr.shape) == 4: 14 | arr = arr.permute((0, 2, 3, 1)).contiguous() 15 | elif key.endswith("ff.net.0.proj.weight"): 16 | w1, w2 = arr.chunk(2, dim=0) 17 | params_ait[key.replace(".", "_")] = w1 18 | params_ait[key.replace(".", "_").replace("proj", "gate")] = w2 19 | continue 20 | elif key.endswith("ff.net.0.proj.bias"): 21 | w1, w2 = arr.chunk(2, dim=0) 22 | params_ait[key.replace(".", "_")] = w1 23 | params_ait[key.replace(".", "_").replace("proj", "gate")] = w2 24 | continue 25 | params_ait[key.replace(".", "_")] = arr 26 | params_ait["controlnet_cond_embedding_conv_in_weight"] = torch.nn.functional.pad( 27 | params_ait["controlnet_cond_embedding_conv_in_weight"], (0, 1, 0, 0, 0, 0, 0, 0) 28 | ) 29 | params_ait["arange"] = ( 30 | torch.arange(start=0, end=dim // 2, dtype=torch.float32).cuda().half() 31 | ) 32 | return params_ait 33 | -------------------------------------------------------------------------------- /AITemplate/ait/util/mapping/unet.py: -------------------------------------------------------------------------------- 1 | import torch 2 | from ...util import torch_dtype_from_str 3 | 4 | def map_unet(pt_mod, in_channels=None, conv_in_key=None, dim=320, device="cuda", dtype="float16"): 5 | if in_channels is not None and conv_in_key is None: 6 | raise ValueError("conv_in_key must be specified if in_channels is not None for padding") 7 | if not isinstance(pt_mod, dict): 8 | pt_params = dict(pt_mod.named_parameters()) 9 | else: 10 | pt_params = pt_mod 11 | params_ait = {} 12 | for key, arr in pt_params.items(): 13 | if key.startswith("model.diffusion_model."): 14 | key = key.replace("model.diffusion_model.", "") 15 | arr = arr.to(device, dtype=torch_dtype_from_str(dtype)) 16 | if len(arr.shape) == 4: 17 | arr = arr.permute((0, 2, 3, 1)).contiguous() 18 | elif key.endswith("ff.net.0.proj.weight"): 19 | w1, w2 = arr.chunk(2, dim=0) 20 | params_ait[key.replace(".", "_")] = w1 21 | params_ait[key.replace(".", "_").replace("proj", "gate")] = w2 22 | continue 23 | elif key.endswith("ff.net.0.proj.bias"): 24 | w1, w2 = 
arr.chunk(2, dim=0) 25 | params_ait[key.replace(".", "_")] = w1 26 | params_ait[key.replace(".", "_").replace("proj", "gate")] = w2 27 | continue 28 | params_ait[key.replace(".", "_")] = arr 29 | 30 | if conv_in_key is not None: 31 | if in_channels % 4 != 0: 32 | pad_by = 4 - (in_channels % 4) 33 | params_ait[conv_in_key] = torch.functional.F.pad(params_ait[conv_in_key], (0, pad_by)) 34 | 35 | params_ait["arange"] = ( 36 | torch.arange(start=0, end=dim // 2, dtype=torch.float32).to(device, dtype=torch_dtype_from_str(dtype)) 37 | ) 38 | 39 | return params_ait 40 | -------------------------------------------------------------------------------- /AITemplate/ait/util/mapping/vae.py: -------------------------------------------------------------------------------- 1 | import torch 2 | from ...util import torch_dtype_from_str 3 | 4 | def map_vae(pt_module, device="cuda", dtype="float16", encoder=False): 5 | if not isinstance(pt_module, dict): 6 | pt_params = dict(pt_module.named_parameters()) 7 | else: 8 | pt_params = pt_module 9 | params_ait = {} 10 | quant_key = "post_quant" if encoder else "quant" 11 | vae_key = "decoder" if encoder else "encoder" 12 | for key, arr in pt_params.items(): 13 | if key.startswith(vae_key): 14 | continue 15 | if key.startswith(quant_key): 16 | continue 17 | arr = arr.to(device, dtype=torch_dtype_from_str(dtype)) 18 | key = key.replace(".", "_") 19 | if ( 20 | "conv" in key 21 | and "norm" not in key 22 | and key.endswith("_weight") 23 | and len(arr.shape) == 4 24 | ): 25 | params_ait[key] = torch.permute(arr, [0, 2, 3, 1]).contiguous() 26 | elif key.endswith("proj_attn_weight"): 27 | prefix = key[: -len("proj_attn_weight")] 28 | key = prefix + "attention_proj_weight" 29 | params_ait[key] = arr 30 | elif key.endswith("to_out_0_weight"): 31 | prefix = key[: -len("to_out_0_weight")] 32 | key = prefix + "attention_proj_weight" 33 | params_ait[key] = arr 34 | elif key.endswith("proj_attn_bias"): 35 | prefix = key[: -len("proj_attn_bias")] 36 | key = prefix + "attention_proj_bias" 37 | params_ait[key] = arr 38 | elif key.endswith("to_out_0_bias"): 39 | prefix = key[: -len("to_out_0_bias")] 40 | key = prefix + "attention_proj_bias" 41 | params_ait[key] = arr 42 | elif key.endswith("query_weight"): 43 | prefix = key[: -len("query_weight")] 44 | key = prefix + "attention_proj_q_weight" 45 | params_ait[key] = arr 46 | elif key.endswith("to_q_weight"): 47 | prefix = key[: -len("to_q_weight")] 48 | key = prefix + "attention_proj_q_weight" 49 | params_ait[key] = arr 50 | elif key.endswith("query_bias"): 51 | prefix = key[: -len("query_bias")] 52 | key = prefix + "attention_proj_q_bias" 53 | params_ait[key] = arr 54 | elif key.endswith("to_q_bias"): 55 | prefix = key[: -len("to_q_bias")] 56 | key = prefix + "attention_proj_q_bias" 57 | params_ait[key] = arr 58 | elif key.endswith("key_weight"): 59 | prefix = key[: -len("key_weight")] 60 | key = prefix + "attention_proj_k_weight" 61 | params_ait[key] = arr 62 | elif key.endswith("key_bias"): 63 | prefix = key[: -len("key_bias")] 64 | key = prefix + "attention_proj_k_bias" 65 | params_ait[key] = arr 66 | elif key.endswith("value_weight"): 67 | prefix = key[: -len("value_weight")] 68 | key = prefix + "attention_proj_v_weight" 69 | params_ait[key] = arr 70 | elif key.endswith("value_bias"): 71 | prefix = key[: -len("value_bias")] 72 | key = prefix + "attention_proj_v_bias" 73 | params_ait[key] = arr 74 | elif key.endswith("to_k_weight"): 75 | prefix = key[: -len("to_k_weight")] 76 | key = prefix + "attention_proj_k_weight" 77 | 
params_ait[key] = arr 78 | elif key.endswith("to_v_weight"): 79 | prefix = key[: -len("to_v_weight")] 80 | key = prefix + "attention_proj_v_weight" 81 | params_ait[key] = arr 82 | elif key.endswith("to_k_bias"): 83 | prefix = key[: -len("to_k_bias")] 84 | key = prefix + "attention_proj_k_bias" 85 | params_ait[key] = arr 86 | elif key.endswith("to_v_bias"): 87 | prefix = key[: -len("to_v_bias")] 88 | key = prefix + "attention_proj_v_bias" 89 | params_ait[key] = arr 90 | else: 91 | params_ait[key] = arr 92 | if encoder: 93 | params_ait["encoder_conv_in_weight"] = torch.functional.F.pad( 94 | params_ait["encoder_conv_in_weight"], (0, 1, 0, 0, 0, 0, 0, 0) 95 | ) 96 | 97 | return params_ait 98 | -------------------------------------------------------------------------------- /AITemplate/ait/util/torch_dtype_from_str.py: -------------------------------------------------------------------------------- 1 | import torch 2 | 3 | def torch_dtype_from_str(dtype: str): 4 | return torch.__dict__.get(dtype, None) -------------------------------------------------------------------------------- /AITemplate/clip.py: -------------------------------------------------------------------------------- 1 | # Copyright (c) Meta Platforms, Inc. and affiliates. 2 | # 3 | # Licensed under the Apache License, Version 2.0 (the "License"); 4 | # you may not use this file except in compliance with the License. 5 | # You may obtain a copy of the License at 6 | # 7 | # http://www.apache.org/licenses/LICENSE-2.0 8 | # 9 | # Unless required by applicable law or agreed to in writing, software 10 | # distributed under the License is distributed on an "AS IS" BASIS, 11 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 12 | # See the License for the specific language governing permissions and 13 | # limitations under the License. 14 | # 15 | import logging 16 | 17 | import click 18 | import torch 19 | from aitemplate.testing import detect_target 20 | from transformers import CLIPTextModel, CLIPTextModelWithProjection 21 | from ait.compile.clip import compile_clip 22 | 23 | @click.command() 24 | @click.option( 25 | "--hf-hub-or-path", 26 | default=r"runwayml/stable-diffusion-v1-5", 27 | help="the local diffusers pipeline directory or hf hub path e.g. runwayml/stable-diffusion-v1-5", 28 | ) 29 | @click.option( 30 | "--batch-size", 31 | default=(1, 2), 32 | type=(int, int), 33 | nargs=2, 34 | help="Minimum and maximum batch size", 35 | ) 36 | @click.option( 37 | "--output-hidden-states", 38 | default=False, 39 | type=bool, 40 | help="Output hidden states", 41 | ) 42 | @click.option( 43 | "--text-projection", 44 | default=False, 45 | type=bool, 46 | help="use text projection", 47 | ) 48 | @click.option( 49 | "--include-constants", 50 | default=False, 51 | type=bool, 52 | help="include constants (model weights) with compiled model", 53 | ) 54 | @click.option( 55 | "--subfolder", 56 | default="text_encoder", 57 | help="subfolder of hf repo or path. 
default `text_encoder`, this is `text_encoder_2` for SDXL.", 58 | ) 59 | @click.option("--use-fp16-acc", default=True, help="use fp16 accumulation") 60 | @click.option("--convert-conv-to-gemm", default=True, help="convert 1x1 conv to gemm") 61 | @click.option("--model-name", default="CLIPTextModel", help="module name") 62 | @click.option("--work-dir", default="./tmp", help="work directory") 63 | @click.option("--out-dir", default="./out", help="out directory") 64 | def compile_diffusers( 65 | hf_hub_or_path, 66 | batch_size, 67 | output_hidden_states, 68 | text_projection, 69 | include_constants, 70 | subfolder="text_encoder", 71 | use_fp16_acc=True, 72 | convert_conv_to_gemm=True, 73 | model_name="CLIPTextModel", 74 | work_dir="./tmp", 75 | out_dir="./out", 76 | ): 77 | logging.getLogger().setLevel(logging.INFO) 78 | torch.manual_seed(4896) 79 | 80 | if detect_target().name() == "rocm": 81 | convert_conv_to_gemm = False 82 | 83 | if text_projection: 84 | pipe = CLIPTextModelWithProjection.from_pretrained( 85 | hf_hub_or_path, 86 | subfolder=subfolder, 87 | variant="fp16", 88 | torch_dtype=torch.float16, 89 | use_safetensors=True, 90 | ).to("cuda") 91 | else: 92 | pipe = CLIPTextModel.from_pretrained( 93 | hf_hub_or_path, 94 | subfolder=subfolder, 95 | variant="fp16", 96 | torch_dtype=torch.float16, 97 | use_safetensors=True, 98 | ).to("cuda") 99 | 100 | compile_clip( 101 | pipe, 102 | batch_size=batch_size, 103 | seqlen=pipe.config.max_position_embeddings, 104 | use_fp16_acc=use_fp16_acc, 105 | convert_conv_to_gemm=convert_conv_to_gemm, 106 | output_hidden_states=output_hidden_states, 107 | text_projection_dim=pipe.config.projection_dim if text_projection else None, 108 | depth=pipe.config.num_hidden_layers, 109 | num_heads=pipe.config.num_attention_heads, 110 | dim=pipe.config.hidden_size, 111 | act_layer=pipe.config.hidden_act, 112 | constants=include_constants, 113 | model_name=model_name, 114 | work_dir=work_dir, 115 | out_dir=out_dir, 116 | ) 117 | 118 | if __name__ == "__main__": 119 | compile_diffusers() 120 | -------------------------------------------------------------------------------- /AITemplate/controlnet.py: -------------------------------------------------------------------------------- 1 | # Copyright (c) Meta Platforms, Inc. and affiliates. 2 | # 3 | # Licensed under the Apache License, Version 2.0 (the "License"); 4 | # you may not use this file except in compliance with the License. 5 | # You may obtain a copy of the License at 6 | # 7 | # http://www.apache.org/licenses/LICENSE-2.0 8 | # 9 | # Unless required by applicable law or agreed to in writing, software 10 | # distributed under the License is distributed on an "AS IS" BASIS, 11 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 12 | # See the License for the specific language governing permissions and 13 | # limitations under the License. 14 | # 15 | import logging 16 | 17 | import click 18 | import torch 19 | from aitemplate.testing import detect_target 20 | 21 | try: 22 | from diffusers import ControlNetModel 23 | except ImportError: 24 | raise ImportError( 25 | "Please install diffusers with `pip install diffusers` to use this script." 26 | ) 27 | from ait.compile.controlnet import compile_controlnet 28 | 29 | 30 | @click.command() 31 | @click.option( 32 | "--hf-hub-or-path", 33 | default="lllyasviel/sd-controlnet-canny", 34 | help="the local diffusers pipeline directory or hf hub path e.g. 
lllyasviel/sd-controlnet-canny", 35 | ) 36 | @click.option( 37 | "--width", 38 | default=(64, 2048), 39 | type=(int, int), 40 | nargs=2, 41 | help="Minimum and maximum width", 42 | ) 43 | @click.option( 44 | "--height", 45 | default=(64, 2048), 46 | type=(int, int), 47 | nargs=2, 48 | help="Minimum and maximum height", 49 | ) 50 | @click.option( 51 | "--batch-size", 52 | default=(1, 1), 53 | type=(int, int), 54 | nargs=2, 55 | help="Minimum and maximum batch size", 56 | ) 57 | @click.option("--clip-chunks", default=30, help="Maximum number of clip chunks") 58 | @click.option( 59 | "--include-constants", 60 | default=None, 61 | help="include constants (model weights) with compiled model", 62 | ) 63 | @click.option("--use-fp16-acc", default=True, help="use fp16 accumulation") 64 | @click.option("--convert-conv-to-gemm", default=True, help="convert 1x1 conv to gemm") 65 | @click.option("--model-name", default="ControlNetModel", help="module name") 66 | @click.option("--work-dir", default="./tmp", help="work directory") 67 | @click.option("--out-dir", default="./out", help="out directory") 68 | def compile_diffusers( 69 | hf_hub_or_path, 70 | width, 71 | height, 72 | batch_size, 73 | clip_chunks, 74 | include_constants, 75 | use_fp16_acc=True, 76 | convert_conv_to_gemm=True, 77 | model_name="ControlNetModel", 78 | work_dir="./tmp", 79 | out_dir="./out", 80 | ): 81 | logging.getLogger().setLevel(logging.INFO) 82 | torch.manual_seed(4896) 83 | 84 | if detect_target().name() == "rocm": 85 | convert_conv_to_gemm = False 86 | 87 | pipe = ControlNetModel.from_pretrained( 88 | hf_hub_or_path, 89 | use_safetensors=True, 90 | # variant="fp16", 91 | torch_dtype=torch.float16, 92 | ).to("cuda") 93 | 94 | compile_controlnet( 95 | pipe, 96 | batch_size=batch_size, 97 | width=width, 98 | height=height, 99 | clip_chunks=clip_chunks, 100 | convert_conv_to_gemm=convert_conv_to_gemm, 101 | use_fp16_acc=use_fp16_acc, 102 | constants=include_constants, 103 | model_name=model_name, 104 | work_dir=work_dir, 105 | hidden_dim=pipe.config.cross_attention_dim, 106 | use_linear_projection=pipe.config.get("use_linear_projection", False), 107 | block_out_channels=pipe.config.block_out_channels, 108 | down_block_types=pipe.config.down_block_types, 109 | in_channels=pipe.config.in_channels, 110 | class_embed_type=pipe.config.class_embed_type, 111 | num_class_embeds=pipe.config.num_class_embeds, 112 | dim=pipe.config.block_out_channels[0], 113 | time_embedding_dim=None, 114 | projection_class_embeddings_input_dim=pipe.config.projection_class_embeddings_input_dim 115 | if hasattr(pipe.config, "projection_class_embeddings_input_dim") 116 | else None, 117 | addition_embed_type=pipe.config.addition_embed_type 118 | if hasattr(pipe.config, "addition_embed_type") 119 | else None, 120 | addition_time_embed_dim=pipe.config.addition_time_embed_dim 121 | if hasattr(pipe.config, "addition_time_embed_dim") 122 | else None, 123 | transformer_layers_per_block=pipe.config.transformer_layers_per_block 124 | if hasattr(pipe.config, "transformer_layers_per_block") 125 | else 1, 126 | out_dir=out_dir, 127 | ) 128 | 129 | 130 | if __name__ == "__main__": 131 | compile_diffusers() 132 | -------------------------------------------------------------------------------- /AITemplate/download_pipeline.py: -------------------------------------------------------------------------------- 1 | # Copyright (c) Meta Platforms, Inc. and affiliates. 
2 | # 3 | # Licensed under the Apache License, Version 2.0 (the "License"); 4 | # you may not use this file except in compliance with the License. 5 | # You may obtain a copy of the License at 6 | # 7 | # http://www.apache.org/licenses/LICENSE-2.0 8 | # 9 | # Unless required by applicable law or agreed to in writing, software 10 | # distributed under the License is distributed on an "AS IS" BASIS, 11 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 12 | # See the License for the specific language governing permissions and 13 | # limitations under the License. 14 | # 15 | import click 16 | import torch 17 | try: 18 | from diffusers import StableDiffusionPipeline 19 | except ImportError: 20 | raise ImportError( 21 | "Please install diffusers with `pip install diffusers` to use this script." 22 | ) 23 | 24 | 25 | @click.command() 26 | @click.option("--token", default="", help="access token") 27 | @click.option( 28 | "--hf-hub", 29 | default="runwayml/stable-diffusion-v1-5", 30 | help="hf hub", 31 | ) 32 | @click.option( 33 | "--save_directory", 34 | default="./tmp/diffusers-pipeline/runwayml/stable-diffusion-v1-5", 35 | help="pipeline files local directory", 36 | ) 37 | @click.option("--revision", default="fp16", help="access token") 38 | def download_pipeline_files(token, hf_hub, save_directory, revision="fp16") -> None: 39 | StableDiffusionPipeline.from_pretrained( 40 | hf_hub, 41 | revision=revision if revision != "" else None, 42 | torch_dtype=torch.float16, 43 | # use provided token or the one generated with `huggingface-cli login`` 44 | use_auth_token=token if token != "" else True, 45 | ).save_pretrained(save_directory) 46 | 47 | 48 | if __name__ == "__main__": 49 | download_pipeline_files() 50 | -------------------------------------------------------------------------------- /AITemplate/modules/place_modules_here: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/FizzleDorf/AIT/b46f17dd095283c31d12d5ffd6fb90457c85e5e0/AITemplate/modules/place_modules_here -------------------------------------------------------------------------------- /AITemplate/test.py: -------------------------------------------------------------------------------- 1 | from ait.ait import AIT 2 | 3 | if __name__ == "__main__": 4 | ait = AIT() 5 | ait.load("/home/user/ait_modules/unet_64_1024_1_1.so", "runwayml/stable-diffusion-v1-5", "unet") 6 | ait.test_unet() 7 | ait = AIT() 8 | ait.load("/home/user/ait_tmp/tmp/v1_vae_64_1024/test.so", "runwayml/stable-diffusion-v1-5", "vae") 9 | ait.test_vae() 10 | ait = AIT() 11 | ait.load("/home/user/ait_tmp/v1_clip_1/test.so", "runwayml/stable-diffusion-v1-5", "clip") 12 | ait.test_clip() 13 | ait = AIT() 14 | ait.load("/home/user/ait_tmp/v1_controlnet_512_512_1/test.so", "lllyasviel/sd-controlnet-canny", "controlnet") 15 | ait.test_controlnet() 16 | # ait = AIT() 17 | # ait.load_compvis("/home/user/ait_modules/unet_64_1024_1_1.so", "/home/user/checkpoints/v1-5-pruned-emaonly.ckpt", "unet") 18 | # ait.test_unet() -------------------------------------------------------------------------------- /AITemplate/unet.py: -------------------------------------------------------------------------------- 1 | # Copyright (c) Meta Platforms, Inc. and affiliates. 2 | # 3 | # Licensed under the Apache License, Version 2.0 (the "License"); 4 | # you may not use this file except in compliance with the License. 
5 | # You may obtain a copy of the License at 6 | # 7 | # http://www.apache.org/licenses/LICENSE-2.0 8 | # 9 | # Unless required by applicable law or agreed to in writing, software 10 | # distributed under the License is distributed on an "AS IS" BASIS, 11 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 12 | # See the License for the specific language governing permissions and 13 | # limitations under the License. 14 | # 15 | import logging 16 | 17 | import click 18 | import torch 19 | from aitemplate.testing import detect_target 20 | try: 21 | from diffusers import UNet2DConditionModel 22 | except ImportError: 23 | raise ImportError( 24 | "Please install diffusers with `pip install diffusers` to use this script." 25 | ) 26 | from ait.compile.unet import compile_unet 27 | 28 | @click.command() 29 | @click.option( 30 | "--hf-hub-or-path", 31 | default="runwayml/stable-diffusion-v1-5", 32 | help="the local diffusers pipeline directory or hf hub path e.g. runwayml/stable-diffusion-v1-5", 33 | ) 34 | @click.option( 35 | "--width", 36 | default=(64, 1024), 37 | type=(int, int), 38 | nargs=2, 39 | help="Minimum and maximum width", 40 | ) 41 | @click.option( 42 | "--height", 43 | default=(64, 1024), 44 | type=(int, int), 45 | nargs=2, 46 | help="Minimum and maximum height", 47 | ) 48 | @click.option( 49 | "--batch-size", 50 | default=(1, 1), 51 | type=(int, int), 52 | nargs=2, 53 | help="Minimum and maximum batch size", 54 | ) 55 | @click.option("--clip-chunks", default=30, help="Maximum number of clip chunks") 56 | @click.option( 57 | "--include-constants", 58 | default=None, 59 | help="include constants (model weights) with compiled model", 60 | ) 61 | @click.option( 62 | "--down-factor", 63 | default=8, 64 | type=int, 65 | help="Down factor, this is 4 for x4-upscaler", 66 | ) 67 | @click.option("--fp32", default=False, help="use fp32") 68 | @click.option("--use-fp16-acc", default=True, help="use fp16 accumulation") 69 | @click.option("--convert-conv-to-gemm", default=True, help="convert 1x1 conv to gemm") 70 | @click.option("--controlnet", default=False, help="UNet for controlnet") 71 | @click.option("--model-name", default="UNet2DConditionModel", help="module name") 72 | @click.option("--work-dir", default="./tmp", help="work directory") 73 | @click.option("--out-dir", default="./out", help="out directory") 74 | def compile_diffusers( 75 | hf_hub_or_path, 76 | width, 77 | height, 78 | batch_size, 79 | clip_chunks, 80 | include_constants, 81 | down_factor=8, 82 | fp32=False, 83 | use_fp16_acc=True, 84 | convert_conv_to_gemm=True, 85 | controlnet=False, 86 | model_name="UNet2DConditionModel", 87 | work_dir="./tmp", 88 | out_dir="./out", 89 | ): 90 | logging.getLogger().setLevel(logging.INFO) 91 | torch.manual_seed(4896) 92 | 93 | if detect_target().name() == "rocm": 94 | convert_conv_to_gemm = False 95 | 96 | pipe = UNet2DConditionModel.from_pretrained( 97 | hf_hub_or_path, 98 | subfolder="unet" if not hf_hub_or_path.endswith("unet") else None, 99 | variant="fp16", 100 | use_safetensors=True, 101 | torch_dtype=torch.float16, 102 | ).to("cuda") 103 | 104 | compile_unet( 105 | pipe, 106 | batch_size=batch_size, 107 | width=width, 108 | height=height, 109 | clip_chunks=clip_chunks, 110 | use_fp16_acc=use_fp16_acc, 111 | convert_conv_to_gemm=convert_conv_to_gemm, 112 | hidden_dim=pipe.config.cross_attention_dim, 113 | attention_head_dim=pipe.config.attention_head_dim, 114 | use_linear_projection=pipe.config.get("use_linear_projection", False), 115 | 
block_out_channels=pipe.config.block_out_channels, 116 | down_block_types=pipe.config.down_block_types, 117 | up_block_types=pipe.config.up_block_types, 118 | in_channels=pipe.config.in_channels, 119 | out_channels=pipe.config.out_channels, 120 | class_embed_type=pipe.config.class_embed_type, 121 | num_class_embeds=pipe.config.num_class_embeds, 122 | only_cross_attention=pipe.config.only_cross_attention, 123 | sample_size=pipe.config.sample_size, 124 | dim=pipe.config.block_out_channels[0], 125 | time_embedding_dim = None, 126 | conv_in_kernel=pipe.config.conv_in_kernel, 127 | projection_class_embeddings_input_dim=pipe.config.projection_class_embeddings_input_dim, 128 | addition_embed_type = pipe.config.addition_embed_type, 129 | addition_time_embed_dim = pipe.config.addition_time_embed_dim, 130 | transformer_layers_per_block = pipe.config.transformer_layers_per_block, 131 | constants=True if include_constants else False, 132 | controlnet=True if controlnet else False, 133 | model_name=model_name, 134 | work_dir=work_dir, 135 | down_factor=down_factor, 136 | dtype="float32" if fp32 else "float16", 137 | out_dir=out_dir, 138 | ) 139 | 140 | if __name__ == "__main__": 141 | compile_diffusers() 142 | -------------------------------------------------------------------------------- /AITemplate/vae.py: -------------------------------------------------------------------------------- 1 | # Copyright (c) Meta Platforms, Inc. and affiliates. 2 | # 3 | # Licensed under the Apache License, Version 2.0 (the "License"); 4 | # you may not use this file except in compliance with the License. 5 | # You may obtain a copy of the License at 6 | # 7 | # http://www.apache.org/licenses/LICENSE-2.0 8 | # 9 | # Unless required by applicable law or agreed to in writing, software 10 | # distributed under the License is distributed on an "AS IS" BASIS, 11 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 12 | # See the License for the specific language governing permissions and 13 | # limitations under the License. 14 | # 15 | import logging 16 | 17 | import click 18 | import torch 19 | from aitemplate.testing import detect_target 20 | try: 21 | from diffusers import AutoencoderKL 22 | except ImportError: 23 | raise ImportError( 24 | "Please install diffusers with `pip install diffusers` to use this script." 25 | ) 26 | from ait.compile.vae import compile_vae 27 | 28 | @click.command() 29 | @click.option( 30 | "--hf-hub-or-path", 31 | default="runwayml/stable-diffusion-v1-5", 32 | help="the local diffusers pipeline directory or hf hub path e.g. runwayml/stable-diffusion-v1-5", 33 | ) 34 | @click.option( 35 | "--width", 36 | default=(64, 1024), 37 | type=(int, int), 38 | nargs=2, 39 | help="Minimum and maximum width", 40 | ) 41 | @click.option( 42 | "--height", 43 | default=(64, 1024), 44 | type=(int, int), 45 | nargs=2, 46 | help="Minimum and maximum height", 47 | ) 48 | @click.option( 49 | "--batch-size", 50 | default=(1, 1), 51 | type=(int, int), 52 | nargs=2, 53 | help="Minimum and maximum batch size", 54 | ) 55 | @click.option("--fp32", default=False, help="use fp32") 56 | @click.option( 57 | "--include-constants", 58 | default=None, 59 | help="include constants (model weights) with compiled model", 60 | ) 61 | @click.option( 62 | "--input-size", 63 | default=64, 64 | type=int, 65 | help="Input sample size, same as sample size of the unet model. 
this is 128 for x4-upscaler", 66 | ) 67 | @click.option( 68 | "--down-factor", 69 | default=8, 70 | type=int, 71 | help="Down factor, this is 4 for x4-upscaler", 72 | ) 73 | @click.option( 74 | "--encoder", 75 | default=False, 76 | type=bool, 77 | help="If True, compile encoder, otherwise, compile decoder", 78 | ) 79 | @click.option("--use-fp16-acc", default=True, help="use fp16 accumulation") 80 | @click.option("--convert-conv-to-gemm", default=True, help="convert 1x1 conv to gemm") 81 | @click.option("--model-name", default="AutoencoderKL", help="module name") 82 | @click.option("--work-dir", default="./tmp", help="work directory") 83 | @click.option("--out-dir", default="./out", help="out directory") 84 | def compile_diffusers( 85 | hf_hub_or_path, 86 | width, 87 | height, 88 | batch_size, 89 | fp32, 90 | include_constants, 91 | input_size=64, 92 | down_factor=8, 93 | encoder=False, 94 | use_fp16_acc=True, 95 | convert_conv_to_gemm=True, 96 | model_name="AutoencoderKL", 97 | work_dir="./tmp", 98 | out_dir="./out", 99 | ): 100 | logging.getLogger().setLevel(logging.INFO) 101 | torch.manual_seed(4896) 102 | 103 | if detect_target().name() == "rocm": 104 | convert_conv_to_gemm = False 105 | 106 | pipe = AutoencoderKL.from_pretrained( 107 | hf_hub_or_path, 108 | subfolder="vae" if not hf_hub_or_path.endswith("vae") else None, 109 | torch_dtype=torch.float32 if fp32 else torch.float16 110 | ).to("cuda") 111 | 112 | compile_vae( 113 | pipe, 114 | batch_size=batch_size, 115 | width=width, 116 | height=height, 117 | use_fp16_acc=use_fp16_acc, 118 | convert_conv_to_gemm=convert_conv_to_gemm, 119 | constants=True if include_constants else False, 120 | block_out_channels=pipe.config.block_out_channels, 121 | layers_per_block=pipe.config.layers_per_block, 122 | act_fn=pipe.config.act_fn, 123 | latent_channels=pipe.config.latent_channels, 124 | in_channels=pipe.config.in_channels, 125 | out_channels=pipe.config.out_channels, 126 | down_block_types=pipe.config.down_block_types, 127 | up_block_types=pipe.config.up_block_types, 128 | sample_size=pipe.config.sample_size, 129 | input_size=(input_size, input_size), 130 | down_factor=down_factor, 131 | model_name=model_name, 132 | dtype="float32" if fp32 else "float16", 133 | work_dir=work_dir, 134 | vae_encode=encoder, 135 | out_dir=out_dir, 136 | ) 137 | 138 | if __name__ == "__main__": 139 | compile_diffusers() 140 | -------------------------------------------------------------------------------- /LICENSE: -------------------------------------------------------------------------------- 1 | Apache License 2 | Version 2.0, January 2004 3 | http://www.apache.org/licenses/ 4 | 5 | TERMS AND CONDITIONS FOR USE, REPRODUCTION, AND DISTRIBUTION 6 | 7 | 1. Definitions. 8 | 9 | "License" shall mean the terms and conditions for use, reproduction, 10 | and distribution as defined by Sections 1 through 9 of this document. 11 | 12 | "Licensor" shall mean the copyright owner or entity authorized by 13 | the copyright owner that is granting the License. 14 | 15 | "Legal Entity" shall mean the union of the acting entity and all 16 | other entities that control, are controlled by, or are under common 17 | control with that entity. For the purposes of this definition, 18 | "control" means (i) the power, direct or indirect, to cause the 19 | direction or management of such entity, whether by contract or 20 | otherwise, or (ii) ownership of fifty percent (50%) or more of the 21 | outstanding shares, or (iii) beneficial ownership of such entity. 
22 | 23 | "You" (or "Your") shall mean an individual or Legal Entity 24 | exercising permissions granted by this License. 25 | 26 | "Source" form shall mean the preferred form for making modifications, 27 | including but not limited to software source code, documentation 28 | source, and configuration files. 29 | 30 | "Object" form shall mean any form resulting from mechanical 31 | transformation or translation of a Source form, including but 32 | not limited to compiled object code, generated documentation, 33 | and conversions to other media types. 34 | 35 | "Work" shall mean the work of authorship, whether in Source or 36 | Object form, made available under the License, as indicated by a 37 | copyright notice that is included in or attached to the work 38 | (an example is provided in the Appendix below). 39 | 40 | "Derivative Works" shall mean any work, whether in Source or Object 41 | form, that is based on (or derived from) the Work and for which the 42 | editorial revisions, annotations, elaborations, or other modifications 43 | represent, as a whole, an original work of authorship. For the purposes 44 | of this License, Derivative Works shall not include works that remain 45 | separable from, or merely link (or bind by name) to the interfaces of, 46 | the Work and Derivative Works thereof. 47 | 48 | "Contribution" shall mean any work of authorship, including 49 | the original version of the Work and any modifications or additions 50 | to that Work or Derivative Works thereof, that is intentionally 51 | submitted to Licensor for inclusion in the Work by the copyright owner 52 | or by an individual or Legal Entity authorized to submit on behalf of 53 | the copyright owner. For the purposes of this definition, "submitted" 54 | means any form of electronic, verbal, or written communication sent 55 | to the Licensor or its representatives, including but not limited to 56 | communication on electronic mailing lists, source code control systems, 57 | and issue tracking systems that are managed by, or on behalf of, the 58 | Licensor for the purpose of discussing and improving the Work, but 59 | excluding communication that is conspicuously marked or otherwise 60 | designated in writing by the copyright owner as "Not a Contribution." 61 | 62 | "Contributor" shall mean Licensor and any individual or Legal Entity 63 | on behalf of whom a Contribution has been received by Licensor and 64 | subsequently incorporated within the Work. 65 | 66 | 2. Grant of Copyright License. Subject to the terms and conditions of 67 | this License, each Contributor hereby grants to You a perpetual, 68 | worldwide, non-exclusive, no-charge, royalty-free, irrevocable 69 | copyright license to reproduce, prepare Derivative Works of, 70 | publicly display, publicly perform, sublicense, and distribute the 71 | Work and such Derivative Works in Source or Object form. 72 | 73 | 3. Grant of Patent License. Subject to the terms and conditions of 74 | this License, each Contributor hereby grants to You a perpetual, 75 | worldwide, non-exclusive, no-charge, royalty-free, irrevocable 76 | (except as stated in this section) patent license to make, have made, 77 | use, offer to sell, sell, import, and otherwise transfer the Work, 78 | where such license applies only to those patent claims licensable 79 | by such Contributor that are necessarily infringed by their 80 | Contribution(s) alone or by combination of their Contribution(s) 81 | with the Work to which such Contribution(s) was submitted. 
If You 82 | institute patent litigation against any entity (including a 83 | cross-claim or counterclaim in a lawsuit) alleging that the Work 84 | or a Contribution incorporated within the Work constitutes direct 85 | or contributory patent infringement, then any patent licenses 86 | granted to You under this License for that Work shall terminate 87 | as of the date such litigation is filed. 88 | 89 | 4. Redistribution. You may reproduce and distribute copies of the 90 | Work or Derivative Works thereof in any medium, with or without 91 | modifications, and in Source or Object form, provided that You 92 | meet the following conditions: 93 | 94 | (a) You must give any other recipients of the Work or 95 | Derivative Works a copy of this License; and 96 | 97 | (b) You must cause any modified files to carry prominent notices 98 | stating that You changed the files; and 99 | 100 | (c) You must retain, in the Source form of any Derivative Works 101 | that You distribute, all copyright, patent, trademark, and 102 | attribution notices from the Source form of the Work, 103 | excluding those notices that do not pertain to any part of 104 | the Derivative Works; and 105 | 106 | (d) If the Work includes a "NOTICE" text file as part of its 107 | distribution, then any Derivative Works that You distribute must 108 | include a readable copy of the attribution notices contained 109 | within such NOTICE file, excluding those notices that do not 110 | pertain to any part of the Derivative Works, in at least one 111 | of the following places: within a NOTICE text file distributed 112 | as part of the Derivative Works; within the Source form or 113 | documentation, if provided along with the Derivative Works; or, 114 | within a display generated by the Derivative Works, if and 115 | wherever such third-party notices normally appear. The contents 116 | of the NOTICE file are for informational purposes only and 117 | do not modify the License. You may add Your own attribution 118 | notices within Derivative Works that You distribute, alongside 119 | or as an addendum to the NOTICE text from the Work, provided 120 | that such additional attribution notices cannot be construed 121 | as modifying the License. 122 | 123 | You may add Your own copyright statement to Your modifications and 124 | may provide additional or different license terms and conditions 125 | for use, reproduction, or distribution of Your modifications, or 126 | for any such Derivative Works as a whole, provided Your use, 127 | reproduction, and distribution of the Work otherwise complies with 128 | the conditions stated in this License. 129 | 130 | 5. Submission of Contributions. Unless You explicitly state otherwise, 131 | any Contribution intentionally submitted for inclusion in the Work 132 | by You to the Licensor shall be under the terms and conditions of 133 | this License, without any additional terms or conditions. 134 | Notwithstanding the above, nothing herein shall supersede or modify 135 | the terms of any separate license agreement you may have executed 136 | with Licensor regarding such Contributions. 137 | 138 | 6. Trademarks. This License does not grant permission to use the trade 139 | names, trademarks, service marks, or product names of the Licensor, 140 | except as required for reasonable and customary use in describing the 141 | origin of the Work and reproducing the content of the NOTICE file. 142 | 143 | 7. Disclaimer of Warranty. 
Unless required by applicable law or 144 | agreed to in writing, Licensor provides the Work (and each 145 | Contributor provides its Contributions) on an "AS IS" BASIS, 146 | WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or 147 | implied, including, without limitation, any warranties or conditions 148 | of TITLE, NON-INFRINGEMENT, MERCHANTABILITY, or FITNESS FOR A 149 | PARTICULAR PURPOSE. You are solely responsible for determining the 150 | appropriateness of using or redistributing the Work and assume any 151 | risks associated with Your exercise of permissions under this License. 152 | 153 | 8. Limitation of Liability. In no event and under no legal theory, 154 | whether in tort (including negligence), contract, or otherwise, 155 | unless required by applicable law (such as deliberate and grossly 156 | negligent acts) or agreed to in writing, shall any Contributor be 157 | liable to You for damages, including any direct, indirect, special, 158 | incidental, or consequential damages of any character arising as a 159 | result of this License or out of the use or inability to use the 160 | Work (including but not limited to damages for loss of goodwill, 161 | work stoppage, computer failure or malfunction, or any and all 162 | other commercial damages or losses), even if such Contributor 163 | has been advised of the possibility of such damages. 164 | 165 | 9. Accepting Warranty or Additional Liability. While redistributing 166 | the Work or Derivative Works thereof, You may choose to offer, 167 | and charge a fee for, acceptance of support, warranty, indemnity, 168 | or other liability obligations and/or rights consistent with this 169 | License. However, in accepting such obligations, You may act only 170 | on Your own behalf and on Your sole responsibility, not on behalf 171 | of any other Contributor, and only if You agree to indemnify, 172 | defend, and hold each Contributor harmless for any liability 173 | incurred by, or claims asserted against, such Contributor by reason 174 | of your accepting any such warranty or additional liability. 175 | 176 | END OF TERMS AND CONDITIONS 177 | 178 | APPENDIX: How to apply the Apache License to your work. 179 | 180 | To apply the Apache License to your work, attach the following 181 | boilerplate notice, with the fields enclosed by brackets "[]" 182 | replaced with your own identifying information. (Don't include 183 | the brackets!) The text should be enclosed in the appropriate 184 | comment syntax for the file format. We also recommend that a 185 | file or class name and description of purpose be included on the 186 | same "printed page" as the copyright notice for easier 187 | identification within third-party archives. 188 | 189 | Copyright [yyyy] [name of copyright owner] 190 | 191 | Licensed under the Apache License, Version 2.0 (the "License"); 192 | you may not use this file except in compliance with the License. 193 | You may obtain a copy of the License at 194 | 195 | http://www.apache.org/licenses/LICENSE-2.0 196 | 197 | Unless required by applicable law or agreed to in writing, software 198 | distributed under the License is distributed on an "AS IS" BASIS, 199 | WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 200 | See the License for the specific language governing permissions and 201 | limitations under the License. 
202 | -------------------------------------------------------------------------------- /README.md: -------------------------------------------------------------------------------- 1 | ////////////////////////////////////////////////////////////////////////////////////////////////////////////////////////////////////////////////////////////////// 2 | 3 | This repo is now deprecated. New AIT developments will take place in a 4 | [new repo](https://github.com/FizzleDorf/ComfyUI-AIT). This is due to a 5 | volatile foundation that too frequently needs fixing after ComfyUI updates. 6 | The new repo was made alongside Comfyanonymous so this doesn't happen in the future. 7 | Thank you for your patience and understanding. 8 | 9 | ///////////////////////////////////////////////////////////////////////////////////////////////////////////////////////////////////////////////////////////////// 10 | 11 | # AIT 12 | 13 | Experimental usage of [AITemplate](https://github.com/facebookincubator/AITemplate). 14 | 15 | [New XL Modules](https://huggingface.co/Fizzledorf/AITemplateXL) 16 | 17 | [Alternative Pre-compiled modules](https://huggingface.co/city96/AITemplate) (by City96) 18 | 19 | [Old Pre-compiled modules](https://huggingface.co/datasets/Fizzledorf/AITemplate_V1_V2) for earlier versions of this node. 20 | 21 | ## ComfyUI custom node 22 | 23 | You can use this [workflow for SDXL]( 24 | https://civitai.com/models/133818). 25 | Thanks a bunch, tdg8uu! 26 | 27 | ### Installation 28 | 29 | This repo can be cloned directly into ComfyUI's `custom_nodes` folder. 30 | 31 | Adjust the path as required; the example assumes you are working from the ComfyUI repo. 32 | ``` 33 | git clone https://github.com/FizzleDorf/AIT custom_nodes/AIT 34 | ``` 35 | 36 | ### Modules 37 | 38 | Modules will be automatically selected, downloaded, and decompressed by the plugin. 39 | 40 | #### Nodes 41 | 42 | #### Load AITemplate 43 | 44 | ![image](https://github.com/hlky/AIT/assets/106811348/75d25eac-4c50-4a83-bb47-58a10d38e094) 45 | 46 | `Loaders > Load AITemplate` 47 | 48 | #### Load AITemplate (ControlNet) 49 | 50 | ![image](https://github.com/hlky/AIT/assets/106811348/d410a55b-2d45-4e5c-8f36-50b1d3b84b4b) 51 | 52 | `Loaders > Load AITemplate (ControlNet)` 53 | 54 | #### VAE Decode (AITemplate) 55 | 56 | ![image](https://github.com/hlky/AIT/assets/106811348/75cfe24d-912a-4e7b-880f-18e97809d810) 57 | 58 | `Latent > VAE Decode (AITemplate)` 59 | 60 | #### VAE Encode (AITemplate) 61 | 62 | ![image](https://github.com/hlky/AIT/assets/106811348/7562c744-e3b1-4a63-9c49-b1a9875dbc47) 63 | 64 | `Latent > VAE Encode (AITemplate)` 65 | 66 | #### VAE Encode (AITemplate, Inpaint) 67 | 68 | ![image](https://github.com/hlky/AIT/assets/106811348/dce433cb-8160-4cba-9d87-829b0e75288e) 69 | 70 | `Latent > Inpaint > VAE Encode (AITemplate)` 71 | 72 | 73 | ### Workflow 74 | 75 | Example workflows in [`workflows/`](https://github.com/hlky/AIT/tree/main/workflows) 76 | 77 | ### Errors 78 | 79 | * Part of the error will be printed by the AITemplate module, so it will appear above the Python traceback. 
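* To narrow down whether a compiled module itself is at fault (rather than the ComfyUI integration), it can be loaded and exercised outside ComfyUI with the bundled `AIT` helper, following the pattern in `AITemplate/test.py`. A minimal sketch, run from the `AITemplate/` directory so the `ait` package resolves; the module path below is a placeholder to be replaced with your own compiled `.so`:
```
from ait.ait import AIT

ait = AIT()
# placeholder path - point this at your own compiled UNet module
ait.load("/path/to/unet_64_1024_1_1.so", "runwayml/stable-diffusion-v1-5", "unet")
ait.test_unet()
```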
80 | 81 | ## Supported model types 82 | * ControlNet 83 | * CLIPTextModel 84 | * UNet 85 | * **Inpainting UNet** 86 | * VAE 87 | * **VAE encode** 88 | 89 | ## Developers 90 | 91 | ### [Compile](https://github.com/hlky/AIT/blob/main/docs/compile.md) 92 | 93 | ### [CLIP](https://github.com/hlky/AIT/blob/main/docs/clip.md) 94 | 95 | ### [ControlNet](https://github.com/hlky/AIT/blob/main/docs/controlnet.md) 96 | 97 | ### [CompVis](https://github.com/hlky/AIT/blob/main/docs/compvis.md) 98 | 99 | ### [UNet](https://github.com/hlky/AIT/blob/main/docs/unet.md) 100 | 101 | ### [VAE](https://github.com/hlky/AIT/blob/main/docs/vae.md) 102 | -------------------------------------------------------------------------------- /__init__.py: -------------------------------------------------------------------------------- 1 | from .AITemplate.AITemplate import AITemplateLoader, AITemplateControlNetLoader, AITemplateVAEDecode, AITemplateVAEEncode, VAEEncodeForInpaint, AITemplateEmptyLatentImage, AITemplateLatentUpscale 2 | 3 | NODE_CLASS_MAPPINGS = { 4 | "AITemplateLoader": AITemplateLoader, 5 | "AITemplateControlNetLoader": AITemplateControlNetLoader, 6 | "AITemplateVAEDecode": AITemplateVAEDecode, 7 | "AITemplateVAEEncode": AITemplateVAEEncode, 8 | "AITemplateVAEEncodeForInpaint": VAEEncodeForInpaint, 9 | "AITemplateEmptyLatentImage": AITemplateEmptyLatentImage, 10 | "AITemplateLatentUpscale": AITemplateLatentUpscale, 11 | } 12 | 13 | # A dictionary that contains the friendly/human-readable titles for the nodes 14 | NODE_DISPLAY_NAME_MAPPINGS = { 15 | "AITemplateLoader": "Load AITemplate", 16 | "AITemplateControlNetLoader": "Load AITemplate (ControlNet)", 17 | "AITemplateVAELoader": "Load AITemplate (VAE)", 18 | "AITemplateVAEDecode": "VAE Decode (AITemplate)", 19 | "AITemplateVAEEncode": "VAE Encode (AITemplate)", 20 | "AITemplateVAEEncodeForInpaint": "VAE Encode (AITemplate, Inpaint)", 21 | "AITemplateEmptyLatentImage": "Empty Latent Image (AITemplate)", 22 | "AITemplateLatentUpscale": "Upscale Latent Image (AITemplate)", 23 | } 24 | -------------------------------------------------------------------------------- /docs/clip.md: -------------------------------------------------------------------------------- 1 | # CLIP 2 | 3 | ## Limitations 4 | * None known 5 | 6 | ## Inference 7 | 8 | ### Inputs 9 | 10 | * `sequence_length` is typically 77 11 | * longer prompts are accepted by UNet modules when compiled with an appropriate `--clip-chunks` option, which sets the maximum prompt length to `clip_chunks * 77` tokens 12 | 13 | * `"input0"` - `input_ids` 14 | ``` 15 | from transformers import CLIPTokenizer 16 | tokenizer = CLIPTokenizer.from_pretrained("openai/clip-vit-large-patch14") 17 | text_input = tokenizer( 18 | ["a photo of an astronaut riding a horse on mars"] * batch_size, 19 | padding="max_length", 20 | max_length=sequence_length, 21 | truncation=True, 22 | return_tensors="pt", 23 | ) 24 | input_ids = text_input["input_ids"].cuda() 25 | ``` 26 | 27 | * `"input1"` - `position_ids` 28 | * dealt with internally by `clip_inference` 29 | e.g. `torch.arange(sequence_length).expand((batch, -1)).to(device)` 30 | 31 | ### Outputs 32 | 33 | `torch.randn(batch_size, sequence_length, hidden_dim)` `torch.randn(1, 77, 768)` 34 | 35 | ## Function 36 | 37 | ``` 38 | def clip_inference( 39 | exe_module: Model, 40 | input_ids: torch.Tensor, 41 | seqlen: int = 77, 42 | device: str = "cuda", 43 | dtype: str = "float16", 44 | ): 45 | ``` 46 | * `seqlen` is unlikely to need to be changed. 47 | * `device` could be specified e.g. 
`cuda:1` if required 48 | * `dtype` is experimental, the module would need to be compiled as `float32` 49 | -------------------------------------------------------------------------------- /docs/compile.md: -------------------------------------------------------------------------------- 1 | 2 | ## Compile 3 | 4 | [AITemplate](https://github.com/facebookincubator/AITemplate) must be installed 5 | 6 | ``` 7 | git clone --recursive https://github.com/facebookincubator/AITemplate 8 | cd AITemplate/python 9 | python setup.py bdist_wheel 10 | pip install dist/*.whl 11 | ``` 12 | 13 | ### VRAM Usage 14 | * For dynamic shape modules the vram usage of the module will be that of the maximum shape. 15 | * This includes batch size. 16 | 17 | ### All 18 | 19 | * use `--include-constants True` to include the model weights in the compiled module 20 | * by default the modules do not include model weights 21 | * use `--work-dir` to set the directory where profilers and modules will be built 22 | * use `--model-name` to set the name of the compiled module 23 | 24 | ### UNet/Control UNet 25 | ``` 26 | python unet.py --hf-hub-or-path "runwayml/stable-diffusion-v1-5" --width 64 1024 --height 64 1024 --batch-size 1 2 --clip-chunks 8 --model-name "v1_unet_64_1024_1_2" --work-dir "/home/user/ait_tmp/" 27 | ``` 28 | ``` 29 | python unet.py --hf-hub-or-path "runwayml/stable-diffusion-v1-5" --width 512 1024 --height 512 1024 --batch-size 1 1 --clip-chunks 8 --model-name "v1_control_unet_512_512" --work-dir "/home/user/ait_tmp/" --controlnet True 30 | ``` 31 | ``` 32 | Usage: unet.py [OPTIONS] 33 | 34 | Options: 35 | --hf-hub-or-path TEXT the local diffusers pipeline directory or hf 36 | hub path e.g. runwayml/stable-diffusion-v1-5 37 | --width ... Minimum and maximum width 38 | --height ... Minimum and maximum height 39 | --batch-size ... 40 | Minimum and maximum batch size 41 | --clip-chunks INTEGER Maximum number of clip chunks 42 | --include-constants TEXT include constants (model weights) with 43 | compiled model 44 | --use-fp16-acc TEXT use fp16 accumulation 45 | --convert-conv-to-gemm TEXT convert 1x1 conv to gemm 46 | --controlnet TEXT UNet for controlnet 47 | --model-name TEXT module name 48 | --work-dir TEXT work directory 49 | --help Show this message and exit. 50 | ``` 51 | 52 | 53 | ### ControlNet 54 | ``` 55 | python controlnet.py --width 64 1024 --height 64 1024 --batch-size 1 1 --model-name "v1_controlnet_64_512_1" --work-dir "/home/user/ait_tmp/" 56 | ``` 57 | ``` 58 | Usage: controlnet.py [OPTIONS] 59 | 60 | Options: 61 | --hf-hub-or-path TEXT the local diffusers pipeline directory or hf 62 | hub path e.g. lllyasviel/sd-controlnet-canny 63 | --width ... Minimum and maximum width 64 | --height ... Minimum and maximum height 65 | --batch-size ... 66 | Minimum and maximum batch size 67 | --clip-chunks INTEGER Maximum number of clip chunks 68 | --include-constants TEXT include constants (model weights) with 69 | compiled model 70 | --use-fp16-acc TEXT use fp16 accumulation 71 | --convert-conv-to-gemm TEXT convert 1x1 conv to gemm 72 | --model-name TEXT module name 73 | --work-dir TEXT work directory 74 | --help Show this message and exit. 75 | ``` 76 | 77 | ### CLIPTextModel 78 | ``` 79 | python clip.py --hf-hub-or-path "runwayml/stable-diffusion-v1-5" --batch-size 1 8 --model-name "v1_clip_1" --work-dir "/home/user/ait_tmp/" 80 | ``` 81 | ``` 82 | Usage: clip.py [OPTIONS] 83 | 84 | Options: 85 | --hf-hub-or-path TEXT the local diffusers pipeline directory or hf 86 | hub path e.g. 
runwayml/stable-diffusion-v1-5 87 | --batch-size ... 88 | Minimum and maximum batch size 89 | --include-constants TEXT include constants (model weights) with 90 | compiled model 91 | --use-fp16-acc TEXT use fp16 accumulation 92 | --convert-conv-to-gemm TEXT convert 1x1 conv to gemm 93 | --model-name TEXT module name 94 | --work-dir TEXT work directory 95 | --help Show this message and exit. 96 | ``` 97 | 98 | ### VAE 99 | ``` 100 | python vae.py --hf-hub-or-path "runwayml/stable-diffusion-v1-5" --width 64 1024 --height 64 1024 --batch-size 1 1 --model-name "v1_vae_64_1024" --work-dir "/home/user/ait_tmp/" 101 | ``` 102 | ``` 103 | Usage: vae.py [OPTIONS] 104 | 105 | Options: 106 | --hf-hub-or-path TEXT the local diffusers pipeline directory or hf 107 | hub path e.g. runwayml/stable-diffusion-v1-5 108 | --width ... Minimum and maximum width 109 | --height ... Minimum and maximum height 110 | --batch-size ... 111 | Minimum and maximum batch size 112 | --fp32 TEXT use fp32 113 | --include-constants TEXT include constants (model weights) with 114 | compiled model 115 | --use-fp16-acc TEXT use fp16 accumulation 116 | --convert-conv-to-gemm TEXT convert 1x1 conv to gemm 117 | --model-name TEXT module name 118 | --work-dir TEXT work directory 119 | --help Show this message and exit. 120 | ``` 121 | -------------------------------------------------------------------------------- /docs/compvis.md: -------------------------------------------------------------------------------- 1 | # CompVis 2 | 3 | ## AITemplate Model Wrapper 4 | 5 | * Used in place of `CFGNoisePredictor` which is then wrapped by `CompVisDenoiser`/`CompVisVDenoiser` 6 | * Provides `apply_model` 7 | ``` 8 | def apply_model( 9 | self, 10 | x: torch.Tensor, 11 | t: torch.Tensor, 12 | c_crossattn = None, 13 | c_concat = None, 14 | control = None, 15 | transformer_options = None, 16 | ): 17 | ``` 18 | * `timesteps_pt` = `t` 19 | * `latent_model_input` = `x` 20 | * `c_crossattn` = encoder_hidden_states 21 | * `c_concat` = will be concatenated to `latent_model_input` if not `None` 22 | * for ControlNet, additional residuals are expected under `control` 23 | * `down_block_residuals` = `control`.`output` 24 | * `mid_block_residual` = `control`.`middle[0]` 25 | * `transformer_options` is unused but present for ComfyUI compatibility 26 | 27 | Input is passed to [`unet_inference`](https://github.com/hlky/AIT/blob/main/docs/unet.md) 28 | -------------------------------------------------------------------------------- /docs/controlnet.md: -------------------------------------------------------------------------------- 1 | # ControlNet 2 | 3 | ## Limitations 4 | * None 5 | 6 | ## Inference 7 | 8 | ### Inputs 9 | 10 | * `"input0"` - `latent_model_input` 11 | e.g. `torch.randn(batch_size, latent_channels, latent_height, latent_width)` `torch.randn(2, 4, 64, 64)` 12 | 13 | * `"input1"` - `timesteps` 14 | e.g. `torch.Tensor([1] * batch_size)` 15 | 16 | * `"input2"` - `encoder_hidden_states` 17 | e.g. `torch.randn(batch_size, sequence_length, hidden_dim)` `torch.randn(2, 77, 768)` 18 | 19 | * `"input3"` - `controlnet_cond` 20 | * This is typically the output from a ControlNet annotator. 21 | e.g. `torch.randn(batch_size, control_channels, control_height, control_width)` `torch.randn(2, 3, 512, 512)` 22 | 23 | ### Outputs 24 | 25 | * `"down_block_residual_{i}"` i = 0..11 26 | e.g. 
27 | ``` 28 | torch.Size([2, 64, 64, 320]) 29 | torch.Size([2, 64, 64, 320]) 30 | torch.Size([2, 64, 64, 320]) 31 | torch.Size([2, 32, 32, 320]) 32 | torch.Size([2, 32, 32, 640]) 33 | torch.Size([2, 32, 32, 640]) 34 | torch.Size([2, 16, 16, 640]) 35 | torch.Size([2, 16, 16, 1280]) 36 | torch.Size([2, 16, 16, 1280]) 37 | torch.Size([2, 8, 8, 1280]) 38 | torch.Size([2, 8, 8, 1280]) 39 | torch.Size([2, 8, 8, 1280]) 40 | ``` 41 | 42 | * `"mid_block_residual"` 43 | e.g. 44 | `torch.Size([2, 8, 8, 1280])` 45 | 46 | ## Function 47 | 48 | ``` 49 | def controlnet_inference( 50 | exe_module: Model, 51 | latent_model_input: torch.Tensor, 52 | timesteps: torch.Tensor, 53 | encoder_hidden_states: torch.Tensor, 54 | controlnet_cond: torch.Tensor, 55 | device: str = "cuda", 56 | dtype: str = "float16", 57 | ): 58 | ``` 59 | * `device` could be specified e.g. `cuda:1` if required. 60 | * `dtype` is experimental, the module would need to be compiled as `float32`. 61 | -------------------------------------------------------------------------------- /docs/unet.md: -------------------------------------------------------------------------------- 1 | # UNet 2 | 3 | ## Limitations 4 | * Models with odd input channels (x4-upscaler etc.) are experimental. 5 | 6 | ## Inference 7 | 8 | * AITemplate uses `bhwc` 9 | * for input to `_inference` functions provide input as `bchw`, output will be in `bchw` 10 | 11 | ### Inputs 12 | 13 | * `"input0"` - `latent_model_input` 14 | e.g. `torch.randn(batch_size, latent_channels, height, width)` `torch.randn(2, 4, 64, 64)` 15 | 16 | * `"input1"` - `timesteps` 17 | e.g. `torch.Tensor([1] * batch_size)` 18 | 19 | * `"input2"` - `encoder_hidden_states` 20 | e.g. `torch.randn(batch_size, sequence_length, hidden_dim)` `torch.randn(2, 77, 768)` 21 | 22 | #### ControlNet 23 | 24 | These are the output from ControlNet modules. The sizes are determined by batch size, latent height and width, and block_out_channels. 25 | 26 | * `"down_block_residual_{i}"` i = 0..11 27 | e.g. 28 | ``` 29 | torch.Size([2, 64, 64, 320]) 30 | torch.Size([2, 64, 64, 320]) 31 | torch.Size([2, 64, 64, 320]) 32 | torch.Size([2, 32, 32, 320]) 33 | torch.Size([2, 32, 32, 640]) 34 | torch.Size([2, 32, 32, 640]) 35 | torch.Size([2, 16, 16, 640]) 36 | torch.Size([2, 16, 16, 1280]) 37 | torch.Size([2, 16, 16, 1280]) 38 | torch.Size([2, 8, 8, 1280]) 39 | torch.Size([2, 8, 8, 1280]) 40 | torch.Size([2, 8, 8, 1280]) 41 | ``` 42 | 43 | * `"mid_block_residual"` 44 | e.g. 45 | `torch.Size([2, 8, 8, 1280])` 46 | 47 | #### Experimental 48 | 49 | * `"input3"` - `class_labels` 50 | * This is the noise level for `x4-upscaler` 51 | e.g. `torch.tensor([20] * batch_size, dtype=torch.long)` 52 | 53 | ### Outputs 54 | 55 | Same size as `latent_model_input` e.g. `torch.randn(batch_size, latent_channels, height, width)` `torch.randn(2, 4, 64, 64)` 56 | 57 | ## Function 58 | 59 | ``` 60 | def unet_inference( 61 | exe_module: Model, 62 | latent_model_input: torch.Tensor, 63 | timesteps: torch.Tensor, 64 | encoder_hidden_states: torch.Tensor, 65 | class_labels: torch.Tensor = None, 66 | down_block_residuals: List[torch.Tensor] = None, 67 | mid_block_residual: torch.Tensor = None, 68 | device: str = "cuda", 69 | dtype: str = "float16", 70 | ): 71 | ``` 72 | * `class_labels` is experimental and used by `x4-upscaler`. 73 | * `down_block_residuals` and `mid_block_residual` require a `Control UNet` module. 74 | * `device` could be specified e.g. `cuda:1` if required. 
75 | * `dtype` is experimental, the module would need to be compiled as `float32`. 76 | -------------------------------------------------------------------------------- /docs/vae.md: -------------------------------------------------------------------------------- 1 | # VAE 2 | 3 | ## Limitations 4 | * None known 5 | 6 | ## Inference 7 | 8 | ### Inputs 9 | 10 | * `"vae_input"` - vae_input 11 | * VAE input is the output from UNet 12 | e.g. `torch.randn(batch_size, latent_channels, height, width)` `torch.randn(2, 4, 64, 64)` 13 | or for VAE encode 14 | `torch.randn(batch_size, latent_channels, height, width)` `torch.randn(2, 4, 512, 512)` 15 | 16 | ### Outputs 17 | 18 | `h, w` * `factor` 19 | e.g. `torch.randn(2, 4, 512, 512)` 20 | 21 | or for VAE encode 22 | 23 | `h, w` // `factor` 24 | e.g. `torch.randn(2, 4, 64, 64)` 25 | 26 | ## Function 27 | 28 | ``` 29 | def vae_inference( 30 | exe_module: Model, 31 | vae_input: torch.Tensor, 32 | factor: int = 8, 33 | device: str = "cuda", 34 | dtype: str = "float16", 35 | encoder: bool = False, 36 | ): 37 | ``` 38 | `factor` must be set correctly when experimenting with non-standard VAE, default is 8 e.g. `64->512`. 39 | `x4-upscaler` uses `4` as the UNet sample size is bigger e.g. `128->512`. 40 | `factor` is used to set the output shape to accomodate dynamic shape support. 41 | `device` could be specified e.g. `cuda:1` if required 42 | `dtype` is experimental, the module would need to be compiled as `float32`, this is required for `x4-upscaler`. 43 | `encoder` is set for VAE encode inference. 44 | -------------------------------------------------------------------------------- /workflows/aitemplate_controlnet.json: -------------------------------------------------------------------------------- 1 | { 2 | "last_node_id": 15, 3 | "last_link_id": 19, 4 | "nodes": [ 5 | { 6 | "id": 5, 7 | "type": "EmptyLatentImage", 8 | "pos": [ 9 | 473, 10 | 609 11 | ], 12 | "size": { 13 | "0": 315, 14 | "1": 106 15 | }, 16 | "flags": {}, 17 | "order": 0, 18 | "mode": 0, 19 | "outputs": [ 20 | { 21 | "name": "LATENT", 22 | "type": "LATENT", 23 | "links": [ 24 | 2 25 | ], 26 | "slot_index": 0 27 | } 28 | ], 29 | "properties": { 30 | "Node name for S&R": "EmptyLatentImage" 31 | }, 32 | "widgets_values": [ 33 | 512, 34 | 512, 35 | 1 36 | ] 37 | }, 38 | { 39 | "id": 9, 40 | "type": "SaveImage", 41 | "pos": [ 42 | 1451, 43 | 189 44 | ], 45 | "size": [ 46 | 210, 47 | 270 48 | ], 49 | "flags": {}, 50 | "order": 11, 51 | "mode": 0, 52 | "inputs": [ 53 | { 54 | "name": "images", 55 | "type": "IMAGE", 56 | "link": 13 57 | } 58 | ], 59 | "properties": {}, 60 | "widgets_values": [ 61 | "ComfyUI" 62 | ] 63 | }, 64 | { 65 | "id": 10, 66 | "type": "AITemplateLoader", 67 | "pos": [ 68 | 26, 69 | 375 70 | ], 71 | "size": { 72 | "0": 315, 73 | "1": 58 74 | }, 75 | "flags": {}, 76 | "order": 5, 77 | "mode": 0, 78 | "inputs": [ 79 | { 80 | "name": "model", 81 | "type": "MODEL", 82 | "link": 10 83 | } 84 | ], 85 | "outputs": [ 86 | { 87 | "name": "MODEL", 88 | "type": "MODEL", 89 | "links": [ 90 | 11 91 | ], 92 | "shape": 3, 93 | "slot_index": 0 94 | } 95 | ], 96 | "properties": { 97 | "Node name for S&R": "AITemplateLoader" 98 | }, 99 | "widgets_values": [ 100 | "enable" 101 | ] 102 | }, 103 | { 104 | "id": 3, 105 | "type": "KSampler", 106 | "pos": [ 107 | 863, 108 | 186 109 | ], 110 | "size": { 111 | "0": 315, 112 | "1": 262 113 | }, 114 | "flags": {}, 115 | "order": 9, 116 | "mode": 0, 117 | "inputs": [ 118 | { 119 | "name": "model", 120 | "type": "MODEL", 121 | "link": 11 122 | }, 123 
| { 124 | "name": "positive", 125 | "type": "CONDITIONING", 126 | "link": 17 127 | }, 128 | { 129 | "name": "negative", 130 | "type": "CONDITIONING", 131 | "link": 6 132 | }, 133 | { 134 | "name": "latent_image", 135 | "type": "LATENT", 136 | "link": 2 137 | } 138 | ], 139 | "outputs": [ 140 | { 141 | "name": "LATENT", 142 | "type": "LATENT", 143 | "links": [ 144 | 12 145 | ], 146 | "slot_index": 0 147 | } 148 | ], 149 | "properties": { 150 | "Node name for S&R": "KSampler" 151 | }, 152 | "widgets_values": [ 153 | 103478091149550, 154 | "randomize", 155 | 20, 156 | 8, 157 | "euler", 158 | "normal", 159 | 1 160 | ] 161 | }, 162 | { 163 | "id": 11, 164 | "type": "AITemplateVAEDecode", 165 | "pos": [ 166 | 1150, 167 | 71 168 | ], 169 | "size": { 170 | "0": 315, 171 | "1": 78 172 | }, 173 | "flags": {}, 174 | "order": 10, 175 | "mode": 0, 176 | "inputs": [ 177 | { 178 | "name": "vae", 179 | "type": "VAE", 180 | "link": 19 181 | }, 182 | { 183 | "name": "samples", 184 | "type": "LATENT", 185 | "link": 12 186 | } 187 | ], 188 | "outputs": [ 189 | { 190 | "name": "IMAGE", 191 | "type": "IMAGE", 192 | "links": [ 193 | 13 194 | ], 195 | "shape": 3, 196 | "slot_index": 0 197 | } 198 | ], 199 | "properties": { 200 | "Node name for S&R": "AITemplateVAEDecode" 201 | }, 202 | "widgets_values": [ 203 | "enable" 204 | ] 205 | }, 206 | { 207 | "id": 13, 208 | "type": "ControlNetLoader", 209 | "pos": [ 210 | 30, 211 | 620 212 | ], 213 | "size": { 214 | "0": 315, 215 | "1": 58 216 | }, 217 | "flags": {}, 218 | "order": 1, 219 | "mode": 0, 220 | "outputs": [ 221 | { 222 | "name": "CONTROL_NET", 223 | "type": "CONTROL_NET", 224 | "links": [ 225 | 14 226 | ], 227 | "shape": 3, 228 | "slot_index": 0 229 | } 230 | ], 231 | "properties": { 232 | "Node name for S&R": "ControlNetLoader" 233 | }, 234 | "widgets_values": [ 235 | "control_v11p_sd15_canny.pth" 236 | ] 237 | }, 238 | { 239 | "id": 12, 240 | "type": "AITemplateControlNetLoader", 241 | "pos": [ 242 | 25, 243 | 730 244 | ], 245 | "size": { 246 | "0": 315, 247 | "1": 58 248 | }, 249 | "flags": {}, 250 | "order": 4, 251 | "mode": 0, 252 | "inputs": [ 253 | { 254 | "name": "control_net", 255 | "type": "CONTROL_NET", 256 | "link": 14 257 | } 258 | ], 259 | "outputs": [ 260 | { 261 | "name": "CONTROL_NET", 262 | "type": "CONTROL_NET", 263 | "links": [ 264 | 15 265 | ], 266 | "shape": 3, 267 | "slot_index": 0 268 | } 269 | ], 270 | "properties": { 271 | "Node name for S&R": "AITemplateControlNetLoader" 272 | }, 273 | "widgets_values": [ 274 | "enable" 275 | ] 276 | }, 277 | { 278 | "id": 15, 279 | "type": "LoadImage", 280 | "pos": [ 281 | 26, 282 | 830 283 | ], 284 | "size": [ 285 | 315, 286 | 314.00001525878906 287 | ], 288 | "flags": {}, 289 | "order": 2, 290 | "mode": 0, 291 | "outputs": [ 292 | { 293 | "name": "IMAGE", 294 | "type": "IMAGE", 295 | "links": [ 296 | 16 297 | ], 298 | "shape": 3, 299 | "slot_index": 0 300 | }, 301 | { 302 | "name": "MASK", 303 | "type": "MASK", 304 | "links": null, 305 | "shape": 3 306 | } 307 | ], 308 | "properties": { 309 | "Node name for S&R": "LoadImage" 310 | }, 311 | "widgets_values": [ 312 | "canny.png", 313 | "image" 314 | ] 315 | }, 316 | { 317 | "id": 14, 318 | "type": "ControlNetApply", 319 | "pos": [ 320 | 357, 321 | 754 322 | ], 323 | "size": { 324 | "0": 317.4000244140625, 325 | "1": 98 326 | }, 327 | "flags": {}, 328 | "order": 8, 329 | "mode": 0, 330 | "inputs": [ 331 | { 332 | "name": "conditioning", 333 | "type": "CONDITIONING", 334 | "link": 18 335 | }, 336 | { 337 | "name": "control_net", 338 | "type": 
"CONTROL_NET", 339 | "link": 15 340 | }, 341 | { 342 | "name": "image", 343 | "type": "IMAGE", 344 | "link": 16 345 | } 346 | ], 347 | "outputs": [ 348 | { 349 | "name": "CONDITIONING", 350 | "type": "CONDITIONING", 351 | "links": [ 352 | 17 353 | ], 354 | "shape": 3, 355 | "slot_index": 0 356 | } 357 | ], 358 | "properties": { 359 | "Node name for S&R": "ControlNetApply" 360 | }, 361 | "widgets_values": [ 362 | 1 363 | ] 364 | }, 365 | { 366 | "id": 7, 367 | "type": "CLIPTextEncode", 368 | "pos": [ 369 | 410, 370 | 390 371 | ], 372 | "size": { 373 | "0": 425.27801513671875, 374 | "1": 180.6060791015625 375 | }, 376 | "flags": {}, 377 | "order": 7, 378 | "mode": 0, 379 | "inputs": [ 380 | { 381 | "name": "clip", 382 | "type": "CLIP", 383 | "link": 5 384 | } 385 | ], 386 | "outputs": [ 387 | { 388 | "name": "CONDITIONING", 389 | "type": "CONDITIONING", 390 | "links": [ 391 | 6 392 | ], 393 | "slot_index": 0 394 | } 395 | ], 396 | "properties": { 397 | "Node name for S&R": "CLIPTextEncode" 398 | }, 399 | "widgets_values": [ 400 | "text, watermark" 401 | ] 402 | }, 403 | { 404 | "id": 6, 405 | "type": "CLIPTextEncode", 406 | "pos": [ 407 | 415, 408 | 186 409 | ], 410 | "size": { 411 | "0": 422.84503173828125, 412 | "1": 164.31304931640625 413 | }, 414 | "flags": {}, 415 | "order": 6, 416 | "mode": 0, 417 | "inputs": [ 418 | { 419 | "name": "clip", 420 | "type": "CLIP", 421 | "link": 3 422 | } 423 | ], 424 | "outputs": [ 425 | { 426 | "name": "CONDITIONING", 427 | "type": "CONDITIONING", 428 | "links": [ 429 | 18 430 | ], 431 | "slot_index": 0 432 | } 433 | ], 434 | "properties": { 435 | "Node name for S&R": "CLIPTextEncode" 436 | }, 437 | "widgets_values": [ 438 | "beautiful scenery nature glass bottle landscape, , purple galaxy bottle," 439 | ] 440 | }, 441 | { 442 | "id": 4, 443 | "type": "CheckpointLoaderSimple", 444 | "pos": [ 445 | 26, 446 | 474 447 | ], 448 | "size": { 449 | "0": 315, 450 | "1": 98 451 | }, 452 | "flags": {}, 453 | "order": 3, 454 | "mode": 0, 455 | "outputs": [ 456 | { 457 | "name": "MODEL", 458 | "type": "MODEL", 459 | "links": [ 460 | 10 461 | ], 462 | "slot_index": 0 463 | }, 464 | { 465 | "name": "CLIP", 466 | "type": "CLIP", 467 | "links": [ 468 | 3, 469 | 5 470 | ], 471 | "slot_index": 1 472 | }, 473 | { 474 | "name": "VAE", 475 | "type": "VAE", 476 | "links": [ 477 | 19 478 | ], 479 | "slot_index": 2 480 | } 481 | ], 482 | "properties": { 483 | "Node name for S&R": "CheckpointLoaderSimple" 484 | }, 485 | "widgets_values": [ 486 | "v1-5-pruned-emaonly.safetensors" 487 | ] 488 | } 489 | ], 490 | "links": [ 491 | [ 492 | 2, 493 | 5, 494 | 0, 495 | 3, 496 | 3, 497 | "LATENT" 498 | ], 499 | [ 500 | 3, 501 | 4, 502 | 1, 503 | 6, 504 | 0, 505 | "CLIP" 506 | ], 507 | [ 508 | 5, 509 | 4, 510 | 1, 511 | 7, 512 | 0, 513 | "CLIP" 514 | ], 515 | [ 516 | 6, 517 | 7, 518 | 0, 519 | 3, 520 | 2, 521 | "CONDITIONING" 522 | ], 523 | [ 524 | 10, 525 | 4, 526 | 0, 527 | 10, 528 | 0, 529 | "MODEL" 530 | ], 531 | [ 532 | 11, 533 | 10, 534 | 0, 535 | 3, 536 | 0, 537 | "MODEL" 538 | ], 539 | [ 540 | 12, 541 | 3, 542 | 0, 543 | 11, 544 | 1, 545 | "LATENT" 546 | ], 547 | [ 548 | 13, 549 | 11, 550 | 0, 551 | 9, 552 | 0, 553 | "IMAGE" 554 | ], 555 | [ 556 | 14, 557 | 13, 558 | 0, 559 | 12, 560 | 0, 561 | "CONTROL_NET" 562 | ], 563 | [ 564 | 15, 565 | 12, 566 | 0, 567 | 14, 568 | 1, 569 | "CONTROL_NET" 570 | ], 571 | [ 572 | 16, 573 | 15, 574 | 0, 575 | 14, 576 | 2, 577 | "IMAGE" 578 | ], 579 | [ 580 | 17, 581 | 14, 582 | 0, 583 | 3, 584 | 1, 585 | "CONDITIONING" 586 | ], 587 | [ 588 | 18, 
589 | 6, 590 | 0, 591 | 14, 592 | 0, 593 | "CONDITIONING" 594 | ], 595 | [ 596 | 19, 597 | 4, 598 | 2, 599 | 11, 600 | 0, 601 | "VAE" 602 | ] 603 | ], 604 | "groups": [], 605 | "config": {}, 606 | "extra": {}, 607 | "version": 0.4 608 | } -------------------------------------------------------------------------------- /workflows/aitemplate_img2img_unet_vae.json: -------------------------------------------------------------------------------- 1 | { 2 | "last_node_id": 13, 3 | "last_link_id": 17, 4 | "nodes": [ 5 | { 6 | "id": 7, 7 | "type": "CLIPTextEncode", 8 | "pos": [ 9 | 413, 10 | 389 11 | ], 12 | "size": { 13 | "0": 425.27801513671875, 14 | "1": 180.6060791015625 15 | }, 16 | "flags": {}, 17 | "order": 4, 18 | "mode": 0, 19 | "inputs": [ 20 | { 21 | "name": "clip", 22 | "type": "CLIP", 23 | "link": 5 24 | } 25 | ], 26 | "outputs": [ 27 | { 28 | "name": "CONDITIONING", 29 | "type": "CONDITIONING", 30 | "links": [ 31 | 6 32 | ], 33 | "slot_index": 0 34 | } 35 | ], 36 | "properties": { 37 | "Node name for S&R": "CLIPTextEncode" 38 | }, 39 | "widgets_values": [ 40 | "text, watermark" 41 | ] 42 | }, 43 | { 44 | "id": 6, 45 | "type": "CLIPTextEncode", 46 | "pos": [ 47 | 415, 48 | 186 49 | ], 50 | "size": { 51 | "0": 422.84503173828125, 52 | "1": 164.31304931640625 53 | }, 54 | "flags": {}, 55 | "order": 3, 56 | "mode": 0, 57 | "inputs": [ 58 | { 59 | "name": "clip", 60 | "type": "CLIP", 61 | "link": 3 62 | } 63 | ], 64 | "outputs": [ 65 | { 66 | "name": "CONDITIONING", 67 | "type": "CONDITIONING", 68 | "links": [ 69 | 4 70 | ], 71 | "slot_index": 0 72 | } 73 | ], 74 | "properties": { 75 | "Node name for S&R": "CLIPTextEncode" 76 | }, 77 | "widgets_values": [ 78 | "beautiful scenery nature glass bottle landscape, , purple galaxy bottle," 79 | ] 80 | }, 81 | { 82 | "id": 9, 83 | "type": "SaveImage", 84 | "pos": [ 85 | 1451, 86 | 189 87 | ], 88 | "size": [ 89 | 210, 90 | 270 91 | ], 92 | "flags": {}, 93 | "order": 8, 94 | "mode": 0, 95 | "inputs": [ 96 | { 97 | "name": "images", 98 | "type": "IMAGE", 99 | "link": 13 100 | } 101 | ], 102 | "properties": {}, 103 | "widgets_values": [ 104 | "ComfyUI" 105 | ] 106 | }, 107 | { 108 | "id": 10, 109 | "type": "AITemplateLoader", 110 | "pos": [ 111 | 26, 112 | 375 113 | ], 114 | "size": { 115 | "0": 315, 116 | "1": 58 117 | }, 118 | "flags": {}, 119 | "order": 2, 120 | "mode": 0, 121 | "inputs": [ 122 | { 123 | "name": "model", 124 | "type": "MODEL", 125 | "link": 10 126 | } 127 | ], 128 | "outputs": [ 129 | { 130 | "name": "MODEL", 131 | "type": "MODEL", 132 | "links": [ 133 | 11 134 | ], 135 | "shape": 3, 136 | "slot_index": 0 137 | } 138 | ], 139 | "properties": { 140 | "Node name for S&R": "AITemplateLoader" 141 | }, 142 | "widgets_values": [ 143 | "enable" 144 | ] 145 | }, 146 | { 147 | "id": 11, 148 | "type": "AITemplateVAEDecode", 149 | "pos": [ 150 | 1150, 151 | 71 152 | ], 153 | "size": { 154 | "0": 315, 155 | "1": 78 156 | }, 157 | "flags": {}, 158 | "order": 7, 159 | "mode": 0, 160 | "inputs": [ 161 | { 162 | "name": "vae", 163 | "type": "VAE", 164 | "link": 17 165 | }, 166 | { 167 | "name": "samples", 168 | "type": "LATENT", 169 | "link": 12 170 | } 171 | ], 172 | "outputs": [ 173 | { 174 | "name": "IMAGE", 175 | "type": "IMAGE", 176 | "links": [ 177 | 13 178 | ], 179 | "shape": 3, 180 | "slot_index": 0 181 | } 182 | ], 183 | "properties": { 184 | "Node name for S&R": "AITemplateVAEDecode" 185 | }, 186 | "widgets_values": [ 187 | "enable" 188 | ] 189 | }, 190 | { 191 | "id": 13, 192 | "type": "LoadImage", 193 | "pos": [ 194 | 17, 195 | 735 
196 | ], 197 | "size": { 198 | "0": 315, 199 | "1": 314 200 | }, 201 | "flags": {}, 202 | "order": 0, 203 | "mode": 0, 204 | "outputs": [ 205 | { 206 | "name": "IMAGE", 207 | "type": "IMAGE", 208 | "links": [ 209 | 15 210 | ], 211 | "shape": 3, 212 | "slot_index": 0 213 | }, 214 | { 215 | "name": "MASK", 216 | "type": "MASK", 217 | "links": null, 218 | "shape": 3 219 | } 220 | ], 221 | "properties": { 222 | "Node name for S&R": "LoadImage" 223 | }, 224 | "widgets_values": [ 225 | "example.png", 226 | "image" 227 | ] 228 | }, 229 | { 230 | "id": 12, 231 | "type": "AITemplateVAEEncode", 232 | "pos": [ 233 | 18, 234 | 616 235 | ], 236 | "size": { 237 | "0": 315, 238 | "1": 78 239 | }, 240 | "flags": {}, 241 | "order": 5, 242 | "mode": 0, 243 | "inputs": [ 244 | { 245 | "name": "pixels", 246 | "type": "IMAGE", 247 | "link": 15 248 | }, 249 | { 250 | "name": "vae", 251 | "type": "VAE", 252 | "link": 14 253 | } 254 | ], 255 | "outputs": [ 256 | { 257 | "name": "LATENT", 258 | "type": "LATENT", 259 | "links": [ 260 | 16 261 | ], 262 | "shape": 3, 263 | "slot_index": 0 264 | } 265 | ], 266 | "properties": { 267 | "Node name for S&R": "AITemplateVAEEncode" 268 | }, 269 | "widgets_values": [ 270 | "enable" 271 | ] 272 | }, 273 | { 274 | "id": 4, 275 | "type": "CheckpointLoaderSimple", 276 | "pos": [ 277 | 26, 278 | 474 279 | ], 280 | "size": { 281 | "0": 315, 282 | "1": 98 283 | }, 284 | "flags": {}, 285 | "order": 1, 286 | "mode": 0, 287 | "outputs": [ 288 | { 289 | "name": "MODEL", 290 | "type": "MODEL", 291 | "links": [ 292 | 10 293 | ], 294 | "slot_index": 0 295 | }, 296 | { 297 | "name": "CLIP", 298 | "type": "CLIP", 299 | "links": [ 300 | 3, 301 | 5 302 | ], 303 | "slot_index": 1 304 | }, 305 | { 306 | "name": "VAE", 307 | "type": "VAE", 308 | "links": [ 309 | 14, 310 | 17 311 | ], 312 | "slot_index": 2 313 | } 314 | ], 315 | "properties": { 316 | "Node name for S&R": "CheckpointLoaderSimple" 317 | }, 318 | "widgets_values": [ 319 | "v1-5-pruned-emaonly.safetensors" 320 | ] 321 | }, 322 | { 323 | "id": 3, 324 | "type": "KSampler", 325 | "pos": [ 326 | 863, 327 | 186 328 | ], 329 | "size": { 330 | "0": 315, 331 | "1": 262 332 | }, 333 | "flags": {}, 334 | "order": 6, 335 | "mode": 0, 336 | "inputs": [ 337 | { 338 | "name": "model", 339 | "type": "MODEL", 340 | "link": 11 341 | }, 342 | { 343 | "name": "positive", 344 | "type": "CONDITIONING", 345 | "link": 4 346 | }, 347 | { 348 | "name": "negative", 349 | "type": "CONDITIONING", 350 | "link": 6 351 | }, 352 | { 353 | "name": "latent_image", 354 | "type": "LATENT", 355 | "link": 16 356 | } 357 | ], 358 | "outputs": [ 359 | { 360 | "name": "LATENT", 361 | "type": "LATENT", 362 | "links": [ 363 | 12 364 | ], 365 | "slot_index": 0 366 | } 367 | ], 368 | "properties": { 369 | "Node name for S&R": "KSampler" 370 | }, 371 | "widgets_values": [ 372 | 748405773077563, 373 | "randomize", 374 | 20, 375 | 8, 376 | "euler", 377 | "normal", 378 | 0.8 379 | ] 380 | } 381 | ], 382 | "links": [ 383 | [ 384 | 3, 385 | 4, 386 | 1, 387 | 6, 388 | 0, 389 | "CLIP" 390 | ], 391 | [ 392 | 4, 393 | 6, 394 | 0, 395 | 3, 396 | 1, 397 | "CONDITIONING" 398 | ], 399 | [ 400 | 5, 401 | 4, 402 | 1, 403 | 7, 404 | 0, 405 | "CLIP" 406 | ], 407 | [ 408 | 6, 409 | 7, 410 | 0, 411 | 3, 412 | 2, 413 | "CONDITIONING" 414 | ], 415 | [ 416 | 10, 417 | 4, 418 | 0, 419 | 10, 420 | 0, 421 | "MODEL" 422 | ], 423 | [ 424 | 11, 425 | 10, 426 | 0, 427 | 3, 428 | 0, 429 | "MODEL" 430 | ], 431 | [ 432 | 12, 433 | 3, 434 | 0, 435 | 11, 436 | 1, 437 | "LATENT" 438 | ], 439 | [ 440 | 13, 441 | 
11, 442 | 0, 443 | 9, 444 | 0, 445 | "IMAGE" 446 | ], 447 | [ 448 | 14, 449 | 4, 450 | 2, 451 | 12, 452 | 1, 453 | "VAE" 454 | ], 455 | [ 456 | 15, 457 | 13, 458 | 0, 459 | 12, 460 | 0, 461 | "IMAGE" 462 | ], 463 | [ 464 | 16, 465 | 12, 466 | 0, 467 | 3, 468 | 3, 469 | "LATENT" 470 | ], 471 | [ 472 | 17, 473 | 4, 474 | 2, 475 | 11, 476 | 0, 477 | "VAE" 478 | ] 479 | ], 480 | "groups": [], 481 | "config": {}, 482 | "extra": {}, 483 | "version": 0.4 484 | } -------------------------------------------------------------------------------- /workflows/aitemplate_two_pass.json: -------------------------------------------------------------------------------- 1 | { 2 | "last_node_id": 14, 3 | "last_link_id": 21, 4 | "nodes": [ 5 | { 6 | "id": 5, 7 | "type": "EmptyLatentImage", 8 | "pos": [ 9 | 460, 10 | 470 11 | ], 12 | "size": { 13 | "0": 315, 14 | "1": 106 15 | }, 16 | "flags": {}, 17 | "order": 0, 18 | "mode": 0, 19 | "outputs": [ 20 | { 21 | "name": "LATENT", 22 | "type": "LATENT", 23 | "links": [ 24 | 2 25 | ], 26 | "slot_index": 0 27 | } 28 | ], 29 | "properties": { 30 | "Node name for S&R": "EmptyLatentImage" 31 | }, 32 | "widgets_values": [ 33 | 512, 34 | 512, 35 | 1 36 | ] 37 | }, 38 | { 39 | "id": 3, 40 | "type": "KSampler", 41 | "pos": [ 42 | 850, 43 | 50 44 | ], 45 | "size": { 46 | "0": 315, 47 | "1": 262 48 | }, 49 | "flags": {}, 50 | "order": 5, 51 | "mode": 0, 52 | "inputs": [ 53 | { 54 | "name": "model", 55 | "type": "MODEL", 56 | "link": 11 57 | }, 58 | { 59 | "name": "positive", 60 | "type": "CONDITIONING", 61 | "link": 4 62 | }, 63 | { 64 | "name": "negative", 65 | "type": "CONDITIONING", 66 | "link": 6 67 | }, 68 | { 69 | "name": "latent_image", 70 | "type": "LATENT", 71 | "link": 2 72 | } 73 | ], 74 | "outputs": [ 75 | { 76 | "name": "LATENT", 77 | "type": "LATENT", 78 | "links": [ 79 | 14 80 | ], 81 | "slot_index": 0 82 | } 83 | ], 84 | "properties": { 85 | "Node name for S&R": "KSampler" 86 | }, 87 | "widgets_values": [ 88 | 986778340352574, 89 | "randomize", 90 | 20, 91 | 8, 92 | "euler", 93 | "normal", 94 | 1 95 | ] 96 | }, 97 | { 98 | "id": 6, 99 | "type": "CLIPTextEncode", 100 | "pos": [ 101 | 400, 102 | 50 103 | ], 104 | "size": { 105 | "0": 422.84503173828125, 106 | "1": 164.31304931640625 107 | }, 108 | "flags": {}, 109 | "order": 3, 110 | "mode": 0, 111 | "inputs": [ 112 | { 113 | "name": "clip", 114 | "type": "CLIP", 115 | "link": 3 116 | } 117 | ], 118 | "outputs": [ 119 | { 120 | "name": "CONDITIONING", 121 | "type": "CONDITIONING", 122 | "links": [ 123 | 4, 124 | 15 125 | ], 126 | "slot_index": 0 127 | } 128 | ], 129 | "properties": { 130 | "Node name for S&R": "CLIPTextEncode" 131 | }, 132 | "widgets_values": [ 133 | "beautiful scenery nature glass bottle landscape, , purple galaxy bottle," 134 | ] 135 | }, 136 | { 137 | "id": 7, 138 | "type": "CLIPTextEncode", 139 | "pos": [ 140 | 400, 141 | 250 142 | ], 143 | "size": { 144 | "0": 425.27801513671875, 145 | "1": 180.6060791015625 146 | }, 147 | "flags": {}, 148 | "order": 4, 149 | "mode": 0, 150 | "inputs": [ 151 | { 152 | "name": "clip", 153 | "type": "CLIP", 154 | "link": 5 155 | } 156 | ], 157 | "outputs": [ 158 | { 159 | "name": "CONDITIONING", 160 | "type": "CONDITIONING", 161 | "links": [ 162 | 6, 163 | 16 164 | ], 165 | "slot_index": 0 166 | } 167 | ], 168 | "properties": { 169 | "Node name for S&R": "CLIPTextEncode" 170 | }, 171 | "widgets_values": [ 172 | "text, watermark" 173 | ] 174 | }, 175 | { 176 | "id": 10, 177 | "type": "AITemplateLoader", 178 | "pos": [ 179 | 10, 180 | 240 181 | ], 182 | "size": 
{ 183 | "0": 315, 184 | "1": 58 185 | }, 186 | "flags": {}, 187 | "order": 2, 188 | "mode": 0, 189 | "inputs": [ 190 | { 191 | "name": "model", 192 | "type": "MODEL", 193 | "link": 10 194 | } 195 | ], 196 | "outputs": [ 197 | { 198 | "name": "MODEL", 199 | "type": "MODEL", 200 | "links": [ 201 | 11, 202 | 17 203 | ], 204 | "shape": 3, 205 | "slot_index": 0 206 | } 207 | ], 208 | "properties": { 209 | "Node name for S&R": "AITemplateLoader" 210 | }, 211 | "widgets_values": [ 212 | "enable" 213 | ] 214 | }, 215 | { 216 | "id": 13, 217 | "type": "LatentUpscale", 218 | "pos": [ 219 | 851, 220 | 356 221 | ], 222 | "size": { 223 | "0": 315, 224 | "1": 130 225 | }, 226 | "flags": {}, 227 | "order": 6, 228 | "mode": 0, 229 | "inputs": [ 230 | { 231 | "name": "samples", 232 | "type": "LATENT", 233 | "link": 14 234 | } 235 | ], 236 | "outputs": [ 237 | { 238 | "name": "LATENT", 239 | "type": "LATENT", 240 | "links": [ 241 | 18 242 | ], 243 | "shape": 3, 244 | "slot_index": 0 245 | } 246 | ], 247 | "properties": { 248 | "Node name for S&R": "LatentUpscale" 249 | }, 250 | "widgets_values": [ 251 | "nearest-exact", 252 | 768, 253 | 768, 254 | "disabled" 255 | ] 256 | }, 257 | { 258 | "id": 14, 259 | "type": "AITemplateVAEDecode", 260 | "pos": [ 261 | 1193, 262 | 500 263 | ], 264 | "size": { 265 | "0": 315, 266 | "1": 78 267 | }, 268 | "flags": {}, 269 | "order": 8, 270 | "mode": 0, 271 | "inputs": [ 272 | { 273 | "name": "vae", 274 | "type": "VAE", 275 | "link": 21 276 | }, 277 | { 278 | "name": "samples", 279 | "type": "LATENT", 280 | "link": 19 281 | } 282 | ], 283 | "outputs": [ 284 | { 285 | "name": "IMAGE", 286 | "type": "IMAGE", 287 | "links": [ 288 | 20 289 | ], 290 | "shape": 3, 291 | "slot_index": 0 292 | } 293 | ], 294 | "properties": { 295 | "Node name for S&R": "AITemplateVAEDecode" 296 | }, 297 | "widgets_values": [ 298 | "enable" 299 | ] 300 | }, 301 | { 302 | "id": 4, 303 | "type": "CheckpointLoaderSimple", 304 | "pos": [ 305 | 10, 306 | 340 307 | ], 308 | "size": { 309 | "0": 315, 310 | "1": 98 311 | }, 312 | "flags": {}, 313 | "order": 1, 314 | "mode": 0, 315 | "outputs": [ 316 | { 317 | "name": "MODEL", 318 | "type": "MODEL", 319 | "links": [ 320 | 10 321 | ], 322 | "slot_index": 0 323 | }, 324 | { 325 | "name": "CLIP", 326 | "type": "CLIP", 327 | "links": [ 328 | 3, 329 | 5 330 | ], 331 | "slot_index": 1 332 | }, 333 | { 334 | "name": "VAE", 335 | "type": "VAE", 336 | "links": [ 337 | 21 338 | ], 339 | "slot_index": 2 340 | } 341 | ], 342 | "properties": { 343 | "Node name for S&R": "CheckpointLoaderSimple" 344 | }, 345 | "widgets_values": [ 346 | "v1-5-pruned-emaonly.safetensors" 347 | ] 348 | }, 349 | { 350 | "id": 12, 351 | "type": "KSampler", 352 | "pos": [ 353 | 853, 354 | 530 355 | ], 356 | "size": { 357 | "0": 315, 358 | "1": 262 359 | }, 360 | "flags": {}, 361 | "order": 7, 362 | "mode": 0, 363 | "inputs": [ 364 | { 365 | "name": "model", 366 | "type": "MODEL", 367 | "link": 17 368 | }, 369 | { 370 | "name": "positive", 371 | "type": "CONDITIONING", 372 | "link": 15 373 | }, 374 | { 375 | "name": "negative", 376 | "type": "CONDITIONING", 377 | "link": 16 378 | }, 379 | { 380 | "name": "latent_image", 381 | "type": "LATENT", 382 | "link": 18 383 | } 384 | ], 385 | "outputs": [ 386 | { 387 | "name": "LATENT", 388 | "type": "LATENT", 389 | "links": [ 390 | 19 391 | ], 392 | "shape": 3, 393 | "slot_index": 0 394 | } 395 | ], 396 | "properties": { 397 | "Node name for S&R": "KSampler" 398 | }, 399 | "widgets_values": [ 400 | 290845048926429, 401 | "randomize", 402 | 20, 403 | 8, 
404 | "euler", 405 | "normal", 406 | 0.5 407 | ] 408 | }, 409 | { 410 | "id": 9, 411 | "type": "SaveImage", 412 | "pos": [ 413 | 1176, 414 | 49 415 | ], 416 | "size": [ 417 | 345.14699847412135, 418 | 397.81999157714847 419 | ], 420 | "flags": {}, 421 | "order": 9, 422 | "mode": 0, 423 | "inputs": [ 424 | { 425 | "name": "images", 426 | "type": "IMAGE", 427 | "link": 20 428 | } 429 | ], 430 | "properties": {}, 431 | "widgets_values": [ 432 | "ComfyUI" 433 | ] 434 | } 435 | ], 436 | "links": [ 437 | [ 438 | 2, 439 | 5, 440 | 0, 441 | 3, 442 | 3, 443 | "LATENT" 444 | ], 445 | [ 446 | 3, 447 | 4, 448 | 1, 449 | 6, 450 | 0, 451 | "CLIP" 452 | ], 453 | [ 454 | 4, 455 | 6, 456 | 0, 457 | 3, 458 | 1, 459 | "CONDITIONING" 460 | ], 461 | [ 462 | 5, 463 | 4, 464 | 1, 465 | 7, 466 | 0, 467 | "CLIP" 468 | ], 469 | [ 470 | 6, 471 | 7, 472 | 0, 473 | 3, 474 | 2, 475 | "CONDITIONING" 476 | ], 477 | [ 478 | 10, 479 | 4, 480 | 0, 481 | 10, 482 | 0, 483 | "MODEL" 484 | ], 485 | [ 486 | 11, 487 | 10, 488 | 0, 489 | 3, 490 | 0, 491 | "MODEL" 492 | ], 493 | [ 494 | 14, 495 | 3, 496 | 0, 497 | 13, 498 | 0, 499 | "LATENT" 500 | ], 501 | [ 502 | 15, 503 | 6, 504 | 0, 505 | 12, 506 | 1, 507 | "CONDITIONING" 508 | ], 509 | [ 510 | 16, 511 | 7, 512 | 0, 513 | 12, 514 | 2, 515 | "CONDITIONING" 516 | ], 517 | [ 518 | 17, 519 | 10, 520 | 0, 521 | 12, 522 | 0, 523 | "MODEL" 524 | ], 525 | [ 526 | 18, 527 | 13, 528 | 0, 529 | 12, 530 | 3, 531 | "LATENT" 532 | ], 533 | [ 534 | 19, 535 | 12, 536 | 0, 537 | 14, 538 | 1, 539 | "LATENT" 540 | ], 541 | [ 542 | 20, 543 | 14, 544 | 0, 545 | 9, 546 | 0, 547 | "IMAGE" 548 | ], 549 | [ 550 | 21, 551 | 4, 552 | 2, 553 | 14, 554 | 0, 555 | "VAE" 556 | ] 557 | ], 558 | "groups": [], 559 | "config": {}, 560 | "extra": {}, 561 | "version": 0.4 562 | } -------------------------------------------------------------------------------- /workflows/aitemplate_unet_only.json: -------------------------------------------------------------------------------- 1 | { 2 | "last_node_id": 10, 3 | "last_link_id": 11, 4 | "nodes": [ 5 | { 6 | "id": 7, 7 | "type": "CLIPTextEncode", 8 | "pos": [ 9 | 413, 10 | 389 11 | ], 12 | "size": { 13 | "0": 425.27801513671875, 14 | "1": 180.6060791015625 15 | }, 16 | "flags": {}, 17 | "order": 4, 18 | "mode": 0, 19 | "inputs": [ 20 | { 21 | "name": "clip", 22 | "type": "CLIP", 23 | "link": 5 24 | } 25 | ], 26 | "outputs": [ 27 | { 28 | "name": "CONDITIONING", 29 | "type": "CONDITIONING", 30 | "links": [ 31 | 6 32 | ], 33 | "slot_index": 0 34 | } 35 | ], 36 | "properties": { 37 | "Node name for S&R": "CLIPTextEncode" 38 | }, 39 | "widgets_values": [ 40 | "text, watermark" 41 | ] 42 | }, 43 | { 44 | "id": 6, 45 | "type": "CLIPTextEncode", 46 | "pos": [ 47 | 415, 48 | 186 49 | ], 50 | "size": { 51 | "0": 422.84503173828125, 52 | "1": 164.31304931640625 53 | }, 54 | "flags": {}, 55 | "order": 3, 56 | "mode": 0, 57 | "inputs": [ 58 | { 59 | "name": "clip", 60 | "type": "CLIP", 61 | "link": 3 62 | } 63 | ], 64 | "outputs": [ 65 | { 66 | "name": "CONDITIONING", 67 | "type": "CONDITIONING", 68 | "links": [ 69 | 4 70 | ], 71 | "slot_index": 0 72 | } 73 | ], 74 | "properties": { 75 | "Node name for S&R": "CLIPTextEncode" 76 | }, 77 | "widgets_values": [ 78 | "beautiful scenery nature glass bottle landscape, , purple galaxy bottle," 79 | ] 80 | }, 81 | { 82 | "id": 5, 83 | "type": "EmptyLatentImage", 84 | "pos": [ 85 | 473, 86 | 609 87 | ], 88 | "size": { 89 | "0": 315, 90 | "1": 106 91 | }, 92 | "flags": {}, 93 | "order": 0, 94 | "mode": 0, 95 | "outputs": [ 96 | { 97 | "name": 
"LATENT", 98 | "type": "LATENT", 99 | "links": [ 100 | 2 101 | ], 102 | "slot_index": 0 103 | } 104 | ], 105 | "properties": { 106 | "Node name for S&R": "EmptyLatentImage" 107 | }, 108 | "widgets_values": [ 109 | 512, 110 | 512, 111 | 1 112 | ] 113 | }, 114 | { 115 | "id": 3, 116 | "type": "KSampler", 117 | "pos": [ 118 | 863, 119 | 186 120 | ], 121 | "size": { 122 | "0": 315, 123 | "1": 262 124 | }, 125 | "flags": {}, 126 | "order": 5, 127 | "mode": 0, 128 | "inputs": [ 129 | { 130 | "name": "model", 131 | "type": "MODEL", 132 | "link": 11 133 | }, 134 | { 135 | "name": "positive", 136 | "type": "CONDITIONING", 137 | "link": 4 138 | }, 139 | { 140 | "name": "negative", 141 | "type": "CONDITIONING", 142 | "link": 6 143 | }, 144 | { 145 | "name": "latent_image", 146 | "type": "LATENT", 147 | "link": 2 148 | } 149 | ], 150 | "outputs": [ 151 | { 152 | "name": "LATENT", 153 | "type": "LATENT", 154 | "links": [ 155 | 7 156 | ], 157 | "slot_index": 0 158 | } 159 | ], 160 | "properties": { 161 | "Node name for S&R": "KSampler" 162 | }, 163 | "widgets_values": [ 164 | 156680208700286, 165 | "randomize", 166 | 20, 167 | 8, 168 | "euler", 169 | "normal", 170 | 1 171 | ] 172 | }, 173 | { 174 | "id": 8, 175 | "type": "VAEDecode", 176 | "pos": [ 177 | 1209, 178 | 188 179 | ], 180 | "size": { 181 | "0": 210, 182 | "1": 46 183 | }, 184 | "flags": {}, 185 | "order": 6, 186 | "mode": 0, 187 | "inputs": [ 188 | { 189 | "name": "samples", 190 | "type": "LATENT", 191 | "link": 7 192 | }, 193 | { 194 | "name": "vae", 195 | "type": "VAE", 196 | "link": 8 197 | } 198 | ], 199 | "outputs": [ 200 | { 201 | "name": "IMAGE", 202 | "type": "IMAGE", 203 | "links": [ 204 | 9 205 | ], 206 | "slot_index": 0 207 | } 208 | ], 209 | "properties": { 210 | "Node name for S&R": "VAEDecode" 211 | } 212 | }, 213 | { 214 | "id": 9, 215 | "type": "SaveImage", 216 | "pos": [ 217 | 1451, 218 | 189 219 | ], 220 | "size": { 221 | "0": 210, 222 | "1": 58 223 | }, 224 | "flags": {}, 225 | "order": 7, 226 | "mode": 0, 227 | "inputs": [ 228 | { 229 | "name": "images", 230 | "type": "IMAGE", 231 | "link": 9 232 | } 233 | ], 234 | "properties": {}, 235 | "widgets_values": [ 236 | "ComfyUI" 237 | ] 238 | }, 239 | { 240 | "id": 4, 241 | "type": "CheckpointLoaderSimple", 242 | "pos": [ 243 | 26, 244 | 474 245 | ], 246 | "size": { 247 | "0": 315, 248 | "1": 98 249 | }, 250 | "flags": {}, 251 | "order": 1, 252 | "mode": 0, 253 | "outputs": [ 254 | { 255 | "name": "MODEL", 256 | "type": "MODEL", 257 | "links": [ 258 | 10 259 | ], 260 | "slot_index": 0 261 | }, 262 | { 263 | "name": "CLIP", 264 | "type": "CLIP", 265 | "links": [ 266 | 3, 267 | 5 268 | ], 269 | "slot_index": 1 270 | }, 271 | { 272 | "name": "VAE", 273 | "type": "VAE", 274 | "links": [ 275 | 8 276 | ], 277 | "slot_index": 2 278 | } 279 | ], 280 | "properties": { 281 | "Node name for S&R": "CheckpointLoaderSimple" 282 | }, 283 | "widgets_values": [ 284 | "v1-5-pruned-emaonly.safetensors" 285 | ] 286 | }, 287 | { 288 | "id": 10, 289 | "type": "AITemplateLoader", 290 | "pos": [ 291 | 26, 292 | 375 293 | ], 294 | "size": { 295 | "0": 315, 296 | "1": 58 297 | }, 298 | "flags": {}, 299 | "order": 2, 300 | "mode": 0, 301 | "inputs": [ 302 | { 303 | "name": "model", 304 | "type": "MODEL", 305 | "link": 10 306 | } 307 | ], 308 | "outputs": [ 309 | { 310 | "name": "MODEL", 311 | "type": "MODEL", 312 | "links": [ 313 | 11 314 | ], 315 | "shape": 3, 316 | "slot_index": 0 317 | } 318 | ], 319 | "properties": { 320 | "Node name for S&R": "AITemplateLoader" 321 | }, 322 | "widgets_values": [ 323 
| "enable" 324 | ] 325 | } 326 | ], 327 | "links": [ 328 | [ 329 | 2, 330 | 5, 331 | 0, 332 | 3, 333 | 3, 334 | "LATENT" 335 | ], 336 | [ 337 | 3, 338 | 4, 339 | 1, 340 | 6, 341 | 0, 342 | "CLIP" 343 | ], 344 | [ 345 | 4, 346 | 6, 347 | 0, 348 | 3, 349 | 1, 350 | "CONDITIONING" 351 | ], 352 | [ 353 | 5, 354 | 4, 355 | 1, 356 | 7, 357 | 0, 358 | "CLIP" 359 | ], 360 | [ 361 | 6, 362 | 7, 363 | 0, 364 | 3, 365 | 2, 366 | "CONDITIONING" 367 | ], 368 | [ 369 | 7, 370 | 3, 371 | 0, 372 | 8, 373 | 0, 374 | "LATENT" 375 | ], 376 | [ 377 | 8, 378 | 4, 379 | 2, 380 | 8, 381 | 1, 382 | "VAE" 383 | ], 384 | [ 385 | 9, 386 | 8, 387 | 0, 388 | 9, 389 | 0, 390 | "IMAGE" 391 | ], 392 | [ 393 | 10, 394 | 4, 395 | 0, 396 | 10, 397 | 0, 398 | "MODEL" 399 | ], 400 | [ 401 | 11, 402 | 10, 403 | 0, 404 | 3, 405 | 0, 406 | "MODEL" 407 | ] 408 | ], 409 | "groups": [], 410 | "config": {}, 411 | "extra": {}, 412 | "version": 0.4 413 | } -------------------------------------------------------------------------------- /workflows/aitemplate_unet_vae.json: -------------------------------------------------------------------------------- 1 | { 2 | "last_node_id": 11, 3 | "last_link_id": 14, 4 | "nodes": [ 5 | { 6 | "id": 7, 7 | "type": "CLIPTextEncode", 8 | "pos": [ 9 | 413, 10 | 389 11 | ], 12 | "size": { 13 | "0": 425.27801513671875, 14 | "1": 180.6060791015625 15 | }, 16 | "flags": {}, 17 | "order": 4, 18 | "mode": 0, 19 | "inputs": [ 20 | { 21 | "name": "clip", 22 | "type": "CLIP", 23 | "link": 5 24 | } 25 | ], 26 | "outputs": [ 27 | { 28 | "name": "CONDITIONING", 29 | "type": "CONDITIONING", 30 | "links": [ 31 | 6 32 | ], 33 | "slot_index": 0 34 | } 35 | ], 36 | "properties": { 37 | "Node name for S&R": "CLIPTextEncode" 38 | }, 39 | "widgets_values": [ 40 | "text, watermark" 41 | ] 42 | }, 43 | { 44 | "id": 6, 45 | "type": "CLIPTextEncode", 46 | "pos": [ 47 | 415, 48 | 186 49 | ], 50 | "size": { 51 | "0": 422.84503173828125, 52 | "1": 164.31304931640625 53 | }, 54 | "flags": {}, 55 | "order": 3, 56 | "mode": 0, 57 | "inputs": [ 58 | { 59 | "name": "clip", 60 | "type": "CLIP", 61 | "link": 3 62 | } 63 | ], 64 | "outputs": [ 65 | { 66 | "name": "CONDITIONING", 67 | "type": "CONDITIONING", 68 | "links": [ 69 | 4 70 | ], 71 | "slot_index": 0 72 | } 73 | ], 74 | "properties": { 75 | "Node name for S&R": "CLIPTextEncode" 76 | }, 77 | "widgets_values": [ 78 | "beautiful scenery nature glass bottle landscape, , purple galaxy bottle," 79 | ] 80 | }, 81 | { 82 | "id": 5, 83 | "type": "EmptyLatentImage", 84 | "pos": [ 85 | 473, 86 | 609 87 | ], 88 | "size": { 89 | "0": 315, 90 | "1": 106 91 | }, 92 | "flags": {}, 93 | "order": 0, 94 | "mode": 0, 95 | "outputs": [ 96 | { 97 | "name": "LATENT", 98 | "type": "LATENT", 99 | "links": [ 100 | 2 101 | ], 102 | "slot_index": 0 103 | } 104 | ], 105 | "properties": { 106 | "Node name for S&R": "EmptyLatentImage" 107 | }, 108 | "widgets_values": [ 109 | 512, 110 | 512, 111 | 1 112 | ] 113 | }, 114 | { 115 | "id": 9, 116 | "type": "SaveImage", 117 | "pos": [ 118 | 1451, 119 | 189 120 | ], 121 | "size": { 122 | "0": 210, 123 | "1": 58 124 | }, 125 | "flags": {}, 126 | "order": 7, 127 | "mode": 0, 128 | "inputs": [ 129 | { 130 | "name": "images", 131 | "type": "IMAGE", 132 | "link": 13 133 | } 134 | ], 135 | "properties": {}, 136 | "widgets_values": [ 137 | "ComfyUI" 138 | ] 139 | }, 140 | { 141 | "id": 10, 142 | "type": "AITemplateLoader", 143 | "pos": [ 144 | 26, 145 | 375 146 | ], 147 | "size": { 148 | "0": 315, 149 | "1": 58 150 | }, 151 | "flags": {}, 152 | "order": 2, 153 | "mode": 0, 
154 | "inputs": [ 155 | { 156 | "name": "model", 157 | "type": "MODEL", 158 | "link": 10 159 | } 160 | ], 161 | "outputs": [ 162 | { 163 | "name": "MODEL", 164 | "type": "MODEL", 165 | "links": [ 166 | 11 167 | ], 168 | "shape": 3, 169 | "slot_index": 0 170 | } 171 | ], 172 | "properties": { 173 | "Node name for S&R": "AITemplateLoader" 174 | }, 175 | "widgets_values": [ 176 | "enable" 177 | ] 178 | }, 179 | { 180 | "id": 3, 181 | "type": "KSampler", 182 | "pos": [ 183 | 863, 184 | 186 185 | ], 186 | "size": { 187 | "0": 315, 188 | "1": 262 189 | }, 190 | "flags": {}, 191 | "order": 5, 192 | "mode": 0, 193 | "inputs": [ 194 | { 195 | "name": "model", 196 | "type": "MODEL", 197 | "link": 11 198 | }, 199 | { 200 | "name": "positive", 201 | "type": "CONDITIONING", 202 | "link": 4 203 | }, 204 | { 205 | "name": "negative", 206 | "type": "CONDITIONING", 207 | "link": 6 208 | }, 209 | { 210 | "name": "latent_image", 211 | "type": "LATENT", 212 | "link": 2 213 | } 214 | ], 215 | "outputs": [ 216 | { 217 | "name": "LATENT", 218 | "type": "LATENT", 219 | "links": [ 220 | 12 221 | ], 222 | "slot_index": 0 223 | } 224 | ], 225 | "properties": { 226 | "Node name for S&R": "KSampler" 227 | }, 228 | "widgets_values": [ 229 | 156680208700286, 230 | "randomize", 231 | 20, 232 | 8, 233 | "euler", 234 | "normal", 235 | 1 236 | ] 237 | }, 238 | { 239 | "id": 11, 240 | "type": "AITemplateVAEDecode", 241 | "pos": [ 242 | 1150, 243 | 71 244 | ], 245 | "size": { 246 | "0": 315, 247 | "1": 78 248 | }, 249 | "flags": {}, 250 | "order": 6, 251 | "mode": 0, 252 | "inputs": [ 253 | { 254 | "name": "vae", 255 | "type": "VAE", 256 | "link": 14 257 | }, 258 | { 259 | "name": "samples", 260 | "type": "LATENT", 261 | "link": 12 262 | } 263 | ], 264 | "outputs": [ 265 | { 266 | "name": "IMAGE", 267 | "type": "IMAGE", 268 | "links": [ 269 | 13 270 | ], 271 | "shape": 3, 272 | "slot_index": 0 273 | } 274 | ], 275 | "properties": { 276 | "Node name for S&R": "AITemplateVAEDecode" 277 | }, 278 | "widgets_values": [ 279 | "enable" 280 | ] 281 | }, 282 | { 283 | "id": 4, 284 | "type": "CheckpointLoaderSimple", 285 | "pos": [ 286 | 26, 287 | 474 288 | ], 289 | "size": { 290 | "0": 315, 291 | "1": 98 292 | }, 293 | "flags": {}, 294 | "order": 1, 295 | "mode": 0, 296 | "outputs": [ 297 | { 298 | "name": "MODEL", 299 | "type": "MODEL", 300 | "links": [ 301 | 10 302 | ], 303 | "slot_index": 0 304 | }, 305 | { 306 | "name": "CLIP", 307 | "type": "CLIP", 308 | "links": [ 309 | 3, 310 | 5 311 | ], 312 | "slot_index": 1 313 | }, 314 | { 315 | "name": "VAE", 316 | "type": "VAE", 317 | "links": [ 318 | 14 319 | ], 320 | "slot_index": 2 321 | } 322 | ], 323 | "properties": { 324 | "Node name for S&R": "CheckpointLoaderSimple" 325 | }, 326 | "widgets_values": [ 327 | "v1-5-pruned-emaonly.safetensors" 328 | ] 329 | } 330 | ], 331 | "links": [ 332 | [ 333 | 2, 334 | 5, 335 | 0, 336 | 3, 337 | 3, 338 | "LATENT" 339 | ], 340 | [ 341 | 3, 342 | 4, 343 | 1, 344 | 6, 345 | 0, 346 | "CLIP" 347 | ], 348 | [ 349 | 4, 350 | 6, 351 | 0, 352 | 3, 353 | 1, 354 | "CONDITIONING" 355 | ], 356 | [ 357 | 5, 358 | 4, 359 | 1, 360 | 7, 361 | 0, 362 | "CLIP" 363 | ], 364 | [ 365 | 6, 366 | 7, 367 | 0, 368 | 3, 369 | 2, 370 | "CONDITIONING" 371 | ], 372 | [ 373 | 10, 374 | 4, 375 | 0, 376 | 10, 377 | 0, 378 | "MODEL" 379 | ], 380 | [ 381 | 11, 382 | 10, 383 | 0, 384 | 3, 385 | 0, 386 | "MODEL" 387 | ], 388 | [ 389 | 12, 390 | 3, 391 | 0, 392 | 11, 393 | 1, 394 | "LATENT" 395 | ], 396 | [ 397 | 13, 398 | 11, 399 | 0, 400 | 9, 401 | 0, 402 | "IMAGE" 403 | ], 404 | [ 
405 | 14, 406 | 4, 407 | 2, 408 | 11, 409 | 0, 410 | "VAE" 411 | ] 412 | ], 413 | "groups": [], 414 | "config": {}, 415 | "extra": {}, 416 | "version": 0.4 417 | } --------------------------------------------------------------------------------
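
A note on reading the workflow files above: each `.json` in `workflows/` stores its node graph in the litegraph-style layout visible in the listings, where every entry in `"links"` is a six-element array `[link_id, source_node_id, source_slot, target_node_id, target_slot, type]` and each node's `"inputs"` reference a link by id. The sketch below is a minimal, hypothetical helper (not part of the repository) that loads one of these files and prints the edges in a human-readable form, assuming only the structure shown above; the file path is just an example.

```python
import json


def describe_workflow(path: str) -> None:
    """Print the node graph encoded in a ComfyUI workflow JSON file.

    Assumes the litegraph-style layout seen in the workflows above:
    "nodes" is a list of dicts with "id" and "type", and "links" is a list of
    [link_id, source_node_id, source_slot, target_node_id, target_slot, type].
    """
    with open(path, "r", encoding="utf-8") as f:
        workflow = json.load(f)

    # Map node ids to their node type (e.g. KSampler, AITemplateLoader).
    node_types = {node["id"]: node["type"] for node in workflow["nodes"]}

    for link_id, src_id, src_slot, dst_id, dst_slot, link_type in workflow["links"]:
        print(
            f"link {link_id}: {node_types[src_id]}[{src_slot}] "
            f"--{link_type}--> {node_types[dst_id]}[{dst_slot}]"
        )


if __name__ == "__main__":
    # Example: inspect the UNet + VAE workflow shipped with the repository.
    describe_workflow("workflows/aitemplate_unet_vae.json")
```

Run against `aitemplate_unet_vae.json`, for instance, link 11 prints as `AITemplateLoader[0] --MODEL--> KSampler[0]`, matching the `"model"` input of node 3 in the listing above.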