├── .gitignore ├── compile_ait.py ├── demo_part.py ├── demo_sd_fp16.py ├── demo_sd_openvino.py ├── export.py ├── export_unet.py ├── prompts.txt ├── prompts ├── dream.txt ├── people.txt └── wallpaper.txt ├── readme.md ├── stablefusion ├── .gitignore ├── __init__.py ├── __pycache__ │ ├── __init__.cpython-38.pyc │ ├── clip_textmodel.cpython-38.pyc │ └── stablefusion_ov_engine.cpython-38.pyc ├── ait_modeling │ ├── attention.py │ ├── clip.py │ ├── embeddings.py │ ├── resnet.py │ ├── unet_2d_condition.py │ ├── unet_blocks.py │ └── vae.py ├── clip_textmodel.py ├── stablefusion_ov_engine.py ├── stablefusion_pipeline.py ├── test.py ├── trt_model.py └── unet_2d_condition.py └── test_diffusers.py /.gitignore: -------------------------------------------------------------------------------- 1 | .vscode/ 2 | .vs/ 3 | *.png 4 | weights/ 5 | __pycache__/ 6 | results/ 7 | *.onnx 8 | *.trt 9 | 10 | vendor/ 11 | -------------------------------------------------------------------------------- /compile_ait.py: -------------------------------------------------------------------------------- 1 | # Copyright (c) Meta Platforms, Inc. and affiliates. 2 | # 3 | # Licensed under the Apache License, Version 2.0 (the "License"); 4 | # you may not use this file except in compliance with the License. 5 | # You may obtain a copy of the License at 6 | # 7 | # http://www.apache.org/licenses/LICENSE-2.0 8 | # 9 | # Unless required by applicable law or agreed to in writing, software 10 | # distributed under the License is distributed on an "AS IS" BASIS, 11 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 12 | # See the License for the specific language governing permissions and 13 | # limitations under the License. 14 | # 15 | import logging 16 | from collections import OrderedDict 17 | 18 | import click 19 | import numpy as np 20 | 21 | import torch 22 | 23 | from aitemplate.compiler import compile_model 24 | from aitemplate.frontend import Tensor 25 | from aitemplate.testing import detect_target 26 | from diffusers import StableDiffusionPipeline 27 | 28 | from modeling.clip import CLIPTextTransformer as ait_CLIPTextTransformer 29 | 30 | from modeling.unet_2d_condition import UNet2DConditionModel as ait_UNet2DConditionModel 31 | 32 | from modeling.vae import AutoencoderKL as ait_AutoencoderKL 33 | 34 | 35 | USE_CUDA = detect_target().name() == "cuda" 36 | 37 | access_token = True 38 | pipe = None 39 | 40 | 41 | def mark_output(y): 42 | if type(y) is not tuple: 43 | y = (y,) 44 | for i in range(len(y)): 45 | y[i]._attrs["is_output"] = True 46 | y[i]._attrs["name"] = "output_%d" % (i) 47 | y_shape = [d._attrs["values"][0] for d in y[i]._attrs["shape"]] 48 | print("AIT output_{} shape: {}".format(i, y_shape)) 49 | 50 | 51 | def map_unet_params(pt_mod, dim): 52 | pt_params = dict(pt_mod.named_parameters()) 53 | params_ait = {} 54 | for key, arr in pt_params.items(): 55 | if len(arr.shape) == 4: 56 | arr = arr.permute((0, 2, 3, 1)).contiguous() 57 | elif key.endswith("ff.net.0.proj.weight"): 58 | w1, w2 = arr.chunk(2, dim=0) 59 | params_ait[key.replace(".", "_")] = w1 60 | params_ait[key.replace(".", "_").replace("proj", "gate")] = w2 61 | continue 62 | elif key.endswith("ff.net.0.proj.bias"): 63 | w1, w2 = arr.chunk(2, dim=0) 64 | params_ait[key.replace(".", "_")] = w1 65 | params_ait[key.replace(".", "_").replace("proj", "gate")] = w2 66 | continue 67 | params_ait[key.replace(".", "_")] = arr 68 | 69 | params_ait["arange"] = ( 70 | torch.arange(start=0, end=dim // 2, 
dtype=torch.float32).cuda().half() 71 | ) 72 | return params_ait 73 | 74 | 75 | def map_vae_params(ait_module, pt_module, batch_size, seq_len): 76 | pt_params = dict(pt_module.named_parameters()) 77 | mapped_pt_params = OrderedDict() 78 | for name, _ in ait_module.named_parameters(): 79 | ait_name = name.replace(".", "_") 80 | if name in pt_params: 81 | if ( 82 | "conv" in name 83 | and "norm" not in name 84 | and name.endswith(".weight") 85 | and len(pt_params[name].shape) == 4 86 | ): 87 | mapped_pt_params[ait_name] = torch.permute( 88 | pt_params[name], [0, 2, 3, 1] 89 | ).contiguous() 90 | else: 91 | mapped_pt_params[ait_name] = pt_params[name] 92 | elif name.endswith("attention.qkv.weight"): 93 | prefix = name[: -len("attention.qkv.weight")] 94 | q_weight = pt_params[prefix + "query.weight"] 95 | k_weight = pt_params[prefix + "key.weight"] 96 | v_weight = pt_params[prefix + "value.weight"] 97 | qkv_weight = torch.cat([q_weight, k_weight, v_weight], dim=0) 98 | mapped_pt_params[ait_name] = qkv_weight 99 | elif name.endswith("attention.qkv.bias"): 100 | prefix = name[: -len("attention.qkv.bias")] 101 | q_bias = pt_params[prefix + "query.bias"] 102 | k_bias = pt_params[prefix + "key.bias"] 103 | v_bias = pt_params[prefix + "value.bias"] 104 | qkv_bias = torch.cat([q_bias, k_bias, v_bias], dim=0) 105 | mapped_pt_params[ait_name] = qkv_bias 106 | elif name.endswith("attention.proj.weight"): 107 | prefix = name[: -len("attention.proj.weight")] 108 | pt_name = prefix + "proj_attn.weight" 109 | mapped_pt_params[ait_name] = pt_params[pt_name] 110 | elif name.endswith("attention.proj.bias"): 111 | prefix = name[: -len("attention.proj.bias")] 112 | pt_name = prefix + "proj_attn.bias" 113 | mapped_pt_params[ait_name] = pt_params[pt_name] 114 | elif name.endswith("attention.cu_length"): 115 | cu_len = np.cumsum([0] + [seq_len] * batch_size).astype("int32") 116 | mapped_pt_params[ait_name] = torch.from_numpy(cu_len).cuda() 117 | else: 118 | pt_param = pt_module.get_parameter(name) 119 | mapped_pt_params[ait_name] = pt_param 120 | 121 | return mapped_pt_params 122 | 123 | 124 | def map_clip_params(pt_mod, batch_size, seqlen, depth): 125 | 126 | params_pt = list(pt_mod.named_parameters()) 127 | 128 | params_ait = {} 129 | pt_params = {} 130 | for key, arr in params_pt: 131 | pt_params[key.replace("text_model.", "")] = arr 132 | 133 | pt_params = dict(pt_mod.named_parameters()) 134 | for key, arr in pt_params.items(): 135 | name = key.replace("text_model.", "") 136 | ait_name = name.replace(".", "_") 137 | if name.endswith("out_proj.weight"): 138 | ait_name = ait_name.replace("out_proj", "proj") 139 | elif name.endswith("out_proj.bias"): 140 | ait_name = ait_name.replace("out_proj", "proj") 141 | elif name.endswith("q_proj.weight"): 142 | ait_name = ait_name.replace("q_proj", "qkv") 143 | prefix = key[: -len("q_proj.weight")] 144 | q = pt_params[prefix + "q_proj.weight"] 145 | k = pt_params[prefix + "k_proj.weight"] 146 | v = pt_params[prefix + "v_proj.weight"] 147 | qkv_weight = torch.cat([q, k, v], dim=0) 148 | params_ait[ait_name] = qkv_weight 149 | continue 150 | elif name.endswith("q_proj.bias"): 151 | ait_name = ait_name.replace("q_proj", "qkv") 152 | prefix = key[: -len("q_proj.bias")] 153 | q = pt_params[prefix + "q_proj.bias"] 154 | k = pt_params[prefix + "k_proj.bias"] 155 | v = pt_params[prefix + "v_proj.bias"] 156 | qkv_bias = torch.cat([q, k, v], dim=0) 157 | params_ait[ait_name] = qkv_bias 158 | continue 159 | elif name.endswith("k_proj.weight"): 160 | continue 161 | elif 
name.endswith("k_proj.bias"): 162 | continue 163 | elif name.endswith("v_proj.weight"): 164 | continue 165 | elif name.endswith("v_proj.bias"): 166 | continue 167 | params_ait[ait_name] = arr 168 | 169 | if USE_CUDA: 170 | for i in range(depth): 171 | prefix = "encoder_layers_%d_self_attn_cu_length" % (i) 172 | cu_len = np.cumsum([0] + [seqlen] * batch_size).astype("int32") 173 | params_ait[prefix] = torch.from_numpy(cu_len).cuda() 174 | 175 | return params_ait 176 | 177 | 178 | def compile_unet( 179 | batch_size=2, 180 | hh=64, 181 | ww=64, 182 | dim=320, 183 | use_fp16_acc=False, 184 | convert_conv_to_gemm=False, 185 | ): 186 | 187 | ait_mod = ait_UNet2DConditionModel(sample_size=64, cross_attention_dim=768) 188 | ait_mod.name_parameter_tensor() 189 | 190 | # set AIT parameters 191 | pt_mod = pipe.unet 192 | pt_mod = pt_mod.eval() 193 | params_ait = map_unet_params(pt_mod, dim) 194 | 195 | latent_model_input_ait = Tensor( 196 | [batch_size, hh, ww, 4], name="input0", is_input=True 197 | ) 198 | timesteps_ait = Tensor([batch_size], name="input1", is_input=True) 199 | text_embeddings_pt_ait = Tensor([batch_size, 64, 768], name="input2", is_input=True) 200 | 201 | Y = ait_mod(latent_model_input_ait, timesteps_ait, text_embeddings_pt_ait) 202 | mark_output(Y) 203 | 204 | target = detect_target( 205 | use_fp16_acc=use_fp16_acc, convert_conv_to_gemm=convert_conv_to_gemm 206 | ) 207 | compile_model(Y, target, "./tmp", "UNet2DConditionModel", constants=params_ait) 208 | 209 | 210 | def compile_clip( 211 | batch_size=1, 212 | seqlen=64, 213 | dim=768, 214 | num_heads=12, 215 | hidden_size=768, 216 | vocab_size=49408, 217 | max_position_embeddings=77, 218 | use_fp16_acc=False, 219 | convert_conv_to_gemm=False, 220 | ): 221 | mask_seq = 0 222 | causal = True 223 | depth = 12 224 | 225 | ait_mod = ait_CLIPTextTransformer( 226 | num_hidden_layers=depth, 227 | hidden_size=dim, 228 | num_attention_heads=num_heads, 229 | batch_size=batch_size, 230 | seq_len=seqlen, 231 | causal=causal, 232 | mask_seq=mask_seq, 233 | ) 234 | ait_mod.name_parameter_tensor() 235 | 236 | pt_mod = pipe.text_encoder 237 | pt_mod = pt_mod.eval() 238 | params_ait = map_clip_params(pt_mod, batch_size, seqlen, depth) 239 | 240 | input_ids_ait = Tensor( 241 | [batch_size, seqlen], name="input0", dtype="int64", is_input=True 242 | ) 243 | position_ids_ait = Tensor( 244 | [batch_size, seqlen], name="input1", dtype="int64", is_input=True 245 | ) 246 | Y = ait_mod(input_ids=input_ids_ait, position_ids=position_ids_ait) 247 | mark_output(Y) 248 | 249 | target = detect_target( 250 | use_fp16_acc=use_fp16_acc, convert_conv_to_gemm=convert_conv_to_gemm 251 | ) 252 | compile_model(Y, target, "./tmp", "CLIPTextModel", constants=params_ait) 253 | 254 | 255 | def compile_vae( 256 | batch_size=1, height=64, width=64, use_fp16_acc=False, convert_conv_to_gemm=False 257 | ): 258 | in_channels = 3 259 | out_channels = 3 260 | down_block_types = [ 261 | "DownEncoderBlock2D", 262 | "DownEncoderBlock2D", 263 | "DownEncoderBlock2D", 264 | "DownEncoderBlock2D", 265 | ] 266 | up_block_types = [ 267 | "UpDecoderBlock2D", 268 | "UpDecoderBlock2D", 269 | "UpDecoderBlock2D", 270 | "UpDecoderBlock2D", 271 | ] 272 | block_out_channels = [128, 256, 512, 512] 273 | layers_per_block = 2 274 | act_fn = "silu" 275 | latent_channels = 4 276 | sample_size = 512 277 | 278 | ait_vae = ait_AutoencoderKL( 279 | batch_size, 280 | height, 281 | width, 282 | in_channels=in_channels, 283 | out_channels=out_channels, 284 | down_block_types=down_block_types, 285 | 
up_block_types=up_block_types, 286 | block_out_channels=block_out_channels, 287 | layers_per_block=layers_per_block, 288 | act_fn=act_fn, 289 | latent_channels=latent_channels, 290 | sample_size=sample_size, 291 | ) 292 | ait_input = Tensor( 293 | shape=[batch_size, height, width, latent_channels], 294 | name="vae_input", 295 | is_input=True, 296 | ) 297 | ait_vae.name_parameter_tensor() 298 | 299 | pt_mod = pipe.vae 300 | pt_mod = pt_mod.eval() 301 | params_ait = map_vae_params(ait_vae, pt_mod, batch_size, height * width) 302 | 303 | Y = ait_vae.decode(ait_input) 304 | mark_output(Y) 305 | target = detect_target( 306 | use_fp16_acc=use_fp16_acc, convert_conv_to_gemm=convert_conv_to_gemm 307 | ) 308 | compile_model( 309 | Y, 310 | target, 311 | "./tmp", 312 | "AutoencoderKL", 313 | constants=params_ait, 314 | ) 315 | 316 | 317 | @click.command() 318 | @click.option("--token", default="", help="access token") 319 | @click.option("--batch-size", default=1, help="batch size") 320 | @click.option("--use-fp16-acc", default=True, help="use fp16 accumulation") 321 | @click.option("--convert-conv-to-gemm", default=True, help="convert 1x1 conv to gemm") 322 | def compile_diffusers(token, batch_size, use_fp16_acc=True, convert_conv_to_gemm=True): 323 | logging.getLogger().setLevel(logging.INFO) 324 | np.random.seed(0) 325 | torch.manual_seed(4896) 326 | 327 | if detect_target().name() == "rocm": 328 | convert_conv_to_gemm = False 329 | 330 | global access_token, pipe 331 | if token != "": 332 | access_token = token 333 | 334 | pipe = StableDiffusionPipeline.from_pretrained( 335 | "CompVis/stable-diffusion-v1-4", 336 | revision="fp16", 337 | torch_dtype=torch.float16, 338 | use_auth_token=True, 339 | ).to("cuda") 340 | 341 | # CLIP 342 | compile_clip( 343 | batch_size=batch_size, 344 | use_fp16_acc=use_fp16_acc, 345 | convert_conv_to_gemm=convert_conv_to_gemm, 346 | ) 347 | # UNet 348 | compile_unet( 349 | batch_size=batch_size * 2, 350 | use_fp16_acc=use_fp16_acc, 351 | convert_conv_to_gemm=convert_conv_to_gemm, 352 | ) 353 | # VAE 354 | compile_vae( 355 | batch_size=batch_size, 356 | use_fp16_acc=use_fp16_acc, 357 | convert_conv_to_gemm=convert_conv_to_gemm, 358 | ) 359 | 360 | 361 | if __name__ == "__main__": 362 | compile_diffusers() 363 | -------------------------------------------------------------------------------- /demo_part.py: -------------------------------------------------------------------------------- 1 | import os 2 | import torch 3 | from torch import autocast 4 | from transformers import CLIPTextModel, CLIPTokenizer 5 | from diffusers import AutoencoderKL, UNet2DConditionModel, PNDMScheduler 6 | from diffusers import LMSDiscreteScheduler 7 | from tqdm.auto import tqdm 8 | from PIL import Image 9 | import argparse 10 | 11 | # from alfred.deploy.tensorrt.wrapper import TensorRTInferencer 12 | from stablefusion.trt_model import TRTModel 13 | from alfred import logger 14 | 15 | 16 | torch_device = "cuda" 17 | 18 | YOUR_TOKEN = None 19 | # height = 960 20 | # width = 1080 21 | height = 512 22 | width = 544 23 | UNET_INPUTS_CHANNEL = 4 24 | BASE_MODEL_DIR = "weights/stable-diffusion-v1-4" 25 | 26 | 27 | def parse_args(): 28 | parser = argparse.ArgumentParser() 29 | parser.add_argument("--seed", type=int, default=1024) 30 | parser.add_argument("--beta-start", type=float, default=0.00085, help="beta_start") 31 | parser.add_argument("--beta-end", type=float, default=0.012, help="::beta_end") 32 | parser.add_argument("--beta-schedule", type=str, default="scaled_linear") 33 | 
parser.add_argument("--trt", action="store_true", default=False) 34 | 35 | parser.add_argument("--num_inference_steps", type=int, default=68) 36 | parser.add_argument("--guidance_scale", type=float, default=7.5) 37 | parser.add_argument("--eta", type=float, default=0.0, help="eta") 38 | 39 | parser.add_argument("--prompt", type=str, default="prompts.txt", help="prompt") 40 | parser.add_argument("--init-image", type=str, default=None, help=" image") 41 | parser.add_argument("--strength", type=float, default=0.5, help="how [0.0, 1.0]") 42 | parser.add_argument("--mask", type=str, default=None, help="maskial image") 43 | return parser.parse_args() 44 | 45 | 46 | def main(): 47 | args = parse_args() 48 | num_inference_steps = args.num_inference_steps 49 | guidance_scale = args.guidance_scale 50 | batch_size = 1 51 | 52 | if os.path.isfile(args.prompt): 53 | txts = open(args.prompt, "r").readlines() 54 | txts = [i.strip() for i in txts] 55 | else: 56 | txts = [args.prompt] 57 | 58 | if args.trt: 59 | unet_trt_enigne = "unet_fp16.trt" 60 | logger.info(f"using TensorRT inference unet: {unet_trt_enigne}") 61 | assert os.path.exists(unet_trt_enigne), f"{unet_trt_enigne} not found!" 62 | # unet = TensorRTInferencer(unet_trt_enigne) 63 | unet = TRTModel(unet_trt_enigne) 64 | logger.info("unet loaded in trt.") 65 | else: 66 | unet = UNet2DConditionModel.from_pretrained( 67 | BASE_MODEL_DIR, 68 | subfolder="unet", 69 | torch_dtype=torch.float16, 70 | revision="fp16", 71 | use_auth_token=YOUR_TOKEN, 72 | ) 73 | 74 | vae = AutoencoderKL.from_pretrained( 75 | BASE_MODEL_DIR, subfolder="vae", use_auth_token=YOUR_TOKEN 76 | ) 77 | tokenizer = CLIPTokenizer.from_pretrained(BASE_MODEL_DIR + "/tokenizer") 78 | text_encoder = CLIPTextModel.from_pretrained(BASE_MODEL_DIR + "/text_encoder") 79 | 80 | scheduler = LMSDiscreteScheduler( 81 | beta_start=0.00085, 82 | beta_end=0.012, 83 | beta_schedule="scaled_linear", 84 | num_train_timesteps=1000, 85 | ) 86 | # Set the models to your inference device 87 | vae.to(torch_device) 88 | text_encoder.to(torch_device) 89 | if not args.trt: 90 | unet.to(torch_device) 91 | 92 | for index, prompt in enumerate(txts): 93 | text_input = tokenizer( 94 | prompt, 95 | padding="max_length", 96 | max_length=tokenizer.model_max_length, 97 | truncation=True, 98 | return_tensors="pt", 99 | ) 100 | text_embeddings = text_encoder(text_input.input_ids.to(torch_device))[0] 101 | 102 | max_length = text_input.input_ids.shape[-1] 103 | uncond_input = tokenizer( 104 | [""] * batch_size, 105 | padding="max_length", 106 | max_length=max_length, 107 | return_tensors="pt", 108 | ) 109 | uncond_embeddings = text_encoder(uncond_input.input_ids.to(torch_device))[0] 110 | 111 | text_embeddings = torch.cat([uncond_embeddings, text_embeddings]).half().cuda() 112 | 113 | latents = torch.randn( 114 | (batch_size, UNET_INPUTS_CHANNEL, height // 8, width // 8) 115 | ) 116 | latents = latents.to(torch_device) 117 | scheduler.set_timesteps(num_inference_steps) 118 | latents = latents * scheduler.sigmas[0] 119 | 120 | scheduler.set_timesteps(num_inference_steps) 121 | # Denoising Loop 122 | with torch.inference_mode(), autocast("cuda"): 123 | for i, t in tqdm(enumerate(scheduler.timesteps)): 124 | # expand the latents if we are doing classifier-free guidance to avoid doing two forward passes. 
125 | latent_model_input = torch.cat([latents] * 2) 126 | sigma = scheduler.sigmas[i] 127 | latent_model_input = latent_model_input / ((sigma**2 + 1) ** 0.5) 128 | 129 | # predict the noise residual 130 | if args.trt: 131 | # noise_pred = unet.infer( 132 | # latent_model_input, t, encoder_hidden_states=text_embeddings 133 | # ) 134 | inputs = [ 135 | latent_model_input, 136 | torch.tensor([t]).to(torch_device), 137 | text_embeddings, 138 | ] 139 | noise_pred, duration = unet(inputs, timing=True) 140 | noise_pred = torch.reshape( 141 | noise_pred[0], (batch_size * 2, 4, 64, 64) 142 | ) 143 | else: 144 | noise_pred = unet( 145 | latent_model_input, t, encoder_hidden_states=text_embeddings 146 | )["sample"] 147 | 148 | # perform guidance 149 | noise_pred_uncond, noise_pred_text = noise_pred.chunk(2) 150 | noise_pred = noise_pred_uncond + guidance_scale * ( 151 | noise_pred_text - noise_pred_uncond 152 | ) 153 | 154 | # compute the previous noisy sample x_t -> x_t-1 155 | latents = scheduler.step(noise_pred, i, latents)["prev_sample"] 156 | 157 | # scale and decode the image latents with vae 158 | latents = 1 / 0.18215 * latents 159 | image = vae.decode(latents) 160 | 161 | image = image.sample 162 | # Convert the image with PIL and save it 163 | image = (image / 2 + 0.5).clamp(0, 1) 164 | image = image.detach().cpu().permute(0, 2, 3, 1).numpy() 165 | images = (image * 255).round().astype("uint8") 166 | pil_images = [Image.fromarray(image) for image in images] 167 | pil_images[0].save(f"image_generated_{index}.png") 168 | 169 | 170 | if __name__ == "__main__": 171 | main() 172 | -------------------------------------------------------------------------------- /demo_sd_fp16.py: -------------------------------------------------------------------------------- 1 | from stablefusion.stablefusion_ov_engine import StableDiffusionEngine 2 | from diffusers.pipelines import StableDiffusionPipeline 3 | import torch 4 | import cv2 5 | 6 | 7 | def main(): 8 | pipe = StableDiffusionPipeline.from_pretrained( 9 | "weights/stable-diffusion-v1-4", 10 | revision="fp16", 11 | torch_dtype=torch.float16, 12 | safety_checker=None, 13 | ) 14 | pipe = pipe.to("cuda") 15 | 16 | prompt = "a photo of an astronaut riding a horse on mars" 17 | image = pipe(prompt).images[0] 18 | image.save("res.png") 19 | # cv2.imwrite(f"res.png", image.cpu().numpy()) 20 | 21 | 22 | if __name__ == '__main__': 23 | main() 24 | -------------------------------------------------------------------------------- /demo_sd_openvino.py: -------------------------------------------------------------------------------- 1 | import argparse 2 | import os 3 | from stablefusion.stablefusion_ov_engine import StableDiffusionEngine 4 | from diffusers import LMSDiscreteScheduler, PNDMScheduler 5 | import cv2 6 | import numpy as np 7 | from alfred import logger 8 | import os 9 | 10 | 11 | def main(args): 12 | if args.seed is not None: 13 | np.random.seed(args.seed) 14 | if args.init_image is None: 15 | scheduler = LMSDiscreteScheduler( 16 | beta_start=args.beta_start, 17 | beta_end=args.beta_end, 18 | beta_schedule=args.beta_schedule, 19 | tensor_format="np", 20 | ) 21 | else: 22 | scheduler = PNDMScheduler( 23 | beta_start=args.beta_start, 24 | beta_end=args.beta_end, 25 | beta_schedule=args.beta_schedule, 26 | skip_prk_steps=True, 27 | tensor_format="np", 28 | ) 29 | engine = StableDiffusionEngine( 30 | model=args.model, 31 | scheduler=scheduler, 32 | tokenizer=args.tokenizer, 33 | local_model_path="weights/onnx", 34 | ) 35 | txts = [] 36 | if 
os.path.isfile(args.prompt): 37 | txts = open(args.prompt, "r").readlines() 38 | txts = [i.strip() for i in txts] 39 | else: 40 | txts = [args.prompt] 41 | 42 | for i, prompt in enumerate(txts): 43 | image = engine( 44 | prompt=prompt, 45 | init_image=None if args.init_image is None else cv2.imread(args.init_image), 46 | mask=None if args.mask is None else cv2.imread(args.mask, 0), 47 | strength=args.strength, 48 | num_inference_steps=args.num_inference_steps, 49 | guidance_scale=args.guidance_scale, 50 | eta=args.eta, 51 | ) 52 | cv2.imwrite(f"res{i}_{args.output}", image) 53 | logger.info(f"result save into: res{i}_{args.output}") 54 | 55 | 56 | if __name__ == "__main__": 57 | parser = argparse.ArgumentParser() 58 | # pipeline configure 59 | parser.add_argument( 60 | "--model", 61 | type=str, 62 | default="bes-dev/stable-diffusion-v1-4-openvino", 63 | help="model name", 64 | ) 65 | # randomizer params 66 | parser.add_argument( 67 | "--seed", 68 | type=int, 69 | default=None, 70 | help="random seed for generating consistent images per prompt", 71 | ) 72 | # scheduler params 73 | parser.add_argument( 74 | "--beta-start", 75 | type=float, 76 | default=0.00085, 77 | help="LMSDiscreteScheduler::beta_start", 78 | ) 79 | parser.add_argument( 80 | "--beta-end", type=float, default=0.012, help="LMSDiscreteScheduler::beta_end" 81 | ) 82 | parser.add_argument( 83 | "--beta-schedule", 84 | type=str, 85 | default="scaled_linear", 86 | help="LMSDiscreteScheduler::beta_schedule", 87 | ) 88 | # diffusion params 89 | parser.add_argument( 90 | "--num-inference-steps", type=int, default=32, help="num inference steps" 91 | ) 92 | parser.add_argument( 93 | "--guidance-scale", type=float, default=7.5, help="guidance scale" 94 | ) 95 | parser.add_argument("--eta", type=float, default=0.0, help="eta") 96 | # tokenizer 97 | parser.add_argument( 98 | "--tokenizer", 99 | type=str, 100 | default="openai/clip-vit-large-patch14", 101 | help="tokenizer", 102 | ) 103 | # prompt 104 | parser.add_argument( 105 | "--prompt", type=str, default="prompts.txt", help="prompt", 106 | ) 107 | # img2img params 108 | parser.add_argument( 109 | "--init-image", type=str, default=None, help="path to initial image" 110 | ) 111 | parser.add_argument( 112 | "--strength", 113 | type=float, 114 | default=0.5, 115 | help="how strong the initial image should be noised [0.0, 1.0]", 116 | ) 117 | # inpainting 118 | parser.add_argument( 119 | "--mask", 120 | type=str, 121 | default=None, 122 | help="mask of the region to inpaint on the initial image", 123 | ) 124 | # output name 125 | parser.add_argument( 126 | "--output", type=str, default="output.png", help="output image name" 127 | ) 128 | args = parser.parse_args() 129 | main(args) 130 | -------------------------------------------------------------------------------- /export.py: -------------------------------------------------------------------------------- 1 | from pathlib import Path 2 | from diffusers import StableDiffusionPipeline, LMSDiscreteScheduler 3 | import torch 4 | from transformers import CLIPTextModel 5 | from alfred import logger 6 | from torch import nn 7 | from stablefusion.clip_textmodel import CLIPTextModelTracable 8 | from stablefusion.unet_2d_condition import UNet2DConditionModelTracable 9 | 10 | 11 | text_encoder = CLIPTextModelTracable.from_pretrained( 12 | "weights/stable-diffusion-v1-4/text_encoder", return_dict=False 13 | ) 14 | 15 | torch.manual_seed(42) 16 | lms = LMSDiscreteScheduler( 17 | beta_start=0.00085, beta_end=0.012, beta_schedule="scaled_linear" 18 | 
) 19 | 20 | pipe = StableDiffusionPipeline.from_pretrained( 21 | "weights/stable-diffusion-v1-4", scheduler=lms, use_auth_token=False 22 | ) 23 | 24 | 25 | def convert_to_onnx( 26 | unet, post_quant_conv, decoder, text_encoder, height=512, width=512 27 | ): 28 | """Convert given input models to onnx files. 29 | unet: UNet2DConditionModel 30 | post_quant_conv: AutoencoderKL.post_quant_conv 31 | decoder: AutoencoderKL.decoder 32 | text_encoder: CLIPTextModel 33 | feature_extractor: TODO 34 | safetychecker: TODO 35 | height: Int 36 | width: Int 37 | Note: 38 | - opset_version required is 15 for CLIPTextModel 39 | """ 40 | p = Path("weights/onnx/") 41 | p.mkdir(parents=True, exist_ok=True) 42 | 43 | if height % 8 != 0 or width % 8 != 0: 44 | raise ValueError( 45 | f"`height` and `width` have to be divisible by 8 but are {height} and {width}." 46 | ) 47 | h, w = height // 8, width // 8 48 | # unet onnx export 49 | check_inputs = [ 50 | ( 51 | torch.rand(2, 4, h, w), 52 | torch.tensor([980], dtype=torch.long), 53 | torch.rand(2, 77, 768), 54 | ), 55 | ( 56 | torch.rand(2, 4, h, w), 57 | torch.tensor([910], dtype=torch.long), 58 | torch.rand(2, 12, 768), 59 | ), # batch change, text embed with no trunc 60 | ] 61 | # traced_model = torch.jit.trace( 62 | # unet, check_inputs[0], check_inputs=[check_inputs[1]], strict=True 63 | # ) 64 | # torch.onnx.export( 65 | # # traced_model, 66 | # unet, 67 | # check_inputs[0], 68 | # p / "unet.onnx", 69 | # input_names=["latent_model_input", "t", "encoder_hidden_states"], 70 | # dynamic_axes={ 71 | # "latent_model_input": [0], 72 | # "t": [0], 73 | # "encoder_hidden_states": [0, 1], 74 | # }, 75 | # opset_version=12, 76 | # ) 77 | # logger.info("unet saved.") 78 | 79 | # post_quant_conv onnx export 80 | check_inputs = [(torch.rand(1, 4, h, w),), (torch.rand(2, 4, h, w),)] 81 | traced_model = torch.jit.trace( 82 | post_quant_conv, check_inputs[0], check_inputs=[check_inputs[1]] 83 | ) 84 | torch.onnx.export( 85 | traced_model, 86 | check_inputs[0], 87 | p / "vae_encoder.onnx", 88 | input_names=["init_image"], 89 | dynamic_axes={"init_image": [0]}, 90 | opset_version=12, 91 | ) 92 | 93 | # decoder onnx export 94 | check_inputs = [(torch.rand(1, 4, h, w),), (torch.rand(2, 4, h, w),)] 95 | traced_model = torch.jit.trace( 96 | decoder, check_inputs[0], check_inputs=[check_inputs[1]] 97 | ) 98 | torch.onnx.export( 99 | traced_model, 100 | check_inputs[0], 101 | p / "vae_decoder.onnx", 102 | input_names=["latents"], 103 | dynamic_axes={"latents": [0]}, 104 | opset_version=12, 105 | ) 106 | logger.info("vae decoder saved.") 107 | 108 | # encoder onnx export 109 | check_inputs = [ 110 | (torch.randint(1, 24000, (1, 77)),), 111 | (torch.randint(1, 24000, (2, 77)),), 112 | ] 113 | 114 | traced_model = torch.jit.trace( 115 | text_encoder, check_inputs[0], check_inputs=[check_inputs[1]], strict=False 116 | ) 117 | torch.onnx.export( 118 | # traced_model, 119 | text_encoder, 120 | check_inputs[0], 121 | p / "text_encoder.onnx", 122 | input_names=["tokens"], 123 | dynamic_axes={"tokens": [0, 1]}, 124 | opset_version=12, 125 | ) 126 | logger.info("vae encoder saved.") 127 | 128 | 129 | # Change height and width to create ONNX model file for that size 130 | convert_to_onnx( 131 | pipe.unet, 132 | pipe.vae.post_quant_conv, 133 | pipe.vae.decoder, 134 | text_encoder, 135 | height=512, 136 | width=512, 137 | ) 138 | -------------------------------------------------------------------------------- /export_unet.py: 
-------------------------------------------------------------------------------- 1 | import onnx 2 | import torch 3 | from diffusers import UNet2DConditionModel 4 | from onnxsim import simplify 5 | 6 | ''' 7 | unet is the most big part in stable fusion 8 | so we make it runing under tensorrt, it might 9 | have a best speed on GPU 10 | ''' 11 | 12 | unet = UNet2DConditionModel.from_pretrained( 13 | "weights/stable-diffusion-v1-4", 14 | torch_dtype=torch.float16, 15 | revision="fp16", 16 | subfolder="unet", 17 | # use_auth_token=YOUR_TOKEN, 18 | ) 19 | unet.cuda() 20 | 21 | with torch.inference_mode(), torch.autocast("cuda"): 22 | inputs = ( 23 | torch.randn(1, 4, 64, 64, dtype=torch.half, device="cuda"), 24 | torch.randn(1, dtype=torch.half, device="cuda"), 25 | torch.randn(1, 77, 768, dtype=torch.half, device="cuda"), 26 | ) 27 | 28 | save_f = 'unet_v1_4_fp16_pytorch.onnx' 29 | save_sim_f = 'sim_unet_v1_4_fp16_pytorch.onnx' 30 | 31 | # Export the model 32 | torch.onnx.export( 33 | unet, # model being run 34 | inputs, # model input (or a tuple for multiple inputs) 35 | save_f, # where to save the model (can be a file or file-like object) 36 | export_params=True, # store the trained parameter weights inside the model file 37 | opset_version=12, # the ONNX version to export the model to 38 | do_constant_folding=True, # whether to execute constant folding for optimization 39 | input_names=["input_0", "input_1", "input_2"], 40 | output_names=["output_0"], 41 | ) 42 | 43 | sim_model, check = simplify(save_f) 44 | onnx.save(sim_model, save_sim_f) 45 | print('model saved') 46 | -------------------------------------------------------------------------------- /prompts.txt: -------------------------------------------------------------------------------- 1 | A gaint tiger standing on a train, some people fighting with it. 2 | A green dog with white tail standing on the roof. 3 | Beautiful girl with extremly detail, red hair and sexy. 4 | AI girl in blue skin, have wings can fly. 5 | A yellow fox and a lonely child on a tiny planet, there is a lake on it. 6 | A man with a gaint glass standing on the cave, holding a torch, talking the angry people. 7 | A red tv in front of sofa, a child is looking at it with a dog. 8 | -------------------------------------------------------------------------------- /prompts/dream.txt: -------------------------------------------------------------------------------- 1 | Iron man in green armor holding Captain America shield, armor flashing, 4k, very detailed, the background is avengers building. 
2 | A Hulk with blue skin, smashing on Iron man, very detailed, 4k
--------------------------------------------------------------------------------
/prompts/people.txt:
--------------------------------------------------------------------------------
1 | portrait photo of a asia old warrior chief, tribal panther make up, blue on red, side profile, looking away, serious eyes, 50mm portrait photography, hard rim lighting photography
2 | Keanu Reeves portrait photo of a asia old warrior chief, tribal panther make up, blue on red, side profile, looking away, serious eyes, 50mm portrait photography, hard rim lighting photography
3 | a vibrant professional studio portrait photography of a young, pale, goth, attractive, friendly, casual, delightful, intricate, gorgeous, female, piercing green eyes, wears a gold ankh necklace, femme fatale, nouveau, curated collection, annie leibovitz, nikon, award winning, breathtaking, groundbreaking, superb, outstanding, lensculture portrait awards, photoshopped, dramatic lighting, 8 k, hi res –testp –ar 3:4 –upbeta
4 | gorgeous young Swiss girl sitting by window with headphones on, wearing white bra with translucent shirt over, soft lips, beach blonde hair, octane render, unreal engine, photograph, realistic skin texture, photorealistic, hyper realism, highly detailed, 85mm portrait photography, award winning, hard rim lighting photography–beta –ar 9:16 –s 5000 –testp –upbeta –upbeta –upbeta
5 | city made out of glass : : close shot : : 3 5 mm, realism, octane render, 8 k, exploration, cinematic, trending on artstation, realistic, 3 5 mm camera, unreal engine, hyper detailed, photo – realistic maximum detail, volumetric light, moody cinematic epic concept art, realistic matte painting, hyper photorealistic, concept art, volumetric light, cinematic epic, octane render, 8 k, corona render, movie concept art, octane render, 8 k, corona render, cinematic, trending on artstation, movie concept art, cinematic composition, ultra – detailed, realistic, hyper – realistic, volumetric lighting, 8 k
6 | a cute magical flying dog, fantasy art drawn by disney concept artists, golden colour, high quality, highly detailed, elegant, sharp focus, concept art, character concepts, digital painting, mystery, adventure
7 | 
--------------------------------------------------------------------------------
/prompts/wallpaper.txt:
--------------------------------------------------------------------------------
1 | white horses riding alongside river, sunset, 4k, woods, extremly detailed, horses with green hair
2 | dinosaurs roaming the earth during the Jurassic era, detailed, realistic, floogy.
3 | A tiger elephant chasing a monkey in Africa, 4k, extremly detailed, realistic
--------------------------------------------------------------------------------
/readme.md:
--------------------------------------------------------------------------------
1 | # GaintModels
2 | 
3 | Experiments on testing giant models such as GPT-3 and StableFusion. We offer TensorRT and INT8 quantization for these giant models, so you can run inference on a GPU with less than 6GB of memory!
4 | 
5 | 
6 | ## Install
7 | 
8 | Some requirements to install:
9 | 
10 | ```
11 | pip install diffusers
12 | pip install transformers
13 | pip install alfred-py
14 | ```
15 | 
16 | 
17 | ## Models
18 | 
19 | 
20 | 1. `StableFusion`:
21 | 
22 | **update:**
23 | 
24 | Now the best way to accelerate StableFusion is to run the unet with TensorRT and keep the other parts in PyTorch (their runtime is not critical).
25 | To export the unet to ONNX, run `python export_unet.py` (a minimal sketch of its export call is shown below).
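For reference, `export_unet.py` traces the fp16 unet with fixed-shape dummy inputs and then simplifies the result with `onnxsim`. A minimal sketch of that export call (names, shapes, and flags follow the script in this repo; 64x64 latents correspond to 512x512 images):

```
import torch
from diffusers import UNet2DConditionModel

unet = UNet2DConditionModel.from_pretrained(
    "weights/stable-diffusion-v1-4", subfolder="unet",
    torch_dtype=torch.float16, revision="fp16",
).cuda()

inputs = (
    torch.randn(1, 4, 64, 64, dtype=torch.half, device="cuda"),  # latent sample
    torch.randn(1, dtype=torch.half, device="cuda"),             # timestep
    torch.randn(1, 77, 768, dtype=torch.half, device="cuda"),    # CLIP text embeddings
)
torch.onnx.export(
    unet, inputs, "unet_v1_4_fp16_pytorch.onnx",
    export_params=True, opset_version=12, do_constant_folding=True,
    input_names=["input_0", "input_1", "input_2"], output_names=["output_0"],
)
```

Note that the script passes no `dynamic_axes`, so the ONNX graph is traced with these fixed shapes.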
26 | 
27 | Then you will have the unet ONNX model. Use `trtexec --onnx=sim_unet_v1_4_fp16_pytorch.onnx --fp16 --saveEngine=unet_fp16.trt` to convert it into an fp16 TensorRT engine.
28 | 
29 | Then you can run with the TensorRT unet:
30 | 
31 | ```
32 | python demo_part.py --trt
33 | ```
34 | 
35 | 
36 | First, we need to download the Stable Diffusion weights from Hugging Face:
37 | 
38 | ```
39 | git clone https://huggingface.co/CompVis/stable-diffusion-v1-4
40 | git lfs install
41 | cd stable-diffusion-v1-4
42 | git lfs pull
43 | ```
44 | 
45 | You should download the weights with `git lfs` (Git Large File Storage); the model is about `3GB`.
46 | 
47 | To make the `unet_2d_condition` model exportable to ONNX, apply a small modification to `diffusers`, following this commit: [link](https://github.com/harishanand95/diffusers/commit/8dd4e822f87e1b4259755a2181218797ceecc410)
48 | 
49 | file: `diffusers/models/unet_2d_condition.py`
50 | 
51 | ```
52 | # around L137: replace
53 | # timesteps = timesteps.broadcast_to(sample.shape[0])
54 | # with
55 | timesteps = timesteps * torch.ones(sample.shape[0])
56 | 
57 | # and replace the dict return
58 | # output = {"sample": sample}
59 | # return output
60 | # with
61 | return sample
62 | ```
63 | 
64 | After that, move `stable-diffusion-v1-4` into the `weights` folder. Run:
65 | 
66 | ```
67 | python export.py
68 | ```
69 | 
70 | to generate the ONNX models.
--------------------------------------------------------------------------------
/stablefusion/.gitignore:
--------------------------------------------------------------------------------
1 | vendor/
2 | 
--------------------------------------------------------------------------------
/stablefusion/__init__.py:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/luohao123/gaintmodels/d9b389a3e781fdeafc7f695c6ce021d5c9ceebbb/stablefusion/__init__.py
--------------------------------------------------------------------------------
/stablefusion/__pycache__/__init__.cpython-38.pyc:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/luohao123/gaintmodels/d9b389a3e781fdeafc7f695c6ce021d5c9ceebbb/stablefusion/__pycache__/__init__.cpython-38.pyc
--------------------------------------------------------------------------------
/stablefusion/__pycache__/clip_textmodel.cpython-38.pyc:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/luohao123/gaintmodels/d9b389a3e781fdeafc7f695c6ce021d5c9ceebbb/stablefusion/__pycache__/clip_textmodel.cpython-38.pyc
--------------------------------------------------------------------------------
/stablefusion/__pycache__/stablefusion_ov_engine.cpython-38.pyc:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/luohao123/gaintmodels/d9b389a3e781fdeafc7f695c6ce021d5c9ceebbb/stablefusion/__pycache__/stablefusion_ov_engine.cpython-38.pyc
--------------------------------------------------------------------------------
/stablefusion/ait_modeling/attention.py:
--------------------------------------------------------------------------------
1 | # Copyright (c) Meta Platforms, Inc. and affiliates.
2 | #
3 | # Licensed under the Apache License, Version 2.0 (the "License");
4 | # you may not use this file except in compliance with the License.
5 | # You may obtain a copy of the License at 6 | # 7 | # http://www.apache.org/licenses/LICENSE-2.0 8 | # 9 | # Unless required by applicable law or agreed to in writing, software 10 | # distributed under the License is distributed on an "AS IS" BASIS, 11 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 12 | # See the License for the specific language governing permissions and 13 | # limitations under the License. 14 | # 15 | 16 | """ 17 | Implementations are translated from https://github.com/huggingface/diffusers/blob/main/src/diffusers/models/attention.py. 18 | """ 19 | 20 | from typing import Optional 21 | 22 | from aitemplate.compiler.ops import reshape 23 | 24 | from aitemplate.frontend import nn, Tensor 25 | 26 | 27 | class AttentionBlock(nn.Module): 28 | """ 29 | An attention block that allows spatial positions to attend to each other. Originally ported from here, but adapted 30 | to the N-d case. 31 | https://github.com/hojonathanho/diffusion/blob/1e0dceb3b3495bbe19116a5e1b3596cd0706c543/diffusion_tf/models/unet.py#L66. 32 | Uses three q, k, v linear layers to compute attention. 33 | Parameters: 34 | batch_size (:obj:`int`): The number of examples per batch. 35 | height (:obj:`int`): Height of each image example. 36 | width (:obj:`int`): Width of each image example. 37 | channels (:obj:`int`): The number of channels in the input and output. 38 | num_head_channels (:obj:`int`, *optional*): 39 | The number of channels in each head. If None, then `num_heads` = 1. 40 | num_groups (:obj:`int`, *optional*, defaults to 32): The number of groups to use for group norm. 41 | eps (:obj:`float`, *optional*, defaults to 1e-5): The epsilon value to use for group norm. 42 | """ 43 | 44 | def __init__( 45 | self, 46 | batch_size: int, 47 | height: int, 48 | width: int, 49 | channels: int, 50 | num_head_channels: Optional[int] = None, 51 | num_groups: int = 32, 52 | rescale_output_factor: float = 1.0, 53 | eps: float = 1e-5, 54 | ): 55 | super().__init__() 56 | self.batch_size = batch_size 57 | self.height = height 58 | self.width = width 59 | self.channels = channels 60 | self.num_heads = ( 61 | channels // num_head_channels if num_head_channels is not None else 1 62 | ) 63 | self.num_head_size = num_head_channels 64 | self.group_norm = nn.GroupNorm(num_groups, channels, eps) 65 | self.attention = nn.MultiheadAttention( 66 | channels, 67 | batch_size, 68 | height * width, 69 | self.num_heads, 70 | qkv_bias=True, 71 | has_residual=True, 72 | ) 73 | self.rescale_output_factor = rescale_output_factor 74 | 75 | def forward(self, hidden_states) -> Tensor: 76 | """ 77 | input hidden_states shape: [batch, height, width, channel] 78 | output shape: [batch, height, width, channel] 79 | """ 80 | residual = hidden_states 81 | 82 | # norm 83 | hidden_states = self.group_norm(hidden_states) 84 | 85 | hidden_states = reshape()( 86 | hidden_states, [self.batch_size, self.height * self.width, self.channels] 87 | ) 88 | 89 | batch, hw, channel = hidden_states.shape() 90 | if ( 91 | batch.value() != self.batch_size 92 | or hw.value() != self.width * self.height 93 | or channel.value() != self.channels 94 | ): 95 | raise RuntimeError( 96 | "nchw params do not match! " 97 | f"Expected: {self.batch_size}, {self.channels}, {self.height} * {self.width}, " 98 | f"actual: {batch}, {channel}, {hw}." 
99 | ) 100 | 101 | res = self.attention(hidden_states, residual) * (1 / self.rescale_output_factor) 102 | res = reshape()(res, [self.batch_size, self.height, self.width, self.channels]) 103 | 104 | return res 105 | -------------------------------------------------------------------------------- /stablefusion/ait_modeling/clip.py: -------------------------------------------------------------------------------- 1 | # Copyright (c) Meta Platforms, Inc. and affiliates. 2 | # 3 | # Licensed under the Apache License, Version 2.0 (the "License"); 4 | # you may not use this file except in compliance with the License. 5 | # You may obtain a copy of the License at 6 | # 7 | # http://www.apache.org/licenses/LICENSE-2.0 8 | # 9 | # Unless required by applicable law or agreed to in writing, software 10 | # distributed under the License is distributed on an "AS IS" BASIS, 11 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 12 | # See the License for the specific language governing permissions and 13 | # limitations under the License. 14 | # 15 | from inspect import isfunction 16 | from typing import Optional 17 | 18 | from aitemplate.compiler import ops 19 | from aitemplate.frontend import nn, Tensor 20 | from aitemplate.testing import detect_target 21 | 22 | # pylint: disable=W0102 23 | 24 | USE_CUDA = detect_target().name() == "cuda" 25 | 26 | 27 | def get_shape(x): 28 | shape = [it.value() for it in x._attrs["shape"]] 29 | return shape 30 | 31 | 32 | def exists(val): 33 | return val is not None 34 | 35 | 36 | def default(val, d): 37 | if exists(val): 38 | return val 39 | return d() if isfunction(d) else d 40 | 41 | 42 | class CrossAttention(nn.Module): 43 | def __init__( 44 | self, 45 | query_dim, 46 | context_dim=None, 47 | heads=8, 48 | dim_head=64, 49 | dropout=0.0, 50 | dtype="float16", 51 | ): 52 | super().__init__() 53 | inner_dim = dim_head * heads 54 | context_dim = default(context_dim, query_dim) 55 | 56 | self.scale = dim_head**-0.5 57 | self.heads = heads 58 | self.dim_head = dim_head 59 | 60 | self.to_q_weight = nn.Parameter(shape=[inner_dim, query_dim], dtype=dtype) 61 | self.to_k_weight = nn.Parameter(shape=[inner_dim, context_dim], dtype=dtype) 62 | self.to_v_weight = nn.Parameter(shape=[inner_dim, context_dim], dtype=dtype) 63 | self.to_out = nn.Sequential( 64 | nn.Linear(inner_dim, query_dim), nn.Dropout(dropout) 65 | ) 66 | 67 | def forward(self, x, context=None, mask=None, residual=None): 68 | nheads = self.heads 69 | d = self.dim_head 70 | 71 | layout = "20314" if USE_CUDA else "m2n3" 72 | 73 | bs, seqlen, _ = get_shape(x) 74 | q = ops.gemm_rcr_permute(shape=(seqlen, 1, nheads), layout=layout)( 75 | ops.reshape()(x, [bs * seqlen, -1]), self.to_q_weight.tensor() 76 | ) 77 | context = default(context, x) 78 | 79 | seqlen = get_shape(context)[1] 80 | k = ops.gemm_rcr_permute(shape=(seqlen, 1, nheads), layout=layout)( 81 | ops.reshape()(context, [bs * seqlen, -1]), self.to_k_weight.tensor() 82 | ) 83 | v = ops.gemm_rcr_permute(shape=(seqlen, 1, nheads), layout=layout)( 84 | ops.reshape()(context, [bs * seqlen, -1]), self.to_v_weight.tensor() 85 | ) 86 | 87 | if USE_CUDA: 88 | q = q * self.scale 89 | attn = ops.bmm_rcr()( 90 | (ops.reshape()(q, [bs * nheads, -1, d])), 91 | (ops.reshape()(k, [bs * nheads, -1, d])), 92 | ) 93 | attn = ops.softmax()(attn, -1) 94 | v = ops.reshape()(v, [bs * nheads, -1, d]) 95 | out = ops.bmm_rrr_permute((nheads,))(attn, v) 96 | else: 97 | OP = ops.bmm_softmax_bmm_permute(shape=(nheads,), scale=self.scale) 98 | out = OP( 99 | 
(ops.reshape()(q, [bs * nheads, -1, d])), 100 | (ops.reshape()(k, [bs * nheads, -1, d])), 101 | (ops.reshape()(v, [bs * nheads, -1, d])), 102 | ) 103 | out = ops.reshape()(out, [bs, -1, nheads * d]) 104 | proj = self.to_out(out) 105 | proj = ops.reshape()(proj, [bs, -1, nheads * d]) 106 | if residual is not None: 107 | return proj + residual 108 | else: 109 | return proj 110 | 111 | 112 | class GEGLU(nn.Module): 113 | def __init__(self, dim_in, dim_out): 114 | super().__init__() 115 | self.proj = nn.Linear(dim_in, dim_out, specialization="mul") 116 | self.gate = nn.Linear(dim_in, dim_out, specialization="fast_gelu") 117 | 118 | def forward(self, x): 119 | return self.proj(x, self.gate(x)) 120 | 121 | 122 | class FeedForward(nn.Module): 123 | def __init__(self, dim, dim_out=None, mult=4, glu=False, dropout=0.0): 124 | super().__init__() 125 | inner_dim = int(dim * mult) 126 | dim_out = default(dim_out, dim) 127 | project_in = ( 128 | nn.Sequential( 129 | nn.Linear(dim, inner_dim, specialization="fast_gelu"), 130 | ) 131 | if not glu 132 | else GEGLU(dim, inner_dim) 133 | ) 134 | 135 | self.net = nn.Sequential( 136 | project_in, nn.Dropout(dropout), nn.Linear(inner_dim, dim_out) 137 | ) 138 | 139 | def forward(self, x, residual=None): 140 | shape = ops.size()(x) 141 | x = self.net(x) 142 | x = ops.reshape()(x, shape) 143 | if residual is not None: 144 | return x + residual 145 | else: 146 | return x 147 | 148 | 149 | class BasicTransformerBlock(nn.Module): 150 | def __init__( 151 | self, 152 | dim, 153 | n_heads, 154 | d_head, 155 | dropout=0.0, 156 | context_dim=None, 157 | gated_ff=True, 158 | checkpoint=True, 159 | ): 160 | super().__init__() 161 | self.attn1 = CrossAttention( 162 | query_dim=dim, heads=n_heads, dim_head=d_head, dropout=dropout 163 | ) # is a self-attention 164 | self.ff = FeedForward(dim, dropout=dropout, glu=gated_ff) 165 | self.attn2 = CrossAttention( 166 | query_dim=dim, 167 | context_dim=context_dim, 168 | heads=n_heads, 169 | dim_head=d_head, 170 | dropout=dropout, 171 | ) 172 | self.norm1 = nn.LayerNorm(dim) 173 | self.norm2 = nn.LayerNorm(dim) 174 | self.norm3 = nn.LayerNorm(dim) 175 | self.checkpoint = checkpoint 176 | 177 | self.param = (dim, n_heads, d_head, context_dim, gated_ff, checkpoint) 178 | 179 | def forward(self, x, context=None): 180 | x = self.attn1(self.norm1(x), residual=x) 181 | x = self.attn2(self.norm2(x), context=context, residual=x) 182 | x = self.ff(self.norm3(x), residual=x) 183 | return x 184 | 185 | 186 | def Normalize(in_channels): 187 | return nn.GroupNorm(num_groups=32, num_channels=in_channels, eps=1e-6, affine=True) 188 | 189 | 190 | class SpatialTransformer(nn.Module): 191 | """ 192 | Transformer block for image-like data. 193 | First, project the input (aka embedding) 194 | and reshape to b, t, d. 195 | Then apply standard transformer action. 
196 | Finally, reshape to image 197 | """ 198 | 199 | def __init__( 200 | self, in_channels, n_heads, d_head, depth=1, dropout=0.0, context_dim=None 201 | ): 202 | super().__init__() 203 | self.in_channels = in_channels 204 | inner_dim = n_heads * d_head 205 | self.norm = Normalize(in_channels) # Group Norm 206 | 207 | self.proj_in = nn.Conv2dBias( 208 | in_channels, inner_dim, kernel_size=1, stride=1, padding=0 209 | ) 210 | 211 | self.transformer_blocks = nn.ModuleList( 212 | [ 213 | BasicTransformerBlock( 214 | inner_dim, n_heads, d_head, dropout=dropout, context_dim=context_dim 215 | ) 216 | for d in range(depth) 217 | ] 218 | ) 219 | 220 | self.proj_out = nn.Conv2dBias( 221 | inner_dim, in_channels, kernel_size=1, stride=1, padding=0 222 | ) 223 | 224 | def forward(self, x, context=None): 225 | # note: if no context is given, cross-attention defaults to self-attention 226 | b, h, w, c = get_shape(x) 227 | x_in = x 228 | x = self.norm(x) 229 | x = self.proj_in(x) 230 | x = ops.reshape()(x, [b, -1, c]) 231 | for block in self.transformer_blocks: 232 | x = block(x, context=context) 233 | x = ops.reshape()(x, [b, h, w, c]) 234 | x = self.proj_out(x) 235 | return x + x_in 236 | 237 | 238 | class CLIPAttention(nn.Module): 239 | """Multi-headed attention from 'Attention Is All You Need' paper""" 240 | 241 | def __init__( 242 | self, 243 | hidden_size=768, 244 | num_attention_heads=12, 245 | attention_dropout=0.0, 246 | batch_size=1, 247 | seq_len=16, 248 | layer_norm_eps=1e-5, 249 | hidden_dropout_prob=0.0, 250 | causal=False, 251 | mask_seq=0, 252 | ): 253 | super().__init__() 254 | self.attn = nn.MultiheadAttention( 255 | dim=hidden_size, 256 | batch_size=batch_size, 257 | seq_len=seq_len, 258 | num_heads=num_attention_heads, 259 | qkv_bias=True, 260 | attn_drop=attention_dropout, 261 | proj_drop=hidden_dropout_prob, 262 | has_residual=False, 263 | causal=causal, 264 | mask_seq=mask_seq, 265 | ) 266 | 267 | def forward( 268 | self, 269 | hidden_states: Tensor, 270 | attention_mask: Optional[Tensor] = None, 271 | causal_attention_mask: Optional[Tensor] = None, 272 | output_attentions: Optional[bool] = False, 273 | residual: Optional[Tensor] = None, 274 | ): 275 | if residual is not None: 276 | self_output = self.attn(hidden_states, residual) 277 | else: 278 | self_output = self.attn(hidden_states) 279 | return self_output 280 | 281 | 282 | class QuickGELUActivation(nn.Module): 283 | """ 284 | Applies GELU approximation that is fast but somewhat inaccurate. 
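    Computes x * sigmoid(1.702 * x), matching the forward pass below.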
See: https://github.com/hendrycks/GELUs 285 | """ 286 | 287 | def forward(self, x): 288 | x1 = x * 1.702 289 | x1 = ops.sigmoid(x1) 290 | x = x * x1 291 | return x 292 | 293 | 294 | class CLIPMLP(nn.Module): 295 | """MLP as used in Vision Transformer, MLP-Mixer and related networks""" 296 | 297 | def __init__( 298 | self, 299 | in_features, 300 | hidden_features=None, 301 | out_features=None, 302 | act_layer="GELU", 303 | drop=0, 304 | ): 305 | super().__init__() 306 | out_features = out_features or in_features 307 | hidden_features = hidden_features or in_features 308 | 309 | self.fc1 = nn.Linear( 310 | in_features, 311 | hidden_features, 312 | ) 313 | self.activation_fn = QuickGELUActivation() 314 | self.fc2 = nn.Linear(hidden_features, out_features, specialization="add") 315 | 316 | def forward(self, x, res): 317 | shape = get_shape(x) 318 | x = self.fc1(x) 319 | x = self.activation_fn(x) 320 | x = self.fc2(x, res) 321 | return ops.reshape()(x, shape) 322 | 323 | 324 | class CLIPEncoderLayer(nn.Module): 325 | def __init__( 326 | self, 327 | hidden_size=768, 328 | num_attention_heads=12, 329 | attention_dropout=0.0, 330 | mlp_ratio=4.0, 331 | batch_size=1, 332 | seq_len=16, 333 | causal=False, 334 | mask_seq=0, 335 | ): 336 | super().__init__() 337 | self.embed_dim = hidden_size 338 | self.self_attn = nn.MultiheadAttention( 339 | dim=hidden_size, 340 | batch_size=batch_size, 341 | seq_len=seq_len, 342 | num_heads=num_attention_heads, 343 | qkv_bias=True, 344 | attn_drop=attention_dropout, 345 | proj_drop=0, 346 | has_residual=True, 347 | causal=causal, 348 | mask_seq=mask_seq, 349 | ) 350 | self.layer_norm1 = nn.LayerNorm(self.embed_dim) 351 | self.mlp = CLIPMLP(hidden_size, int(hidden_size * mlp_ratio)) 352 | self.layer_norm2 = nn.LayerNorm(self.embed_dim) 353 | 354 | def forward( 355 | self, 356 | hidden_states: Tensor, 357 | output_attentions: Optional[bool] = False, 358 | ): 359 | """ 360 | Args: 361 | hidden_states (`torch.FloatTensor`): input to the layer of shape `(batch, seq_len, embed_dim)` 362 | attention_mask (`torch.FloatTensor`): attention mask of size 363 | `(batch, 1, tgt_len, src_len)` where padding elements are indicated by very large negative values. 364 | `(config.encoder_attention_heads,)`. 365 | output_attentions (`bool`, *optional*): 366 | Whether or not to return the attentions tensors of all attention layers. See `attentions` under 367 | returned tensors for more detail. 368 | """ 369 | residual = hidden_states 370 | 371 | hidden_states = self.layer_norm1(hidden_states) 372 | hidden_states = self.self_attn(hidden_states, residual) 373 | 374 | residual = hidden_states 375 | hidden_states = self.layer_norm2(hidden_states) 376 | hidden_states = self.mlp(hidden_states, residual) 377 | 378 | return hidden_states 379 | 380 | 381 | class CLIPEncoder(nn.Module): 382 | """ 383 | Transformer encoder consisting of `config.num_hidden_layers` self attention layers. Each layer is a 384 | [`CLIPEncoderLayer`]. 
385 | Args: 386 | config: CLIPConfig 387 | """ 388 | 389 | def __init__( 390 | self, 391 | num_hidden_layers=12, 392 | output_attentions=False, 393 | output_hidden_states=False, 394 | use_return_dict=False, 395 | hidden_size=768, 396 | num_attention_heads=12, 397 | batch_size=1, 398 | seq_len=64, 399 | causal=False, 400 | mask_seq=0, 401 | ): 402 | super().__init__() 403 | self.layers = nn.ModuleList( 404 | [ 405 | CLIPEncoderLayer( 406 | hidden_size=hidden_size, 407 | num_attention_heads=num_attention_heads, 408 | batch_size=batch_size, 409 | seq_len=seq_len, 410 | causal=causal, 411 | mask_seq=mask_seq, 412 | ) 413 | for _ in range(num_hidden_layers) 414 | ] 415 | ) 416 | self.output_attentions = output_attentions 417 | self.output_hidden_states = output_hidden_states 418 | self.use_return_dict = use_return_dict 419 | 420 | def forward( 421 | self, 422 | inputs_embeds, 423 | attention_mask: Optional[Tensor] = None, 424 | causal_attention_mask: Optional[Tensor] = None, 425 | output_attentions: Optional[bool] = None, 426 | output_hidden_states: Optional[bool] = None, 427 | return_dict: Optional[bool] = None, 428 | ): 429 | r""" 430 | Args: 431 | inputs_embeds (`torch.FloatTensor` of shape `(batch_size, sequence_length, hidden_size)`): 432 | Optionally, instead of passing `input_ids` you can choose to directly pass an embedded representation. 433 | This is useful if you want more control over how to convert `input_ids` indices into associated vectors 434 | than the model's internal embedding lookup matrix. 435 | attention_mask (`torch.Tensor` of shape `(batch_size, sequence_length)`, *optional*): 436 | Mask to avoid performing attention on padding token indices. Mask values selected in `[0, 1]`: 437 | - 1 for tokens that are **not masked**, 438 | - 0 for tokens that are **masked**. 439 | [What are attention masks?](../glossary#attention-mask) 440 | causal_attention_mask (`torch.Tensor` of shape `(batch_size, sequence_length)`, *optional*): 441 | Causal mask for the text model. Mask values selected in `[0, 1]`: 442 | - 1 for tokens that are **not masked**, 443 | - 0 for tokens that are **masked**. 444 | [What are attention masks?](../glossary#attention-mask) 445 | output_attentions (`bool`, *optional*): 446 | Whether or not to return the attentions tensors of all attention layers. See `attentions` under 447 | returned tensors for more detail. 448 | output_hidden_states (`bool`, *optional*): 449 | Whether or not to return the hidden states of all layers. See `hidden_states` under returned tensors 450 | for more detail. 451 | return_dict (`bool`, *optional*): 452 | Whether or not to return a [`~utils.ModelOutput`] instead of a plain tuple. 
453 | """ 454 | output_attentions = ( 455 | output_attentions 456 | if output_attentions is not None 457 | else self.output_attentions 458 | ) 459 | output_hidden_states = ( 460 | output_hidden_states 461 | if output_hidden_states is not None 462 | else self.output_hidden_states 463 | ) 464 | return_dict = return_dict if return_dict is not None else self.use_return_dict 465 | 466 | encoder_states = () if output_hidden_states else None 467 | # all_attentions = () if output_attentions else None 468 | 469 | hidden_states = inputs_embeds 470 | for _, encoder_layer in enumerate(self.layers): 471 | if output_hidden_states: 472 | encoder_states = encoder_states + (hidden_states,) 473 | layer_outputs = encoder_layer(hidden_states) 474 | hidden_states = layer_outputs 475 | 476 | return hidden_states 477 | 478 | 479 | class CLIPTextEmbeddings(nn.Module): 480 | def __init__( 481 | self, 482 | hidden_size=768, 483 | vocab_size=49408, 484 | max_position_embeddings=77, 485 | dtype="float16", 486 | ): 487 | super().__init__() 488 | embed_dim = hidden_size 489 | 490 | self.token_embedding = nn.Embedding(shape=[vocab_size, embed_dim], dtype=dtype) 491 | self.position_embedding = nn.Embedding( 492 | shape=[max_position_embeddings, embed_dim], dtype=dtype 493 | ) 494 | 495 | def forward( 496 | self, 497 | input_ids: Tensor, 498 | position_ids: Tensor, 499 | inputs_embeds: Optional[Tensor] = None, 500 | ) -> Tensor: 501 | 502 | input_shape = ops.size()(input_ids) 503 | 504 | # [B * S] 505 | input_ids = ops.reshape()(input_ids, [-1]) 506 | 507 | position_ids = ops.reshape()(position_ids, [-1]) 508 | 509 | if inputs_embeds is None: 510 | inputs_embeds = ops.batch_gather()(self.token_embedding.tensor(), input_ids) 511 | 512 | position_embeddings = ops.batch_gather()( 513 | self.position_embedding.tensor(), position_ids 514 | ) 515 | 516 | embeddings = inputs_embeds + position_embeddings 517 | 518 | embeddings = ops.reshape()(embeddings, [input_shape[0], input_shape[1], -1]) 519 | 520 | return embeddings 521 | 522 | 523 | class CLIPTextTransformer(nn.Module): 524 | def __init__( 525 | self, 526 | hidden_size=768, 527 | output_attentions=False, 528 | output_hidden_states=False, 529 | use_return_dict=False, 530 | num_hidden_layers=12, 531 | num_attention_heads=12, 532 | batch_size=1, 533 | seq_len=64, 534 | causal=False, 535 | mask_seq=0, 536 | ): 537 | super().__init__() 538 | embed_dim = hidden_size 539 | self.embeddings = CLIPTextEmbeddings() 540 | self.encoder = CLIPEncoder( 541 | num_hidden_layers=num_hidden_layers, 542 | hidden_size=hidden_size, 543 | num_attention_heads=num_attention_heads, 544 | batch_size=batch_size, 545 | seq_len=seq_len, 546 | causal=causal, 547 | mask_seq=mask_seq, 548 | ) 549 | self.final_layer_norm = nn.LayerNorm(embed_dim) 550 | 551 | self.output_attentions = output_attentions 552 | self.output_hidden_states = output_hidden_states 553 | self.use_return_dict = use_return_dict 554 | 555 | def forward( 556 | self, 557 | input_ids: Optional[Tensor] = None, 558 | attention_mask: Optional[Tensor] = None, 559 | position_ids: Optional[Tensor] = None, 560 | output_attentions: Optional[bool] = None, 561 | output_hidden_states: Optional[bool] = None, 562 | return_dict: Optional[bool] = None, 563 | ): 564 | r""" 565 | Returns: 566 | """ 567 | output_attentions = ( 568 | output_attentions 569 | if output_attentions is not None 570 | else self.output_attentions 571 | ) 572 | output_hidden_states = ( 573 | output_hidden_states 574 | if output_hidden_states is not None 575 | else 
self.output_hidden_states 576 | ) 577 | return_dict = return_dict if return_dict is not None else self.use_return_dict 578 | 579 | if input_ids is None: 580 | raise ValueError("You have to specify input_ids") 581 | 582 | hidden_states = self.embeddings(input_ids=input_ids, position_ids=position_ids) 583 | 584 | encoder_outputs = self.encoder( 585 | inputs_embeds=hidden_states, 586 | ) 587 | 588 | last_hidden_state = encoder_outputs 589 | last_hidden_state = self.final_layer_norm(last_hidden_state) 590 | return last_hidden_state 591 | -------------------------------------------------------------------------------- /stablefusion/ait_modeling/embeddings.py: -------------------------------------------------------------------------------- 1 | # Copyright (c) Meta Platforms, Inc. and affiliates. 2 | # 3 | # Licensed under the Apache License, Version 2.0 (the "License"); 4 | # you may not use this file except in compliance with the License. 5 | # You may obtain a copy of the License at 6 | # 7 | # http://www.apache.org/licenses/LICENSE-2.0 8 | # 9 | # Unless required by applicable law or agreed to in writing, software 10 | # distributed under the License is distributed on an "AS IS" BASIS, 11 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 12 | # See the License for the specific language governing permissions and 13 | # limitations under the License. 14 | # 15 | import math 16 | 17 | from aitemplate.compiler import ops 18 | from aitemplate.frontend import nn, Tensor 19 | 20 | 21 | def get_shape(x): 22 | shape = [it.value() for it in x._attrs["shape"]] 23 | return shape 24 | 25 | 26 | def get_timestep_embedding( 27 | timesteps: Tensor, 28 | embedding_dim: int, 29 | flip_sin_to_cos: bool = False, 30 | downscale_freq_shift: float = 1, 31 | scale: float = 1, 32 | max_period: int = 10000, 33 | ): 34 | """ 35 | Create sinusoidal timestep embeddings. This matches the implementation in Denoising Diffusion Probabilistic Models. 36 | 37 | :param timesteps: a 1-D Tensor of N indices, one per batch element. These may be fractional. 38 | :param embedding_dim: the dimension of the output. 39 | :param max_period: controls the minimum frequency of the embeddings. 40 | :return: an [N x dim] Tensor of positional embeddings.
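    For illustration, with the defaults above (flip_sin_to_cos=False, downscale_freq_shift=1, scale=1) and
    embedding_dim = 4: half_dim = 2, the per-channel frequencies are exp(-ln(10000) * [0, 1] / (2 - 1)) = [1, 1e-4],
    and a timestep t is embedded as [sin(t), sin(1e-4 * t), cos(t), cos(1e-4 * t)].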
41 | """ 42 | assert len(get_shape(timesteps)) == 1, "Timesteps should be a 1d-array" 43 | 44 | half_dim = embedding_dim // 2 45 | 46 | exponent = (-math.log(max_period)) * Tensor( 47 | shape=[half_dim], dtype="float16", name="arange" 48 | ) 49 | 50 | exponent = exponent * (1.0 / (half_dim - downscale_freq_shift)) 51 | 52 | emb = ops.exp(exponent) 53 | emb = ops.reshape()(timesteps, [-1, 1]) * ops.reshape()(emb, [1, -1]) 54 | 55 | # scale embeddings 56 | emb = scale * emb 57 | 58 | # concat sine and cosine embeddings 59 | if flip_sin_to_cos: 60 | emb = ops.concatenate()( 61 | [ops.cos(emb), ops.sin(emb)], 62 | dim=-1, 63 | ) 64 | else: 65 | emb = ops.concatenate()( 66 | [ops.sin(emb), ops.cos(emb)], 67 | dim=-1, 68 | ) 69 | return emb 70 | 71 | 72 | class TimestepEmbedding(nn.Module): 73 | def __init__(self, channel: int, time_embed_dim: int, act_fn: str = "silu"): 74 | super().__init__() 75 | 76 | self.linear_1 = nn.Linear(channel, time_embed_dim, specialization="swish") 77 | self.linear_2 = nn.Linear(time_embed_dim, time_embed_dim) 78 | 79 | def forward(self, sample): 80 | sample = self.linear_1(sample) 81 | sample = self.linear_2(sample) 82 | return sample 83 | 84 | 85 | class Timesteps(nn.Module): 86 | def __init__( 87 | self, num_channels: int, flip_sin_to_cos: bool, downscale_freq_shift: float 88 | ): 89 | super().__init__() 90 | self.num_channels = num_channels 91 | self.flip_sin_to_cos = flip_sin_to_cos 92 | self.downscale_freq_shift = downscale_freq_shift 93 | 94 | def forward(self, timesteps): 95 | t_emb = get_timestep_embedding( 96 | timesteps, 97 | self.num_channels, 98 | flip_sin_to_cos=self.flip_sin_to_cos, 99 | downscale_freq_shift=self.downscale_freq_shift, 100 | ) 101 | return t_emb 102 | -------------------------------------------------------------------------------- /stablefusion/ait_modeling/resnet.py: -------------------------------------------------------------------------------- 1 | # Copyright (c) Meta Platforms, Inc. and affiliates. 2 | # 3 | # Licensed under the Apache License, Version 2.0 (the "License"); 4 | # you may not use this file except in compliance with the License. 5 | # You may obtain a copy of the License at 6 | # 7 | # http://www.apache.org/licenses/LICENSE-2.0 8 | # 9 | # Unless required by applicable law or agreed to in writing, software 10 | # distributed under the License is distributed on an "AS IS" BASIS, 11 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 12 | # See the License for the specific language governing permissions and 13 | # limitations under the License. 14 | # 15 | from aitemplate.compiler import ops 16 | from aitemplate.frontend import nn 17 | 18 | 19 | def get_shape(x): 20 | shape = [it.value() for it in x._attrs["shape"]] 21 | return shape 22 | 23 | 24 | class Upsample2D(nn.Module): 25 | """ 26 | An upsampling layer with an optional convolution. 27 | 28 | :param channels: channels in the inputs and outputs. :param use_conv: a bool determining if a convolution is 29 | applied. :param dims: determines if the signal is 1D, 2D, or 3D. If 3D, then 30 | upsampling occurs in the inner-two dimensions. 
31 | """ 32 | 33 | def __init__( 34 | self, 35 | channels, 36 | use_conv=False, 37 | use_conv_transpose=False, 38 | out_channels=None, 39 | name="conv", 40 | ): 41 | super().__init__() 42 | self.channels = channels 43 | self.out_channels = out_channels or channels 44 | self.use_conv = use_conv 45 | self.use_conv_transpose = use_conv_transpose 46 | self.name = name 47 | 48 | conv = None 49 | if use_conv_transpose: 50 | conv = nn.ConvTranspose2dBias(channels, self.out_channels, 4, 2, 1) 51 | elif use_conv: 52 | conv = nn.Conv2dBias(self.channels, self.out_channels, 3, 1, 1) 53 | 54 | # TODO(Suraj, Patrick) - clean up after weight dicts are correctly renamed 55 | if name == "conv": 56 | self.conv = conv 57 | else: 58 | self.Conv2d_0 = conv 59 | 60 | def forward(self, x): 61 | assert get_shape(x)[-1] == self.channels 62 | if self.use_conv_transpose: 63 | return self.conv(x) 64 | 65 | x = nn.Upsampling2d(scale_factor=2.0, mode="nearest")(x) 66 | 67 | # TODO(Suraj, Patrick) - clean up after weight dicts are correctly renamed 68 | if self.use_conv: 69 | if self.name == "conv": 70 | x = self.conv(x) 71 | else: 72 | x = self.Conv2d_0(x) 73 | 74 | return x 75 | 76 | 77 | class Downsample2D(nn.Module): 78 | """ 79 | A downsampling layer with an optional convolution. 80 | 81 | :param channels: channels in the inputs and outputs. :param use_conv: a bool determining if a convolution is 82 | applied. :param dims: determines if the signal is 1D, 2D, or 3D. If 3D, then 83 | downsampling occurs in the inner-two dimensions. 84 | """ 85 | 86 | def __init__( 87 | self, channels, use_conv=False, out_channels=None, padding=1, name="conv" 88 | ): 89 | super().__init__() 90 | self.channels = channels 91 | self.out_channels = out_channels or channels 92 | self.use_conv = use_conv 93 | self.padding = padding 94 | stride = 2 95 | self.name = name 96 | 97 | if use_conv: 98 | conv = nn.Conv2dBias( 99 | self.channels, self.out_channels, 3, stride=stride, padding=padding 100 | ) 101 | else: 102 | assert self.channels == self.out_channels 103 | conv = nn.AvgPool2d(kernel_size=stride, stride=stride, padding=0) 104 | 105 | # TODO(Suraj, Patrick) - clean up after weight dicts are correctly renamed 106 | if name == "conv": 107 | self.Conv2d_0 = conv 108 | self.conv = conv 109 | elif name == "Conv2d_0": 110 | self.conv = conv 111 | else: 112 | self.conv = conv 113 | 114 | def forward(self, x): 115 | assert get_shape(x)[-1] == self.channels 116 | x = self.conv(x) 117 | 118 | return x 119 | 120 | 121 | class ResnetBlock2D(nn.Module): 122 | def __init__( 123 | self, 124 | *, 125 | in_channels, 126 | out_channels=None, 127 | conv_shortcut=False, 128 | dropout=0.0, 129 | temb_channels=512, 130 | groups=32, 131 | groups_out=None, 132 | pre_norm=True, 133 | eps=1e-6, 134 | non_linearity="swish", 135 | time_embedding_norm="default", 136 | kernel=None, 137 | output_scale_factor=1.0, 138 | use_nin_shortcut=None, 139 | up=False, 140 | down=False, 141 | ): 142 | super().__init__() 143 | self.pre_norm = pre_norm 144 | self.pre_norm = True 145 | self.in_channels = in_channels 146 | out_channels = in_channels if out_channels is None else out_channels 147 | self.out_channels = out_channels 148 | self.use_conv_shortcut = conv_shortcut 149 | self.time_embedding_norm = time_embedding_norm 150 | self.up = up 151 | self.down = down 152 | self.output_scale_factor = output_scale_factor 153 | 154 | if groups_out is None: 155 | groups_out = groups 156 | 157 | self.norm1 = nn.GroupNorm( 158 | num_groups=groups, 159 | num_channels=in_channels, 160 | 
eps=eps, 161 | affine=True, 162 | use_swish=True, 163 | ) 164 | 165 | self.conv1 = nn.Conv2dBias( 166 | in_channels, out_channels, kernel_size=3, stride=1, padding=1 167 | ) 168 | 169 | if temb_channels is not None: 170 | self.time_emb_proj = nn.Linear(temb_channels, out_channels) 171 | else: 172 | self.time_emb_proj = None 173 | 174 | self.norm2 = nn.GroupNorm( 175 | num_groups=groups_out, 176 | num_channels=out_channels, 177 | eps=eps, 178 | affine=True, 179 | use_swish=True, 180 | ) 181 | self.dropout = nn.Dropout(dropout) 182 | self.conv2 = nn.Conv2dBias( 183 | out_channels, out_channels, kernel_size=3, stride=1, padding=1 184 | ) 185 | 186 | self.upsample = self.downsample = None 187 | 188 | self.use_nin_shortcut = ( 189 | self.in_channels != self.out_channels 190 | if use_nin_shortcut is None 191 | else use_nin_shortcut 192 | ) 193 | 194 | if self.use_nin_shortcut: 195 | self.conv_shortcut = nn.Conv2dBias( 196 | in_channels, out_channels, 1, 1, 0 197 | ) # kernel_size=1, stride=1, padding=0) # conv_bias_add 198 | else: 199 | self.conv_shortcut = None 200 | 201 | def forward(self, x, temb=None): 202 | hidden_states = x 203 | 204 | # make sure hidden states is in float32 205 | # when running in half-precision 206 | hidden_states = self.norm1( 207 | hidden_states 208 | ) # .float()).type(hidden_states.dtype) # fused swish 209 | # hidden_states = self.nonlinearity(hidden_states) 210 | 211 | if self.upsample is not None: 212 | x = self.upsample(x) 213 | hidden_states = self.upsample(hidden_states) 214 | elif self.downsample is not None: 215 | x = self.downsample(x) 216 | hidden_states = self.downsample(hidden_states) 217 | 218 | hidden_states = self.conv1(hidden_states) 219 | 220 | if temb is not None: 221 | temb = self.time_emb_proj(ops.silu(temb)) 222 | bs, dim = get_shape(temb) 223 | temb = ops.reshape()(temb, [bs, 1, 1, dim]) 224 | hidden_states = hidden_states + temb 225 | 226 | # make sure hidden states is in float32 227 | # when running in half-precision 228 | hidden_states = self.norm2(hidden_states) 229 | 230 | hidden_states = self.dropout(hidden_states) 231 | hidden_states = self.conv2(hidden_states) 232 | 233 | if self.conv_shortcut is not None: 234 | x = self.conv_shortcut(x) 235 | 236 | out = hidden_states + x 237 | 238 | return out 239 | -------------------------------------------------------------------------------- /stablefusion/ait_modeling/unet_2d_condition.py: -------------------------------------------------------------------------------- 1 | # Copyright (c) Meta Platforms, Inc. and affiliates. 2 | # 3 | # Licensed under the Apache License, Version 2.0 (the "License"); 4 | # you may not use this file except in compliance with the License. 5 | # You may obtain a copy of the License at 6 | # 7 | # http://www.apache.org/licenses/LICENSE-2.0 8 | # 9 | # Unless required by applicable law or agreed to in writing, software 10 | # distributed under the License is distributed on an "AS IS" BASIS, 11 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 12 | # See the License for the specific language governing permissions and 13 | # limitations under the License. 
14 | # 15 | from typing import Optional, Tuple 16 | 17 | from aitemplate.frontend import nn 18 | 19 | from modeling.embeddings import TimestepEmbedding, Timesteps 20 | from modeling.unet_blocks import get_down_block, get_up_block, UNetMidBlock2DCrossAttn 21 | 22 | 23 | class UNet2DConditionModel(nn.Module): 24 | r""" 25 | UNet2DConditionModel is a conditional 2D UNet model that takes in a noisy sample, conditional state, and a timestep 26 | and returns sample shaped output. 27 | 28 | This model inherits from [`ModelMixin`]. Check the superclass documentation for the generic methods the library 29 | implements for all the model (such as downloading or saving, etc.) 30 | 31 | Parameters: 32 | sample_size (`int`, *optional*): The size of the input sample. 33 | in_channels (`int`, *optional*, defaults to 4): The number of channels in the input sample. 34 | out_channels (`int`, *optional*, defaults to 4): The number of channels in the output. 35 | center_input_sample (`bool`, *optional*, defaults to `False`): Whether to center the input sample. 36 | flip_sin_to_cos (`bool`, *optional*, defaults to `False`): 37 | Whether to flip the sin to cos in the time embedding. 38 | freq_shift (`int`, *optional*, defaults to 0): The frequency shift to apply to the time embedding. 39 | down_block_types (`Tuple[str]`, *optional*, defaults to `("CrossAttnDownBlock2D", "CrossAttnDownBlock2D", "CrossAttnDownBlock2D", "DownBlock2D")`): 40 | The tuple of downsample blocks to use. 41 | up_block_types (`Tuple[str]`, *optional*, defaults to `("UpBlock2D", "CrossAttnUpBlock2D", "CrossAttnUpBlock2D", "CrossAttnUpBlock2D",)`): 42 | The tuple of upsample blocks to use. 43 | block_out_channels (`Tuple[int]`, *optional*, defaults to `(320, 640, 1280, 1280)`): 44 | The tuple of output channels for each block. 45 | layers_per_block (`int`, *optional*, defaults to 2): The number of layers per block. 46 | downsample_padding (`int`, *optional*, defaults to 1): The padding to use for the downsampling convolution. 47 | mid_block_scale_factor (`float`, *optional*, defaults to 1.0): The scale factor to use for the mid block. 48 | act_fn (`str`, *optional*, defaults to `"silu"`): The activation function to use. 49 | norm_num_groups (`int`, *optional*, defaults to 32): The number of groups to use for the normalization. 50 | norm_eps (`float`, *optional*, defaults to 1e-5): The epsilon to use for the normalization. 51 | cross_attention_dim (`int`, *optional*, defaults to 1280): The dimension of the cross attention features. 52 | attention_head_dim (`int`, *optional*, defaults to 8): The dimension of the attention heads. 
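    Example (a construction sketch for illustration only; the values shown are the defaults listed above, except
    `cross_attention_dim=768`, which matches the `hidden_size` default of the CLIPTextTransformer in this repository
    rather than the 1280 default here):

        unet = UNet2DConditionModel(
            in_channels=4,
            out_channels=4,
            block_out_channels=(320, 640, 1280, 1280),
            layers_per_block=2,
            cross_attention_dim=768,
            attention_head_dim=8,
        )
        # forward() below expects AITemplate Tensors, e.g.
        # noise_pred = unet(sample, timesteps, encoder_hidden_states)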
53 | """ 54 | 55 | def __init__( 56 | self, 57 | sample_size: Optional[int] = None, 58 | in_channels: int = 4, 59 | out_channels: int = 4, 60 | center_input_sample: bool = False, 61 | flip_sin_to_cos: bool = True, 62 | freq_shift: int = 0, 63 | down_block_types: Tuple[str] = ( 64 | "CrossAttnDownBlock2D", 65 | "CrossAttnDownBlock2D", 66 | "CrossAttnDownBlock2D", 67 | "DownBlock2D", 68 | ), 69 | up_block_types: Tuple[str] = ( 70 | "UpBlock2D", 71 | "CrossAttnUpBlock2D", 72 | "CrossAttnUpBlock2D", 73 | "CrossAttnUpBlock2D", 74 | ), 75 | block_out_channels: Tuple[int] = (320, 640, 1280, 1280), 76 | layers_per_block: int = 2, 77 | downsample_padding: int = 1, 78 | mid_block_scale_factor: float = 1, 79 | act_fn: str = "silu", 80 | norm_num_groups: int = 32, 81 | norm_eps: float = 1e-5, 82 | cross_attention_dim: int = 1280, 83 | attention_head_dim: int = 8, 84 | ): 85 | super().__init__() 86 | self.center_input_sample = center_input_sample 87 | self.sample_size = sample_size 88 | time_embed_dim = block_out_channels[0] * 4 89 | 90 | # input 91 | self.conv_in = nn.Conv2dBias(in_channels, block_out_channels[0], 3, 1, 1) 92 | # time 93 | self.time_proj = Timesteps(block_out_channels[0], flip_sin_to_cos, freq_shift) 94 | timestep_input_dim = block_out_channels[0] 95 | 96 | self.time_embedding = TimestepEmbedding(timestep_input_dim, time_embed_dim) 97 | 98 | self.down_blocks = nn.ModuleList([]) 99 | self.up_blocks = nn.ModuleList([]) 100 | 101 | # down 102 | output_channel = block_out_channels[0] 103 | for i, down_block_type in enumerate(down_block_types): 104 | input_channel = output_channel 105 | output_channel = block_out_channels[i] 106 | is_final_block = i == len(block_out_channels) - 1 107 | 108 | down_block = get_down_block( 109 | down_block_type, 110 | num_layers=layers_per_block, 111 | in_channels=input_channel, 112 | out_channels=output_channel, 113 | temb_channels=time_embed_dim, 114 | add_downsample=not is_final_block, 115 | resnet_eps=norm_eps, 116 | resnet_act_fn=act_fn, 117 | cross_attention_dim=cross_attention_dim, 118 | attn_num_head_channels=attention_head_dim, 119 | downsample_padding=downsample_padding, 120 | ) 121 | self.down_blocks.append(down_block) 122 | 123 | # mid 124 | self.mid_block = UNetMidBlock2DCrossAttn( 125 | in_channels=block_out_channels[-1], 126 | temb_channels=time_embed_dim, 127 | resnet_eps=norm_eps, 128 | resnet_act_fn=act_fn, 129 | output_scale_factor=mid_block_scale_factor, 130 | resnet_time_scale_shift="default", 131 | cross_attention_dim=cross_attention_dim, 132 | attn_num_head_channels=attention_head_dim, 133 | resnet_groups=norm_num_groups, 134 | ) 135 | 136 | # up 137 | reversed_block_out_channels = list(reversed(block_out_channels)) 138 | output_channel = reversed_block_out_channels[0] 139 | for i, up_block_type in enumerate(up_block_types): 140 | prev_output_channel = output_channel 141 | output_channel = reversed_block_out_channels[i] 142 | input_channel = reversed_block_out_channels[ 143 | min(i + 1, len(block_out_channels) - 1) 144 | ] 145 | 146 | is_final_block = i == len(block_out_channels) - 1 147 | 148 | up_block = get_up_block( 149 | up_block_type, 150 | num_layers=layers_per_block + 1, 151 | in_channels=input_channel, 152 | out_channels=output_channel, 153 | prev_output_channel=prev_output_channel, 154 | temb_channels=time_embed_dim, 155 | add_upsample=not is_final_block, 156 | resnet_eps=norm_eps, 157 | resnet_act_fn=act_fn, 158 | cross_attention_dim=cross_attention_dim, 159 | attn_num_head_channels=attention_head_dim, 160 | ) 161 | 
self.up_blocks.append(up_block) 162 | prev_output_channel = output_channel 163 | 164 | # out 165 | self.conv_norm_out = nn.GroupNorm( 166 | num_channels=block_out_channels[0], 167 | num_groups=norm_num_groups, 168 | eps=norm_eps, 169 | use_swish=True, 170 | ) 171 | 172 | self.conv_out = nn.Conv2dBias(block_out_channels[0], out_channels, 3, 1, 1) 173 | 174 | def forward( 175 | self, 176 | sample, 177 | timesteps, 178 | encoder_hidden_states, 179 | return_dict: bool = True, 180 | ): 181 | """r 182 | Args: 183 | sample (`torch.FloatTensor`): (batch, channel, height, width) noisy inputs tensor 184 | timestep (`torch.FloatTensor` or `float` or `int): (batch) timesteps 185 | encoder_hidden_states (`torch.FloatTensor`): (batch, channel, height, width) encoder hidden states 186 | return_dict (`bool`, *optional*, defaults to `True`): 187 | Whether or not to return a [`models.unet_2d_condition.UNet2DConditionOutput`] instead of a plain tuple. 188 | 189 | Returns: 190 | [`~models.unet_2d_condition.UNet2DConditionOutput`] or `tuple`: 191 | [`~models.unet_2d_condition.UNet2DConditionOutput`] if `return_dict` is True, otherwise a `tuple`. When 192 | returning a tuple, the first element is the sample tensor. 193 | """ 194 | 195 | # 1. time 196 | t_emb = self.time_proj(timesteps) 197 | emb = self.time_embedding(t_emb) 198 | 199 | # 2. pre-process 200 | sample = self.conv_in(sample) 201 | 202 | # 3. down 203 | down_block_res_samples = (sample,) 204 | for downsample_block in self.down_blocks: 205 | if ( 206 | hasattr(downsample_block, "attentions") 207 | and downsample_block.attentions is not None 208 | ): 209 | sample, res_samples = downsample_block( 210 | hidden_states=sample, 211 | temb=emb, 212 | encoder_hidden_states=encoder_hidden_states, 213 | ) 214 | else: 215 | sample, res_samples = downsample_block(hidden_states=sample, temb=emb) 216 | 217 | down_block_res_samples += res_samples 218 | 219 | # 4. mid 220 | sample = self.mid_block( 221 | sample, emb, encoder_hidden_states=encoder_hidden_states 222 | ) 223 | 224 | # 5. up 225 | for upsample_block in self.up_blocks: 226 | res_samples = down_block_res_samples[-len(upsample_block.resnets) :] 227 | down_block_res_samples = down_block_res_samples[ 228 | : -len(upsample_block.resnets) 229 | ] 230 | 231 | if ( 232 | hasattr(upsample_block, "attentions") 233 | and upsample_block.attentions is not None 234 | ): 235 | sample = upsample_block( 236 | hidden_states=sample, 237 | temb=emb, 238 | res_hidden_states_tuple=res_samples, 239 | encoder_hidden_states=encoder_hidden_states, 240 | ) 241 | else: 242 | sample = upsample_block( 243 | hidden_states=sample, temb=emb, res_hidden_states_tuple=res_samples 244 | ) 245 | 246 | # 6. post-process 247 | # make sure hidden states is in float32 248 | # when running in half-precision 249 | sample = self.conv_norm_out(sample) 250 | sample = self.conv_out(sample) 251 | return sample 252 | -------------------------------------------------------------------------------- /stablefusion/ait_modeling/unet_blocks.py: -------------------------------------------------------------------------------- 1 | # Copyright (c) Meta Platforms, Inc. and affiliates. 2 | # 3 | # Licensed under the Apache License, Version 2.0 (the "License"); 4 | # you may not use this file except in compliance with the License. 
5 | # You may obtain a copy of the License at 6 | # 7 | # http://www.apache.org/licenses/LICENSE-2.0 8 | # 9 | # Unless required by applicable law or agreed to in writing, software 10 | # distributed under the License is distributed on an "AS IS" BASIS, 11 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 12 | # See the License for the specific language governing permissions and 13 | # limitations under the License. 14 | # 15 | # Copyright 2022 The HuggingFace Team. All rights reserved. 16 | # 17 | # Licensed under the Apache License, Version 2.0 (the "License"); 18 | # you may not use this file except in compliance with the License. 19 | # You may obtain a copy of the License at 20 | # 21 | # http://www.apache.org/licenses/LICENSE-2.0 22 | # 23 | # Unless required by applicable law or agreed to in writing, software 24 | # distributed under the License is distributed on an "AS IS" BASIS, 25 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 26 | # See the License for the specific language governing permissions and 27 | 28 | # flake8: noqa 29 | from aitemplate.compiler import ops 30 | 31 | from aitemplate.frontend import nn, Tensor 32 | from aitemplate.testing import detect_target 33 | from modeling.attention import AttentionBlock 34 | 35 | from modeling.clip import SpatialTransformer 36 | from modeling.resnet import Downsample2D, ResnetBlock2D, Upsample2D 37 | 38 | # pylint: disable=W0102 39 | 40 | 41 | def get_down_block( 42 | down_block_type, 43 | num_layers, 44 | in_channels, 45 | out_channels, 46 | temb_channels, 47 | add_downsample, 48 | resnet_eps, 49 | resnet_act_fn, 50 | attn_num_head_channels, 51 | cross_attention_dim=None, 52 | downsample_padding=None, 53 | ): 54 | down_block_type = ( 55 | down_block_type[7:] 56 | if down_block_type.startswith("UNetRes") 57 | else down_block_type 58 | ) 59 | if down_block_type == "DownBlock2D": 60 | return DownBlock2D( 61 | num_layers=num_layers, 62 | in_channels=in_channels, 63 | out_channels=out_channels, 64 | temb_channels=temb_channels, 65 | add_downsample=add_downsample, 66 | resnet_eps=resnet_eps, 67 | resnet_act_fn=resnet_act_fn, 68 | downsample_padding=downsample_padding, 69 | ) 70 | elif down_block_type == "AttnDownBlock2D": 71 | return AttnDownBlock2D( 72 | num_layers=num_layers, 73 | in_channels=in_channels, 74 | out_channels=out_channels, 75 | temb_channels=temb_channels, 76 | add_downsample=add_downsample, 77 | resnet_eps=resnet_eps, 78 | resnet_act_fn=resnet_act_fn, 79 | downsample_padding=downsample_padding, 80 | attn_num_head_channels=attn_num_head_channels, 81 | ) 82 | elif down_block_type == "CrossAttnDownBlock2D": 83 | if cross_attention_dim is None: 84 | raise ValueError( 85 | "cross_attention_dim must be specified for CrossAttnDownBlock2D" 86 | ) 87 | return CrossAttnDownBlock2D( 88 | num_layers=num_layers, 89 | in_channels=in_channels, 90 | out_channels=out_channels, 91 | temb_channels=temb_channels, 92 | add_downsample=add_downsample, 93 | resnet_eps=resnet_eps, 94 | resnet_act_fn=resnet_act_fn, 95 | downsample_padding=downsample_padding, 96 | cross_attention_dim=cross_attention_dim, 97 | attn_num_head_channels=attn_num_head_channels, 98 | ) 99 | elif down_block_type == "SkipDownBlock2D": 100 | return SkipDownBlock2D( 101 | num_layers=num_layers, 102 | in_channels=in_channels, 103 | out_channels=out_channels, 104 | temb_channels=temb_channels, 105 | add_downsample=add_downsample, 106 | resnet_eps=resnet_eps, 107 | resnet_act_fn=resnet_act_fn, 108 | 
downsample_padding=downsample_padding, 109 | ) 110 | elif down_block_type == "AttnSkipDownBlock2D": 111 | return AttnSkipDownBlock2D( 112 | num_layers=num_layers, 113 | in_channels=in_channels, 114 | out_channels=out_channels, 115 | temb_channels=temb_channels, 116 | add_downsample=add_downsample, 117 | resnet_eps=resnet_eps, 118 | resnet_act_fn=resnet_act_fn, 119 | downsample_padding=downsample_padding, 120 | attn_num_head_channels=attn_num_head_channels, 121 | ) 122 | elif down_block_type == "DownEncoderBlock2D": 123 | return DownEncoderBlock2D( 124 | num_layers=num_layers, 125 | in_channels=in_channels, 126 | out_channels=out_channels, 127 | add_downsample=add_downsample, 128 | resnet_eps=resnet_eps, 129 | resnet_act_fn=resnet_act_fn, 130 | downsample_padding=downsample_padding, 131 | ) 132 | 133 | 134 | def get_up_block( 135 | up_block_type, 136 | num_layers, 137 | in_channels, 138 | out_channels, 139 | prev_output_channel, 140 | temb_channels, 141 | add_upsample, 142 | resnet_eps, 143 | resnet_act_fn, 144 | attn_num_head_channels, 145 | cross_attention_dim=None, 146 | ): 147 | up_block_type = ( 148 | up_block_type[7:] if up_block_type.startswith("UNetRes") else up_block_type 149 | ) 150 | if up_block_type == "UpBlock2D": 151 | return UpBlock2D( 152 | num_layers=num_layers, 153 | in_channels=in_channels, 154 | out_channels=out_channels, 155 | prev_output_channel=prev_output_channel, 156 | temb_channels=temb_channels, 157 | add_upsample=add_upsample, 158 | resnet_eps=resnet_eps, 159 | resnet_act_fn=resnet_act_fn, 160 | ) 161 | elif up_block_type == "CrossAttnUpBlock2D": 162 | if cross_attention_dim is None: 163 | raise ValueError( 164 | "cross_attention_dim must be specified for CrossAttnUpBlock2D" 165 | ) 166 | return CrossAttnUpBlock2D( 167 | num_layers=num_layers, 168 | in_channels=in_channels, 169 | out_channels=out_channels, 170 | prev_output_channel=prev_output_channel, 171 | temb_channels=temb_channels, 172 | add_upsample=add_upsample, 173 | resnet_eps=resnet_eps, 174 | resnet_act_fn=resnet_act_fn, 175 | cross_attention_dim=cross_attention_dim, 176 | attn_num_head_channels=attn_num_head_channels, 177 | ) 178 | elif up_block_type == "AttnUpBlock2D": 179 | return AttnUpBlock2D( 180 | num_layers=num_layers, 181 | in_channels=in_channels, 182 | out_channels=out_channels, 183 | prev_output_channel=prev_output_channel, 184 | temb_channels=temb_channels, 185 | add_upsample=add_upsample, 186 | resnet_eps=resnet_eps, 187 | resnet_act_fn=resnet_act_fn, 188 | attn_num_head_channels=attn_num_head_channels, 189 | ) 190 | elif up_block_type == "SkipUpBlock2D": 191 | return SkipUpBlock2D( 192 | num_layers=num_layers, 193 | in_channels=in_channels, 194 | out_channels=out_channels, 195 | prev_output_channel=prev_output_channel, 196 | temb_channels=temb_channels, 197 | add_upsample=add_upsample, 198 | resnet_eps=resnet_eps, 199 | resnet_act_fn=resnet_act_fn, 200 | ) 201 | elif up_block_type == "AttnSkipUpBlock2D": 202 | return AttnSkipUpBlock2D( 203 | num_layers=num_layers, 204 | in_channels=in_channels, 205 | out_channels=out_channels, 206 | prev_output_channel=prev_output_channel, 207 | temb_channels=temb_channels, 208 | add_upsample=add_upsample, 209 | resnet_eps=resnet_eps, 210 | resnet_act_fn=resnet_act_fn, 211 | attn_num_head_channels=attn_num_head_channels, 212 | ) 213 | elif up_block_type == "UpDecoderBlock2D": 214 | return UpDecoderBlock2D( 215 | num_layers=num_layers, 216 | in_channels=in_channels, 217 | out_channels=out_channels, 218 | add_upsample=add_upsample, 219 | 
resnet_eps=resnet_eps, 220 | resnet_act_fn=resnet_act_fn, 221 | ) 222 | raise ValueError(f"{up_block_type} does not exist.") 223 | 224 | 225 | class UNetMidBlock2DCrossAttn(nn.Module): 226 | def __init__( 227 | self, 228 | in_channels: int, 229 | temb_channels: int, 230 | dropout: float = 0.0, 231 | num_layers: int = 1, 232 | resnet_eps: float = 1e-6, 233 | resnet_time_scale_shift: str = "default", 234 | resnet_act_fn: str = "swish", 235 | resnet_groups: int = 32, 236 | resnet_pre_norm: bool = True, 237 | attn_num_head_channels=1, 238 | attention_type="default", 239 | output_scale_factor=1.0, 240 | cross_attention_dim=1280, 241 | **kwargs, 242 | ): 243 | super().__init__() 244 | 245 | self.attention_type = attention_type 246 | self.attn_num_head_channels = attn_num_head_channels 247 | resnet_groups = ( 248 | resnet_groups if resnet_groups is not None else min(in_channels // 4, 32) 249 | ) 250 | 251 | # there is always at least one resnet 252 | resnets = [ 253 | ResnetBlock2D( 254 | in_channels=in_channels, 255 | out_channels=in_channels, 256 | temb_channels=temb_channels, 257 | eps=resnet_eps, 258 | groups=resnet_groups, 259 | dropout=dropout, 260 | time_embedding_norm=resnet_time_scale_shift, 261 | non_linearity=resnet_act_fn, 262 | output_scale_factor=output_scale_factor, 263 | pre_norm=resnet_pre_norm, 264 | ) 265 | ] 266 | attentions = [] 267 | 268 | for _ in range(num_layers): 269 | attentions.append( 270 | SpatialTransformer( 271 | in_channels, 272 | attn_num_head_channels, 273 | in_channels // attn_num_head_channels, 274 | depth=1, 275 | context_dim=cross_attention_dim, 276 | ) 277 | ) 278 | resnets.append( 279 | ResnetBlock2D( 280 | in_channels=in_channels, 281 | out_channels=in_channels, 282 | temb_channels=temb_channels, 283 | eps=resnet_eps, 284 | groups=resnet_groups, 285 | dropout=dropout, 286 | time_embedding_norm=resnet_time_scale_shift, 287 | non_linearity=resnet_act_fn, 288 | output_scale_factor=output_scale_factor, 289 | pre_norm=resnet_pre_norm, 290 | ) 291 | ) 292 | 293 | self.attentions = nn.ModuleList(attentions) 294 | self.resnets = nn.ModuleList(resnets) 295 | 296 | def forward(self, hidden_states, temb=None, encoder_hidden_states=None): 297 | hidden_states = self.resnets[0](hidden_states, temb) 298 | for attn, resnet in zip(self.attentions, self.resnets[1:]): 299 | hidden_states = attn(hidden_states, encoder_hidden_states) 300 | hidden_states = resnet(hidden_states, temb) 301 | 302 | return hidden_states 303 | 304 | 305 | class CrossAttnDownBlock2D(nn.Module): 306 | def __init__( 307 | self, 308 | in_channels: int, 309 | out_channels: int, 310 | temb_channels: int, 311 | dropout: float = 0.0, 312 | num_layers: int = 1, 313 | resnet_eps: float = 1e-6, 314 | resnet_time_scale_shift: str = "default", 315 | resnet_act_fn: str = "swish", 316 | resnet_groups: int = 32, 317 | resnet_pre_norm: bool = True, 318 | attn_num_head_channels=1, 319 | cross_attention_dim=1280, 320 | attention_type="default", 321 | output_scale_factor=1.0, 322 | downsample_padding=1, 323 | add_downsample=True, 324 | ): 325 | super().__init__() 326 | 327 | resnets = [] 328 | attentions = [] 329 | 330 | self.attention_type = attention_type 331 | self.attn_num_head_channels = attn_num_head_channels 332 | 333 | for i in range(num_layers): 334 | in_channels = in_channels if i == 0 else out_channels 335 | resnets.append( 336 | ResnetBlock2D( 337 | in_channels=in_channels, 338 | out_channels=out_channels, 339 | temb_channels=temb_channels, 340 | eps=resnet_eps, 341 | groups=resnet_groups, 342 | 
dropout=dropout, 343 | time_embedding_norm=resnet_time_scale_shift, 344 | non_linearity=resnet_act_fn, 345 | output_scale_factor=output_scale_factor, 346 | pre_norm=resnet_pre_norm, 347 | ) 348 | ) 349 | attentions.append( 350 | SpatialTransformer( 351 | out_channels, 352 | attn_num_head_channels, 353 | out_channels // attn_num_head_channels, 354 | depth=1, 355 | context_dim=cross_attention_dim, 356 | ) 357 | ) 358 | self.attentions = nn.ModuleList(attentions) 359 | self.resnets = nn.ModuleList(resnets) 360 | 361 | if add_downsample: 362 | self.downsamplers = nn.ModuleList( 363 | [ 364 | Downsample2D( 365 | in_channels, 366 | use_conv=True, 367 | out_channels=out_channels, 368 | padding=downsample_padding, 369 | name="op", 370 | ) 371 | ] 372 | ) 373 | else: 374 | self.downsamplers = None 375 | 376 | def forward(self, hidden_states, temb=None, encoder_hidden_states=None): 377 | output_states = () 378 | 379 | for resnet, attn in zip(self.resnets, self.attentions): 380 | hidden_states = resnet(hidden_states, temb) 381 | hidden_states = attn(hidden_states, context=encoder_hidden_states) 382 | output_states += (hidden_states,) 383 | 384 | if self.downsamplers is not None: 385 | for downsampler in self.downsamplers: 386 | hidden_states = downsampler(hidden_states) 387 | 388 | output_states += (hidden_states,) 389 | 390 | return hidden_states, output_states 391 | 392 | 393 | class DownBlock2D(nn.Module): 394 | def __init__( 395 | self, 396 | in_channels: int, 397 | out_channels: int, 398 | temb_channels: int, 399 | dropout: float = 0.0, 400 | num_layers: int = 1, 401 | resnet_eps: float = 1e-6, 402 | resnet_time_scale_shift: str = "default", 403 | resnet_act_fn: str = "swish", 404 | resnet_groups: int = 32, 405 | resnet_pre_norm: bool = True, 406 | output_scale_factor=1.0, 407 | add_downsample=True, 408 | downsample_padding=1, 409 | ): 410 | super().__init__() 411 | resnets = [] 412 | 413 | for i in range(num_layers): 414 | in_channels = in_channels if i == 0 else out_channels 415 | resnets.append( 416 | ResnetBlock2D( 417 | in_channels=in_channels, 418 | out_channels=out_channels, 419 | temb_channels=temb_channels, 420 | eps=resnet_eps, 421 | groups=resnet_groups, 422 | dropout=dropout, 423 | time_embedding_norm=resnet_time_scale_shift, 424 | non_linearity=resnet_act_fn, 425 | output_scale_factor=output_scale_factor, 426 | pre_norm=resnet_pre_norm, 427 | ) 428 | ) 429 | 430 | self.resnets = nn.ModuleList(resnets) 431 | 432 | if add_downsample: 433 | self.downsamplers = nn.ModuleList( 434 | [ 435 | Downsample2D( 436 | in_channels, 437 | use_conv=True, 438 | out_channels=out_channels, 439 | padding=downsample_padding, 440 | name="op", 441 | ) 442 | ] 443 | ) 444 | else: 445 | self.downsamplers = None 446 | 447 | def forward(self, hidden_states, temb=None): 448 | output_states = () 449 | 450 | for resnet in self.resnets: 451 | hidden_states = resnet(hidden_states, temb) 452 | output_states += (hidden_states,) 453 | 454 | if self.downsamplers is not None: 455 | for downsampler in self.downsamplers: 456 | hidden_states = downsampler(hidden_states) 457 | 458 | output_states += (hidden_states,) 459 | 460 | return hidden_states, output_states 461 | 462 | 463 | class CrossAttnUpBlock2D(nn.Module): 464 | def __init__( 465 | self, 466 | in_channels: int, 467 | out_channels: int, 468 | prev_output_channel: int, 469 | temb_channels: int, 470 | dropout: float = 0.0, 471 | num_layers: int = 1, 472 | resnet_eps: float = 1e-6, 473 | resnet_time_scale_shift: str = "default", 474 | resnet_act_fn: str = "swish", 
475 | resnet_groups: int = 32, 476 | resnet_pre_norm: bool = True, 477 | attn_num_head_channels=1, 478 | cross_attention_dim=1280, 479 | attention_type="default", 480 | output_scale_factor=1.0, 481 | downsample_padding=1, 482 | add_upsample=True, 483 | ): 484 | super().__init__() 485 | 486 | resnets = [] 487 | attentions = [] 488 | 489 | self.attention_type = attention_type 490 | self.attn_num_head_channels = attn_num_head_channels 491 | 492 | for i in range(num_layers): 493 | res_skip_channels = in_channels if (i == num_layers - 1) else out_channels 494 | resnet_in_channels = prev_output_channel if i == 0 else out_channels 495 | 496 | resnets.append( 497 | ResnetBlock2D( 498 | in_channels=resnet_in_channels + res_skip_channels, 499 | out_channels=out_channels, 500 | temb_channels=temb_channels, 501 | eps=resnet_eps, 502 | groups=resnet_groups, 503 | dropout=dropout, 504 | time_embedding_norm=resnet_time_scale_shift, 505 | non_linearity=resnet_act_fn, 506 | output_scale_factor=output_scale_factor, 507 | pre_norm=resnet_pre_norm, 508 | ) 509 | ) 510 | attentions.append( 511 | SpatialTransformer( 512 | out_channels, 513 | attn_num_head_channels, 514 | out_channels // attn_num_head_channels, 515 | depth=1, 516 | context_dim=cross_attention_dim, 517 | ) 518 | ) 519 | self.attentions = nn.ModuleList(attentions) 520 | self.resnets = nn.ModuleList(resnets) 521 | 522 | if add_upsample: 523 | self.upsamplers = nn.ModuleList( 524 | [Upsample2D(out_channels, use_conv=True, out_channels=out_channels)] 525 | ) 526 | else: 527 | self.upsamplers = None 528 | 529 | def forward( 530 | self, 531 | hidden_states, 532 | res_hidden_states_tuple, 533 | temb=None, 534 | encoder_hidden_states=None, 535 | ): 536 | for resnet, attn in zip(self.resnets, self.attentions): 537 | # pop res hidden states 538 | res_hidden_states = res_hidden_states_tuple[-1] 539 | res_hidden_states_tuple = res_hidden_states_tuple[:-1] 540 | hidden_states = ops.concatenate()( 541 | [hidden_states, res_hidden_states], dim=-1 542 | ) 543 | 544 | hidden_states = resnet(hidden_states, temb=temb) 545 | hidden_states = attn(hidden_states, context=encoder_hidden_states) 546 | 547 | if self.upsamplers is not None: 548 | for upsampler in self.upsamplers: 549 | hidden_states = upsampler(hidden_states) 550 | 551 | return hidden_states 552 | 553 | 554 | class UpBlock2D(nn.Module): 555 | def __init__( 556 | self, 557 | in_channels: int, 558 | prev_output_channel: int, 559 | out_channels: int, 560 | temb_channels: int, 561 | dropout: float = 0.0, 562 | num_layers: int = 1, 563 | resnet_eps: float = 1e-6, 564 | resnet_time_scale_shift: str = "default", 565 | resnet_act_fn: str = "swish", 566 | resnet_groups: int = 32, 567 | resnet_pre_norm: bool = True, 568 | output_scale_factor=1.0, 569 | add_upsample=True, 570 | ): 571 | super().__init__() 572 | resnets = [] 573 | 574 | for i in range(num_layers): 575 | res_skip_channels = in_channels if (i == num_layers - 1) else out_channels 576 | resnet_in_channels = prev_output_channel if i == 0 else out_channels 577 | 578 | resnets.append( 579 | ResnetBlock2D( 580 | in_channels=resnet_in_channels + res_skip_channels, 581 | out_channels=out_channels, 582 | temb_channels=temb_channels, 583 | eps=resnet_eps, 584 | groups=resnet_groups, 585 | dropout=dropout, 586 | time_embedding_norm=resnet_time_scale_shift, 587 | non_linearity=resnet_act_fn, 588 | output_scale_factor=output_scale_factor, 589 | pre_norm=resnet_pre_norm, 590 | ) 591 | ) 592 | 593 | self.resnets = nn.ModuleList(resnets) 594 | 595 | if add_upsample: 
596 | self.upsamplers = nn.ModuleList( 597 | [Upsample2D(out_channels, use_conv=True, out_channels=out_channels)] 598 | ) 599 | else: 600 | self.upsamplers = None 601 | 602 | def forward(self, hidden_states, res_hidden_states_tuple, temb=None): 603 | for resnet in self.resnets: 604 | # pop res hidden states 605 | res_hidden_states = res_hidden_states_tuple[-1] 606 | res_hidden_states_tuple = res_hidden_states_tuple[:-1] 607 | hidden_states = ops.concatenate()( 608 | [hidden_states, res_hidden_states], dim=-1 609 | ) 610 | 611 | hidden_states = resnet(hidden_states, temb) 612 | 613 | if self.upsamplers is not None: 614 | for upsampler in self.upsamplers: 615 | hidden_states = upsampler(hidden_states) 616 | 617 | return hidden_states 618 | 619 | 620 | class UpDecoderBlock2D(nn.Module): 621 | def __init__( 622 | self, 623 | in_channels: int, 624 | out_channels: int, 625 | dropout: float = 0.0, 626 | num_layers: int = 1, 627 | resnet_eps: float = 1e-6, 628 | resnet_time_scale_shift: str = "default", 629 | resnet_act_fn: str = "swish", 630 | resnet_groups: int = 32, 631 | resnet_pre_norm: bool = True, 632 | output_scale_factor=1.0, 633 | add_upsample=True, 634 | ): 635 | super().__init__() 636 | resnets = [] 637 | 638 | for i in range(num_layers): 639 | input_channels = in_channels if i == 0 else out_channels 640 | 641 | resnets.append( 642 | ResnetBlock2D( 643 | in_channels=input_channels, 644 | out_channels=out_channels, 645 | temb_channels=None, 646 | eps=resnet_eps, 647 | groups=resnet_groups, 648 | dropout=dropout, 649 | time_embedding_norm=resnet_time_scale_shift, 650 | non_linearity=resnet_act_fn, 651 | output_scale_factor=output_scale_factor, 652 | pre_norm=resnet_pre_norm, 653 | ) 654 | ) 655 | 656 | self.resnets = nn.ModuleList(resnets) 657 | 658 | if add_upsample: 659 | self.upsamplers = nn.ModuleList( 660 | [Upsample2D(out_channels, use_conv=True, out_channels=out_channels)] 661 | ) 662 | else: 663 | self.upsamplers = None 664 | 665 | def forward(self, hidden_states): 666 | for resnet in self.resnets: 667 | hidden_states = resnet(hidden_states, temb=None) 668 | 669 | if self.upsamplers is not None: 670 | for upsampler in self.upsamplers: 671 | hidden_states = upsampler(hidden_states) 672 | 673 | return hidden_states 674 | 675 | 676 | class UNetMidBlock2D(nn.Module): 677 | def __init__( 678 | self, 679 | batch_size, 680 | height, 681 | width, 682 | in_channels: int, 683 | temb_channels: int, 684 | dropout: float = 0.0, 685 | num_layers: int = 1, 686 | resnet_eps: float = 1e-6, 687 | resnet_time_scale_shift: str = "default", 688 | resnet_act_fn: str = "swish", 689 | resnet_groups: int = 32, 690 | resnet_pre_norm: bool = True, 691 | attn_num_head_channels=1, 692 | attention_type="default", 693 | output_scale_factor=1.0, 694 | **kwargs, 695 | ): 696 | super().__init__() 697 | 698 | if attention_type != "default": 699 | raise NotImplementedError( 700 | f"attention_type must be default! 
current value: {attention_type}" 701 | ) 702 | 703 | resnet_groups = ( 704 | resnet_groups if resnet_groups is not None else min(in_channels // 4, 32) 705 | ) 706 | 707 | # there is always at least one resnet 708 | resnets = [ 709 | ResnetBlock2D( 710 | in_channels=in_channels, 711 | out_channels=in_channels, 712 | temb_channels=temb_channels, 713 | eps=resnet_eps, 714 | groups=resnet_groups, 715 | dropout=dropout, 716 | time_embedding_norm=resnet_time_scale_shift, 717 | non_linearity=resnet_act_fn, 718 | output_scale_factor=output_scale_factor, 719 | pre_norm=resnet_pre_norm, 720 | ) 721 | ] 722 | attentions = [] 723 | 724 | for _ in range(num_layers): 725 | attentions.append( 726 | AttentionBlock( 727 | batch_size, 728 | height, 729 | width, 730 | in_channels, 731 | num_head_channels=attn_num_head_channels, 732 | rescale_output_factor=output_scale_factor, 733 | eps=resnet_eps, 734 | num_groups=resnet_groups, 735 | ) 736 | ) 737 | resnets.append( 738 | ResnetBlock2D( 739 | in_channels=in_channels, 740 | out_channels=in_channels, 741 | temb_channels=temb_channels, 742 | eps=resnet_eps, 743 | groups=resnet_groups, 744 | dropout=dropout, 745 | time_embedding_norm=resnet_time_scale_shift, 746 | non_linearity=resnet_act_fn, 747 | output_scale_factor=output_scale_factor, 748 | pre_norm=resnet_pre_norm, 749 | ) 750 | ) 751 | 752 | self.attentions = nn.ModuleList(attentions) 753 | self.resnets = nn.ModuleList(resnets) 754 | 755 | def forward(self, hidden_states, temb=None, encoder_states=None): 756 | hidden_states = self.resnets[0](hidden_states, temb) 757 | for attn, resnet in zip(self.attentions, self.resnets[1:]): 758 | hidden_states = attn(hidden_states) 759 | hidden_states = resnet(hidden_states, temb) 760 | 761 | return hidden_states 762 | -------------------------------------------------------------------------------- /stablefusion/ait_modeling/vae.py: -------------------------------------------------------------------------------- 1 | # Copyright (c) Meta Platforms, Inc. and affiliates. 2 | # 3 | # Licensed under the Apache License, Version 2.0 (the "License"); 4 | # you may not use this file except in compliance with the License. 5 | # You may obtain a copy of the License at 6 | # 7 | # http://www.apache.org/licenses/LICENSE-2.0 8 | # 9 | # Unless required by applicable law or agreed to in writing, software 10 | # distributed under the License is distributed on an "AS IS" BASIS, 11 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 12 | # See the License for the specific language governing permissions and 13 | # limitations under the License. 14 | """ 15 | Translated from https://github.com/huggingface/diffusers/blob/main/src/diffusers/models/vae.py. 
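    Only the decoder path (AutoencoderKL.decode below) is ported; in the usual Stable Diffusion pipeline the
    denoised latents are scaled by 1 / 0.18215 before being passed to decode().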
16 | """ 17 | 18 | from typing import Tuple 19 | 20 | from aitemplate.frontend import nn, Tensor 21 | from modeling.unet_blocks import get_up_block, UNetMidBlock2D 22 | 23 | 24 | class Decoder(nn.Module): 25 | def __init__( 26 | self, 27 | batch_size, 28 | height, 29 | width, 30 | in_channels=3, 31 | out_channels=3, 32 | up_block_types=("UpDecoderBlock2D",), 33 | block_out_channels=(64,), 34 | layers_per_block=2, 35 | act_fn="silu", 36 | ): 37 | super().__init__() 38 | self.layers_per_block = layers_per_block 39 | 40 | self.conv_in = nn.Conv2dBias( 41 | in_channels, block_out_channels[-1], kernel_size=3, stride=1, padding=1 42 | ) 43 | 44 | # mid 45 | self.mid_block = UNetMidBlock2D( 46 | batch_size, 47 | height, 48 | width, 49 | in_channels=block_out_channels[-1], 50 | resnet_eps=1e-6, 51 | resnet_act_fn=act_fn, 52 | output_scale_factor=1, 53 | resnet_time_scale_shift="default", 54 | attn_num_head_channels=None, 55 | resnet_groups=32, 56 | temb_channels=None, 57 | ) 58 | 59 | # up 60 | self.up_blocks = nn.ModuleList([]) 61 | reversed_block_out_channels = list(reversed(block_out_channels)) 62 | output_channel = reversed_block_out_channels[0] 63 | for i, up_block_type in enumerate(up_block_types): 64 | prev_output_channel = output_channel 65 | output_channel = reversed_block_out_channels[i] 66 | 67 | is_final_block = i == len(block_out_channels) - 1 68 | 69 | up_block = get_up_block( 70 | up_block_type, 71 | num_layers=self.layers_per_block + 1, 72 | in_channels=prev_output_channel, 73 | out_channels=output_channel, 74 | prev_output_channel=None, 75 | add_upsample=not is_final_block, 76 | resnet_eps=1e-6, 77 | resnet_act_fn=act_fn, 78 | attn_num_head_channels=None, 79 | temb_channels=None, 80 | ) 81 | self.up_blocks.append(up_block) 82 | prev_output_channel = output_channel 83 | 84 | # out 85 | num_groups_out = 32 86 | self.conv_norm_out = nn.GroupNorm( 87 | num_channels=block_out_channels[0], 88 | num_groups=num_groups_out, 89 | eps=1e-6, 90 | use_swish=True, 91 | ) 92 | self.conv_out = nn.Conv2dBias( 93 | block_out_channels[0], out_channels, kernel_size=3, padding=1, stride=1 94 | ) 95 | 96 | def forward(self, z) -> Tensor: 97 | sample = z 98 | sample = self.conv_in(sample) 99 | 100 | # middle 101 | sample = self.mid_block(sample) 102 | 103 | # up 104 | for up_block in self.up_blocks: 105 | sample = up_block(sample) 106 | 107 | sample = self.conv_norm_out(sample) 108 | sample = self.conv_out(sample) 109 | 110 | return sample 111 | 112 | 113 | class AutoencoderKL(nn.Module): 114 | def __init__( 115 | self, 116 | batch_size: int, 117 | height: int, 118 | width: int, 119 | in_channels: int = 3, 120 | out_channels: int = 3, 121 | down_block_types: Tuple[str] = ("DownEncoderBlock2D",), 122 | up_block_types: Tuple[str] = ("UpDecoderBlock2D",), 123 | block_out_channels: Tuple[int] = (64,), 124 | layers_per_block: int = 1, 125 | act_fn: str = "silu", 126 | latent_channels: int = 4, 127 | sample_size: int = 32, 128 | ): 129 | super().__init__() 130 | self.decoder = Decoder( 131 | batch_size, 132 | height, 133 | width, 134 | in_channels=latent_channels, 135 | out_channels=out_channels, 136 | up_block_types=up_block_types, 137 | block_out_channels=block_out_channels, 138 | layers_per_block=layers_per_block, 139 | act_fn=act_fn, 140 | ) 141 | self.post_quant_conv = nn.Conv2dBias( 142 | latent_channels, latent_channels, kernel_size=1, stride=1, padding=0 143 | ) 144 | 145 | def decode(self, z: Tensor, return_dict: bool = True): 146 | 147 | z = self.post_quant_conv(z) 148 | dec = self.decoder(z) 149 | 
return dec 150 | 151 | def forward(self): 152 | raise NotImplementedError("Only decode() is implemented for AutoencoderKL!") 153 | -------------------------------------------------------------------------------- /stablefusion/clip_textmodel.py: -------------------------------------------------------------------------------- 1 | from numpy import triu 2 | from transformers import CLIPTextConfig 3 | from transformers.models.clip.modeling_clip import CLIPTextModel, CLIPTextTransformer 4 | import torch 5 | from torch import nn 6 | 7 | 8 | def triu_onnx(x, diagonal=0): 9 | l = x.shape[0] 10 | arange = torch.arange(l, device=x.device) 11 | mask = arange.expand(l, l) 12 | arange = arange.unsqueeze(-1) 13 | if diagonal: 14 | arange = arange + diagonal 15 | mask = mask >= arange 16 | return mask * x 17 | 18 | 19 | class Triu(nn.Module): 20 | """export-friendly version of torch.triu()""" 21 | 22 | @staticmethod 23 | def forward(x): 24 | return triu_onnx(x) 25 | 26 | 27 | class CIPTextTransformerTracable(CLIPTextTransformer): 28 | def __init__(self, config: CLIPTextConfig): 29 | super().__init__(config) 30 | 31 | def _build_causal_attention_mask(self, bsz, seq_len, dtype): 32 | # lazily create causal attention mask, with full attention between the vision tokens 33 | # pytorch uses additive attention mask; fill with -inf 34 | mask = torch.empty(bsz, seq_len, seq_len, dtype=dtype) 35 | mask.fill_(torch.tensor(torch.finfo(dtype).min)) 36 | # mask.triu_(1) # zero out the lower diagonal 37 | mask = triu_onnx(mask, 1) # assign the result: triu_onnx is not in-place 38 | mask = mask.unsqueeze(1) # expand mask 39 | return mask 40 | 41 | 42 | class CLIPTextModelTracable(CLIPTextModel): 43 | def __init__(self, config: CLIPTextConfig): 44 | super().__init__(config) 45 | self.text_model = CIPTextTransformerTracable(config) 46 | -------------------------------------------------------------------------------- /stablefusion/stablefusion_ov_engine.py: -------------------------------------------------------------------------------- 1 | """ 2 | we don't need the safety checker!!
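A usage sketch, for illustration only (the scheduler settings and file names are assumptions, not values taken from
this repository; local_model_path is passed explicitly because the constructor calls os.path.exists() on it, which
would raise on the default of None):

    import cv2
    from diffusers import LMSDiscreteScheduler

    scheduler = LMSDiscreteScheduler(beta_start=0.00085, beta_end=0.012, beta_schedule="scaled_linear")
    engine = StableDiffusionEngine(scheduler, local_model_path="weights", device="CPU")
    image = engine("a photo of an astronaut riding a horse")  # BGR uint8 image
    cv2.imwrite("result.png", image)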
3 | """ 4 | import inspect 5 | import numpy as np 6 | 7 | # openvino 8 | from openvino.runtime import Core 9 | 10 | # tokenizer 11 | from transformers import CLIPTokenizer 12 | 13 | # utils 14 | from tqdm import tqdm 15 | from huggingface_hub import hf_hub_download 16 | from diffusers import LMSDiscreteScheduler, PNDMScheduler 17 | import cv2 18 | import os 19 | from alfred import logger 20 | 21 | 22 | def result(var): 23 | return next(iter(var.values())) 24 | 25 | 26 | class StableDiffusionEngine: 27 | def __init__( 28 | self, 29 | scheduler, 30 | model="bes-dev/stable-diffusion-v1-4-openvino", 31 | local_model_path=None, 32 | tokenizer="openai/clip-vit-large-patch14", 33 | int8=False, 34 | device="CPU", 35 | ): 36 | self.tokenizer = CLIPTokenizer.from_pretrained(tokenizer) 37 | self.scheduler = scheduler 38 | # models 39 | self.core = Core() 40 | 41 | load_local = False 42 | if os.path.exists(local_model_path): 43 | logger.info( 44 | f"detected model in local path, loading onnx model: {local_model_path}" 45 | ) 46 | load_local = True 47 | 48 | # text features 49 | if False: 50 | self._text_encoder = self.core.read_model( 51 | os.path.join(local_model_path, "text_encoder.onnx") 52 | ) 53 | logger.info("text encoder read.") 54 | else: 55 | self._text_encoder = self.core.read_model( 56 | hf_hub_download(repo_id=model, filename="text_encoder.xml"), 57 | hf_hub_download(repo_id=model, filename="text_encoder.bin"), 58 | ) 59 | self.text_encoder = self.core.compile_model(self._text_encoder, device) 60 | 61 | # diffusion 62 | self._unet = self.core.read_model( 63 | hf_hub_download(repo_id=model, filename="unet.xml"), 64 | hf_hub_download(repo_id=model, filename="unet.bin"), 65 | ) 66 | self.unet = self.core.compile_model(self._unet, device) 67 | self.latent_shape = tuple(self._unet.inputs[0].shape)[1:] 68 | 69 | # decoder 70 | if False: 71 | self._vae_decoder = self.core.read_model( 72 | os.path.join(local_model_path, "vae_decoder.onnx") 73 | ) 74 | else: 75 | self._vae_decoder = self.core.read_model( 76 | hf_hub_download(repo_id=model, filename="vae_decoder.xml"), 77 | hf_hub_download(repo_id=model, filename="vae_decoder.bin"), 78 | ) 79 | self.vae_decoder = self.core.compile_model(self._vae_decoder, device) 80 | 81 | # encoder 82 | # if load_local: 83 | if False: 84 | self._vae_encoder = self.core.read_model( 85 | os.path.join(local_model_path, "post_quant_conv.onnx") 86 | ) 87 | else: 88 | self._vae_encoder = self.core.read_model( 89 | hf_hub_download(repo_id=model, filename="vae_encoder.xml"), 90 | hf_hub_download(repo_id=model, filename="vae_encoder.bin"), 91 | ) 92 | self.vae_encoder = self.core.compile_model(self._vae_encoder, device) 93 | self.init_image_shape = tuple(self._vae_encoder.inputs[0].shape)[2:] 94 | 95 | def _preprocess_mask(self, mask): 96 | h, w = mask.shape 97 | if h != self.init_image_shape[0] and w != self.init_image_shape[1]: 98 | mask = cv2.resize( 99 | mask, 100 | (self.init_image_shape[1], self.init_image_shape[0]), 101 | interpolation=cv2.INTER_NEAREST, 102 | ) 103 | mask = cv2.resize( 104 | mask, 105 | (self.init_image_shape[1] // 8, self.init_image_shape[0] // 8), 106 | interpolation=cv2.INTER_NEAREST, 107 | ) 108 | mask = mask.astype(np.float32) / 255.0 109 | mask = np.tile(mask, (4, 1, 1)) 110 | mask = mask[None].transpose(0, 1, 2, 3) 111 | mask = 1 - mask 112 | return mask 113 | 114 | def _preprocess_image(self, image): 115 | image = cv2.cvtColor(image, cv2.COLOR_BGR2RGB) 116 | h, w = image.shape[1:] 117 | if h != self.init_image_shape[0] and w != 
self.init_image_shape[1]: 118 | image = cv2.resize( 119 | image, 120 | (self.init_image_shape[1], self.init_image_shape[0]), 121 | interpolation=cv2.INTER_LANCZOS4, 122 | ) 123 | # normalize 124 | image = image.astype(np.float32) / 255.0 125 | image = 2.0 * image - 1.0 126 | # to batch 127 | image = image[None].transpose(0, 3, 1, 2) 128 | return image 129 | 130 | def _encode_image(self, init_image): 131 | moments = result( 132 | self.vae_encoder.infer_new_request( 133 | {"init_image": self._preprocess_image(init_image)} 134 | ) 135 | ) 136 | mean, logvar = np.split(moments, 2, axis=1) 137 | std = np.exp(logvar * 0.5) 138 | latent = (mean + std * np.random.randn(*mean.shape)) * 0.18215 139 | return latent 140 | 141 | def __call__( 142 | self, 143 | prompt, 144 | init_image=None, 145 | mask=None, 146 | strength=0.5, 147 | num_inference_steps=32, 148 | guidance_scale=7.5, 149 | eta=0.0, 150 | ): 151 | # extract condition 152 | tokens = self.tokenizer( 153 | prompt, 154 | padding="max_length", 155 | max_length=self.tokenizer.model_max_length, 156 | truncation=True, 157 | ).input_ids 158 | text_embeddings = result( 159 | self.text_encoder.infer_new_request({"tokens": np.array([tokens])}) 160 | ) 161 | 162 | # do classifier free guidance 163 | if guidance_scale > 1.0: 164 | tokens_uncond = self.tokenizer( 165 | "", 166 | padding="max_length", 167 | max_length=self.tokenizer.model_max_length, 168 | truncation=True, 169 | ).input_ids 170 | uncond_embeddings = result( 171 | self.text_encoder.infer_new_request( 172 | {"tokens": np.array([tokens_uncond])} 173 | ) 174 | ) 175 | text_embeddings = np.concatenate( 176 | (uncond_embeddings, text_embeddings), axis=0 177 | ) 178 | 179 | # set timesteps 180 | accepts_offset = "offset" in set( 181 | inspect.signature(self.scheduler.set_timesteps).parameters.keys() 182 | ) 183 | extra_set_kwargs = {} 184 | offset = 0 185 | if accepts_offset: 186 | offset = 1 187 | extra_set_kwargs["offset"] = 1 188 | 189 | self.scheduler.set_timesteps(num_inference_steps, **extra_set_kwargs) 190 | 191 | # initialize latent latent 192 | if init_image is None: 193 | latents = np.random.randn(*self.latent_shape) 194 | init_timestep = num_inference_steps 195 | else: 196 | init_latents = self._encode_image(init_image) 197 | init_timestep = int(num_inference_steps * strength) + offset 198 | init_timestep = min(init_timestep, num_inference_steps) 199 | timesteps = np.array([[self.scheduler.timesteps[-init_timestep]]]).astype( 200 | np.long 201 | ) 202 | noise = np.random.randn(*self.latent_shape) 203 | latents = self.scheduler.add_noise(init_latents, noise, timesteps)[0] 204 | 205 | if init_image is not None and mask is not None: 206 | mask = self._preprocess_mask(mask) 207 | else: 208 | mask = None 209 | 210 | # if we use LMSDiscreteScheduler, let's make sure latents are mulitplied by sigmas 211 | if isinstance(self.scheduler, LMSDiscreteScheduler): 212 | latents = latents * self.scheduler.sigmas[0] 213 | 214 | # prepare extra kwargs for the scheduler step, since not all schedulers have the same signature 215 | # eta (η) is only used with the DDIMScheduler, it will be ignored for other schedulers. 
        # eta corresponds to η in DDIM paper: https://arxiv.org/abs/2010.02502
        # and should be between [0, 1]
        accepts_eta = "eta" in set(
            inspect.signature(self.scheduler.step).parameters.keys()
        )
        extra_step_kwargs = {}
        if accepts_eta:
            extra_step_kwargs["eta"] = eta

        t_start = max(num_inference_steps - init_timestep + offset, 0)
        for i, t in tqdm(enumerate(self.scheduler.timesteps[t_start:])):
            # expand the latents if we are doing classifier free guidance
            latent_model_input = (
                np.stack([latents, latents], 0)
                if guidance_scale > 1.0
                else latents[None]
            )
            if isinstance(self.scheduler, LMSDiscreteScheduler):
                sigma = self.scheduler.sigmas[i]
                latent_model_input = latent_model_input / ((sigma ** 2 + 1) ** 0.5)

            # predict the noise residual
            noise_pred = result(
                self.unet.infer_new_request(
                    {
                        "latent_model_input": latent_model_input,
                        "t": t,
                        "encoder_hidden_states": text_embeddings,
                    }
                )
            )

            # perform guidance
            if guidance_scale > 1.0:
                noise_pred = noise_pred[0] + guidance_scale * (
                    noise_pred[1] - noise_pred[0]
                )

            # compute the previous noisy sample x_t -> x_t-1
            if isinstance(self.scheduler, LMSDiscreteScheduler):
                latents = self.scheduler.step(
                    noise_pred, i, latents, **extra_step_kwargs
                )["prev_sample"]
            else:
                latents = self.scheduler.step(
                    noise_pred, t, latents, **extra_step_kwargs
                )["prev_sample"]

            # masking for inpainting
            if mask is not None:
                init_latents_proper = self.scheduler.add_noise(init_latents, noise, t)
                latents = ((init_latents_proper * mask) + (latents * (1 - mask)))[0]

        image = result(
            self.vae_decoder.infer_new_request({"latents": np.expand_dims(latents, 0)})
        )

        # convert tensor to opencv's image format
        image = (image / 2 + 0.5).clip(0, 1)
        image = (image[0].transpose(1, 2, 0)[:, :, ::-1] * 255).astype(np.uint8)
        return image
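
For reference, a minimal sketch of driving the engine above. The scheduler hyperparameters are the commonly used Stable Diffusion v1 values, and the tensor_format argument targets the older numpy-based diffusers schedulers this engine was written against; neither value comes from this file.

from diffusers import LMSDiscreteScheduler
import cv2

scheduler = LMSDiscreteScheduler(
    beta_start=0.00085,
    beta_end=0.012,
    beta_schedule="scaled_linear",
    tensor_format="np",  # only present in older diffusers releases that support numpy schedulers
)
engine = StableDiffusionEngine(scheduler=scheduler)  # downloads the OpenVINO IRs from the Hub
image = engine(
    "a photo of an astronaut riding a horse on mars",
    num_inference_steps=32,
    guidance_scale=7.5,
)  # returns a BGR uint8 array
cv2.imwrite("output.png", image)
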
15 | """ 16 | 17 | class HostDeviceMem(object): 18 | """ 19 | Helper class to record host-device memory pointer pairs 20 | """ 21 | 22 | def __init__(self, host_mem, device_mem): 23 | self.host = host_mem 24 | self.device = device_mem 25 | 26 | def __str__(self): 27 | return "Host:\n" + str(self.host) + "\nDevice:\n" + str(self.device) 28 | 29 | def __repr__(self): 30 | return self.__str__() 31 | 32 | def __init__(self, engine_path): 33 | self.engine_path = engine_path 34 | self.logger = trt.Logger(trt.Logger.WARNING) 35 | self.runtime = trt.Runtime(self.logger) 36 | 37 | # load and deserialize TRT engine 38 | self.engine = self.load_engine() 39 | 40 | # allocate input/output memory buffers 41 | self.inputs, self.outputs, self.bindings, self.stream = self.allocate_buffers( 42 | self.engine 43 | ) 44 | 45 | # create context 46 | self.context = self.engine.create_execution_context() 47 | 48 | # Dict of NumPy dtype -> torch dtype (when the correspondence exists). From: https://github.com/pytorch/pytorch/blob/e180ca652f8a38c479a3eff1080efe69cbc11621/torch/testing/_internal/common_utils.py#L349 49 | self.numpy_to_torch_dtype_dict = { 50 | bool: torch.bool, 51 | np.uint8: torch.uint8, 52 | np.int8: torch.int8, 53 | np.int16: torch.int16, 54 | np.int32: torch.int32, 55 | np.int64: torch.int64, 56 | np.float16: torch.float16, 57 | np.float32: torch.float32, 58 | np.float64: torch.float64, 59 | np.complex64: torch.complex64, 60 | np.complex128: torch.complex128, 61 | } 62 | 63 | def load_engine(self): 64 | with open(self.engine_path, "rb") as f: 65 | engine = self.runtime.deserialize_cuda_engine(f.read()) 66 | return engine 67 | 68 | def allocate_buffers(self, engine): 69 | """ 70 | Allocates all buffers required for an engine, i.e. host/device inputs/outputs. 71 | """ 72 | inputs = [] 73 | outputs = [] 74 | bindings = [] 75 | stream = cuda.Stream() 76 | 77 | for binding in engine: # binding is the name of input/output 78 | size = trt.volume(engine.get_binding_shape(binding)) * engine.max_batch_size 79 | dtype = trt.nptype(engine.get_binding_dtype(binding)) 80 | 81 | # Allocate host and device buffers 82 | host_mem = cuda.pagelocked_empty( 83 | size, dtype 84 | ) # page-locked memory buffer (won't swapped to disk) 85 | device_mem = cuda.mem_alloc(host_mem.nbytes) 86 | 87 | # Append the device buffer address to device bindings. When cast to int, it's a linear index into the context's memory (like memory address). See https://documen.tician.de/pycuda/driver.html#pycuda.driver.DeviceAllocation 88 | bindings.append(int(device_mem)) 89 | 90 | # Append to the appropriate input/output list. 91 | if engine.binding_is_input(binding): 92 | inputs.append(self.HostDeviceMem(host_mem, device_mem)) 93 | else: 94 | outputs.append(self.HostDeviceMem(host_mem, device_mem)) 95 | 96 | return inputs, outputs, bindings, stream 97 | 98 | def __call__(self, model_inputs: list, timing=False): 99 | """ 100 | Inference step (like forward() in PyTorch). 101 | model_inputs: list of numpy array or list of torch.Tensor (on GPU) 102 | """ 103 | NUMPY = False 104 | TORCH = False 105 | if isinstance(model_inputs[0], np.ndarray): 106 | NUMPY = True 107 | elif torch.is_tensor(model_inputs[0]): 108 | TORCH = True 109 | else: 110 | assert False, "Unsupported input data format!" 

        # batch size consistency check
        if NUMPY:
            batch_size = np.unique(np.array([i.shape[0] for i in model_inputs]))
        elif TORCH:
            batch_size = np.unique(np.array([i.size(dim=0) for i in model_inputs]))
        # assert len(batch_size) == 1, 'Input batch sizes are not consistent!'
        batch_size = batch_size[0]

        for i, model_input in enumerate(model_inputs):
            binding_name = self.engine[i]  # i-th input/output name
            binding_dtype = trt.nptype(
                self.engine.get_binding_dtype(binding_name)
            )  # TensorRT reports binding dtypes as numpy dtypes

            # input type cast
            if NUMPY:
                model_input = model_input.astype(binding_dtype)
            elif TORCH:
                model_input = model_input.to(
                    self.numpy_to_torch_dtype_dict[binding_dtype]
                )

            if NUMPY:
                # fill host memory with flattened input data
                np.copyto(self.inputs[i].host, model_input.ravel())
            elif TORCH:
                if timing:
                    cuda.memcpy_dtod(
                        self.inputs[i].device,
                        model_input.data_ptr(),
                        model_input.element_size() * model_input.nelement(),
                    )
                else:
                    # for a Torch GPU tensor it's easier, we can just do a Device to Device copy
                    cuda.memcpy_dtod_async(
                        self.inputs[i].device,
                        model_input.data_ptr(),
                        model_input.element_size() * model_input.nelement(),
                        self.stream,
                    )  # dtod copies need the size in bytes

        if NUMPY:
            if timing:
                [cuda.memcpy_htod(inp.device, inp.host) for inp in self.inputs]
            else:
                # input, Host to Device
                [
                    cuda.memcpy_htod_async(inp.device, inp.host, self.stream)
                    for inp in self.inputs
                ]

        duration = 0
        if timing:
            start_time = time()
            self.context.execute_v2(bindings=self.bindings)
            end_time = time()
            duration = end_time - start_time
        else:
            # run inference
            self.context.execute_async_v2(
                bindings=self.bindings, stream_handle=self.stream.handle
            )  # v2 does not need a batch_size arg

        if timing:
            [cuda.memcpy_dtoh(out.host, out.device) for out in self.outputs]
        else:
            # output, Device to Host
            [
                cuda.memcpy_dtoh_async(out.host, out.device, self.stream)
                for out in self.outputs
            ]

        if not timing:
            # synchronize to ensure completion of async calls
            self.stream.synchronize()

        if NUMPY:
            return [out.host.reshape(batch_size, -1) for out in self.outputs], duration
        elif TORCH:
            return [
                torch.from_numpy(out.host.reshape(batch_size, -1))
                for out in self.outputs
            ], duration
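
A quick usage sketch for the wrapper above; the engine file name, binding order and shapes are hypothetical and depend on how the model was exported and converted.

import numpy as np

# hypothetical engine, e.g. built with: trtexec --onnx=unet.onnx --saveEngine=unet.trt
model = TRTModel("weights/unet.trt")

# inputs must follow the engine's binding order; shapes/dtypes below are placeholders
inputs = [
    np.random.randn(2, 4, 64, 64).astype(np.float32),  # latent_model_input
    np.random.randn(2).astype(np.float32),             # t
    np.random.randn(2, 77, 768).astype(np.float32),    # encoder_hidden_states
]
outputs, duration = model(inputs, timing=True)
noise_pred = outputs[0].reshape(2, 4, 64, 64)  # host buffers come back flattened
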
--------------------------------------------------------------------------------
/stablefusion/unet_2d_condition.py:
--------------------------------------------------------------------------------
from typing import Union

from diffusers.models.unet_2d_condition import UNet2DConditionModel
import torch


class UNet2DConditionModelTracable(UNet2DConditionModel):
    def __init__(
        self,
        sample_size=None,
        in_channels=4,
        out_channels=4,
        center_input_sample=False,
        flip_sin_to_cos=True,
        freq_shift=0,
        down_block_types=...,
        up_block_types=...,
        block_out_channels=...,
        layers_per_block=2,
        downsample_padding=1,
        mid_block_scale_factor=1,
        act_fn="silu",
        norm_num_groups=32,
        norm_eps=0.00001,
        cross_attention_dim=1280,
        attention_head_dim=8,
    ):
        super().__init__(
            sample_size,
            in_channels,
            out_channels,
            center_input_sample,
            flip_sin_to_cos,
            freq_shift,
            down_block_types,
            up_block_types,
            block_out_channels,
            layers_per_block,
            downsample_padding,
            mid_block_scale_factor,
            act_fn,
            norm_num_groups,
            norm_eps,
            cross_attention_dim,
            attention_head_dim,
        )

    def forward(
        self,
        sample: torch.FloatTensor,
        timestep: Union[torch.Tensor, float, int],
        encoder_hidden_states: torch.Tensor,
    ) -> torch.FloatTensor:

        # 0. center input if necessary
        if self.config.center_input_sample:
            sample = 2 * sample - 1.0

        # 1. time
        timesteps = timestep
        if not torch.is_tensor(timesteps):
            timesteps = torch.tensor(
                [timesteps], dtype=torch.long, device=sample.device
            )
        elif torch.is_tensor(timesteps) and len(timesteps.shape) == 0:
            timesteps = timesteps[None].to(sample.device)

        # broadcast to batch dimension
        # timesteps = timesteps.broadcast_to(sample.shape[0])
        timesteps = timesteps * torch.ones(
            sample.shape[0], dtype=timesteps.dtype, device=timesteps.device
        )

        t_emb = self.time_proj(timesteps)
        emb = self.time_embedding(t_emb)

        # 2. pre-process
        sample = self.conv_in(sample)

        # 3. down
        down_block_res_samples = (sample,)
        for downsample_block in self.down_blocks:
            if (
                hasattr(downsample_block, "attentions")
                and downsample_block.attentions is not None
            ):
                sample, res_samples = downsample_block(
                    hidden_states=sample,
                    temb=emb,
                    encoder_hidden_states=encoder_hidden_states,
                )
            else:
                sample, res_samples = downsample_block(hidden_states=sample, temb=emb)

            down_block_res_samples += res_samples

        # 4. mid
        sample = self.mid_block(
            sample, emb, encoder_hidden_states=encoder_hidden_states
        )

        # 5. up
        for upsample_block in self.up_blocks:
            res_samples = down_block_res_samples[-len(upsample_block.resnets) :]
            down_block_res_samples = down_block_res_samples[
                : -len(upsample_block.resnets)
            ]

            if (
                hasattr(upsample_block, "attentions")
                and upsample_block.attentions is not None
            ):
                sample = upsample_block(
                    hidden_states=sample,
                    temb=emb,
                    res_hidden_states_tuple=res_samples,
                    encoder_hidden_states=encoder_hidden_states,
                )
            else:
                sample = upsample_block(
                    hidden_states=sample, temb=emb, res_hidden_states_tuple=res_samples
                )

        # 6. post-process
        # make sure hidden states are in float32 when running in half-precision
        sample = self.conv_norm_out(sample.float()).type(sample.dtype)
        sample = self.conv_act(sample)
        sample = self.conv_out(sample)

        # return a plain tensor (not a dict) so the module can be traced / exported
        # output = {"sample": sample}
        # return output
        return sample
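
The subclass above mainly exists so the UNet returns a bare tensor and can be traced or exported. A sketch of what an export call might look like; the weight path, shapes and input names are illustrative assumptions, not the repo's actual export script.

import torch

unet = UNet2DConditionModelTracable.from_pretrained(
    "weights/stable-diffusion-v1-4", subfolder="unet"
)
unet.eval()

# SD v1 uses 4x64x64 latents for 512x512 images and 77x768 CLIP text embeddings
sample = torch.randn(2, 4, 64, 64)
timestep = torch.tensor([981], dtype=torch.long)
encoder_hidden_states = torch.randn(2, 77, 768)

with torch.no_grad():
    torch.onnx.export(
        unet,
        (sample, timestep, encoder_hidden_states),
        "unet.onnx",
        input_names=["latent_model_input", "t", "encoder_hidden_states"],
        output_names=["noise_pred"],
        opset_version=14,
    )
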
--------------------------------------------------------------------------------
/test_diffusers.py:
--------------------------------------------------------------------------------
import torch
from diffusers import StableDiffusionPipeline

model_id = "CompVis/stable-diffusion-v1-4"
# device = "cuda"
device = "cpu"


# load a local snapshot of model_id instead of downloading it from the Hub
pipe = StableDiffusionPipeline.from_pretrained('weights/stable-diffusion-v1-4')
pipe = pipe.to(device)

prompt = "naked wonder woman wearing underwear on the beach"
image = pipe(prompt, guidance_scale=7.5)["sample"][0]  # on newer diffusers: pipe(prompt).images[0]

image.save("result.png")
--------------------------------------------------------------------------------