├── fastldm
│   ├── __init__.py
│   ├── environ.py
│   ├── registry.py
│   ├── mapping.py
│   ├── modifier.py
│   ├── experiment.py
│   ├── benchmark.py
│   ├── plugins.py
│   ├── helper.py
│   └── modules.py
├── setup.py
├── examples
│   ├── README.md
│   ├── benchmark_flash_attn.py
│   ├── transform_diffusers_unet.py
│   ├── benchmark_unet.py
│   ├── transform_stable_diffusion_unet.py
│   └── test_plugin.py
├── README.md
├── .gitignore
└── LICENSE

/fastldm/__init__.py:
--------------------------------------------------------------------------------
 1 | 
--------------------------------------------------------------------------------
/setup.py:
--------------------------------------------------------------------------------
 1 | from setuptools import setup
 2 | 
 3 | setup(
 4 |     name='fastldm',
 5 |     version='0.2.2',
 6 |     description='',
 7 |     packages=['fastldm'],
 8 |     install_requires=[
 9 |         'torch',
10 |         'numpy',
11 |         'tqdm',
12 |         'flash-attn',
13 |         'triton'
14 |     ],
15 | )
16 | 
--------------------------------------------------------------------------------
/fastldm/environ.py:
--------------------------------------------------------------------------------
 1 | import os
 2 | 
 3 | PLUGINS = []
 4 | if 'PLUGINS' in os.environ:
 5 |     for path in os.environ['PLUGINS'].split(','):
 6 |         PLUGINS.append(path)
 7 | ONNX_ONLY = False if 'ONNX_ONLY' not in os.environ or not eval(os.environ['ONNX_ONLY']) else True
 8 | DISABLE_FASTLDM = False if 'DISABLE_FASTLDM' not in os.environ or not eval(os.environ['DISABLE_FASTLDM']) else True
 9 | print('DISABLE_FASTLDM', DISABLE_FASTLDM)
10 | print('ONNX_ONLY', ONNX_ONLY)
11 | print('PLUGINS', PLUGINS)
12 | TRT_PATH = None if 'TRT_PATH' not in os.environ else os.environ['TRT_PATH']
13 | TRT_NUM_WORKER = 1 if 'TRT_NUM_WORKER' not in os.environ else int(os.environ['TRT_NUM_WORKER'])
--------------------------------------------------------------------------------
/examples/README.md:
--------------------------------------------------------------------------------
1 | For `transform_stable_diffusion_unet.py`, use this fork: https://github.com/1049451037/stable-diffusion
2 | 
3 | We recommend running the scripts with CUDA_VISIBLE_DEVICES=0,1 (or any other set of at least two devices), because TensorRT always runs on cuda:0. To leave as much GPU memory as possible for TensorRT, put the PyTorch model on cuda:1.
4 | 
5 | On an RTX 3090, the engine is 4x faster than the original torch module and 2.8x faster than running it under autocast. You can run `benchmark_unet.py` to reproduce the benchmark.
6 | 
7 | For `transform_diffusers_unet.py`, you may hit [this issue](https://github.com/pytorch/pytorch/issues/93937). It is a PyTorch bug, so we do not provide a detailed plugin example here. If you want to use attention plugins, the steps are similar to those in `transform_stable_diffusion_unet.py`.
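The example scripts are configured through environment variables that `fastldm/environ.py` reads at import time. Below is a minimal sketch (not part of the original scripts) of setting them in Python before the first `import fastldm`; the commented-out plugin path is a hypothetical placeholder and should only be set if you have actually built a TensorRT plugin library:

```
import os

# fastldm.environ reads os.environ once at import time,
# so set these before the first `import fastldm`.
os.environ['ONNX_ONLY'] = 'False'        # False: also build and compare a TensorRT engine
os.environ['DISABLE_FASTLDM'] = 'False'
os.environ['TRT_NUM_WORKER'] = '1'
# os.environ['PLUGINS'] = '/path/to/libmyplugin.so'  # hypothetical path to a compiled plugin

from fastldm.environ import ONNX_ONLY, PLUGINS, TRT_NUM_WORKER
print(ONNX_ONLY, PLUGINS, TRT_NUM_WORKER)
```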
8 | -------------------------------------------------------------------------------- /fastldm/registry.py: -------------------------------------------------------------------------------- 1 | class Registry: 2 | def __init__(self, name): 3 | self.name = name 4 | self.member = {} 5 | def register(self, src_type, dst_type): 6 | def func(f): 7 | self.member[(src_type, dst_type)] = f 8 | return f 9 | return func 10 | def get(self, src, dst): 11 | if not isinstance(src, type): 12 | src = type(src) 13 | if not isinstance(dst, type): 14 | dst = type(dst) 15 | if (src, dst) not in self.member: 16 | return self.member[(type(None), type(None))] 17 | return self.member[(src, dst)] 18 | def transform(self, src, dst): 19 | return self.get(src, dst)(src, dst) 20 | def __repr__(self): 21 | return 'Registry: ' + self.name + " " + str(self.member) -------------------------------------------------------------------------------- /examples/benchmark_flash_attn.py: -------------------------------------------------------------------------------- 1 | import torch 2 | from fastldm.modules import FlashCrossAttnWG, ldmCrossAttnWG 3 | model = ldmCrossAttnWG(768, 768, 8, 64).half().cuda() 4 | model_flash = FlashCrossAttnWG(768, 768, 8, 64).half().cuda() 5 | from fastldm.mapping import MAPPING 6 | MAPPING.transform(model, model_flash) 7 | x = torch.randn(4, 128, 768).half().cuda() 8 | context = torch.randn(4, 88, 768).half().cuda() 9 | mask = (torch.randn(4, 8, 128, 88) > 0.1).cuda() 10 | from fastldm.experiment import experiment 11 | from fastldm.benchmark import benchmark_backward 12 | 13 | measure, var, outputs = experiment([model, model_flash], [], (x, context, mask)) 14 | print(measure) 15 | print(var) 16 | measure, var, outputs = experiment([model, model_flash], [], (x, context, mask), forward_only=False) 17 | print(measure) 18 | print(var) 19 | measure, var, outputs = experiment([model, model_flash], [], (x, context, mask), benchmark_func=benchmark_backward) 20 | print(measure) 21 | print(var) 22 | breakpoint() -------------------------------------------------------------------------------- /examples/transform_diffusers_unet.py: -------------------------------------------------------------------------------- 1 | import torch 2 | from diffusers import UNet2DConditionModel 3 | 4 | unet = UNet2DConditionModel.from_pretrained("CompVis/stable-diffusion-v1-4", 5 | torch_dtype=torch.float16, 6 | revision="fp16", 7 | subfolder="unet") 8 | unet.cuda(1) 9 | inputs = torch.randn(2, 4, 64, 64, dtype=torch.half, device='cuda:1'), torch.tensor([1, 3], dtype=torch.int32, device='cuda:1'), torch.randn(2, 77, 768, dtype=torch.half, device='cuda:1') 10 | 11 | import fastldm.modules as fm 12 | from fastldm.modifier import modify, MODIFIER 13 | map_dict = { 14 | torch.nn.LayerNorm: fm.LayerNorm, 15 | torch.nn.GroupNorm: fm.GroupNorm, 16 | } 17 | unet = modify(unet, map_dict) 18 | 19 | from fastldm.experiment import experiment_onnx, experiment_trt 20 | from fastldm.environ import ONNX_ONLY 21 | if ONNX_ONLY: 22 | measure, var, outputs = experiment_onnx(unet, inputs) 23 | else: 24 | measure, var, outputs = experiment_trt(unet, inputs) 25 | 26 | print(var) 27 | breakpoint() 28 | for i in range(len(outputs[type(unet).__name__])): 29 | out_model = outputs[type(unet).__name__][i].cpu() 30 | out_ort = outputs['TRTModule'][i].cpu() 31 | 32 | from fastldm.helper import profile_outdiff 33 | measure = profile_outdiff(out_model, out_ort) 34 | # print(measure) 35 | import pprint 36 | pp = pprint.PrettyPrinter(indent=4) 37 | pp.pprint(measure) 
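# Note (added for clarity): `measure` is the OrderedDict returned by
# fastldm.helper.profile_outdiff. It profiles both outputs and their difference:
# the "absolute error", "relative error" and "norm relative error" entries each
# report has_nan/shape/min/max/mean, which is what gets pretty-printed above.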
38 | breakpoint()
--------------------------------------------------------------------------------
/examples/benchmark_unet.py:
--------------------------------------------------------------------------------
 1 | from omegaconf import OmegaConf
 2 | config = OmegaConf.load("stable-diffusion/configs/stable-diffusion/v1-inference.yaml")
 3 | config = config.model.params.unet_config
 4 | # from ldm.util import instantiate_from_config
 5 | # unet = instantiate_from_config(config)
 6 | import torch
 7 | from ldm.modules.diffusionmodules.openaimodel import UNetModel
 8 | 
 9 | unet = UNetModel(**config.params)
10 | ckpt = torch.load('stable-diffusion/models/ldm/stable-diffusion-v1/model.ckpt', map_location='cpu')
11 | state_dict = {}
12 | for k in ckpt['state_dict']:
13 |     if 'model.diffusion_model.' in k:
14 |         state_dict[k[len('model.diffusion_model.'):]] = ckpt['state_dict'][k]
15 | unet.load_state_dict(state_dict)
16 | model = unet.cuda(1)
17 | input_0 = torch.randn(6, 4, 64, 64, dtype=torch.float32).cuda(1)
18 | input_1 = torch.tensor([1, 3, 7, 8, 9, 23], dtype=torch.int32).cuda(1)
19 | input_2 = torch.randn(6, 77, 768, dtype=torch.float32).cuda(1)
20 | 
21 | from fastldm.benchmark import benchmark, benchmark_trt
22 | time_origin, outputs_origin = benchmark(model, (input_0, input_1, input_2), {}, 100)
23 | time_ac, outputs_ac = benchmark(model, (input_0, input_1, input_2), {}, 100, use_autocast=True)
24 | time_trt, outputs_trt = benchmark_trt('trt/UNetModel_ONNX_ONLY_False_.trt', (input_0, input_1, input_2), 100)
25 | 
26 | print('time origin', time_origin)
27 | print('time autocast', time_ac)
28 | print('time tensorrt', time_trt)
29 | 
30 | from fastldm.helper import profile_outdiff
31 | measure = profile_outdiff(outputs_origin, outputs_trt['output_0'])
32 | import pprint
33 | pp = pprint.PrettyPrinter(indent=4)
34 | pp.pprint(measure)
--------------------------------------------------------------------------------
/README.md:
--------------------------------------------------------------------------------
 1 | # FastLDM
 2 | 
 3 | We focus on inference speed-up for [ldm](https://github.com/CompVis/stable-diffusion).
 4 | 
 5 | ## TensorRT
 6 | 
 7 | We are converting the stable-diffusion model into an optimized TensorRT engine. See the `examples/` folder for use cases.
 8 | 
 9 | When we transform a PyTorch model into a TensorRT engine, we usually follow the "torch->onnx->trt" route. This is convenient in most cases, except when we want to use [plugins](https://github.com/NVIDIA/TensorRT/tree/main/plugin). To use plugins, we usually need to modify the exported onnx file, as in the [TensorRT official demo](https://github.com/NVIDIA/TensorRT/tree/main/demo/Diffusion). The core idea of this repo is to provide a shortcut: we provide plugin-friendly PyTorch modules that plug into your normal PyTorch model, so that the exported onnx can be optimized by TensorRT with plugins directly.
10 | 
11 | The advantages of this repo are:
12 | 
13 | * Seamlessly transform from PyTorch to TensorRT with plugins.
14 | * Seamlessly export onnx that runs directly in onnxruntime.
15 | 
16 | More specifically, we provide:
17 | 
18 | * parameter mapping functions from one module to another in fastldm.mapping and fastldm.modifier
19 | * TensorRT-friendly modules in fastldm.plugins and fastldm.modules
20 | * benchmarking functions in fastldm.benchmark
21 | * experiment functions in fastldm.experiment to compare different implementations of a model that share the same interface
22 | 
23 | ## Install
24 | 
25 | ```
26 | pip install -e .
27 | ```
28 | 
29 | or
30 | 
31 | ```
32 | pip install git+https://github.com/THUDM/FastLDM.git
33 | ```
34 | 
35 | ## Environment
36 | 
37 | We rely on the docker environment `nvcr.io/nvidia/pytorch:22.12-py3`. (For newer versions of TensorRT, this repo may not work because some plugins have been removed.)
38 | 
--------------------------------------------------------------------------------
/fastldm/mapping.py:
--------------------------------------------------------------------------------
 1 | from .registry import Registry
 2 | import torch
 3 | 
 4 | MAPPING = Registry('mapping')
 5 | 
 6 | @MAPPING.register(type(None), type(None))
 7 | def mapping_default(src, dst):
 8 |     for s, d in zip(src.parameters(), dst.parameters()):
 9 |         d.data = s.data
10 | 
11 | from .modules import qkvLinearSlow, qkvLinear
12 | @MAPPING.register(qkvLinearSlow, qkvLinear)
13 | def mapping_qkv_linear(src, dst):
14 |     wq = src.Wq.weight.data
15 |     bq = src.Wq.bias.data
16 |     wk = src.Wk.weight.data
17 |     bk = src.Wk.bias.data
18 |     wv = src.Wv.weight.data
19 |     bv = src.Wv.bias.data
20 |     wqkv = torch.cat([wq, wk, wv]).view(3, src.num_heads, src.size_per_head, src.hidden_size).transpose(0, 1).contiguous().view(3*src.hidden_size, src.hidden_size)
21 |     bqkv = torch.cat([bq, bk, bv]).view(3, src.num_heads, src.size_per_head).transpose(0, 1).contiguous().view(3*src.hidden_size)
22 |     dst.Wqkv.weight.data = wqkv
23 |     dst.Wqkv.bias.data = bqkv
24 | 
25 | 
26 | from torch.nn import MultiheadAttention
27 | @MAPPING.register(qkvLinearSlow, MultiheadAttention)
28 | def mapping_qkv_linear_mha(src, dst):
29 |     dst.in_proj_weight.data = torch.cat([src.Wq.weight.data, src.Wk.weight.data, src.Wv.weight.data])
30 |     dst.in_proj_bias.data = torch.cat([src.Wq.bias.data, src.Wk.bias.data, src.Wv.bias.data])
31 |     dst.out_proj.weight.data = torch.eye(dst.out_proj.weight.data.shape[0], dtype=dst.out_proj.weight.data.dtype, device=dst.out_proj.weight.data.device)
32 |     dst.out_proj.bias.data = torch.zeros(dst.out_proj.bias.data.shape[0], dtype=dst.out_proj.bias.data.dtype, device=dst.out_proj.bias.data.device)
33 | 
34 | from torch.nn import Conv2d
35 | from .modules import LinearConv
36 | @MAPPING.register(Conv2d, LinearConv)
37 | def mapping_linear_conv(src, dst):
38 |     dst.linear.weight.data = src.weight.data[...,0,0]
39 |     dst.linear.bias.data = src.bias.data
--------------------------------------------------------------------------------
/examples/transform_stable_diffusion_unet.py:
--------------------------------------------------------------------------------
 1 | from omegaconf import OmegaConf
 2 | config = OmegaConf.load("stable-diffusion/configs/stable-diffusion/v1-inference.yaml")
 3 | config = config.model.params.unet_config
 4 | # from ldm.util import instantiate_from_config
 5 | # unet = instantiate_from_config(config)
 6 | import torch
 7 | from ldm.modules.diffusionmodules.openaimodel import UNetModel
 8 | 
 9 | unet = UNetModel(**config.params)
10 | ckpt = torch.load('stable-diffusion/models/ldm/stable-diffusion-v1/model.ckpt', map_location='cpu')
11 | state_dict = {}
12 | for k in ckpt['state_dict']:
13 |     if 'model.diffusion_model.'
in k: 14 | state_dict[k[len('model.diffusion_model.'):]] = ckpt['state_dict'][k] 15 | unet.load_state_dict(state_dict) 16 | model = unet.cuda(1) 17 | input_0 = torch.randn(6, 4, 64, 64, dtype=torch.float32).cuda(1) 18 | input_1 = torch.tensor([1, 3, 7, 8, 9, 23], dtype=torch.int32).cuda(1) 19 | input_2 = torch.randn(6, 77, 768, dtype=torch.float32).cuda(1) 20 | 21 | import fastldm.modules as fm 22 | from fastldm.modifier import modify, MODIFIER, post_transform 23 | from ldm.modules.diffusionmodules.util import GroupNorm32 24 | MODIFIER.register(GroupNorm32, fm.GroupNorm)(MODIFIER.get(torch.nn.GroupNorm, fm.GroupNorm)) 25 | 26 | from ldm.modules.attention import CrossAttention 27 | @MODIFIER.register(CrossAttention, fm.ldmCrossAttn) 28 | def modifier_crossattn(src_instance, dst_type): 29 | dst_instance = dst_type(src_instance.to_q.weight.shape[1], context_dim=src_instance.to_k.weight.shape[1], heads=src_instance.heads, dim_head=src_instance.to_q.weight.shape[0]//src_instance.heads) 30 | return post_transform(src_instance, dst_instance) 31 | 32 | MODIFIER.register(CrossAttention, fm.ldmSelfAttn)(MODIFIER.get(CrossAttention, fm.ldmCrossAttn)) 33 | 34 | map_dict = { 35 | torch.nn.LayerNorm: fm.LayerNorm, 36 | torch.nn.GroupNorm: fm.GroupNorm, 37 | GroupNorm32: fm.GroupNorm, 38 | CrossAttention: {'attn1': fm.ldmSelfAttn, 'attn2': fm.ldmCrossAttn} 39 | } 40 | model = modify(model, map_dict) 41 | 42 | from fastldm.experiment import experiment_onnx, experiment_trt 43 | from fastldm.environ import ONNX_ONLY 44 | if ONNX_ONLY: 45 | measure, var, outputs = experiment_onnx(model, (input_0, input_1, input_2)) 46 | else: 47 | measure, var, outputs = experiment_trt(model, (input_0, input_1, input_2)) 48 | 49 | print('output maximum absolute error:', var) -------------------------------------------------------------------------------- /.gitignore: -------------------------------------------------------------------------------- 1 | *.engine 2 | *.onnx 3 | *.trt 4 | *.ts 5 | *.pt 6 | *.pth 7 | onnx/ 8 | trt/ 9 | 10 | # Byte-compiled / optimized / DLL files 11 | __pycache__/ 12 | *.py[cod] 13 | *$py.class 14 | 15 | # C extensions 16 | *.so 17 | 18 | # Distribution / packaging 19 | .Python 20 | build/ 21 | develop-eggs/ 22 | dist/ 23 | downloads/ 24 | eggs/ 25 | .eggs/ 26 | lib/ 27 | lib64/ 28 | parts/ 29 | sdist/ 30 | var/ 31 | wheels/ 32 | pip-wheel-metadata/ 33 | share/python-wheels/ 34 | *.egg-info/ 35 | .installed.cfg 36 | *.egg 37 | MANIFEST 38 | 39 | # PyInstaller 40 | # Usually these files are written by a python script from a template 41 | # before PyInstaller builds the exe, so as to inject date/other infos into it. 
42 | *.manifest 43 | *.spec 44 | 45 | # Installer logs 46 | pip-log.txt 47 | pip-delete-this-directory.txt 48 | 49 | # Unit test / coverage reports 50 | htmlcov/ 51 | .tox/ 52 | .nox/ 53 | .coverage 54 | .coverage.* 55 | .cache 56 | nosetests.xml 57 | coverage.xml 58 | *.cover 59 | *.py,cover 60 | .hypothesis/ 61 | .pytest_cache/ 62 | 63 | # Translations 64 | *.mo 65 | *.pot 66 | 67 | # Django stuff: 68 | *.log 69 | local_settings.py 70 | db.sqlite3 71 | db.sqlite3-journal 72 | 73 | # Flask stuff: 74 | instance/ 75 | .webassets-cache 76 | 77 | # Scrapy stuff: 78 | .scrapy 79 | 80 | # Sphinx documentation 81 | docs/_build/ 82 | 83 | # PyBuilder 84 | target/ 85 | 86 | # Jupyter Notebook 87 | .ipynb_checkpoints 88 | 89 | # IPython 90 | profile_default/ 91 | ipython_config.py 92 | 93 | # pyenv 94 | .python-version 95 | 96 | # pipenv 97 | # According to pypa/pipenv#598, it is recommended to include Pipfile.lock in version control. 98 | # However, in case of collaboration, if having platform-specific dependencies or dependencies 99 | # having no cross-platform support, pipenv may install dependencies that don't work, or not 100 | # install all needed dependencies. 101 | #Pipfile.lock 102 | 103 | # PEP 582; used by e.g. github.com/David-OConnor/pyflow 104 | __pypackages__/ 105 | 106 | # Celery stuff 107 | celerybeat-schedule 108 | celerybeat.pid 109 | 110 | # SageMath parsed files 111 | *.sage.py 112 | 113 | # Environments 114 | .env 115 | .venv 116 | env/ 117 | venv/ 118 | ENV/ 119 | env.bak/ 120 | venv.bak/ 121 | 122 | # Spyder project settings 123 | .spyderproject 124 | .spyproject 125 | 126 | # Rope project settings 127 | .ropeproject 128 | 129 | # mkdocs documentation 130 | /site 131 | 132 | # mypy 133 | .mypy_cache/ 134 | .dmypy.json 135 | dmypy.json 136 | 137 | # Pyre type checker 138 | .pyre/ 139 | -------------------------------------------------------------------------------- /fastldm/modifier.py: -------------------------------------------------------------------------------- 1 | from .registry import Registry 2 | import torch 3 | from .mapping import MAPPING 4 | 5 | MODIFIER = Registry('modifier') 6 | 7 | def modify(model, map_dict, name=None): 8 | if type(model) in map_dict: 9 | if type(map_dict[type(model)]) is dict: 10 | key = None if name not in map_dict[type(model)] else name 11 | target = map_dict[type(model)][key] 12 | else: 13 | target = map_dict[type(model)] 14 | return MODIFIER.transform(model, target) 15 | names = set() 16 | cache = {} 17 | for name, child in model.named_children(): 18 | if name not in names: 19 | new_child = modify(child, map_dict, name=name) 20 | cache[child] = new_child 21 | setattr(model, name, new_child) 22 | names.add(name) 23 | flag = True 24 | while flag: 25 | flag = False 26 | for name, child in model.named_children(): 27 | if name not in names: 28 | setattr(model, name, cache[child]) 29 | names.add(name) 30 | flag = True 31 | return model 32 | 33 | @MODIFIER.register(type(None), type(None)) 34 | def modifier_default(src_instance, dst_type): 35 | if type(src_instance) is dst_type: 36 | return src_instance 37 | raise Exception("Related modifier from {} to {} is not implemented yet".format(type(src_instance).__name__, dst_type.__name__)) 38 | 39 | def post_transform(src_instance, dst_instance): 40 | param = src_instance.parameters().__next__() 41 | dst_instance = dst_instance.to(param.dtype).to(param.device) 42 | MAPPING.transform(src_instance, dst_instance) 43 | return dst_instance 44 | 45 | from .modules import qkvLinearSlow, qkvLinear 46 | 
@MODIFIER.register(qkvLinearSlow, qkvLinear) 47 | def modifier_qkv_linear(src_instance, dst_type): 48 | dst_instance = dst_type(src_instance.hidden_size, src_instance.num_heads) 49 | return post_transform(src_instance, dst_instance) 50 | 51 | from .modules import GroupNorm 52 | @MODIFIER.register(torch.nn.GroupNorm, GroupNorm) 53 | def modifier_groupnorm(src_instance, dst_type): 54 | dst_instance = dst_type(src_instance.num_groups, src_instance.num_channels, src_instance.eps) 55 | return post_transform(src_instance, dst_instance) 56 | 57 | from .modules import NewLayerNorm 58 | @MODIFIER.register(torch.nn.LayerNorm, NewLayerNorm) 59 | def modifier_layernorm(src_instance, dst_type): 60 | dst_instance = dst_type(src_instance.normalized_shape[0], src_instance.eps) 61 | return post_transform(src_instance, dst_instance) 62 | 63 | from .modules import LayerNorm 64 | MODIFIER.register(torch.nn.LayerNorm, LayerNorm)(MODIFIER.get(torch.nn.LayerNorm, NewLayerNorm)) 65 | -------------------------------------------------------------------------------- /examples/test_plugin.py: -------------------------------------------------------------------------------- 1 | import torch 2 | import torch.nn as nn 3 | from fastldm.mapping import MAPPING 4 | from fastldm.experiment import generate_trt, experiment 5 | 6 | def experiment_self_attn(): 7 | from fastldm.modules import qkvLinearSlow, TRTSelfAttn, FlashSelfAttn, TorchSelfAttn 8 | x = torch.randn(512, 2, 768).cuda().half() # seq_len 128 bug 9 | lin = qkvLinearSlow(768, 8) 10 | model = TRTSelfAttn(768, 8) 11 | MAPPING.get(lin, model.projection)(lin, model.projection) 12 | model = model.cuda().half() 13 | flash = FlashSelfAttn(768, 8) 14 | MAPPING.get(lin, flash.projection)(lin, flash.projection) 15 | flash = flash.cuda().half() 16 | th_model = TorchSelfAttn(768, 8) 17 | MAPPING.get(lin, th_model.mha)(lin, th_model.mha) 18 | th_model = th_model.cuda().half() 19 | _, trt_name = generate_trt(model, (x,)) 20 | measure_dict, var, outputs_dict = experiment([model, flash, th_model], [trt_name], (x,)) 21 | for k in measure_dict: 22 | print(k, measure_dict[k]) 23 | print(var) 24 | 25 | def experiment_ldm_attn(seq_len): 26 | from fastldm.modules import ldmSelfAttn, ldmCrossAttn 27 | from ldm.modules.attention import CrossAttention 28 | x = torch.randn(2, seq_len, 320).half().cuda() 29 | # model = ldmSelfAttn(320, heads=8, dim_head=40).half() 30 | model = ldmCrossAttn(320, heads=8, dim_head=40).half() 31 | # ldm_model = CrossAttention(320, heads=8, dim_head=40).half() 32 | # for src, dst in zip(model.parameters(), ldm_model.parameters()): 33 | # dst.data = src.data 34 | model = model.cuda() 35 | # ldm_model = ldm_model.cuda() 36 | _, trt_name = generate_trt(model, (x,)) 37 | measure_dict, var, outputs_dict = experiment([model], [trt_name], (x,)) 38 | # for k in measure_dict: 39 | # print(k, measure_dict[k]) 40 | # print(var) 41 | return var[0, 1] 42 | 43 | def experiment_flash_attn(): 44 | from fastldm.modules import ldmSelfAttn 45 | from ldm.modules.attention import CrossAttention 46 | x = torch.randn(2, 4096, 320).half().cuda() 47 | model = ldmSelfAttn(320, heads=8, dim_head=40).half() 48 | ldm_model = CrossAttention(320, heads=8, dim_head=40).half() 49 | MAPPING.get(None, None)(model, ldm_model) 50 | model = model.cuda() 51 | ldm_model = ldm_model.cuda() 52 | _, trt_name = generate_trt(model, (x,)) 53 | measure_dict, var, outputs_dict = experiment([model, ldm_model], [trt_name], (x,)) 54 | for k in measure_dict: 55 | print(k, measure_dict[k]) 56 | print(var) 57 | 58 | 
def experiment_ldm_crossattn(): 59 | from fastldm.modules import ldmCrossAttn 60 | from ldm.modules.attention import CrossAttention 61 | x = torch.randn(2, 4096, 320).cuda() 62 | context = torch.randn(2, 77, 768).cuda() 63 | model = ldmCrossAttn(320, context_dim=768, heads=8, dim_head=40) 64 | ldm_model = CrossAttention(320, context_dim=768, heads=8, dim_head=40) 65 | MAPPING.get(None, None)(model, ldm_model) 66 | model = model.cuda() 67 | ldm_model = ldm_model.cuda() 68 | _, trt_name = generate_trt(model, (x, context)) 69 | measure_dict, var, outputs_dict = experiment([model, ldm_model], [trt_name], (x, context)) 70 | for k in measure_dict: 71 | print(k, measure_dict[k]) 72 | print(var) 73 | 74 | def experiment_var(): 75 | x = [] 76 | y = [] 77 | for l in range(64, 1024+1, 64): 78 | x.append(l) 79 | v = experiment_ldm_attn(l) 80 | y.append(v) 81 | import matplotlib.pyplot as plt 82 | plt.scatter(x, y) 83 | plt.savefig('var.png') 84 | 85 | 86 | if __name__ == '__main__': 87 | # experiment_self_attn() 88 | # experiment_var() 89 | # experiment_ldm_crossattn() 90 | experiment_flash_attn() -------------------------------------------------------------------------------- /fastldm/experiment.py: -------------------------------------------------------------------------------- 1 | import os 2 | from torch.onnx import OperatorExportTypes 3 | from .benchmark import benchmark, benchmark_trt 4 | import numpy as np 5 | from collections import OrderedDict 6 | import torch 7 | from .helper import ORTModule, TRTModule, list_or_tuple 8 | from .environ import ONNX_ONLY, PLUGINS 9 | from collections import Counter 10 | 11 | PLUGIN_config = ' '.join(['--plugins={}'.format(p) for p in PLUGINS]) 12 | 13 | def generate_trt(model, inputs, kw_args={}, experiment_name='', onnx_only=False, use_fp16=True, skip=False, onnx_hook=None): 14 | os.makedirs('./onnx/', exist_ok=True) 15 | os.makedirs('./trt/', exist_ok=True) 16 | name = '{}_ONNX_ONLY_{}_{}'.format(type(model).__name__, ONNX_ONLY, experiment_name) 17 | onnx_path = 'onnx/{}.onnx'.format(name) 18 | trt_path = 'trt/{}.trt'.format(name) 19 | if skip: 20 | if onnx_only: 21 | return onnx_path 22 | return onnx_path, trt_path 23 | model.eval() 24 | with torch.no_grad(): 25 | outputs = model(*inputs, **kw_args) 26 | num_output = 1 if not list_or_tuple(outputs) else len(outputs) 27 | os.system("rm onnx/{}.onnx".format(name)) 28 | torch.onnx.export(model, tuple(inputs)+(kw_args,), onnx_path, operator_export_type=OperatorExportTypes.ONNX_FALLTHROUGH, input_names=['input_{}'.format(i) for i in range(len(inputs)+len(kw_args))], output_names=['output_{}'.format(i) for i in range(num_output)]) 29 | if onnx_hook is not None: 30 | onnx_hook(onnx_path) 31 | if onnx_only: 32 | return onnx_path 33 | os.system("rm {}".format(trt_path)) 34 | os.system("trtexec --onnx={} --saveEngine={} --buildOnly {} {}".format(onnx_path, trt_path, '--fp16' if use_fp16 else '', PLUGIN_config)) 35 | return onnx_path, trt_path 36 | 37 | def experiment(models, trt_models, inputs, kw_args={}, n_iter=100, warm_up_step=5, forward_only=True, benchmark_func=benchmark): 38 | measure_dict = OrderedDict() 39 | outputs_dict = OrderedDict() 40 | name_count = Counter() 41 | len_out = 0 42 | for model in models: 43 | name = type(model).__name__ + str(name_count[type(model).__name__]) 44 | measure, outputs = benchmark_func(model, inputs, kw_args, n_iter, warmup_step=warm_up_step, forward_only=forward_only) 45 | if not list_or_tuple(outputs): 46 | outputs = [outputs] 47 | measure_dict[name] = measure 48 | 
outputs_dict[name] = outputs 49 | len_out = len(outputs) 50 | name_count[type(model).__name__] += 1 51 | inputs_h = {'input_{}'.format(i): inputs[i] for i in range(len(inputs))} 52 | shift = len(inputs) 53 | for k in kw_args: 54 | inputs_h['input_{}'.format(shift)] = kw_args[k] 55 | shift += 1 56 | for trt in trt_models: 57 | measure, outputs = benchmark_trt(trt, inputs_h, n_iter, warmup_step=warm_up_step) 58 | measure_dict[trt] = measure 59 | outputs = [outputs['output_{}'.format(i)] for i in range(len(outputs))] 60 | outputs_dict[trt] = outputs 61 | len_out = len(outputs) 62 | num = len(outputs_dict) 63 | var = np.zeros((len_out, num, num)) 64 | for i, k1 in enumerate(outputs_dict): 65 | for j, k2 in enumerate(outputs_dict): 66 | if i == j: 67 | continue 68 | for k, (o1, o2) in enumerate(zip(outputs_dict[k1], outputs_dict[k2])): 69 | var[k, i, j] = (o1.cpu().type(torch.float32)-o2.cpu().type(torch.float32)).abs().max().item() 70 | return measure_dict, var, outputs_dict 71 | 72 | def experiment_onnx(model, inputs, kw_args={}, new_inputs=None, new_kw_args=None): 73 | onnx_path = generate_trt(model, inputs, kw_args=kw_args, onnx_only=True) 74 | ortmodel = ORTModule(onnx_path) 75 | if new_inputs is None: 76 | new_inputs = inputs 77 | if new_kw_args is None: 78 | new_kw_args = kw_args 79 | return experiment([model, ortmodel], [], new_inputs, kw_args=new_kw_args, n_iter=1, warm_up_step=0) 80 | 81 | def experiment_trt(model, inputs, kw_args={}, new_inputs=None, new_kw_args=None, use_fp16=True, skip=False, onnx_hook=None): 82 | _, trt_path = generate_trt(model, inputs, kw_args=kw_args, use_fp16=use_fp16, skip=skip, onnx_hook=onnx_hook) 83 | trtmodel = TRTModule(trt_path, 1) 84 | if new_inputs is None: 85 | new_inputs = inputs 86 | if new_kw_args is None: 87 | new_kw_args = kw_args 88 | return experiment([model, trtmodel], [], new_inputs, kw_args=new_kw_args, n_iter=1, warm_up_step=0) 89 | 90 | def experiment_onnx_trt(model, inputs, kw_args={}, new_inputs=None, new_kw_args=None): 91 | onnx_path, trt_path = generate_trt(model, inputs, kw_args=kw_args) 92 | ortmodel = ORTModule(onnx_path) 93 | trtmodel = TRTModule(trt_path) 94 | if new_inputs is None: 95 | new_inputs = inputs 96 | if new_kw_args is None: 97 | new_kw_args = kw_args 98 | return experiment([model, ortmodel, trtmodel], [], new_inputs, kw_args=new_kw_args, n_iter=1, warm_up_step=0) -------------------------------------------------------------------------------- /fastldm/benchmark.py: -------------------------------------------------------------------------------- 1 | import time 2 | import torch 3 | import numpy as np 4 | from tqdm import tqdm 5 | from contextlib import nullcontext 6 | from .helper import list_or_tuple 7 | 8 | def benchmark(model, inputs, kwargs, n_iter, func_name=None, warmup_step=5, use_autocast=False, forward_only=True): 9 | if hasattr(model, 'eval') and callable(model.eval): 10 | model.eval() 11 | print('eval mode...') 12 | if func_name is not None: 13 | func = getattr(model, func_name) 14 | else: 15 | func = model 16 | print('start warming up...') 17 | grad = torch.no_grad() if forward_only else nullcontext() 18 | context = torch.autocast("cuda") if use_autocast else nullcontext() 19 | with grad: 20 | with context: 21 | for i in tqdm(range(warmup_step)): 22 | outputs = func(*inputs, **kwargs) 23 | assert outputs is not None 24 | print('start timing...') 25 | time_list = [] 26 | with grad: 27 | with context: 28 | for i in tqdm(range(n_iter)): 29 | torch.cuda.synchronize() 30 | start_time = time.time() 31 | outputs = 
func(*inputs, **kwargs) 32 | torch.cuda.synchronize() 33 | end_time = time.time() 34 | assert outputs is not None 35 | time_list.append(end_time - start_time) 36 | times = np.array(time_list) 37 | measurements = {'average': times.mean(), 'min': times.min()} 38 | return measurements, outputs 39 | 40 | def benchmark_backward(model, inputs, kwargs, n_iter, func_name=None, warmup_step=5, use_autocast=False, forward_only=False): 41 | if hasattr(model, 'eval') and callable(model.eval): 42 | model.eval() 43 | print('eval mode...') 44 | if func_name is not None: 45 | func = getattr(model, func_name) 46 | else: 47 | func = model 48 | print('start warming up...') 49 | grad = nullcontext() 50 | context = torch.autocast("cuda") if use_autocast else nullcontext() 51 | with grad: 52 | with context: 53 | for i in tqdm(range(warmup_step)): 54 | func.zero_grad() 55 | outputs = func(*inputs, **kwargs) 56 | assert outputs is not None 57 | if list_or_tuple(outputs): 58 | outputs = outputs[0] 59 | loss = outputs.sum() 60 | loss.backward() 61 | print('start timing...') 62 | time_list = [] 63 | with grad: 64 | with context: 65 | for i in tqdm(range(n_iter)): 66 | func.zero_grad() 67 | outputs = func(*inputs, **kwargs) 68 | assert outputs is not None 69 | if list_or_tuple(outputs): 70 | outputs = outputs[0] 71 | loss = outputs.sum() 72 | torch.cuda.synchronize() 73 | start_time = time.time() 74 | loss.backward() 75 | torch.cuda.synchronize() 76 | end_time = time.time() 77 | time_list.append(end_time - start_time) 78 | times = np.array(time_list) 79 | measurements = {'average': times.mean(), 'min': times.min()} 80 | return measurements, func.parameters().__next__().grad 81 | 82 | from .helper import get_trt_stuff 83 | from .helper import load_engine 84 | def benchmark_trt(engine_path, inputs_h, n_iter, warmup_step=5): 85 | engine = load_engine(engine_path) 86 | context, bindings, inputs_dict, outputs_dict = get_trt_stuff(engine) 87 | if list_or_tuple(inputs_h): 88 | inputs_h = {'input_{}'.format(i):x for i, x in enumerate(inputs_h)} 89 | for k in inputs_h: 90 | inputs_dict[k].copy_(inputs_h[k]) 91 | stream = torch.cuda.default_stream() 92 | def func(): 93 | state = context.execute_async_v2(bindings=bindings, stream_handle=stream.cuda_stream) 94 | stream.synchronize() 95 | return state 96 | measurement, state = benchmark(func, (), {}, n_iter, warmup_step=warmup_step) 97 | assert state is True 98 | return measurement, outputs_dict 99 | 100 | 101 | def benchmark_trt_np(engine_path, inputs_dict, n_iter, warmup_step=5): 102 | import tensorrt as trt 103 | import pycuda.driver as cuda 104 | import pycuda.autoinit 105 | engine = load_engine(engine_path) 106 | context = engine.create_execution_context() 107 | inputs_h = inputs_dict 108 | inputs_d = {} 109 | outputs_h = {} 110 | outputs_d = {} 111 | bindings = [] 112 | for binding in engine: 113 | binding_idx = engine.get_binding_index(binding) 114 | # size = trt.volume(context.get_binding_shape(binding_idx)) 115 | dtype = trt.nptype(engine.get_binding_dtype(binding)) 116 | if engine.binding_is_input(binding): 117 | inputs_d[binding] = cuda.mem_alloc(inputs_h[binding].nbytes) 118 | bindings.append(int(inputs_d[binding])) 119 | else: 120 | shape = tuple(context.get_binding_shape(binding_idx)) 121 | outputs_h[binding] = cuda.pagelocked_empty(shape, dtype) 122 | outputs_d[binding] = cuda.mem_alloc(outputs_h[binding].nbytes) 123 | bindings.append(int(outputs_d[binding])) 124 | stream = cuda.Stream() 125 | for k in inputs_h: 126 | cuda.memcpy_htod_async(inputs_d[k], inputs_h[k], 
stream) 127 | stream.synchronize() 128 | def func(): 129 | state = context.execute_async_v2(bindings=bindings, stream_handle=stream.handle) 130 | stream.synchronize() 131 | return state 132 | # measurement, _ = benchmark(context, (), {'bindings': bindings, 'stream_handle':stream.handle}, n_iter, func_name='execute_async_v2', warmup_step=warmup_step) 133 | measurement, state = benchmark(func, (), {}, n_iter, warmup_step=warmup_step) 134 | assert state is True 135 | for k in outputs_d: 136 | cuda.memcpy_dtoh_async(outputs_h[k], outputs_d[k], stream) 137 | stream.synchronize() 138 | return measurement, outputs_h -------------------------------------------------------------------------------- /fastldm/plugins.py: -------------------------------------------------------------------------------- 1 | import torch 2 | import torch.nn.functional as F 3 | from .environ import ONNX_ONLY 4 | 5 | class BaseApply: 6 | @classmethod 7 | def apply(cls, *inputs, **kw_args): 8 | return cls.forward(None, *inputs, **kw_args) 9 | 10 | BasePlugin = BaseApply if ONNX_ONLY else torch.autograd.Function 11 | 12 | class CustomQKVToContextPluginDynamic(BasePlugin): 13 | # https://github.com/NVIDIA/TensorRT/tree/release/8.5/plugin/bertQKVToContextPlugin 14 | # Restriction of this plugin: https://github.com/NVIDIA/TensorRT/issues/2653 15 | @staticmethod 16 | def forward(ctx, qkv, hidden_size, num_heads, type_id): 17 | # Now I get qkv with shape (seq_len, batch_size, 3*hidden_size, 1, 1) 18 | # I need to do attention to it 19 | # The layout of 3*hidden_size dimension is (num_heads, 3, size_per_head) 20 | size_per_head = hidden_size // num_heads 21 | seq_len = qkv.size(0) 22 | batch_size = qkv.size(1) 23 | qkv = qkv.view(seq_len, batch_size, num_heads, 3, size_per_head).transpose(0, 2) 24 | # q, k, v = torch.chunk(qkv, 3, dim=3) 25 | q = qkv.select(-2, 0) 26 | k = qkv.select(-2, 1) 27 | v = qkv.select(-2, 2) # (num_heads, batch_size, seq_len, size_per_head) 28 | scores = torch.matmul(q, k.transpose(-2, -1)) * (size_per_head**-0.5) 29 | scores = F.softmax(scores, -1) 30 | result = torch.matmul(scores, v).transpose(0, 2).contiguous().view(seq_len, batch_size, hidden_size, 1, 1) 31 | return result 32 | @staticmethod 33 | def symbolic(g, qkv, hidden_size, num_heads, type_id): 34 | return g.op("CustomQKVToContextPluginDynamic", qkv, plugin_version_s='1', type_id_i=type_id, hidden_size_i=hidden_size, num_heads_i=num_heads, has_mask_i=False) 35 | 36 | class fMHCA(BasePlugin): 37 | # https://github.com/NVIDIA/TensorRT/tree/release/8.5/plugin/multiHeadCrossAttentionPlugin 38 | # Potential issue of this plugin: https://github.com/NVIDIA/TensorRT/issues/2674 39 | @staticmethod 40 | def forward(ctx, q, kv): 41 | """ 42 | q: (batch_size, seq_len, num_head, size_per_head) 43 | kv: (batch_size, seq_len, num_head, 2, size_per_head) 44 | output: like q 45 | """ 46 | size_per_head = q.size(3) 47 | batch_size = q.size(0) 48 | seq_len = q.size(1) 49 | num_head = q.size(2) 50 | q = q.transpose(1, 2).contiguous() 51 | kv = kv.transpose(1, 2).contiguous() 52 | k = kv.select(-2, 0) 53 | v = kv.select(-2, 1) 54 | scores = torch.matmul(q, k.transpose(-2, -1)) * (size_per_head**-0.5) 55 | scores = F.softmax(scores, -1) 56 | result = torch.matmul(scores, v).transpose(1, 2).contiguous().view(batch_size, seq_len, num_head, size_per_head) 57 | return result 58 | @staticmethod 59 | def symbolic(g, q, kv): 60 | return g.op("fMHCA", q, kv, plugin_version_s='1') 61 | 62 | class fMHA_V2(BasePlugin): 63 | # 
https://github.com/NVIDIA/TensorRT/tree/release/8.5/plugin/multiHeadFlashAttentionPlugin 64 | @staticmethod 65 | def forward(ctx, qkv): 66 | """ 67 | qkv: (batch_size, seq_len, num_head, 3, size_per_head) 68 | output: (batch_size, seq_len, num_head, size_per_head) 69 | """ 70 | size_per_head = qkv.size(4) 71 | batch_size = qkv.size(0) 72 | seq_len = qkv.size(1) 73 | num_head = qkv.size(2) 74 | qkv = qkv.transpose(1, 2).contiguous() 75 | q = qkv.select(-2, 0) 76 | k = qkv.select(-2, 1) 77 | v = qkv.select(-2, 2) 78 | scores = torch.matmul(q, k.transpose(-2, -1)) * (size_per_head**-0.5) 79 | scores = F.softmax(scores, -1) 80 | result = torch.matmul(scores, v).transpose(1, 2).contiguous().view(batch_size, seq_len, num_head, size_per_head) 81 | return result 82 | @staticmethod 83 | def symbolic(g, qkv): 84 | return g.op("fMHA_V2", qkv, plugin_version_s='1') 85 | 86 | class GroupNormalizationPlugin(BasePlugin): 87 | # https://github.com/NVIDIA/TensorRT/tree/release/8.5/plugin/groupNormalizationPlugin 88 | @staticmethod 89 | def forward(ctx, x, scale, bias, num_groups, eps): 90 | return F.group_norm(x, num_groups, weight=scale, bias=bias, eps=eps) 91 | @staticmethod 92 | def symbolic(g, x, scale, bias, num_groups, eps): 93 | return g.op("GroupNormalizationPlugin", x, scale, bias, plugin_version_s='1', eps_f=eps, num_groups_i=num_groups) 94 | 95 | class LayerNormPlugin(BasePlugin): 96 | # https://github.com/NVIDIA/TensorRT/tree/release/8.5/plugin/layerNormPlugin 97 | # Potential issue of this plugin: https://github.com/NVIDIA/TensorRT/issues/2707 98 | @staticmethod 99 | def forward(ctx, x, scale, bias, channels, eps): 100 | return F.layer_norm(x, [channels], scale, bias, eps) 101 | @staticmethod 102 | def symbolic(g, x, scale, bias, channels, eps): 103 | return g.op("LayerNorm", x, scale, bias, plugin_version_s='1', epsilon_f=eps, axis_i=-1) 104 | 105 | class NewLayerNormPlugin(BasePlugin): 106 | # https://github.com/1049451037/tensorrt-layernorm-plugin 107 | @staticmethod 108 | def forward(ctx, x, scale, bias, channels, eps): 109 | return F.layer_norm(x, [channels], scale, bias, eps) 110 | @staticmethod 111 | def symbolic(g, x, scale, bias, channels, eps): 112 | return g.op("LayerNormalizationPlugin", x, scale, bias, plugin_version_s='1', eps_f=eps) 113 | 114 | class MHCAWG(BaseApply): 115 | @staticmethod 116 | def forward(ctx, q, kv, bias): 117 | """ 118 | q: (batch_size, seq_len, num_head, size_per_head) 119 | kv: (batch_size, seq_len, num_head, 2, size_per_head) 120 | output: like q 121 | """ 122 | size_per_head = q.size(3) 123 | batch_size = q.size(0) 124 | seq_len = q.size(1) 125 | num_head = q.size(2) 126 | q = q.transpose(1, 2).contiguous() 127 | kv = kv.transpose(1, 2).contiguous() 128 | k = kv.select(-2, 0) 129 | v = kv.select(-2, 1) 130 | scores = torch.matmul(q, k.transpose(-2, -1)) * (size_per_head**-0.5) + bias 131 | scores = F.softmax(scores, -1) 132 | result = torch.matmul(scores, v).transpose(1, 2).contiguous().view(batch_size, seq_len, num_head, size_per_head) 133 | return result -------------------------------------------------------------------------------- /fastldm/helper.py: -------------------------------------------------------------------------------- 1 | import os 2 | import torch 3 | import torch.nn as nn 4 | import tensorrt as trt 5 | import ctypes 6 | from collections import defaultdict 7 | 8 | def list_or_tuple(x): 9 | return isinstance(x, (list, tuple)) 10 | 11 | TRT_LOGGER = trt.Logger() 12 | trt.init_libnvinfer_plugins(TRT_LOGGER, '') 13 | from .environ import 
PLUGINS 14 | for path in PLUGINS: 15 | ctypes.cdll.LoadLibrary(path) 16 | 17 | def load_engine(engine_file_path): 18 | assert os.path.exists(engine_file_path) 19 | print("Reading engine from file {}".format(engine_file_path)) 20 | with open(engine_file_path, "rb") as f, trt.Runtime(TRT_LOGGER) as runtime: 21 | return runtime.deserialize_cuda_engine(f.read()) 22 | 23 | from torch.testing._internal.common_utils import numpy_to_torch_dtype_dict 24 | numpy_to_torch_dtype_dict[bool] = torch.bool 25 | def get_trt_stuff(engine): 26 | context = engine.create_execution_context() 27 | inputs_dict = {} 28 | outputs_dict = {} 29 | bindings = [] 30 | for binding in engine: 31 | binding_idx = engine.get_binding_index(binding) 32 | # size = trt.volume(context.get_binding_shape(binding_idx)) 33 | dtype = trt.nptype(engine.get_binding_dtype(binding)) 34 | shape = tuple(context.get_binding_shape(binding_idx)) 35 | if engine.binding_is_input(binding): 36 | inputs_dict[binding] = torch.empty(*shape, dtype=numpy_to_torch_dtype_dict[dtype], device='cuda') 37 | bindings.append(int(inputs_dict[binding].data_ptr())) 38 | else: 39 | outputs_dict[binding] = torch.empty(*shape, dtype=numpy_to_torch_dtype_dict[dtype], device='cuda') 40 | bindings.append(int(outputs_dict[binding].data_ptr())) 41 | return context, bindings, inputs_dict, outputs_dict 42 | 43 | class TRTModule(nn.Module): 44 | def __init__(self, engine_path, num_worker): 45 | """ 46 | Only support running engine on cuda:0 for now 47 | """ 48 | super().__init__() 49 | self.num_worker = num_worker 50 | self.engine = load_engine(engine_path) 51 | self.context = [] 52 | self.bindings = [] 53 | self.inputs_dict = [] 54 | self.outputs_dict = [] 55 | self.stream = [] 56 | for i in range(num_worker): 57 | context, bindings, inputs_dict, outputs_dict = get_trt_stuff(self.engine) 58 | self.context.append(context) 59 | self.bindings.append(bindings) 60 | self.inputs_dict.append(inputs_dict) 61 | self.outputs_dict.append(outputs_dict) 62 | self.stream.append(torch.cuda.Stream(0)) 63 | def move_to_engine(self, inputs_h): 64 | for i in range(self.num_worker): 65 | for k in inputs_h: 66 | self.inputs_dict[i][k].copy_(inputs_h[k][i]) 67 | torch.cuda.default_stream().synchronize() 68 | def run_engine(self): 69 | for context, bindings, stream in zip(self.context, self.bindings, self.stream): 70 | state = context.execute_async_v2(bindings=bindings, stream_handle=stream.cuda_stream) 71 | if not state: 72 | raise Exception("trt engine execution failed") 73 | for stream in self.stream: 74 | stream.synchronize() 75 | def merge_output(self): 76 | final_outputs_dict = defaultdict(list) 77 | for outputs_dict in self.outputs_dict: 78 | for k in outputs_dict: 79 | final_outputs_dict[k].append(outputs_dict[k]) 80 | for k in final_outputs_dict: 81 | final_outputs_dict[k] = torch.cat(final_outputs_dict[k]) 82 | return final_outputs_dict 83 | def forward(self, *inputs, **kw_args): 84 | inputs_h = {} 85 | device = 'cpu' 86 | for i, inp in enumerate(inputs): 87 | inputs_h['input_{}'.format(i)] = torch.chunk(inp, self.num_worker) 88 | device = inp.device 89 | shift = len(inputs) 90 | for k in kw_args: 91 | inputs_h['input_{}'.format(shift)] = torch.chunk(kw_args[k], self.num_worker) 92 | shift += 1 93 | self.move_to_engine(inputs_h) 94 | self.run_engine() 95 | outputs_dict = self.merge_output() 96 | outputs = [] 97 | for i in range(len(outputs_dict)): 98 | outputs.append(outputs_dict['output_{}'.format(i)].to(device)) 99 | if len(outputs) == 1: 100 | outputs = outputs[0] 101 | return 
outputs 102 | 103 | import onnxruntime as ort 104 | 105 | def get_ort_stuff(onnx_path, providers): 106 | return ort.InferenceSession(onnx_path, providers=providers) 107 | 108 | class ORTModule(nn.Module): 109 | def __init__(self, onnx_path, providers=['CUDAExecutionProvider', 'CPUExecutionProvider']): 110 | super().__init__() 111 | self.sess = get_ort_stuff(onnx_path, providers) 112 | def forward(self, *inputs, **kw_args): 113 | device = 'cpu' 114 | for inp in inputs: 115 | device = inp.device 116 | for k in kw_args: 117 | device = kw_args[k].device 118 | inputs_dict = {'input_{}'.format(i):x.cpu().numpy() if isinstance(x, torch.Tensor) else x for i, x in enumerate(inputs)} 119 | shift = len(inputs_dict) 120 | for k in kw_args: 121 | inputs_dict['input_{}'.format(shift)] = kw_args[k].cpu().numpy() 122 | shift += 1 123 | outputs = self.sess.run(None, inputs_dict) 124 | outputs = [torch.from_numpy(x).to(device) for x in outputs] 125 | if len(outputs) == 1: 126 | outputs = outputs[0] 127 | return outputs 128 | 129 | from collections import OrderedDict 130 | 131 | def profile_matrix(m): 132 | old_m = m 133 | m = torch.nan_to_num(m) 134 | mi = m.min() 135 | ma = m.max() 136 | mm = m.mean() 137 | profile = OrderedDict([ 138 | ("has_nan", not (old_m==m).all()), 139 | ("shape", m.shape), 140 | ("min", (mi, (m==mi).nonzero()[0])), 141 | ("max", (ma, (m==ma).nonzero()[0])), 142 | ("mean", (mm, )) 143 | ]) 144 | return profile 145 | 146 | def profile_outdiff(o1, o2): 147 | o1 = o1.cpu().type(torch.float32) 148 | o2 = o2.cpu().type(torch.float32) 149 | error = (o1-o2).abs() 150 | measure_dict = OrderedDict([ 151 | ("o1", profile_matrix(o1)), 152 | ("o2", profile_matrix(o2)), 153 | ("|o1|", profile_matrix(o1.abs())), 154 | ("|o2|", profile_matrix(o2.abs())), 155 | ("absolute error", profile_matrix(error)), 156 | ("relative error", profile_matrix(error / torch.max(o1.abs(), o2.abs()))), 157 | ("norm relative error", profile_matrix(error / o1.abs().mean())) 158 | ]) 159 | for i in ["absolute error", "relative error", "norm relative error"]: 160 | for j in ["min", "max"]: 161 | ind = measure_dict[i][j][-1] 162 | measure_dict[i][j] += (o1[tuple(ind)], o2[tuple(ind)]) 163 | return measure_dict 164 | 165 | -------------------------------------------------------------------------------- /LICENSE: -------------------------------------------------------------------------------- 1 | Apache License 2 | Version 2.0, January 2004 3 | http://www.apache.org/licenses/ 4 | 5 | TERMS AND CONDITIONS FOR USE, REPRODUCTION, AND DISTRIBUTION 6 | 7 | 1. Definitions. 8 | 9 | "License" shall mean the terms and conditions for use, reproduction, 10 | and distribution as defined by Sections 1 through 9 of this document. 11 | 12 | "Licensor" shall mean the copyright owner or entity authorized by 13 | the copyright owner that is granting the License. 14 | 15 | "Legal Entity" shall mean the union of the acting entity and all 16 | other entities that control, are controlled by, or are under common 17 | control with that entity. For the purposes of this definition, 18 | "control" means (i) the power, direct or indirect, to cause the 19 | direction or management of such entity, whether by contract or 20 | otherwise, or (ii) ownership of fifty percent (50%) or more of the 21 | outstanding shares, or (iii) beneficial ownership of such entity. 22 | 23 | "You" (or "Your") shall mean an individual or Legal Entity 24 | exercising permissions granted by this License. 
25 | 26 | "Source" form shall mean the preferred form for making modifications, 27 | including but not limited to software source code, documentation 28 | source, and configuration files. 29 | 30 | "Object" form shall mean any form resulting from mechanical 31 | transformation or translation of a Source form, including but 32 | not limited to compiled object code, generated documentation, 33 | and conversions to other media types. 34 | 35 | "Work" shall mean the work of authorship, whether in Source or 36 | Object form, made available under the License, as indicated by a 37 | copyright notice that is included in or attached to the work 38 | (an example is provided in the Appendix below). 39 | 40 | "Derivative Works" shall mean any work, whether in Source or Object 41 | form, that is based on (or derived from) the Work and for which the 42 | editorial revisions, annotations, elaborations, or other modifications 43 | represent, as a whole, an original work of authorship. For the purposes 44 | of this License, Derivative Works shall not include works that remain 45 | separable from, or merely link (or bind by name) to the interfaces of, 46 | the Work and Derivative Works thereof. 47 | 48 | "Contribution" shall mean any work of authorship, including 49 | the original version of the Work and any modifications or additions 50 | to that Work or Derivative Works thereof, that is intentionally 51 | submitted to Licensor for inclusion in the Work by the copyright owner 52 | or by an individual or Legal Entity authorized to submit on behalf of 53 | the copyright owner. For the purposes of this definition, "submitted" 54 | means any form of electronic, verbal, or written communication sent 55 | to the Licensor or its representatives, including but not limited to 56 | communication on electronic mailing lists, source code control systems, 57 | and issue tracking systems that are managed by, or on behalf of, the 58 | Licensor for the purpose of discussing and improving the Work, but 59 | excluding communication that is conspicuously marked or otherwise 60 | designated in writing by the copyright owner as "Not a Contribution." 61 | 62 | "Contributor" shall mean Licensor and any individual or Legal Entity 63 | on behalf of whom a Contribution has been received by Licensor and 64 | subsequently incorporated within the Work. 65 | 66 | 2. Grant of Copyright License. Subject to the terms and conditions of 67 | this License, each Contributor hereby grants to You a perpetual, 68 | worldwide, non-exclusive, no-charge, royalty-free, irrevocable 69 | copyright license to reproduce, prepare Derivative Works of, 70 | publicly display, publicly perform, sublicense, and distribute the 71 | Work and such Derivative Works in Source or Object form. 72 | 73 | 3. Grant of Patent License. Subject to the terms and conditions of 74 | this License, each Contributor hereby grants to You a perpetual, 75 | worldwide, non-exclusive, no-charge, royalty-free, irrevocable 76 | (except as stated in this section) patent license to make, have made, 77 | use, offer to sell, sell, import, and otherwise transfer the Work, 78 | where such license applies only to those patent claims licensable 79 | by such Contributor that are necessarily infringed by their 80 | Contribution(s) alone or by combination of their Contribution(s) 81 | with the Work to which such Contribution(s) was submitted. 
If You 82 | institute patent litigation against any entity (including a 83 | cross-claim or counterclaim in a lawsuit) alleging that the Work 84 | or a Contribution incorporated within the Work constitutes direct 85 | or contributory patent infringement, then any patent licenses 86 | granted to You under this License for that Work shall terminate 87 | as of the date such litigation is filed. 88 | 89 | 4. Redistribution. You may reproduce and distribute copies of the 90 | Work or Derivative Works thereof in any medium, with or without 91 | modifications, and in Source or Object form, provided that You 92 | meet the following conditions: 93 | 94 | (a) You must give any other recipients of the Work or 95 | Derivative Works a copy of this License; and 96 | 97 | (b) You must cause any modified files to carry prominent notices 98 | stating that You changed the files; and 99 | 100 | (c) You must retain, in the Source form of any Derivative Works 101 | that You distribute, all copyright, patent, trademark, and 102 | attribution notices from the Source form of the Work, 103 | excluding those notices that do not pertain to any part of 104 | the Derivative Works; and 105 | 106 | (d) If the Work includes a "NOTICE" text file as part of its 107 | distribution, then any Derivative Works that You distribute must 108 | include a readable copy of the attribution notices contained 109 | within such NOTICE file, excluding those notices that do not 110 | pertain to any part of the Derivative Works, in at least one 111 | of the following places: within a NOTICE text file distributed 112 | as part of the Derivative Works; within the Source form or 113 | documentation, if provided along with the Derivative Works; or, 114 | within a display generated by the Derivative Works, if and 115 | wherever such third-party notices normally appear. The contents 116 | of the NOTICE file are for informational purposes only and 117 | do not modify the License. You may add Your own attribution 118 | notices within Derivative Works that You distribute, alongside 119 | or as an addendum to the NOTICE text from the Work, provided 120 | that such additional attribution notices cannot be construed 121 | as modifying the License. 122 | 123 | You may add Your own copyright statement to Your modifications and 124 | may provide additional or different license terms and conditions 125 | for use, reproduction, or distribution of Your modifications, or 126 | for any such Derivative Works as a whole, provided Your use, 127 | reproduction, and distribution of the Work otherwise complies with 128 | the conditions stated in this License. 129 | 130 | 5. Submission of Contributions. Unless You explicitly state otherwise, 131 | any Contribution intentionally submitted for inclusion in the Work 132 | by You to the Licensor shall be under the terms and conditions of 133 | this License, without any additional terms or conditions. 134 | Notwithstanding the above, nothing herein shall supersede or modify 135 | the terms of any separate license agreement you may have executed 136 | with Licensor regarding such Contributions. 137 | 138 | 6. Trademarks. This License does not grant permission to use the trade 139 | names, trademarks, service marks, or product names of the Licensor, 140 | except as required for reasonable and customary use in describing the 141 | origin of the Work and reproducing the content of the NOTICE file. 142 | 143 | 7. Disclaimer of Warranty. 
Unless required by applicable law or 144 | agreed to in writing, Licensor provides the Work (and each 145 | Contributor provides its Contributions) on an "AS IS" BASIS, 146 | WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or 147 | implied, including, without limitation, any warranties or conditions 148 | of TITLE, NON-INFRINGEMENT, MERCHANTABILITY, or FITNESS FOR A 149 | PARTICULAR PURPOSE. You are solely responsible for determining the 150 | appropriateness of using or redistributing the Work and assume any 151 | risks associated with Your exercise of permissions under this License. 152 | 153 | 8. Limitation of Liability. In no event and under no legal theory, 154 | whether in tort (including negligence), contract, or otherwise, 155 | unless required by applicable law (such as deliberate and grossly 156 | negligent acts) or agreed to in writing, shall any Contributor be 157 | liable to You for damages, including any direct, indirect, special, 158 | incidental, or consequential damages of any character arising as a 159 | result of this License or out of the use or inability to use the 160 | Work (including but not limited to damages for loss of goodwill, 161 | work stoppage, computer failure or malfunction, or any and all 162 | other commercial damages or losses), even if such Contributor 163 | has been advised of the possibility of such damages. 164 | 165 | 9. Accepting Warranty or Additional Liability. While redistributing 166 | the Work or Derivative Works thereof, You may choose to offer, 167 | and charge a fee for, acceptance of support, warranty, indemnity, 168 | or other liability obligations and/or rights consistent with this 169 | License. However, in accepting such obligations, You may act only 170 | on Your own behalf and on Your sole responsibility, not on behalf 171 | of any other Contributor, and only if You agree to indemnify, 172 | defend, and hold each Contributor harmless for any liability 173 | incurred by, or claims asserted against, such Contributor by reason 174 | of your accepting any such warranty or additional liability. 175 | 176 | END OF TERMS AND CONDITIONS 177 | 178 | APPENDIX: How to apply the Apache License to your work. 179 | 180 | To apply the Apache License to your work, attach the following 181 | boilerplate notice, with the fields enclosed by brackets "[]" 182 | replaced with your own identifying information. (Don't include 183 | the brackets!) The text should be enclosed in the appropriate 184 | comment syntax for the file format. We also recommend that a 185 | file or class name and description of purpose be included on the 186 | same "printed page" as the copyright notice for easier 187 | identification within third-party archives. 188 | 189 | Copyright 2023 Qingsong Lv 190 | 191 | Licensed under the Apache License, Version 2.0 (the "License"); 192 | you may not use this file except in compliance with the License. 193 | You may obtain a copy of the License at 194 | 195 | http://www.apache.org/licenses/LICENSE-2.0 196 | 197 | Unless required by applicable law or agreed to in writing, software 198 | distributed under the License is distributed on an "AS IS" BASIS, 199 | WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 200 | See the License for the specific language governing permissions and 201 | limitations under the License. 
-------------------------------------------------------------------------------- /fastldm/modules.py: -------------------------------------------------------------------------------- 1 | import torch 2 | import torch.nn as nn 3 | from einops import rearrange 4 | 5 | class qkvLinearSlow(nn.Module): 6 | def __init__(self, hidden_size, num_heads): 7 | super().__init__() 8 | assert hidden_size % num_heads == 0 9 | self.hidden_size = hidden_size 10 | self.num_heads = num_heads 11 | self.size_per_head = hidden_size // num_heads 12 | self.Wq = nn.Linear(self.hidden_size, self.hidden_size) 13 | self.Wk = nn.Linear(self.hidden_size, self.hidden_size) 14 | self.Wv = nn.Linear(self.hidden_size, self.hidden_size) 15 | 16 | def forward(self, x): 17 | Q = self.Wq(x) 18 | K = self.Wk(x) 19 | V = self.Wv(x) 20 | qkv = torch.cat([Q, K, V], dim=2) 21 | qkv = qkv.view(x.size(0), x.size(1), 3, self.num_heads, self.size_per_head) 22 | qkv = qkv.transpose(2, 3).contiguous().view(x.size(0), x.size(1), 3*self.hidden_size, 1, 1) 23 | return qkv 24 | 25 | class qkvLinear(nn.Module): 26 | def __init__(self, hidden_size, num_heads): 27 | super().__init__() 28 | assert hidden_size % num_heads == 0 29 | self.hidden_size = hidden_size 30 | self.num_heads = num_heads 31 | self.size_per_head = hidden_size // num_heads 32 | self.Wqkv = nn.Linear(hidden_size, 3*hidden_size) 33 | 34 | def forward(self, x): 35 | return self.Wqkv(x).view(x.size(0), x.size(1), 3*self.hidden_size, 1, 1) 36 | 37 | from .plugins import CustomQKVToContextPluginDynamic 38 | class TRTSelfAttn(nn.Module): 39 | def __init__(self, hidden_size, num_heads): 40 | super().__init__() 41 | assert hidden_size % num_heads == 0 42 | self.hidden_size = hidden_size 43 | self.num_heads = num_heads 44 | self.projection = qkvLinear(hidden_size, num_heads) 45 | def forward(self, x): 46 | # shape of x (seq_len, batch_size, hidden_size) 47 | # shape of i_mask (batch_size) 48 | # output (seq_len, batch_size, hidden_size) 49 | qkv = self.projection(x) 50 | type_id = 0 if qkv.dtype == torch.float32 else 1 51 | return CustomQKVToContextPluginDynamic.apply(qkv, self.hidden_size, self.num_heads, type_id).select(-1, 0).select(-1, 0) 52 | 53 | from flash_attn.flash_attn_interface import _flash_attn_forward 54 | class FlashSelfAttn(nn.Module): 55 | def __init__(self, hidden_size, num_heads): 56 | super().__init__() 57 | assert hidden_size % num_heads == 0 58 | self.hidden_size = hidden_size 59 | self.num_heads = num_heads 60 | self.size_per_head = hidden_size // num_heads 61 | self.scale = 1 / self.size_per_head**0.5 62 | self.projection = qkvLinear(hidden_size, num_heads) 63 | def forward(self, x): 64 | # shape of x (seq_len, batch_size, hidden_size) 65 | # shape of i_mask (batch_size) 66 | # output (seq_len, batch_size, hidden_size) 67 | seq_len = x.size(0) 68 | batch_size = x.size(1) 69 | qkv = self.projection(x).view(seq_len, batch_size, self.num_heads, 3, self.size_per_head).transpose(0, 1).contiguous().view(seq_len*batch_size, self.num_heads, 3, self.size_per_head) 70 | q = qkv.select(-2, 0) 71 | k = qkv.select(-2, 1) 72 | v = qkv.select(-2, 2) 73 | cu_seqlen = torch.arange(start=0, end=batch_size*seq_len, step=seq_len, dtype=torch.int32, device=q.device) 74 | max_seqlen = seq_len 75 | out = torch.empty_like(v) 76 | return _flash_attn_forward(q, k, v, out, cu_seqlen, cu_seqlen, max_seqlen, max_seqlen, 0., self.scale, False, False)[0].view(seq_len, batch_size, self.hidden_size) 77 | 78 | from flash_attn.modules.mha import FlashSelfAttention 79 | class 
FlashSelfAttnWG(nn.Module): 80 | def __init__(self, query_dim, heads=8, dim_head=64, dropout=0.): 81 | super().__init__() 82 | inner_dim = dim_head * heads 83 | context_dim = query_dim 84 | 85 | self.scale = dim_head ** -0.5 86 | self.heads = heads 87 | 88 | self.to_q = nn.Linear(query_dim, inner_dim, bias=False) 89 | self.to_k = nn.Linear(context_dim, inner_dim, bias=False) 90 | self.to_v = nn.Linear(context_dim, inner_dim, bias=False) 91 | 92 | self.to_out = nn.Sequential( 93 | nn.Linear(inner_dim, query_dim), 94 | nn.Dropout(dropout) 95 | ) 96 | self.flash = FlashSelfAttention(softmax_scale=self.scale) 97 | 98 | def forward(self, x, context=None, mask=None, bias=None): 99 | assert context is None and mask is None and bias is None 100 | h = self.heads 101 | context = x 102 | 103 | inputx_shape = x.size() 104 | if len(x.shape) == 3 and len(context.shape) == 3: # normal 105 | pass 106 | elif len(x.shape) == 4 and len(context.shape) == 3: 107 | # BLOCK x but shared context for cross attention 108 | x = rearrange(x, 'b k n d -> b (k n) d') 109 | elif len(x.shape) == 4 and len(context.shape) == 4: 110 | # BLOCK attention 111 | assert x.shape[1] == context.shape[1], 'num of BLOCK not the same' 112 | # b dim later will be "batch * nblk * nheads" 113 | # assert not exists(mask) and not exists(bias), 'not implemented yet' 114 | else: 115 | raise ValueError(f'x shape: {x.shape} , context shape: {context.shape}') 116 | 117 | q = self.to_q(x) 118 | k = self.to_k(context) 119 | v = self.to_v(context) 120 | 121 | q, k, v = map(lambda t: rearrange(t, '... n (h d) -> (...) n h d', h=h), (q, k, v)) 122 | 123 | out = self.flash(torch.stack([q, k, v], dim=2)) 124 | 125 | out = rearrange(out, 'b n h d -> b n (h d)', h=h).view(inputx_shape) 126 | return self.to_out(out) 127 | 128 | from flash_attn.flash_attn_triton import FlashAttnFunc 129 | class FlashCrossAttnWG(nn.Module): 130 | def __init__(self, query_dim, context_dim=None, heads=8, dim_head=64, dropout=0.): 131 | super().__init__() 132 | inner_dim = dim_head * heads 133 | assert context_dim is not None 134 | 135 | self.scale = dim_head ** -0.5 136 | self.heads = heads 137 | 138 | self.to_q = nn.Linear(query_dim, inner_dim, bias=False) 139 | self.to_k = nn.Linear(context_dim, inner_dim, bias=False) 140 | self.to_v = nn.Linear(context_dim, inner_dim, bias=False) 141 | 142 | self.to_out = nn.Sequential( 143 | nn.Linear(inner_dim, query_dim), 144 | nn.Dropout(dropout) 145 | ) 146 | 147 | def forward(self, x, context=None, mask=None): 148 | h = self.heads 149 | assert context is not None 150 | 151 | q = self.to_q(x) 152 | k = self.to_k(context) 153 | v = self.to_v(context) 154 | 155 | q, k, v = map(lambda t: rearrange(t, 'b n (h d) -> b n h d', h=h), (q, k, v)) 156 | 157 | sim = torch.zeros(q.shape[0], q.shape[2], q.shape[1], k.shape[1], dtype=q.dtype, device=q.device) 158 | max_neg_value = -torch.finfo(sim.dtype).max / 2 159 | sim = sim.masked_fill_(~mask, max_neg_value) 160 | sim = sim.view(q.shape[0], q.shape[2], q.shape[1], k.shape[1]) 161 | 162 | out = FlashAttnFunc.apply(q, k, v, sim, False, self.scale) 163 | 164 | out = rearrange(out, 'b n h d -> b n (h d)') 165 | return self.to_out(out) 166 | 167 | from .plugins import MHCAWG 168 | class ldmCrossAttnWG(nn.Module): 169 | def __init__(self, query_dim, context_dim=None, heads=8, dim_head=64, dropout=0.): 170 | super().__init__() 171 | inner_dim = dim_head * heads 172 | assert context_dim is not None 173 | 174 | self.scale = dim_head ** -0.5 175 | self.heads = heads 176 | 177 | self.to_q = 
nn.Linear(query_dim, inner_dim, bias=False)
178 |         self.to_k = nn.Linear(context_dim, inner_dim, bias=False)
179 |         self.to_v = nn.Linear(context_dim, inner_dim, bias=False)
180 | 
181 |         self.to_out = nn.Sequential(
182 |             nn.Linear(inner_dim, query_dim),
183 |             nn.Dropout(dropout)
184 |         )
185 | 
186 |     def forward(self, x, context=None, mask=None):
187 |         """
188 |         x: (batch_size, seq_len_q, query_dim)
189 |         context: (batch_size, seq_len_kv, context_dim)
190 |         out: (batch_size, seq_len_q, query_dim)
191 |         """
192 |         h = self.heads
193 | 
194 |         q = self.to_q(x)
195 |         assert context is not None
196 |         k = self.to_k(context)
197 |         v = self.to_v(context)
198 | 
199 |         q = rearrange(q, 'b n (h d) -> b n h d', h=h)
200 |         k = rearrange(k, 'b n (h d) -> b n h 1 d', h=h)
201 |         v = rearrange(v, 'b n (h d) -> b n h 1 d', h=h)
202 |         kv = torch.cat([k, v], dim=-2)
203 | 
204 |         sim = torch.zeros(q.shape[0], q.shape[2], q.shape[1], k.shape[1], dtype=q.dtype, device=q.device)
205 |         max_neg_value = -torch.finfo(sim.dtype).max / 2
206 |         sim = sim.masked_fill_(~mask, max_neg_value)
207 |         sim = sim.view(q.shape[0], q.shape[2], q.shape[1], k.shape[1])
208 | 
209 |         out = MHCAWG.apply(q, kv, sim)
210 |         out = rearrange(out, 'b n h d -> b n (h d)', h=h)
211 |         return self.to_out(out)
212 | 
213 | from torch.nn import MultiheadAttention
214 | class TorchSelfAttn(nn.Module):
215 |     def __init__(self, hidden_size, num_heads):
216 |         super().__init__()
217 |         self.mha = MultiheadAttention(hidden_size, num_heads)
218 |     def forward(self, x):
219 |         return self.mha(x, x, x)[0]
220 | 
221 | from .plugins import fMHCA
222 | class ldmCrossAttn(nn.Module):
223 |     def __init__(self, query_dim, context_dim=None, heads=8, dim_head=40):
224 |         super().__init__()
225 |         inner_dim = dim_head * heads
226 |         if context_dim is None:
227 |             context_dim = query_dim
228 | 
229 |         self.scale = dim_head ** -0.5
230 |         self.heads = heads
231 | 
232 |         self.to_q = nn.Linear(query_dim, inner_dim, bias=False)
233 |         self.to_k = nn.Linear(context_dim, inner_dim, bias=False)
234 |         self.to_v = nn.Linear(context_dim, inner_dim, bias=False)
235 | 
236 |         self.to_out = nn.Sequential(nn.Linear(inner_dim, query_dim))
237 | 
238 |     def forward(self, x, context=None):
239 |         """
240 |         x: (batch_size, seq_len_q, query_dim)
241 |         context: (batch_size, seq_len_kv, context_dim)
242 |         out: (batch_size, seq_len_q, query_dim)
243 |         """
244 |         h = self.heads
245 | 
246 |         q = self.to_q(x)
247 |         if context is None:
248 |             context = x
249 |         k = self.to_k(context)
250 |         v = self.to_v(context)
251 | 
252 |         q = rearrange(q, 'b n (h d) -> b n h d', h=h)
253 |         k = rearrange(k, 'b n (h d) -> b n h 1 d', h=h)
254 |         v = rearrange(v, 'b n (h d) -> b n h 1 d', h=h)
255 |         kv = torch.cat([k, v], dim=-2)
256 | 
257 |         out = fMHCA.apply(q, kv)
258 |         out = rearrange(out, 'b n h d -> b n (h d)', h=h)
259 |         return self.to_out(out)
260 | 
261 | from .plugins import fMHA_V2
262 | class ldmSelfAttn(nn.Module):
263 |     def __init__(self, query_dim, context_dim=None, heads=8, dim_head=40):
264 |         super().__init__()
265 |         inner_dim = dim_head * heads
266 |         if context_dim is None:
267 |             context_dim = query_dim
268 | 
269 |         self.scale = dim_head ** -0.5
270 |         self.heads = heads
271 | 
272 |         self.to_q = nn.Linear(query_dim, inner_dim, bias=False)
273 |         self.to_k = nn.Linear(context_dim, inner_dim, bias=False)
274 |         self.to_v = nn.Linear(context_dim, inner_dim, bias=False)
275 | 
276 |         self.to_out = nn.Sequential(nn.Linear(inner_dim, query_dim))
277 | 
278 |     def forward(self, x, context=None):
279 |         """
280 |         x: (batch_size, seq_len_q, query_dim)
281 |         context: (batch_size, seq_len_kv, context_dim)
282 |         out: (batch_size, seq_len_q, query_dim)
283 |         """
284 |         h = self.heads
285 | 
286 |         q = self.to_q(x)
287 |         if context is None:
288 |             context = x
289 |         k = self.to_k(context)
290 |         v = self.to_v(context)
291 | 
292 |         q = rearrange(q, 'b n (h d) -> b n h 1 d', h=h)
293 |         k = rearrange(k, 'b n (h d) -> b n h 1 d', h=h)
294 |         v = rearrange(v, 'b n (h d) -> b n h 1 d', h=h)
295 |         qkv = torch.cat([q, k, v], dim=-2)
296 | 
297 |         out = fMHA_V2.apply(qkv)
298 |         out = rearrange(out, 'b n h d -> b n (h d)', h=h)
299 |         return self.to_out(out)
300 | 
301 | class LinearConv(nn.Module):
302 |     def __init__(self, in_dim, out_dim):
303 |         super().__init__()
304 |         self.linear = nn.Linear(in_dim, out_dim)
305 |     def forward(self, x):
306 |         return self.linear(x.transpose(1, -1)).transpose(1, -1).contiguous()
307 | 
308 | from .plugins import GroupNormalizationPlugin
309 | class GroupNorm(nn.Module):
310 |     def __init__(self, num_groups, num_channels, eps, affine=True):  # affine is kept for torch.nn.GroupNorm signature compatibility; weight and bias are always created
311 |         super().__init__()
312 |         assert num_channels % num_groups == 0
313 |         self.eps = eps
314 |         self.num_groups = num_groups
315 |         self.weight = nn.Parameter(torch.ones(num_channels))
316 |         self.bias = nn.Parameter(torch.zeros(num_channels))
317 |     def forward(self, x):
318 |         return GroupNormalizationPlugin.apply(x, self.weight, self.bias, self.num_groups, self.eps)
319 | 
320 | from .plugins import LayerNormPlugin
321 | class LayerNorm(nn.Module):
322 |     def __init__(self, channels, eps=1e-5):
323 |         super().__init__()
324 |         self.eps = eps
325 |         self.channels = channels
326 |         self.weight = nn.Parameter(torch.ones(channels))
327 |         self.bias = nn.Parameter(torch.zeros(channels))
328 |     def forward(self, x):
329 |         return LayerNormPlugin.apply(x, self.weight, self.bias, self.channels, self.eps)
330 | 
331 | from .plugins import NewLayerNormPlugin
332 | class NewLayerNorm(nn.Module):
333 |     def __init__(self, channels, eps=1e-5):
334 |         super().__init__()
335 |         self.eps = eps
336 |         self.channels = channels
337 |         self.weight = nn.Parameter(torch.ones(channels))
338 |         self.bias = nn.Parameter(torch.zeros(channels))
339 |     def forward(self, x):
340 |         return NewLayerNormPlugin.apply(x, self.weight, self.bias, self.channels, self.eps)
341 | 
--------------------------------------------------------------------------------
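A note on the attention modules in fastldm/modules.py: ldmCrossAttn, ldmCrossAttnWG and FlashCrossAttnWG all keep the to_q/to_k/to_v/to_out projection layout of the latent-diffusion CrossAttention block and only swap out the softmax(QK^T/sqrt(d))V core for a plugin or flash-attention kernel. Below is a minimal pure-PyTorch reference (not part of fastldm; the class name RefCrossAttn is illustrative, and it assumes the plugin paths apply the usual 1/sqrt(dim_head) scaling internally) that can serve as a numerical ground truth when checking plugin outputs.

import torch
import torch.nn as nn
from einops import rearrange

class RefCrossAttn(nn.Module):
    """Plain multi-head cross attention with the same to_q/to_k/to_v/to_out layout
    as the modules above (illustrative reference, not part of fastldm)."""
    def __init__(self, query_dim, context_dim=None, heads=8, dim_head=64, dropout=0.):
        super().__init__()
        inner_dim = dim_head * heads
        context_dim = query_dim if context_dim is None else context_dim
        self.scale = dim_head ** -0.5
        self.heads = heads
        self.to_q = nn.Linear(query_dim, inner_dim, bias=False)
        self.to_k = nn.Linear(context_dim, inner_dim, bias=False)
        self.to_v = nn.Linear(context_dim, inner_dim, bias=False)
        self.to_out = nn.Sequential(nn.Linear(inner_dim, query_dim), nn.Dropout(dropout))

    def forward(self, x, context=None, mask=None):
        # x: (batch, seq_len_q, query_dim); context: (batch, seq_len_kv, context_dim)
        # mask: optional bool tensor (batch, heads, seq_len_q, seq_len_kv), True = attend
        h = self.heads
        context = x if context is None else context
        q = rearrange(self.to_q(x), 'b n (h d) -> b h n d', h=h)
        k = rearrange(self.to_k(context), 'b n (h d) -> b h n d', h=h)
        v = rearrange(self.to_v(context), 'b n (h d) -> b h n d', h=h)
        sim = torch.einsum('bhid,bhjd->bhij', q, k) * self.scale
        if mask is not None:
            # same convention as the modules above: masked-out positions get a large negative bias
            sim = sim.masked_fill(~mask, -torch.finfo(sim.dtype).max / 2)
        out = torch.einsum('bhij,bhjd->bhid', sim.softmax(dim=-1), v)
        return self.to_out(rearrange(out, 'b h n d -> b n (h d)'))

Because the parameter names match (to_q, to_k, to_v, to_out.0), such a reference can usually be initialized straight from one of the modules above with ref.load_state_dict(module.state_dict()), and its output compared against the plugin path, for example with profile_outdiff from fastldm.helper.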