├── LICENSE ├── README.md ├── add_turbo ├── README.md ├── diffaug.py ├── discriminator.py ├── dnnlib │ ├── __init__.py │ └── util.py ├── shared.py ├── torch_utils │ ├── __init__.py │ ├── custom_ops.py │ ├── misc.py │ └── ops │ │ ├── __init__.py │ │ ├── bias_act.cpp │ │ ├── bias_act.cu │ │ ├── bias_act.h │ │ ├── bias_act.py │ │ ├── conv2d_gradfix.py │ │ ├── conv2d_resample.py │ │ ├── filtered_lrelu.cpp │ │ ├── filtered_lrelu.cu │ │ ├── filtered_lrelu.h │ │ ├── filtered_lrelu.py │ │ ├── filtered_lrelu_ns.cu │ │ ├── filtered_lrelu_rd.cu │ │ ├── filtered_lrelu_wr.cu │ │ ├── fma.py │ │ ├── grid_sample_gradfix.py │ │ ├── upfirdn2d.cpp │ │ ├── upfirdn2d.cu │ │ ├── upfirdn2d.h │ │ └── upfirdn2d.py ├── train_add.py └── vit_utils.py ├── mobile_diffusion_unet ├── README.md ├── __init__.py ├── distill_training_md.py ├── mdAttention.py ├── mdBasicTransformerBlock.py ├── mdCADownBlock.py ├── mdCAUpBlock.py ├── mdDownBlock2D.py ├── mdSepResnetBlock2D.py ├── mdTransformer2DModel.py ├── mdUnet.py └── mdUpBlock2D.py ├── progressive_distillation_for_sd ├── README.md └── train_text_to_image_pd.py └── requirements.txt /LICENSE: -------------------------------------------------------------------------------- 1 | MIT License 2 | 3 | Copyright (c) 2024 MFaceTech 4 | 5 | Permission is hereby granted, free of charge, to any person obtaining a copy 6 | of this software and associated documentation files (the "Software"), to deal 7 | in the Software without restriction, including without limitation the rights 8 | to use, copy, modify, merge, publish, distribute, sublicense, and/or sell 9 | copies of the Software, and to permit persons to whom the Software is 10 | furnished to do so, subject to the following conditions: 11 | 12 | The above copyright notice and this permission notice shall be included in all 13 | copies or substantial portions of the Software. 14 | 15 | THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR 16 | IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, 17 | FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE 18 | AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER 19 | LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, 20 | OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE 21 | SOFTWARE. 22 | -------------------------------------------------------------------------------- /README.md: -------------------------------------------------------------------------------- 1 | # Acceleration for Stable Diffusion 2 | 3 | ## Overview 4 | In this project, we provide a reproduction of acceleration schemes, such as Progressive Distillation[Progressive Distillation for Fast Sampling of Diffusion Models](https://arxiv.org/abs/2202.00512), MobileDiffusion[MobileDiffusion: Subsecond Text-to-Image Generation on Mobile Devices](https://arxiv.org/abs/2311.16567) and SDXL-Turbo[Adversarial Diffusion Distillation](https://arxiv.org/abs/2311.17042), etc. 5 | 6 | ## 🔥 Update 7 | - **[2024.3.6]** We release the first version of the code, containing progressive distillation, mobile diffusion unet and add turbo. 8 | 9 | ## Setup 10 | - **Installation:** 11 | We recommend python version >= 3.8 and cuda version >= 11.4. 
And install the packages mentioned in the requirements.txt: 12 | ```bash 13 | pip install -r requirements.txt 14 | ``` 15 | 16 | ## Acknowledgements 17 | This project build upon 18 | - [diffusers](https://github.com/huggingface/diffusers) 19 | - [stylegan-t](https://github.com/autonomousvision/stylegan-t) 20 | - [distill-sd](https://github.com/segmind/distill-sd) 21 | -------------------------------------------------------------------------------- /add_turbo/README.md: -------------------------------------------------------------------------------- 1 | ## ADD(Turbo) Training Code Reproduction 2 | ### Project Description 3 | This repository contains our reproduced version of the ADD(Turbo) training code, which is based on the stylegan-t codebase. 4 | The fundamental structure of the code is aligned with the descriptions provided in the original research paper. 5 | 6 | ### Development Status 7 | Currently, the code is in an experimental phase. Users should be aware that there are known issues regarding the convergence 8 | of the GAN (Generative Adversarial Network) component. 9 | 10 | ### Usage 11 | #### Training 12 | ```shell 13 | export MODEL_NAME="xxx" 14 | export MODEL_PATH="xxx" 15 | export CACHE_PATH="xxx" 16 | export OUTPUT_PATH="xxx" 17 | 18 | accelerate launch --mixed_precision="bf16" train_add.py \ 19 | --pretrained_model_name_or_path=$MODEL_PATH \ 20 | --train_data_dir="" \ 21 | --cache_dir=$CACHE_PATH \ 22 | --resolution=512 --center_crop --random_flip \ 23 | --train_batch_size=16 \ 24 | --gradient_accumulation_steps=1 \ 25 | --gradient_checkpointing \ 26 | --checkpointing_steps=500 \ 27 | --max_train_steps=20000 \ 28 | --snr_gamma=5.0 \ 29 | --blur_fade_kimg=500 \ 30 | --learning_rate_G=0.00001 \ 31 | --learning_rate_D=0.00001 \ 32 | --max_grad_norm=1 \ 33 | --lr_scheduler="constant" --lr_warmup_steps=0 \ 34 | --output_dir=$OUTPUT_PATH 35 | 36 | 37 | ``` 38 | 39 | ### Known Issues 40 | Convergence: The GAN part of the code is experiencing convergence problems. We are actively working on debugging and 41 | refining the training process to achieve stable and reliable results. 42 | ### Usage Warning 43 | Given that the code is experimental, it is recommended for research and development purposes only. We advise against using 44 | this code for production environments until the convergence issues have been resolved. 45 | 46 | ### Contribution 47 | We welcome contributions from the community to help resolve the current issues with the GAN convergence. Collaboration is key 48 | to advancing the field, and we appreciate any insights or improvements that can be shared. 49 | 50 | ### Acknowledgments 51 | Our work builds upon the innovative ideas presented in the ADD(Turbo) paper and the stylegan-t codebase. We acknowledge the original authors and contributors to these foundational resources. 52 | -------------------------------------------------------------------------------- /add_turbo/diffaug.py: -------------------------------------------------------------------------------- 1 | # BSD 2-Clause "Simplified" License 2 | # Copyright (c) 2020, Shengyu Zhao, Zhijian Liu, Ji Lin, Jun-Yan Zhu, and Song Han 3 | # All rights reserved. 4 | # 5 | # Redistribution and use in source and binary forms, with or without 6 | # modification, are permitted provided that the following conditions are met: 7 | # 8 | # * Redistributions of source code must retain the above copyright notice, this 9 | # list of conditions and the following disclaimer. 
10 | # 11 | # * Redistributions in binary form must reproduce the above copyright notice, 12 | # this list of conditions and the following disclaimer in the documentation 13 | # and/or other materials provided with the distribution. 14 | # 15 | # THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" 16 | # AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE 17 | # IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE 18 | # DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE 19 | # FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL 20 | # DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR 21 | # SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER 22 | # CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, 23 | # OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE 24 | # OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. 25 | # 26 | # Code from https://github.com/mit-han-lab/data-efficient-gans 27 | 28 | """Training GANs with DiffAugment.""" 29 | 30 | import numpy as np 31 | import torch 32 | import torch.nn.functional as F 33 | 34 | 35 | def DiffAugment(x: torch.Tensor, policy: str = '', channels_first: bool = True) -> torch.Tensor: 36 | if policy: 37 | if not channels_first: 38 | x = x.permute(0, 3, 1, 2) 39 | for p in policy.split(','): 40 | for f in AUGMENT_FNS[p]: 41 | x = f(x) 42 | if not channels_first: 43 | x = x.permute(0, 2, 3, 1) 44 | x = x.contiguous() 45 | return x 46 | 47 | 48 | def rand_brightness(x: torch.Tensor) -> torch.Tensor: 49 | x = x + (torch.rand(x.size(0), 1, 1, 1, dtype=x.dtype, device=x.device) - 0.5) 50 | return x 51 | 52 | 53 | def rand_saturation(x: torch.Tensor) -> torch.Tensor: 54 | x_mean = x.mean(dim=1, keepdim=True) 55 | x = (x - x_mean) * (torch.rand(x.size(0), 1, 1, 1, dtype=x.dtype, device=x.device) * 2) + x_mean 56 | return x 57 | 58 | 59 | def rand_contrast(x: torch.Tensor) -> torch.Tensor: 60 | x_mean = x.mean(dim=[1, 2, 3], keepdim=True) 61 | x = (x - x_mean) * (torch.rand(x.size(0), 1, 1, 1, dtype=x.dtype, device=x.device) + 0.5) + x_mean 62 | return x 63 | 64 | 65 | def rand_translation(x: torch.Tensor, ratio: float = 0.125) -> torch.Tensor: 66 | shift_x, shift_y = int(x.size(2) * ratio + 0.5), int(x.size(3) * ratio + 0.5) 67 | translation_x = torch.randint(-shift_x, shift_x + 1, size=[x.size(0), 1, 1], device=x.device) 68 | translation_y = torch.randint(-shift_y, shift_y + 1, size=[x.size(0), 1, 1], device=x.device) 69 | grid_batch, grid_x, grid_y = torch.meshgrid( 70 | torch.arange(x.size(0), dtype=torch.long, device=x.device), 71 | torch.arange(x.size(2), dtype=torch.long, device=x.device), 72 | torch.arange(x.size(3), dtype=torch.long, device=x.device), 73 | ) 74 | grid_x = torch.clamp(grid_x + translation_x + 1, 0, x.size(2) + 1) 75 | grid_y = torch.clamp(grid_y + translation_y + 1, 0, x.size(3) + 1) 76 | x_pad = F.pad(x, [1, 1, 1, 1, 0, 0, 0, 0]) 77 | x = x_pad.permute(0, 2, 3, 1).contiguous()[grid_batch, grid_x, grid_y].permute(0, 3, 1, 2) 78 | return x 79 | 80 | 81 | def rand_cutout(x: torch.Tensor, ratio: float = 0.2) -> torch.Tensor: 82 | cutout_size = int(x.size(2) * ratio + 0.5), int(x.size(3) * ratio + 0.5) 83 | offset_x = torch.randint(0, x.size(2) + (1 - cutout_size[0] % 2), size=[x.size(0), 1, 1], device=x.device) 84 | offset_y = torch.randint(0, x.size(3) + (1 - cutout_size[1] % 2), size=[x.size(0), 
1, 1], device=x.device) 85 | grid_batch, grid_x, grid_y = torch.meshgrid( 86 | torch.arange(x.size(0), dtype=torch.long, device=x.device), 87 | torch.arange(cutout_size[0], dtype=torch.long, device=x.device), 88 | torch.arange(cutout_size[1], dtype=torch.long, device=x.device), 89 | ) 90 | grid_x = torch.clamp(grid_x + offset_x - cutout_size[0] // 2, min=0, max=x.size(2) - 1) 91 | grid_y = torch.clamp(grid_y + offset_y - cutout_size[1] // 2, min=0, max=x.size(3) - 1) 92 | mask = torch.ones(x.size(0), x.size(2), x.size(3), dtype=x.dtype, device=x.device) 93 | mask[grid_batch, grid_x, grid_y] = 0 94 | x = x * mask.unsqueeze(1) 95 | return x 96 | 97 | 98 | def rand_resize(x: torch.Tensor, min_ratio: float = 0.8, max_ratio: float = 1.2) -> torch.Tensor: 99 | resize_ratio = np.random.rand()*(max_ratio-min_ratio) + min_ratio 100 | resized_img = F.interpolate(x, size=int(resize_ratio*x.shape[3]), mode='bilinear') 101 | org_size = x.shape[3] 102 | if int(resize_ratio*x.shape[3]) < x.shape[3]: 103 | left_pad = (x.shape[3]-int(resize_ratio*x.shape[3]))/2. 104 | left_pad = int(left_pad) 105 | right_pad = x.shape[3] - left_pad - resized_img.shape[3] 106 | x = F.pad(resized_img, (left_pad, right_pad, left_pad, right_pad), "constant", 0.) 107 | else: 108 | left = (int(resize_ratio*x.shape[3])-x.shape[3])/2. 109 | left = int(left) 110 | x = resized_img[:, :, left:(left+x.shape[3]), left:(left+x.shape[3])] 111 | assert x.shape[2] == org_size 112 | assert x.shape[3] == org_size 113 | return x 114 | 115 | 116 | AUGMENT_FNS = { 117 | 'color': [rand_brightness, rand_saturation, rand_contrast], 118 | 'translation': [rand_translation], 119 | 'resize': [rand_resize], 120 | 'cutout': [rand_cutout], 121 | } 122 | -------------------------------------------------------------------------------- /add_turbo/discriminator.py: -------------------------------------------------------------------------------- 1 | # Copyright (c) 2023, NVIDIA CORPORATION & AFFILIATES. All rights reserved. 2 | # 3 | # NVIDIA CORPORATION and its licensors retain all intellectual property 4 | # and proprietary rights in and to this software, related documentation 5 | # and any modifications thereto. Any use, reproduction, disclosure or 6 | # distribution of this software and related documentation without an express 7 | # license agreement from NVIDIA CORPORATION is strictly prohibited. 8 | 9 | """ 10 | Projected discriminator architecture from 11 | "StyleGAN-T: Unlocking the Power of GANs for Fast Large-Scale Text-to-Image Synthesis". 
12 | """ 13 | 14 | import numpy as np 15 | import torch 16 | import torch.nn as nn 17 | import torch.nn.functional as F 18 | from torch.nn.utils.spectral_norm import SpectralNorm 19 | from torchvision.transforms import RandomCrop, Normalize 20 | import timm 21 | from timm.data import IMAGENET_DEFAULT_MEAN, IMAGENET_DEFAULT_STD 22 | 23 | from shared import ResidualBlock, FullyConnectedLayer 24 | from vit_utils import make_vit_backbone, forward_vit 25 | from diffaug import DiffAugment 26 | 27 | 28 | class SpectralConv1d(nn.Conv1d): 29 | def __init__(self, *args, **kwargs): 30 | super().__init__(*args, **kwargs) 31 | SpectralNorm.apply(self, name='weight', n_power_iterations=1, dim=0, eps=1e-12) 32 | 33 | 34 | class BatchNormLocal(nn.Module): 35 | def __init__(self, num_features: int, affine: bool = True, virtual_bs: int = 8, eps: float = 1e-5): 36 | super().__init__() 37 | self.virtual_bs = virtual_bs 38 | self.eps = eps 39 | self.affine = affine 40 | 41 | if self.affine: 42 | self.weight = nn.Parameter(torch.ones(num_features)) 43 | self.bias = nn.Parameter(torch.zeros(num_features)) 44 | 45 | def forward(self, x: torch.Tensor) -> torch.Tensor: 46 | shape = x.size() 47 | 48 | # Reshape batch into groups. 49 | G = np.ceil(x.size(0) / self.virtual_bs).astype(int) 50 | x = x.view(G, -1, x.size(-2), x.size(-1)) 51 | 52 | # Calculate stats. 53 | mean = x.mean([1, 3], keepdim=True) 54 | var = x.var([1, 3], keepdim=True, unbiased=False) 55 | x = (x - mean) / (torch.sqrt(var + self.eps)) 56 | 57 | if self.affine: 58 | x = x * self.weight[None, :, None] + self.bias[None, :, None] 59 | 60 | return x.view(shape) 61 | 62 | 63 | def make_block(channels: int, kernel_size: int) -> nn.Module: 64 | return nn.Sequential( 65 | SpectralConv1d( 66 | channels, 67 | channels, 68 | kernel_size=kernel_size, 69 | padding=kernel_size // 2, 70 | padding_mode='circular', 71 | ), 72 | BatchNormLocal(channels), 73 | nn.LeakyReLU(0.2, True), 74 | ) 75 | 76 | 77 | class DiscHead(nn.Module): 78 | def __init__(self, channels: int, c_dim: int, cmap_dim: int = 64): 79 | super().__init__() 80 | self.channels = channels 81 | self.c_dim = c_dim 82 | self.cmap_dim = cmap_dim 83 | 84 | self.main = nn.Sequential( 85 | make_block(channels, kernel_size=1), 86 | ResidualBlock(make_block(channels, kernel_size=9)) 87 | ) 88 | 89 | if self.c_dim > 0: 90 | self.cmapper = FullyConnectedLayer(self.c_dim, cmap_dim) 91 | self.cls = SpectralConv1d(channels, cmap_dim, kernel_size=1, padding=0) 92 | else: 93 | self.cls = SpectralConv1d(channels, 1, kernel_size=1, padding=0) 94 | 95 | def forward(self, x: torch.Tensor, c: torch.Tensor) -> torch.Tensor: 96 | h = self.main(x) 97 | out = self.cls(h) 98 | if self.c_dim > 0: 99 | cmap = self.cmapper(c).unsqueeze(-1) 100 | out = (out * cmap).sum(1, keepdim=True) * (1 / np.sqrt(self.cmap_dim)) 101 | return out 102 | 103 | 104 | class DINO(torch.nn.Module): 105 | def __init__(self, hooks: list[int] = [2, 5, 8, 11], hook_patch: bool = True): 106 | super().__init__() 107 | self.n_hooks = len(hooks) + int(hook_patch) 108 | 109 | self.model = make_vit_backbone( 110 | timm.create_model('vit_small_patch16_224_dino', pretrained=True), 111 | patch_size=[16, 16], hooks=hooks, hook_patch=hook_patch, 112 | ) 113 | self.model = self.model.eval().requires_grad_(False) 114 | 115 | self.img_resolution = self.model.model.patch_embed.img_size[0] 116 | self.embed_dim = self.model.model.embed_dim 117 | self.norm = Normalize(IMAGENET_DEFAULT_MEAN, IMAGENET_DEFAULT_STD) 118 | 119 | def forward(self, x: torch.Tensor) -> 
torch.Tensor: 120 | ''' input: x in [0, 1]; output: dict of activations ''' 121 | x = F.interpolate(x, self.img_resolution, mode='area') 122 | x = self.norm(x) 123 | features = forward_vit(self.model, x) 124 | return features 125 | 126 | 127 | class ProjectedDiscriminator(nn.Module): 128 | def __init__(self, c_dim: int, diffaug: bool = True, p_crop: float = 0.5): 129 | super().__init__() 130 | self.c_dim = c_dim 131 | self.diffaug = diffaug 132 | self.p_crop = p_crop 133 | 134 | self.dino = DINO() 135 | 136 | heads = [] 137 | for i in range(self.dino.n_hooks): 138 | heads += [str(i), DiscHead(self.dino.embed_dim, c_dim)], 139 | self.heads = nn.ModuleDict(heads) 140 | 141 | def train(self, mode: bool = True): 142 | self.heads.requires_grad_(mode) 143 | self.heads = self.heads.train(mode) 144 | return self 145 | 146 | def eval(self): 147 | return self.train(False) 148 | 149 | def forward(self, x: torch.Tensor, c: torch.Tensor, phase="D") -> torch.Tensor: 150 | # Apply augmentation (x in [-1, 1]). 151 | if self.diffaug: 152 | x = DiffAugment(x, policy='color,translation,cutout') 153 | 154 | # Transform to [0, 1]. 155 | x = x.add(1).div(2) 156 | 157 | # Take crops with probablity p_crop if the image is larger. 158 | if x.size(-1) > self.dino.img_resolution and np.random.random() < self.p_crop: 159 | x = RandomCrop(self.dino.img_resolution)(x) 160 | 161 | # Forward pass through DINO ViT. 162 | features_tmp = self.dino(x) 163 | 164 | if phase == "D": 165 | features = {} 166 | for k, _ in self.heads.items(): 167 | features[k] = features_tmp[k].detach().clone().requires_grad_(True) 168 | else: 169 | features = features_tmp 170 | 171 | # Apply discriminator heads. 172 | logits = [] 173 | for k, head in self.heads.items(): 174 | logits.append(head(features[k], c).view(x.size(0), -1)) 175 | logits = torch.cat(logits, dim=1) 176 | return logits, features 177 | -------------------------------------------------------------------------------- /add_turbo/dnnlib/__init__.py: -------------------------------------------------------------------------------- 1 | # Copyright (c) 2023, NVIDIA CORPORATION & AFFILIATES. All rights reserved. 2 | # 3 | # NVIDIA CORPORATION and its licensors retain all intellectual property 4 | # and proprietary rights in and to this software, related documentation 5 | # and any modifications thereto. Any use, reproduction, disclosure or 6 | # distribution of this software and related documentation without an express 7 | # license agreement from NVIDIA CORPORATION is strictly prohibited. 8 | 9 | from .util import EasyDict, make_cache_dir_path 10 | -------------------------------------------------------------------------------- /add_turbo/shared.py: -------------------------------------------------------------------------------- 1 | # Copyright (c) 2023, NVIDIA CORPORATION & AFFILIATES. All rights reserved. 2 | # 3 | # NVIDIA CORPORATION and its licensors retain all intellectual property 4 | # and proprietary rights in and to this software, related documentation 5 | # and any modifications thereto. Any use, reproduction, disclosure or 6 | # distribution of this software and related documentation without an express 7 | # license agreement from NVIDIA CORPORATION is strictly prohibited. 
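# A minimal usage sketch for the blocks defined in this file; the shapes and
# hyperparameters below are illustrative assumptions, not values used elsewhere
# in this repository:
#
#   mlp = MLP([768, 256, 64], activation='lrelu', linear_out=True)
#   tokens = torch.randn(4, 77, 768)   # (batch, tokens, features)
#   out = mlp(tokens)                  # token axis is folded into the batch and
#                                      # restored afterwards -> shape (4, 77, 64)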
8 | 9 | """Shared architecture blocks.""" 10 | 11 | from typing import Callable 12 | 13 | import numpy as np 14 | import torch 15 | import torch.nn as nn 16 | 17 | from torch_utils.ops import bias_act 18 | 19 | 20 | class ResidualBlock(nn.Module): 21 | def __init__(self, fn: Callable): 22 | super().__init__() 23 | self.fn = fn 24 | 25 | def forward(self, x: torch.Tensor) -> torch.Tensor: 26 | return (self.fn(x) + x) / np.sqrt(2) 27 | 28 | 29 | class FullyConnectedLayer(nn.Module): 30 | def __init__( 31 | self, 32 | in_features: int, # Number of input features. 33 | out_features: int, # Number of output features. 34 | bias: bool = True, # Apply additive bias before the activation function? 35 | activation: str = 'linear', # Activation function: 'relu', 'lrelu', etc. 36 | lr_multiplier: float = 1.0, # Learning rate multiplier. 37 | weight_init: float = 1.0, # Initial standard deviation of the weight tensor. 38 | bias_init: float = 0.0, # Initial value for the additive bias. 39 | ): 40 | 41 | super().__init__() 42 | self.in_features = in_features 43 | self.out_features = out_features 44 | self.activation = activation 45 | self.weight = torch.nn.Parameter(torch.randn([out_features, in_features]) * (weight_init / lr_multiplier)) 46 | bias_init = np.broadcast_to(np.asarray(bias_init, dtype=np.float32), [out_features]) 47 | self.bias = torch.nn.Parameter(torch.from_numpy(bias_init / lr_multiplier)) if bias else None 48 | self.weight_gain = lr_multiplier / np.sqrt(in_features) 49 | self.bias_gain = lr_multiplier 50 | 51 | def forward(self, x: torch.Tensor) -> torch.Tensor: 52 | w = self.weight.to(x.dtype) * self.weight_gain 53 | b = self.bias 54 | if b is not None: 55 | b = b.to(x.dtype) 56 | if self.bias_gain != 1: 57 | b = b * self.bias_gain 58 | 59 | if self.activation == 'linear' and b is not None: 60 | x = torch.addmm(b.unsqueeze(0), x, w.t()) 61 | else: 62 | x = x.matmul(w.t()) 63 | x = bias_act.bias_act(x, b, act=self.activation) 64 | return x 65 | 66 | def extra_repr(self) -> str: 67 | return f'in_features={self.in_features:d}, out_features={self.out_features:d}, activation={self.activation:s}' 68 | 69 | 70 | class MLP(nn.Module): 71 | def __init__( 72 | self, 73 | features_list: list[int], # Number of features in each layer of the MLP. 74 | activation: str = 'linear', # Activation function: 'relu', 'lrelu', etc. 75 | lr_multiplier: float = 1.0, # Learning rate multiplier. 76 | linear_out: bool = False # Use the 'linear' activation function for the output layer? 
77 | ): 78 | super().__init__() 79 | num_layers = len(features_list) - 1 80 | self.num_layers = num_layers 81 | self.out_dim = features_list[-1] 82 | 83 | for idx in range(num_layers): 84 | in_features = features_list[idx] 85 | out_features = features_list[idx + 1] 86 | if linear_out and idx == num_layers-1: 87 | activation = 'linear' 88 | layer = FullyConnectedLayer(in_features, out_features, activation=activation, lr_multiplier=lr_multiplier) 89 | setattr(self, f'fc{idx}', layer) 90 | 91 | def forward(self, x: torch.Tensor) -> torch.Tensor: 92 | ''' if x is sequence of tokens, shift tokens to batch and apply MLP to all''' 93 | shift2batch = (x.ndim == 3) 94 | 95 | if shift2batch: 96 | B, K, C = x.shape 97 | x = x.flatten(0,1) 98 | 99 | for idx in range(self.num_layers): 100 | layer = getattr(self, f'fc{idx}') 101 | x = layer(x) 102 | 103 | if shift2batch: 104 | x = x.reshape(B, K, -1) 105 | 106 | return x 107 | -------------------------------------------------------------------------------- /add_turbo/torch_utils/__init__.py: -------------------------------------------------------------------------------- 1 | # Copyright (c) 2023, NVIDIA CORPORATION & AFFILIATES. All rights reserved. 2 | # 3 | # NVIDIA CORPORATION and its licensors retain all intellectual property 4 | # and proprietary rights in and to this software, related documentation 5 | # and any modifications thereto. Any use, reproduction, disclosure or 6 | # distribution of this software and related documentation without an express 7 | # license agreement from NVIDIA CORPORATION is strictly prohibited. 8 | 9 | # empty 10 | -------------------------------------------------------------------------------- /add_turbo/torch_utils/custom_ops.py: -------------------------------------------------------------------------------- 1 | # Copyright (c) 2021, NVIDIA CORPORATION & AFFILIATES. All rights reserved. 2 | # 3 | # NVIDIA CORPORATION and its licensors retain all intellectual property 4 | # and proprietary rights in and to this software, related documentation 5 | # and any modifications thereto. Any use, reproduction, disclosure or 6 | # distribution of this software and related documentation without an express 7 | # license agreement from NVIDIA CORPORATION is strictly prohibited. 8 | 9 | import glob 10 | import hashlib 11 | import importlib 12 | import os 13 | import re 14 | import shutil 15 | import uuid 16 | 17 | import torch 18 | import torch.utils.cpp_extension 19 | from torch.utils.file_baton import FileBaton 20 | 21 | #---------------------------------------------------------------------------- 22 | # Global options. 23 | 24 | verbosity = 'brief' # Verbosity level: 'none', 'brief', 'full' 25 | 26 | #---------------------------------------------------------------------------- 27 | # Internal helper funcs. 
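# _find_compiler_bindir() is only consulted on Windows, when cl.exe is not already
# on PATH; _get_mangled_gpu_name() turns torch.cuda.get_device_name() into a
# filesystem-safe string so that each GPU model gets its own cached build directory.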
28 | 29 | def _find_compiler_bindir(): 30 | patterns = [ 31 | 'C:/Program Files*/Microsoft Visual Studio/*/Professional/VC/Tools/MSVC/*/bin/Hostx64/x64', 32 | 'C:/Program Files*/Microsoft Visual Studio/*/BuildTools/VC/Tools/MSVC/*/bin/Hostx64/x64', 33 | 'C:/Program Files*/Microsoft Visual Studio/*/Community/VC/Tools/MSVC/*/bin/Hostx64/x64', 34 | 'C:/Program Files*/Microsoft Visual Studio */vc/bin', 35 | ] 36 | for pattern in patterns: 37 | matches = sorted(glob.glob(pattern)) 38 | if len(matches): 39 | return matches[-1] 40 | return None 41 | 42 | #---------------------------------------------------------------------------- 43 | 44 | def _get_mangled_gpu_name(): 45 | name = torch.cuda.get_device_name().lower() 46 | out = [] 47 | for c in name: 48 | if re.match('[a-z0-9_-]+', c): 49 | out.append(c) 50 | else: 51 | out.append('-') 52 | return ''.join(out) 53 | 54 | #---------------------------------------------------------------------------- 55 | # Main entry point for compiling and loading C++/CUDA plugins. 56 | 57 | _cached_plugins = dict() 58 | 59 | def get_plugin(module_name, sources, headers=None, source_dir=None, **build_kwargs): 60 | assert verbosity in ['none', 'brief', 'full'] 61 | if headers is None: 62 | headers = [] 63 | if source_dir is not None: 64 | sources = [os.path.join(source_dir, fname) for fname in sources] 65 | headers = [os.path.join(source_dir, fname) for fname in headers] 66 | 67 | # Already cached? 68 | if module_name in _cached_plugins: 69 | return _cached_plugins[module_name] 70 | 71 | # Print status. 72 | if verbosity == 'full': 73 | print(f'Setting up PyTorch plugin "{module_name}"...') 74 | elif verbosity == 'brief': 75 | print(f'Setting up PyTorch plugin "{module_name}"... ', end='', flush=True) 76 | verbose_build = (verbosity == 'full') 77 | 78 | # Compile and load. 79 | try: # pylint: disable=too-many-nested-blocks 80 | # Make sure we can find the necessary compiler binaries. 81 | if os.name == 'nt' and os.system("where cl.exe >nul 2>nul") != 0: 82 | compiler_bindir = _find_compiler_bindir() 83 | if compiler_bindir is None: 84 | raise RuntimeError(f'Could not find MSVC/GCC/CLANG installation on this computer. Check _find_compiler_bindir() in "{__file__}".') 85 | os.environ['PATH'] += ';' + compiler_bindir 86 | 87 | # Some containers set TORCH_CUDA_ARCH_LIST to a list that can either 88 | # break the build or unnecessarily restrict what's available to nvcc. 89 | # Unset it to let nvcc decide based on what's available on the 90 | # machine. 91 | os.environ['TORCH_CUDA_ARCH_LIST'] = '' 92 | 93 | # Incremental build md5sum trickery. Copies all the input source files 94 | # into a cached build directory under a combined md5 digest of the input 95 | # source files. Copying is done only if the combined digest has changed. 96 | # This keeps input file timestamps and filenames the same as in previous 97 | # extension builds, allowing for fast incremental rebuilds. 98 | # 99 | # This optimization is done only in case all the source files reside in 100 | # a single directory (just for simplicity) and if the TORCH_EXTENSIONS_DIR 101 | # environment variable is set (we take this as a signal that the user 102 | # actually cares about this.) 103 | # 104 | # EDIT: We now do it regardless of TORCH_EXTENSIOS_DIR, in order to work 105 | # around the *.cu dependency bug in ninja config. 
106 | # 107 | all_source_files = sorted(sources + headers) 108 | all_source_dirs = set(os.path.dirname(fname) for fname in all_source_files) 109 | if len(all_source_dirs) == 1: # and ('TORCH_EXTENSIONS_DIR' in os.environ): 110 | 111 | # Compute combined hash digest for all source files. 112 | hash_md5 = hashlib.md5() 113 | for src in all_source_files: 114 | with open(src, 'rb') as f: 115 | hash_md5.update(f.read()) 116 | 117 | # Select cached build directory name. 118 | source_digest = hash_md5.hexdigest() 119 | build_top_dir = torch.utils.cpp_extension._get_build_directory(module_name, verbose=verbose_build) # pylint: disable=protected-access 120 | cached_build_dir = os.path.join(build_top_dir, f'{source_digest}-{_get_mangled_gpu_name()}') 121 | 122 | if not os.path.isdir(cached_build_dir): 123 | tmpdir = f'{build_top_dir}/srctmp-{uuid.uuid4().hex}' 124 | os.makedirs(tmpdir) 125 | for src in all_source_files: 126 | shutil.copyfile(src, os.path.join(tmpdir, os.path.basename(src))) 127 | try: 128 | os.replace(tmpdir, cached_build_dir) # atomic 129 | except OSError: 130 | # source directory already exists, delete tmpdir and its contents. 131 | shutil.rmtree(tmpdir) 132 | if not os.path.isdir(cached_build_dir): raise 133 | 134 | # Compile. 135 | cached_sources = [os.path.join(cached_build_dir, os.path.basename(fname)) for fname in sources] 136 | torch.utils.cpp_extension.load(name=module_name, build_directory=cached_build_dir, 137 | verbose=verbose_build, sources=cached_sources, **build_kwargs) 138 | else: 139 | torch.utils.cpp_extension.load(name=module_name, verbose=verbose_build, sources=sources, **build_kwargs) 140 | 141 | # Load. 142 | module = importlib.import_module(module_name) 143 | 144 | except: 145 | if verbosity == 'brief': 146 | print('Failed!') 147 | raise 148 | 149 | # Print status and add to cache dict. 150 | if verbosity == 'full': 151 | print(f'Done setting up PyTorch plugin "{module_name}".') 152 | elif verbosity == 'brief': 153 | print('Done.') 154 | _cached_plugins[module_name] = module 155 | return module 156 | 157 | #---------------------------------------------------------------------------- 158 | -------------------------------------------------------------------------------- /add_turbo/torch_utils/misc.py: -------------------------------------------------------------------------------- 1 | # Copyright (c) 2023, NVIDIA CORPORATION & AFFILIATES. All rights reserved. 2 | # 3 | # NVIDIA CORPORATION and its licensors retain all intellectual property 4 | # and proprietary rights in and to this software, related documentation 5 | # and any modifications thereto. Any use, reproduction, disclosure or 6 | # distribution of this software and related documentation without an express 7 | # license agreement from NVIDIA CORPORATION is strictly prohibited. 8 | 9 | import re 10 | import contextlib 11 | import numpy as np 12 | import torch 13 | import warnings 14 | import dnnlib 15 | 16 | #---------------------------------------------------------------------------- 17 | # Cached construction of constant tensors. Avoids CPU=>GPU copy when the 18 | # same constant is used multiple times. 
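# Illustrative example (the filter values and device are assumptions): repeated
# calls with identical arguments return the same cached tensor object.
#
#   f = constant([1, 3, 3, 1], dtype=torch.float32, device=torch.device('cuda'))
#   g = constant([1, 3, 3, 1], dtype=torch.float32, device=torch.device('cuda'))
#   assert f is g   # second call hits _constant_cache, no extra CPU=>GPU copy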
19 | 20 | _constant_cache = dict() 21 | 22 | def constant(value, shape=None, dtype=None, device=None, memory_format=None): 23 | value = np.asarray(value) 24 | if shape is not None: 25 | shape = tuple(shape) 26 | if dtype is None: 27 | dtype = torch.get_default_dtype() 28 | if device is None: 29 | device = torch.device('cpu') 30 | if memory_format is None: 31 | memory_format = torch.contiguous_format 32 | 33 | key = (value.shape, value.dtype, value.tobytes(), shape, dtype, device, memory_format) 34 | tensor = _constant_cache.get(key, None) 35 | if tensor is None: 36 | tensor = torch.as_tensor(value.copy(), dtype=dtype, device=device) 37 | if shape is not None: 38 | tensor, _ = torch.broadcast_tensors(tensor, torch.empty(shape)) 39 | tensor = tensor.contiguous(memory_format=memory_format) 40 | _constant_cache[key] = tensor 41 | return tensor 42 | 43 | #---------------------------------------------------------------------------- 44 | # Replace NaN/Inf with specified numerical values. 45 | 46 | try: 47 | nan_to_num = torch.nan_to_num # 1.8.0a0 48 | except AttributeError: 49 | def nan_to_num(input, nan=0.0, posinf=None, neginf=None, *, out=None): # pylint: disable=redefined-builtin 50 | assert isinstance(input, torch.Tensor) 51 | if posinf is None: 52 | posinf = torch.finfo(input.dtype).max 53 | if neginf is None: 54 | neginf = torch.finfo(input.dtype).min 55 | assert nan == 0 56 | return torch.clamp(input.unsqueeze(0).nansum(0), min=neginf, max=posinf, out=out) 57 | 58 | #---------------------------------------------------------------------------- 59 | # Symbolic assert. 60 | 61 | try: 62 | symbolic_assert = torch._assert # 1.8.0a0 # pylint: disable=protected-access 63 | except AttributeError: 64 | symbolic_assert = torch.Assert # 1.7.0 65 | 66 | #---------------------------------------------------------------------------- 67 | # Context manager to temporarily suppress known warnings in torch.jit.trace(). 68 | # Note: Cannot use catch_warnings because of https://bugs.python.org/issue29672 69 | 70 | @contextlib.contextmanager 71 | def suppress_tracer_warnings(): 72 | flt = ('ignore', None, torch.jit.TracerWarning, None, 0) 73 | warnings.filters.insert(0, flt) 74 | yield 75 | warnings.filters.remove(flt) 76 | 77 | #---------------------------------------------------------------------------- 78 | # Assert that the shape of a tensor matches the given list of integers. 79 | # None indicates that the size of a dimension is allowed to vary. 80 | # Performs symbolic assertion when used in torch.jit.trace(). 
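# Example (illustrative): allow any batch size but pin the remaining dimensions.
#
#   assert_shape(images, [None, 3, 64, 64])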
81 | 82 | def assert_shape(tensor, ref_shape): 83 | if tensor.ndim != len(ref_shape): 84 | raise AssertionError(f'Wrong number of dimensions: got {tensor.ndim}, expected {len(ref_shape)}') 85 | for idx, (size, ref_size) in enumerate(zip(tensor.shape, ref_shape)): 86 | if ref_size is None: 87 | pass 88 | elif isinstance(ref_size, torch.Tensor): 89 | with suppress_tracer_warnings(): # as_tensor results are registered as constants 90 | symbolic_assert(torch.equal(torch.as_tensor(size), ref_size), f'Wrong size for dimension {idx}') 91 | elif isinstance(size, torch.Tensor): 92 | with suppress_tracer_warnings(): # as_tensor results are registered as constants 93 | symbolic_assert(torch.equal(size, torch.as_tensor(ref_size)), f'Wrong size for dimension {idx}: expected {ref_size}') 94 | elif size != ref_size: 95 | raise AssertionError(f'Wrong size for dimension {idx}: got {size}, expected {ref_size}') 96 | 97 | #---------------------------------------------------------------------------- 98 | # Function decorator that calls torch.autograd.profiler.record_function(). 99 | 100 | def profiled_function(fn): 101 | def decorator(*args, **kwargs): 102 | with torch.autograd.profiler.record_function(fn.__name__): 103 | return fn(*args, **kwargs) 104 | decorator.__name__ = fn.__name__ 105 | return decorator 106 | 107 | #---------------------------------------------------------------------------- 108 | # Sampler for torch.utils.data.DataLoader that loops over the dataset 109 | # indefinitely, shuffling items as it goes. 110 | 111 | class InfiniteSampler(torch.utils.data.Sampler): 112 | def __init__(self, dataset, rank=0, num_replicas=1, shuffle=True, seed=0, window_size=0.5): 113 | assert len(dataset) > 0 114 | assert num_replicas > 0 115 | assert 0 <= rank < num_replicas 116 | assert 0 <= window_size <= 1 117 | super().__init__(dataset) 118 | self.dataset = dataset 119 | self.rank = rank 120 | self.num_replicas = num_replicas 121 | self.shuffle = shuffle 122 | self.seed = seed 123 | self.window_size = window_size 124 | 125 | def __iter__(self): 126 | order = np.arange(len(self.dataset)) 127 | rnd = None 128 | window = 0 129 | if self.shuffle: 130 | rnd = np.random.RandomState(self.seed) 131 | rnd.shuffle(order) 132 | window = int(np.rint(order.size * self.window_size)) 133 | 134 | idx = 0 135 | while True: 136 | i = idx % order.size 137 | if idx % self.num_replicas == self.rank: 138 | yield order[i] 139 | if window >= 2: 140 | j = (i - rnd.randint(window)) % order.size 141 | order[i], order[j] = order[j], order[i] 142 | idx += 1 143 | 144 | #---------------------------------------------------------------------------- 145 | # Utilities for operating with torch.nn.Module parameters and buffers. 
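# Illustrative use of copy_params_and_buffers() defined below (the module names
# are assumptions): initialise one network from another with matching parameter
# and buffer names, e.g. for EMA or student/teacher setups.
#
#   with torch.no_grad():
#       copy_params_and_buffers(src_module=teacher, dst_module=student, require_all=False)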
146 | def spectral_to_cpu(model: torch.nn.Module): 147 | def wrapped_in_spectral(m): return hasattr(m, 'weight_v') 148 | children = get_children(model) 149 | for child in children: 150 | if wrapped_in_spectral(child): 151 | child.weight = child.weight.cpu() 152 | return model 153 | 154 | def get_children(model: torch.nn.Module): 155 | children = list(model.children()) 156 | flatt_children = [] 157 | if children == []: 158 | return model 159 | else: 160 | for child in children: 161 | try: 162 | flatt_children.extend(get_children(child)) 163 | except TypeError: 164 | flatt_children.append(get_children(child)) 165 | return flatt_children 166 | 167 | def params_and_buffers(module): 168 | assert isinstance(module, torch.nn.Module) 169 | return list(module.parameters()) + list(module.buffers()) 170 | 171 | def named_params_and_buffers(module): 172 | assert isinstance(module, torch.nn.Module) 173 | return list(module.named_parameters()) + list(module.named_buffers()) 174 | 175 | def copy_params_and_buffers(src_module, dst_module, require_all=False): 176 | assert isinstance(src_module, torch.nn.Module) 177 | assert isinstance(dst_module, torch.nn.Module) 178 | src_tensors = dict(named_params_and_buffers(src_module)) 179 | for name, tensor in named_params_and_buffers(dst_module): 180 | assert (name in src_tensors) or (not require_all) 181 | if name in src_tensors: 182 | tensor.copy_(src_tensors[name].detach()).requires_grad_(tensor.requires_grad) 183 | 184 | #---------------------------------------------------------------------------- 185 | # Context manager for easily enabling/disabling DistributedDataParallel 186 | # synchronization. 187 | 188 | @contextlib.contextmanager 189 | def ddp_sync(module, sync): 190 | assert isinstance(module, torch.nn.Module) 191 | if sync or not isinstance(module, torch.nn.parallel.DistributedDataParallel): 192 | yield 193 | else: 194 | with module.no_sync(): 195 | yield 196 | 197 | #---------------------------------------------------------------------------- 198 | # Check DistributedDataParallel consistency across processes. 199 | 200 | def check_ddp_consistency(module, ignore_regex=None): 201 | assert isinstance(module, torch.nn.Module) 202 | for name, tensor in named_params_and_buffers(module): 203 | fullname = type(module).__name__ + '.' + name 204 | if ignore_regex is not None and re.fullmatch(ignore_regex, fullname): 205 | continue 206 | tensor = tensor.detach() 207 | if tensor.is_floating_point(): 208 | tensor = nan_to_num(tensor) 209 | other = tensor.clone() 210 | torch.distributed.broadcast(tensor=other, src=0) 211 | assert (tensor == other).all(), fullname 212 | 213 | #---------------------------------------------------------------------------- 214 | # Print summary table of module hierarchy. 215 | 216 | def print_module_summary(module, inputs, max_nesting=3, skip_redundant=True): 217 | assert isinstance(module, torch.nn.Module) 218 | assert not isinstance(module, torch.jit.ScriptModule) 219 | assert isinstance(inputs, (tuple, list)) 220 | 221 | # Register hooks. 
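# A forward pre-hook tracks the current nesting depth; the forward hook records
# each submodule's tensor outputs, but only for modules at most max_nesting
# levels deep.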
222 | entries = [] 223 | nesting = [0] 224 | def pre_hook(_mod, _inputs): 225 | nesting[0] += 1 226 | def post_hook(mod, _inputs, outputs): 227 | nesting[0] -= 1 228 | if nesting[0] <= max_nesting: 229 | outputs = list(outputs) if isinstance(outputs, (tuple, list)) else [outputs] 230 | outputs = [t for t in outputs if isinstance(t, torch.Tensor)] 231 | entries.append(dnnlib.EasyDict(mod=mod, outputs=outputs)) 232 | hooks = [mod.register_forward_pre_hook(pre_hook) for mod in module.modules()] 233 | hooks += [mod.register_forward_hook(post_hook) for mod in module.modules()] 234 | 235 | # Run module. 236 | outputs = module(*inputs) 237 | for hook in hooks: 238 | hook.remove() 239 | 240 | # Identify unique outputs, parameters, and buffers. 241 | tensors_seen = set() 242 | for e in entries: 243 | e.unique_params = [t for t in e.mod.parameters() if id(t) not in tensors_seen] 244 | e.unique_buffers = [t for t in e.mod.buffers() if id(t) not in tensors_seen] 245 | e.unique_outputs = [t for t in e.outputs if id(t) not in tensors_seen] 246 | tensors_seen |= {id(t) for t in e.unique_params + e.unique_buffers + e.unique_outputs} 247 | 248 | # Filter out redundant entries. 249 | if skip_redundant: 250 | entries = [e for e in entries if len(e.unique_params) or len(e.unique_buffers) or len(e.unique_outputs)] 251 | 252 | # Construct table. 253 | rows = [[type(module).__name__, 'Parameters', 'Buffers', 'Output shape', 'Datatype']] 254 | rows += [['---'] * len(rows[0])] 255 | param_total = 0 256 | buffer_total = 0 257 | submodule_names = {mod: name for name, mod in module.named_modules()} 258 | for e in entries: 259 | name = '' if e.mod is module else submodule_names[e.mod] 260 | param_size = sum(t.numel() for t in e.unique_params) 261 | buffer_size = sum(t.numel() for t in e.unique_buffers) 262 | output_shapes = [str(list(t.shape)) for t in e.outputs] 263 | output_dtypes = [str(t.dtype).split('.')[-1] for t in e.outputs] 264 | rows += [[ 265 | name + (':0' if len(e.outputs) >= 2 else ''), 266 | str(param_size) if param_size else '-', 267 | str(buffer_size) if buffer_size else '-', 268 | (output_shapes + ['-'])[0], 269 | (output_dtypes + ['-'])[0], 270 | ]] 271 | for idx in range(1, len(e.outputs)): 272 | rows += [[name + f':{idx}', '-', '-', output_shapes[idx], output_dtypes[idx]]] 273 | param_total += param_size 274 | buffer_total += buffer_size 275 | rows += [['---'] * len(rows[0])] 276 | rows += [['Total', str(param_total), str(buffer_total), '-', '-']] 277 | 278 | # Print table. 279 | widths = [max(len(cell) for cell in column) for column in zip(*rows)] 280 | print() 281 | for row in rows: 282 | print(' '.join(cell + ' ' * (width - len(cell)) for cell, width in zip(row, widths))) 283 | print() 284 | return outputs 285 | -------------------------------------------------------------------------------- /add_turbo/torch_utils/ops/__init__.py: -------------------------------------------------------------------------------- 1 | # Copyright (c) 2021, NVIDIA CORPORATION & AFFILIATES. All rights reserved. 2 | # 3 | # NVIDIA CORPORATION and its licensors retain all intellectual property 4 | # and proprietary rights in and to this software, related documentation 5 | # and any modifications thereto. Any use, reproduction, disclosure or 6 | # distribution of this software and related documentation without an express 7 | # license agreement from NVIDIA CORPORATION is strictly prohibited. 
8 | 9 | # empty 10 | -------------------------------------------------------------------------------- /add_turbo/torch_utils/ops/bias_act.cpp: -------------------------------------------------------------------------------- 1 | // Copyright (c) 2021, NVIDIA CORPORATION & AFFILIATES. All rights reserved. 2 | // 3 | // NVIDIA CORPORATION and its licensors retain all intellectual property 4 | // and proprietary rights in and to this software, related documentation 5 | // and any modifications thereto. Any use, reproduction, disclosure or 6 | // distribution of this software and related documentation without an express 7 | // license agreement from NVIDIA CORPORATION is strictly prohibited. 8 | 9 | #include 10 | #include 11 | #include 12 | #include "bias_act.h" 13 | 14 | //------------------------------------------------------------------------ 15 | 16 | static bool has_same_layout(torch::Tensor x, torch::Tensor y) 17 | { 18 | if (x.dim() != y.dim()) 19 | return false; 20 | for (int64_t i = 0; i < x.dim(); i++) 21 | { 22 | if (x.size(i) != y.size(i)) 23 | return false; 24 | if (x.size(i) >= 2 && x.stride(i) != y.stride(i)) 25 | return false; 26 | } 27 | return true; 28 | } 29 | 30 | //------------------------------------------------------------------------ 31 | 32 | static torch::Tensor bias_act(torch::Tensor x, torch::Tensor b, torch::Tensor xref, torch::Tensor yref, torch::Tensor dy, int grad, int dim, int act, float alpha, float gain, float clamp) 33 | { 34 | // Validate arguments. 35 | TORCH_CHECK(x.is_cuda(), "x must reside on CUDA device"); 36 | TORCH_CHECK(b.numel() == 0 || (b.dtype() == x.dtype() && b.device() == x.device()), "b must have the same dtype and device as x"); 37 | TORCH_CHECK(xref.numel() == 0 || (xref.sizes() == x.sizes() && xref.dtype() == x.dtype() && xref.device() == x.device()), "xref must have the same shape, dtype, and device as x"); 38 | TORCH_CHECK(yref.numel() == 0 || (yref.sizes() == x.sizes() && yref.dtype() == x.dtype() && yref.device() == x.device()), "yref must have the same shape, dtype, and device as x"); 39 | TORCH_CHECK(dy.numel() == 0 || (dy.sizes() == x.sizes() && dy.dtype() == x.dtype() && dy.device() == x.device()), "dy must have the same dtype and device as x"); 40 | TORCH_CHECK(x.numel() <= INT_MAX, "x is too large"); 41 | TORCH_CHECK(b.dim() == 1, "b must have rank 1"); 42 | TORCH_CHECK(b.numel() == 0 || (dim >= 0 && dim < x.dim()), "dim is out of bounds"); 43 | TORCH_CHECK(b.numel() == 0 || b.numel() == x.size(dim), "b has wrong number of elements"); 44 | TORCH_CHECK(grad >= 0, "grad must be non-negative"); 45 | 46 | // Validate layout. 47 | TORCH_CHECK(x.is_non_overlapping_and_dense(), "x must be non-overlapping and dense"); 48 | TORCH_CHECK(b.is_contiguous(), "b must be contiguous"); 49 | TORCH_CHECK(xref.numel() == 0 || has_same_layout(xref, x), "xref must have the same layout as x"); 50 | TORCH_CHECK(yref.numel() == 0 || has_same_layout(yref, x), "yref must have the same layout as x"); 51 | TORCH_CHECK(dy.numel() == 0 || has_same_layout(dy, x), "dy must have the same layout as x"); 52 | 53 | // Create output tensor. 54 | const at::cuda::OptionalCUDAGuard device_guard(device_of(x)); 55 | torch::Tensor y = torch::empty_like(x); 56 | TORCH_CHECK(has_same_layout(y, x), "y must have the same layout as x"); 57 | 58 | // Initialize CUDA kernel parameters. 59 | bias_act_kernel_params p; 60 | p.x = x.data_ptr(); 61 | p.b = (b.numel()) ? b.data_ptr() : NULL; 62 | p.xref = (xref.numel()) ? xref.data_ptr() : NULL; 63 | p.yref = (yref.numel()) ? 
yref.data_ptr() : NULL; 64 | p.dy = (dy.numel()) ? dy.data_ptr() : NULL; 65 | p.y = y.data_ptr(); 66 | p.grad = grad; 67 | p.act = act; 68 | p.alpha = alpha; 69 | p.gain = gain; 70 | p.clamp = clamp; 71 | p.sizeX = (int)x.numel(); 72 | p.sizeB = (int)b.numel(); 73 | p.stepB = (b.numel()) ? (int)x.stride(dim) : 1; 74 | 75 | // Choose CUDA kernel. 76 | void* kernel; 77 | AT_DISPATCH_FLOATING_TYPES_AND_HALF(x.scalar_type(), "upfirdn2d_cuda", [&] 78 | { 79 | kernel = choose_bias_act_kernel(p); 80 | }); 81 | TORCH_CHECK(kernel, "no CUDA kernel found for the specified activation func"); 82 | 83 | // Launch CUDA kernel. 84 | p.loopX = 4; 85 | int blockSize = 4 * 32; 86 | int gridSize = (p.sizeX - 1) / (p.loopX * blockSize) + 1; 87 | void* args[] = {&p}; 88 | AT_CUDA_CHECK(cudaLaunchKernel(kernel, gridSize, blockSize, args, 0, at::cuda::getCurrentCUDAStream())); 89 | return y; 90 | } 91 | 92 | //------------------------------------------------------------------------ 93 | 94 | PYBIND11_MODULE(TORCH_EXTENSION_NAME, m) 95 | { 96 | m.def("bias_act", &bias_act); 97 | } 98 | 99 | //------------------------------------------------------------------------ 100 | -------------------------------------------------------------------------------- /add_turbo/torch_utils/ops/bias_act.cu: -------------------------------------------------------------------------------- 1 | // Copyright (c) 2021, NVIDIA CORPORATION & AFFILIATES. All rights reserved. 2 | // 3 | // NVIDIA CORPORATION and its licensors retain all intellectual property 4 | // and proprietary rights in and to this software, related documentation 5 | // and any modifications thereto. Any use, reproduction, disclosure or 6 | // distribution of this software and related documentation without an express 7 | // license agreement from NVIDIA CORPORATION is strictly prohibited. 8 | 9 | #include 10 | #include "bias_act.h" 11 | 12 | //------------------------------------------------------------------------ 13 | // Helpers. 14 | 15 | template struct InternalType; 16 | template <> struct InternalType { typedef double scalar_t; }; 17 | template <> struct InternalType { typedef float scalar_t; }; 18 | template <> struct InternalType { typedef float scalar_t; }; 19 | 20 | //------------------------------------------------------------------------ 21 | // CUDA kernel. 22 | 23 | template 24 | __global__ void bias_act_kernel(bias_act_kernel_params p) 25 | { 26 | typedef typename InternalType::scalar_t scalar_t; 27 | int G = p.grad; 28 | scalar_t alpha = (scalar_t)p.alpha; 29 | scalar_t gain = (scalar_t)p.gain; 30 | scalar_t clamp = (scalar_t)p.clamp; 31 | scalar_t one = (scalar_t)1; 32 | scalar_t two = (scalar_t)2; 33 | scalar_t expRange = (scalar_t)80; 34 | scalar_t halfExpRange = (scalar_t)40; 35 | scalar_t seluScale = (scalar_t)1.0507009873554804934193349852946; 36 | scalar_t seluAlpha = (scalar_t)1.6732632423543772848170429916717; 37 | 38 | // Loop over elements. 39 | int xi = blockIdx.x * p.loopX * blockDim.x + threadIdx.x; 40 | for (int loopIdx = 0; loopIdx < p.loopX && xi < p.sizeX; loopIdx++, xi += blockDim.x) 41 | { 42 | // Load. 43 | scalar_t x = (scalar_t)((const T*)p.x)[xi]; 44 | scalar_t b = (p.b) ? (scalar_t)((const T*)p.b)[(xi / p.stepB) % p.sizeB] : 0; 45 | scalar_t xref = (p.xref) ? (scalar_t)((const T*)p.xref)[xi] : 0; 46 | scalar_t yref = (p.yref) ? (scalar_t)((const T*)p.yref)[xi] : 0; 47 | scalar_t dy = (p.dy) ? (scalar_t)((const T*)p.dy)[xi] : one; 48 | scalar_t yy = (gain != 0) ? yref / gain : 0; 49 | scalar_t y = 0; 50 | 51 | // Apply bias. 
52 | ((G == 0) ? x : xref) += b; 53 | 54 | // linear 55 | if (A == 1) 56 | { 57 | if (G == 0) y = x; 58 | if (G == 1) y = x; 59 | } 60 | 61 | // relu 62 | if (A == 2) 63 | { 64 | if (G == 0) y = (x > 0) ? x : 0; 65 | if (G == 1) y = (yy > 0) ? x : 0; 66 | } 67 | 68 | // lrelu 69 | if (A == 3) 70 | { 71 | if (G == 0) y = (x > 0) ? x : x * alpha; 72 | if (G == 1) y = (yy > 0) ? x : x * alpha; 73 | } 74 | 75 | // tanh 76 | if (A == 4) 77 | { 78 | if (G == 0) { scalar_t c = exp(x); scalar_t d = one / c; y = (x < -expRange) ? -one : (x > expRange) ? one : (c - d) / (c + d); } 79 | if (G == 1) y = x * (one - yy * yy); 80 | if (G == 2) y = x * (one - yy * yy) * (-two * yy); 81 | } 82 | 83 | // sigmoid 84 | if (A == 5) 85 | { 86 | if (G == 0) y = (x < -expRange) ? 0 : one / (exp(-x) + one); 87 | if (G == 1) y = x * yy * (one - yy); 88 | if (G == 2) y = x * yy * (one - yy) * (one - two * yy); 89 | } 90 | 91 | // elu 92 | if (A == 6) 93 | { 94 | if (G == 0) y = (x >= 0) ? x : exp(x) - one; 95 | if (G == 1) y = (yy >= 0) ? x : x * (yy + one); 96 | if (G == 2) y = (yy >= 0) ? 0 : x * (yy + one); 97 | } 98 | 99 | // selu 100 | if (A == 7) 101 | { 102 | if (G == 0) y = (x >= 0) ? seluScale * x : (seluScale * seluAlpha) * (exp(x) - one); 103 | if (G == 1) y = (yy >= 0) ? x * seluScale : x * (yy + seluScale * seluAlpha); 104 | if (G == 2) y = (yy >= 0) ? 0 : x * (yy + seluScale * seluAlpha); 105 | } 106 | 107 | // softplus 108 | if (A == 8) 109 | { 110 | if (G == 0) y = (x > expRange) ? x : log(exp(x) + one); 111 | if (G == 1) y = x * (one - exp(-yy)); 112 | if (G == 2) { scalar_t c = exp(-yy); y = x * c * (one - c); } 113 | } 114 | 115 | // swish 116 | if (A == 9) 117 | { 118 | if (G == 0) 119 | y = (x < -expRange) ? 0 : x / (exp(-x) + one); 120 | else 121 | { 122 | scalar_t c = exp(xref); 123 | scalar_t d = c + one; 124 | if (G == 1) 125 | y = (xref > halfExpRange) ? x : x * c * (xref + d) / (d * d); 126 | else 127 | y = (xref > halfExpRange) ? 0 : x * c * (xref * (two - d) + two * d) / (d * d * d); 128 | yref = (xref < -expRange) ? 0 : xref / (exp(-xref) + one) * gain; 129 | } 130 | } 131 | 132 | // Apply gain. 133 | y *= gain * dy; 134 | 135 | // Clamp. 136 | if (clamp >= 0) 137 | { 138 | if (G == 0) 139 | y = (y > -clamp & y < clamp) ? y : (y >= 0) ? clamp : -clamp; 140 | else 141 | y = (yref > -clamp & yref < clamp) ? y : 0; 142 | } 143 | 144 | // Store. 145 | ((T*)p.y)[xi] = (T)y; 146 | } 147 | } 148 | 149 | //------------------------------------------------------------------------ 150 | // CUDA kernel selection. 151 | 152 | template void* choose_bias_act_kernel(const bias_act_kernel_params& p) 153 | { 154 | if (p.act == 1) return (void*)bias_act_kernel; 155 | if (p.act == 2) return (void*)bias_act_kernel; 156 | if (p.act == 3) return (void*)bias_act_kernel; 157 | if (p.act == 4) return (void*)bias_act_kernel; 158 | if (p.act == 5) return (void*)bias_act_kernel; 159 | if (p.act == 6) return (void*)bias_act_kernel; 160 | if (p.act == 7) return (void*)bias_act_kernel; 161 | if (p.act == 8) return (void*)bias_act_kernel; 162 | if (p.act == 9) return (void*)bias_act_kernel; 163 | return NULL; 164 | } 165 | 166 | //------------------------------------------------------------------------ 167 | // Template specializations. 
168 | 169 | template void* choose_bias_act_kernel (const bias_act_kernel_params& p); 170 | template void* choose_bias_act_kernel (const bias_act_kernel_params& p); 171 | template void* choose_bias_act_kernel (const bias_act_kernel_params& p); 172 | 173 | //------------------------------------------------------------------------ 174 | -------------------------------------------------------------------------------- /add_turbo/torch_utils/ops/bias_act.h: -------------------------------------------------------------------------------- 1 | // Copyright (c) 2021, NVIDIA CORPORATION & AFFILIATES. All rights reserved. 2 | // 3 | // NVIDIA CORPORATION and its licensors retain all intellectual property 4 | // and proprietary rights in and to this software, related documentation 5 | // and any modifications thereto. Any use, reproduction, disclosure or 6 | // distribution of this software and related documentation without an express 7 | // license agreement from NVIDIA CORPORATION is strictly prohibited. 8 | 9 | //------------------------------------------------------------------------ 10 | // CUDA kernel parameters. 11 | 12 | struct bias_act_kernel_params 13 | { 14 | const void* x; // [sizeX] 15 | const void* b; // [sizeB] or NULL 16 | const void* xref; // [sizeX] or NULL 17 | const void* yref; // [sizeX] or NULL 18 | const void* dy; // [sizeX] or NULL 19 | void* y; // [sizeX] 20 | 21 | int grad; 22 | int act; 23 | float alpha; 24 | float gain; 25 | float clamp; 26 | 27 | int sizeX; 28 | int sizeB; 29 | int stepB; 30 | int loopX; 31 | }; 32 | 33 | //------------------------------------------------------------------------ 34 | // CUDA kernel selection. 35 | 36 | template void* choose_bias_act_kernel(const bias_act_kernel_params& p); 37 | 38 | //------------------------------------------------------------------------ 39 | -------------------------------------------------------------------------------- /add_turbo/torch_utils/ops/bias_act.py: -------------------------------------------------------------------------------- 1 | # Copyright (c) 2021, NVIDIA CORPORATION & AFFILIATES. All rights reserved. 2 | # 3 | # NVIDIA CORPORATION and its licensors retain all intellectual property 4 | # and proprietary rights in and to this software, related documentation 5 | # and any modifications thereto. Any use, reproduction, disclosure or 6 | # distribution of this software and related documentation without an express 7 | # license agreement from NVIDIA CORPORATION is strictly prohibited. 8 | 9 | """Custom PyTorch ops for efficient bias and activation.""" 10 | 11 | import os 12 | import numpy as np 13 | import torch 14 | import dnnlib 15 | 16 | from .. import custom_ops 17 | from .. 
import misc 18 | 19 | #---------------------------------------------------------------------------- 20 | 21 | activation_funcs = { 22 | 'linear': dnnlib.EasyDict(func=lambda x, **_: x, def_alpha=0, def_gain=1, cuda_idx=1, ref='', has_2nd_grad=False), 23 | 'relu': dnnlib.EasyDict(func=lambda x, **_: torch.nn.functional.relu(x), def_alpha=0, def_gain=np.sqrt(2), cuda_idx=2, ref='y', has_2nd_grad=False), 24 | 'lrelu': dnnlib.EasyDict(func=lambda x, alpha, **_: torch.nn.functional.leaky_relu(x, alpha), def_alpha=0.2, def_gain=np.sqrt(2), cuda_idx=3, ref='y', has_2nd_grad=False), 25 | 'tanh': dnnlib.EasyDict(func=lambda x, **_: torch.tanh(x), def_alpha=0, def_gain=1, cuda_idx=4, ref='y', has_2nd_grad=True), 26 | 'sigmoid': dnnlib.EasyDict(func=lambda x, **_: torch.sigmoid(x), def_alpha=0, def_gain=1, cuda_idx=5, ref='y', has_2nd_grad=True), 27 | 'elu': dnnlib.EasyDict(func=lambda x, **_: torch.nn.functional.elu(x), def_alpha=0, def_gain=1, cuda_idx=6, ref='y', has_2nd_grad=True), 28 | 'selu': dnnlib.EasyDict(func=lambda x, **_: torch.nn.functional.selu(x), def_alpha=0, def_gain=1, cuda_idx=7, ref='y', has_2nd_grad=True), 29 | 'softplus': dnnlib.EasyDict(func=lambda x, **_: torch.nn.functional.softplus(x), def_alpha=0, def_gain=1, cuda_idx=8, ref='y', has_2nd_grad=True), 30 | 'swish': dnnlib.EasyDict(func=lambda x, **_: torch.sigmoid(x) * x, def_alpha=0, def_gain=np.sqrt(2), cuda_idx=9, ref='x', has_2nd_grad=True), 31 | } 32 | 33 | #---------------------------------------------------------------------------- 34 | 35 | _plugin = None 36 | _null_tensor = torch.empty([0]) 37 | 38 | def _init(): 39 | global _plugin 40 | if _plugin is None: 41 | _plugin = custom_ops.get_plugin( 42 | module_name='bias_act_plugin', 43 | sources=['bias_act.cpp', 'bias_act.cu'], 44 | headers=['bias_act.h'], 45 | source_dir=os.path.dirname(__file__), 46 | extra_cuda_cflags=['--use_fast_math', '--allow-unsupported-compiler'], 47 | ) 48 | return True 49 | 50 | #---------------------------------------------------------------------------- 51 | 52 | def bias_act(x, b=None, dim=1, act='linear', alpha=None, gain=None, clamp=None, impl='cuda'): 53 | r"""Fused bias and activation function. 54 | 55 | Adds bias `b` to activation tensor `x`, evaluates activation function `act`, 56 | and scales the result by `gain`. Each of the steps is optional. In most cases, 57 | the fused op is considerably more efficient than performing the same calculation 58 | using standard PyTorch ops. It supports first and second order gradients, 59 | but not third order gradients. 60 | 61 | Args: 62 | x: Input activation tensor. Can be of any shape. 63 | b: Bias vector, or `None` to disable. Must be a 1D tensor of the same type 64 | as `x`. The shape must be known, and it must match the dimension of `x` 65 | corresponding to `dim`. 66 | dim: The dimension in `x` corresponding to the elements of `b`. 67 | The value of `dim` is ignored if `b` is not specified. 68 | act: Name of the activation function to evaluate, or `"linear"` to disable. 69 | Can be e.g. `"relu"`, `"lrelu"`, `"tanh"`, `"sigmoid"`, `"swish"`, etc. 70 | See `activation_funcs` for a full list. `None` is not allowed. 71 | alpha: Shape parameter for the activation function, or `None` to use the default. 72 | gain: Scaling factor for the output tensor, or `None` to use default. 73 | See `activation_funcs` for the default scaling of each activation function. 74 | If unsure, consider specifying 1. 
75 | clamp: Clamp the output values to `[-clamp, +clamp]`, or `None` to disable 76 | the clamping (default). 77 | impl: Name of the implementation to use. Can be `"ref"` or `"cuda"` (default). 78 | 79 | Returns: 80 | Tensor of the same shape and datatype as `x`. 81 | """ 82 | assert isinstance(x, torch.Tensor) 83 | assert impl in ['ref', 'cuda'] 84 | if impl == 'cuda' and x.device.type == 'cuda' and _init(): 85 | return _bias_act_cuda(dim=dim, act=act, alpha=alpha, gain=gain, clamp=clamp).apply(x, b) 86 | return _bias_act_ref(x=x, b=b, dim=dim, act=act, alpha=alpha, gain=gain, clamp=clamp) 87 | 88 | #---------------------------------------------------------------------------- 89 | 90 | @misc.profiled_function 91 | def _bias_act_ref(x, b=None, dim=1, act='linear', alpha=None, gain=None, clamp=None): 92 | """Slow reference implementation of `bias_act()` using standard TensorFlow ops. 93 | """ 94 | assert isinstance(x, torch.Tensor) 95 | assert clamp is None or clamp >= 0 96 | spec = activation_funcs[act] 97 | alpha = float(alpha if alpha is not None else spec.def_alpha) 98 | gain = float(gain if gain is not None else spec.def_gain) 99 | clamp = float(clamp if clamp is not None else -1) 100 | 101 | # Add bias. 102 | if b is not None: 103 | assert isinstance(b, torch.Tensor) and b.ndim == 1 104 | assert 0 <= dim < x.ndim 105 | assert b.shape[0] == x.shape[dim] 106 | x = x + b.reshape([-1 if i == dim else 1 for i in range(x.ndim)]) 107 | 108 | # Evaluate activation function. 109 | alpha = float(alpha) 110 | x = spec.func(x, alpha=alpha) 111 | 112 | # Scale by gain. 113 | gain = float(gain) 114 | if gain != 1: 115 | x = x * gain 116 | 117 | # Clamp. 118 | if clamp >= 0: 119 | x = x.clamp(-clamp, clamp) # pylint: disable=invalid-unary-operand-type 120 | return x 121 | 122 | #---------------------------------------------------------------------------- 123 | 124 | _bias_act_cuda_cache = dict() 125 | 126 | def _bias_act_cuda(dim=1, act='linear', alpha=None, gain=None, clamp=None): 127 | """Fast CUDA implementation of `bias_act()` using custom ops. 128 | """ 129 | # Parse arguments. 130 | assert clamp is None or clamp >= 0 131 | spec = activation_funcs[act] 132 | alpha = float(alpha if alpha is not None else spec.def_alpha) 133 | gain = float(gain if gain is not None else spec.def_gain) 134 | clamp = float(clamp if clamp is not None else -1) 135 | 136 | # Lookup from cache. 137 | key = (dim, act, alpha, gain, clamp) 138 | if key in _bias_act_cuda_cache: 139 | return _bias_act_cuda_cache[key] 140 | 141 | # Forward op. 
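# A minimal usage sketch of the public `bias_act()` wrapper defined earlier in this
# module (the tensor shapes and the clamp value are illustrative assumptions, not
# taken from the original code):
#
#     x = torch.randn([8, 512, 32, 32], device='cuda')   # activations
#     b = torch.zeros([512], device='cuda')               # per-channel bias
#     y = bias_act(x, b, dim=1, act='lrelu', gain=np.sqrt(2), clamp=256)
#
# On a CUDA tensor with impl='cuda' this routes to the fused plugin wrapped by
# _bias_act_cuda() below; otherwise it falls back to the reference _bias_act_ref().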
142 | class BiasActCuda(torch.autograd.Function): 143 | @staticmethod 144 | def forward(ctx, x, b): # pylint: disable=arguments-differ 145 | ctx.memory_format = torch.channels_last if x.ndim > 2 and x.stride(1) == 1 else torch.contiguous_format 146 | x = x.contiguous(memory_format=ctx.memory_format) 147 | b = b.contiguous() if b is not None else _null_tensor 148 | y = x 149 | if act != 'linear' or gain != 1 or clamp >= 0 or b is not _null_tensor: 150 | y = _plugin.bias_act(x, b, _null_tensor, _null_tensor, _null_tensor, 0, dim, spec.cuda_idx, alpha, gain, clamp) 151 | ctx.save_for_backward( 152 | x if 'x' in spec.ref or spec.has_2nd_grad else _null_tensor, 153 | b if 'x' in spec.ref or spec.has_2nd_grad else _null_tensor, 154 | y if 'y' in spec.ref else _null_tensor) 155 | return y 156 | 157 | @staticmethod 158 | def backward(ctx, dy): # pylint: disable=arguments-differ 159 | dy = dy.contiguous(memory_format=ctx.memory_format) 160 | x, b, y = ctx.saved_tensors 161 | dx = None 162 | db = None 163 | 164 | if ctx.needs_input_grad[0] or ctx.needs_input_grad[1]: 165 | dx = dy 166 | if act != 'linear' or gain != 1 or clamp >= 0: 167 | dx = BiasActCudaGrad.apply(dy, x, b, y) 168 | 169 | if ctx.needs_input_grad[1]: 170 | db = dx.sum([i for i in range(dx.ndim) if i != dim]) 171 | 172 | return dx, db 173 | 174 | # Backward op. 175 | class BiasActCudaGrad(torch.autograd.Function): 176 | @staticmethod 177 | def forward(ctx, dy, x, b, y): # pylint: disable=arguments-differ 178 | ctx.memory_format = torch.channels_last if dy.ndim > 2 and dy.stride(1) == 1 else torch.contiguous_format 179 | dx = _plugin.bias_act(dy, b, x, y, _null_tensor, 1, dim, spec.cuda_idx, alpha, gain, clamp) 180 | ctx.save_for_backward( 181 | dy if spec.has_2nd_grad else _null_tensor, 182 | x, b, y) 183 | return dx 184 | 185 | @staticmethod 186 | def backward(ctx, d_dx): # pylint: disable=arguments-differ 187 | d_dx = d_dx.contiguous(memory_format=ctx.memory_format) 188 | dy, x, b, y = ctx.saved_tensors 189 | d_dy = None 190 | d_x = None 191 | d_b = None 192 | d_y = None 193 | 194 | if ctx.needs_input_grad[0]: 195 | d_dy = BiasActCudaGrad.apply(d_dx, x, b, y) 196 | 197 | if spec.has_2nd_grad and (ctx.needs_input_grad[1] or ctx.needs_input_grad[2]): 198 | d_x = _plugin.bias_act(d_dx, b, x, y, dy, 2, dim, spec.cuda_idx, alpha, gain, clamp) 199 | 200 | if spec.has_2nd_grad and ctx.needs_input_grad[2]: 201 | d_b = d_x.sum([i for i in range(d_x.ndim) if i != dim]) 202 | 203 | return d_dy, d_x, d_b, d_y 204 | 205 | # Add to cache. 206 | _bias_act_cuda_cache[key] = BiasActCuda 207 | return BiasActCuda 208 | 209 | #---------------------------------------------------------------------------- 210 | -------------------------------------------------------------------------------- /add_turbo/torch_utils/ops/conv2d_gradfix.py: -------------------------------------------------------------------------------- 1 | # Copyright (c) 2021, NVIDIA CORPORATION & AFFILIATES. All rights reserved. 2 | # 3 | # NVIDIA CORPORATION and its licensors retain all intellectual property 4 | # and proprietary rights in and to this software, related documentation 5 | # and any modifications thereto. Any use, reproduction, disclosure or 6 | # distribution of this software and related documentation without an express 7 | # license agreement from NVIDIA CORPORATION is strictly prohibited. 
8 | 9 | """Custom replacement for `torch.nn.functional.conv2d` that supports 10 | arbitrarily high order gradients with zero performance penalty.""" 11 | 12 | import contextlib 13 | import torch 14 | from pkg_resources import parse_version 15 | 16 | # pylint: disable=redefined-builtin 17 | # pylint: disable=arguments-differ 18 | # pylint: disable=protected-access 19 | 20 | #---------------------------------------------------------------------------- 21 | 22 | enabled = False # Enable the custom op by setting this to true. 23 | weight_gradients_disabled = False # Forcefully disable computation of gradients with respect to the weights. 24 | _use_pytorch_1_11_api = parse_version(torch.__version__) >= parse_version('1.11.0a') # Allow prerelease builds of 1.11 25 | 26 | @contextlib.contextmanager 27 | def no_weight_gradients(disable=True): 28 | global weight_gradients_disabled 29 | old = weight_gradients_disabled 30 | if disable: 31 | weight_gradients_disabled = True 32 | yield 33 | weight_gradients_disabled = old 34 | 35 | #---------------------------------------------------------------------------- 36 | 37 | def conv2d(input, weight, bias=None, stride=1, padding=0, dilation=1, groups=1): 38 | if _should_use_custom_op(input): 39 | return _conv2d_gradfix(transpose=False, weight_shape=weight.shape, stride=stride, padding=padding, output_padding=0, dilation=dilation, groups=groups).apply(input, weight, bias) 40 | return torch.nn.functional.conv2d(input=input, weight=weight, bias=bias, stride=stride, padding=padding, dilation=dilation, groups=groups) 41 | 42 | def conv_transpose2d(input, weight, bias=None, stride=1, padding=0, output_padding=0, groups=1, dilation=1): 43 | if _should_use_custom_op(input): 44 | return _conv2d_gradfix(transpose=True, weight_shape=weight.shape, stride=stride, padding=padding, output_padding=output_padding, groups=groups, dilation=dilation).apply(input, weight, bias) 45 | return torch.nn.functional.conv_transpose2d(input=input, weight=weight, bias=bias, stride=stride, padding=padding, output_padding=output_padding, groups=groups, dilation=dilation) 46 | 47 | #---------------------------------------------------------------------------- 48 | 49 | def _should_use_custom_op(input): 50 | assert isinstance(input, torch.Tensor) 51 | if (not enabled) or (not torch.backends.cudnn.enabled): 52 | return False 53 | if _use_pytorch_1_11_api: 54 | # The work-around code doesn't work on PyTorch 1.11.0 onwards 55 | return False 56 | if input.device.type != 'cuda': 57 | return False 58 | return True 59 | 60 | def _tuple_of_ints(xs, ndim): 61 | xs = tuple(xs) if isinstance(xs, (tuple, list)) else (xs,) * ndim 62 | assert len(xs) == ndim 63 | assert all(isinstance(x, int) for x in xs) 64 | return xs 65 | 66 | #---------------------------------------------------------------------------- 67 | 68 | _conv2d_gradfix_cache = dict() 69 | _null_tensor = torch.empty([0]) 70 | 71 | def _conv2d_gradfix(transpose, weight_shape, stride, padding, output_padding, dilation, groups): 72 | # Parse arguments. 73 | ndim = 2 74 | weight_shape = tuple(weight_shape) 75 | stride = _tuple_of_ints(stride, ndim) 76 | padding = _tuple_of_ints(padding, ndim) 77 | output_padding = _tuple_of_ints(output_padding, ndim) 78 | dilation = _tuple_of_ints(dilation, ndim) 79 | 80 | # Lookup from cache. 81 | key = (transpose, weight_shape, stride, padding, output_padding, dilation, groups) 82 | if key in _conv2d_gradfix_cache: 83 | return _conv2d_gradfix_cache[key] 84 | 85 | # Validate arguments. 
86 | assert groups >= 1 87 | assert len(weight_shape) == ndim + 2 88 | assert all(stride[i] >= 1 for i in range(ndim)) 89 | assert all(padding[i] >= 0 for i in range(ndim)) 90 | assert all(dilation[i] >= 0 for i in range(ndim)) 91 | if not transpose: 92 | assert all(output_padding[i] == 0 for i in range(ndim)) 93 | else: # transpose 94 | assert all(0 <= output_padding[i] < max(stride[i], dilation[i]) for i in range(ndim)) 95 | 96 | # Helpers. 97 | common_kwargs = dict(stride=stride, padding=padding, dilation=dilation, groups=groups) 98 | def calc_output_padding(input_shape, output_shape): 99 | if transpose: 100 | return [0, 0] 101 | return [ 102 | input_shape[i + 2] 103 | - (output_shape[i + 2] - 1) * stride[i] 104 | - (1 - 2 * padding[i]) 105 | - dilation[i] * (weight_shape[i + 2] - 1) 106 | for i in range(ndim) 107 | ] 108 | 109 | # Forward & backward. 110 | class Conv2d(torch.autograd.Function): 111 | @staticmethod 112 | def forward(ctx, input, weight, bias): 113 | assert weight.shape == weight_shape 114 | ctx.save_for_backward( 115 | input if weight.requires_grad else _null_tensor, 116 | weight if input.requires_grad else _null_tensor, 117 | ) 118 | ctx.input_shape = input.shape 119 | 120 | # Simple 1x1 convolution => cuBLAS (only on Volta, not on Ampere). 121 | if weight_shape[2:] == stride == dilation == (1, 1) and padding == (0, 0) and torch.cuda.get_device_capability(input.device) < (8, 0): 122 | a = weight.reshape(groups, weight_shape[0] // groups, weight_shape[1]) 123 | b = input.reshape(input.shape[0], groups, input.shape[1] // groups, -1) 124 | c = (a.transpose(1, 2) if transpose else a) @ b.permute(1, 2, 0, 3).flatten(2) 125 | c = c.reshape(-1, input.shape[0], *input.shape[2:]).transpose(0, 1) 126 | c = c if bias is None else c + bias.unsqueeze(0).unsqueeze(2).unsqueeze(3) 127 | return c.contiguous(memory_format=(torch.channels_last if input.stride(1) == 1 else torch.contiguous_format)) 128 | 129 | # General case => cuDNN. 130 | if transpose: 131 | return torch.nn.functional.conv_transpose2d(input=input, weight=weight, bias=bias, output_padding=output_padding, **common_kwargs) 132 | return torch.nn.functional.conv2d(input=input, weight=weight, bias=bias, **common_kwargs) 133 | 134 | @staticmethod 135 | def backward(ctx, grad_output): 136 | input, weight = ctx.saved_tensors 137 | input_shape = ctx.input_shape 138 | grad_input = None 139 | grad_weight = None 140 | grad_bias = None 141 | 142 | if ctx.needs_input_grad[0]: 143 | p = calc_output_padding(input_shape=input_shape, output_shape=grad_output.shape) 144 | op = _conv2d_gradfix(transpose=(not transpose), weight_shape=weight_shape, output_padding=p, **common_kwargs) 145 | grad_input = op.apply(grad_output, weight, None) 146 | assert grad_input.shape == input_shape 147 | 148 | if ctx.needs_input_grad[1] and not weight_gradients_disabled: 149 | grad_weight = Conv2dGradWeight.apply(grad_output, input) 150 | assert grad_weight.shape == weight_shape 151 | 152 | if ctx.needs_input_grad[2]: 153 | grad_bias = grad_output.sum([0, 2, 3]) 154 | 155 | return grad_input, grad_weight, grad_bias 156 | 157 | # Gradient with respect to the weights. 
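# A minimal usage sketch of this module's public wrappers from calling code (shapes
# are illustrative assumptions; `enabled` defaults to False, so the custom op has to
# be switched on explicitly, and it is only active on pre-1.11 PyTorch with cuDNN):
#
#     conv2d_gradfix.enabled = True
#     x = torch.randn([4, 64, 32, 32], device='cuda', requires_grad=True)
#     w = torch.randn([128, 64, 3, 3], device='cuda', requires_grad=True)
#     y = conv2d_gradfix.conv2d(x, w, padding=1)
#     with conv2d_gradfix.no_weight_gradients():
#         y.sum().backward()   # on the custom-op path, w receives no gradient here
#
# When _should_use_custom_op() returns False, the wrappers silently fall back to the
# standard torch.nn.functional implementations.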
158 | class Conv2dGradWeight(torch.autograd.Function): 159 | @staticmethod 160 | def forward(ctx, grad_output, input): 161 | ctx.save_for_backward( 162 | grad_output if input.requires_grad else _null_tensor, 163 | input if grad_output.requires_grad else _null_tensor, 164 | ) 165 | ctx.grad_output_shape = grad_output.shape 166 | ctx.input_shape = input.shape 167 | 168 | # Simple 1x1 convolution => cuBLAS (on both Volta and Ampere). 169 | if weight_shape[2:] == stride == dilation == (1, 1) and padding == (0, 0): 170 | a = grad_output.reshape(grad_output.shape[0], groups, grad_output.shape[1] // groups, -1).permute(1, 2, 0, 3).flatten(2) 171 | b = input.reshape(input.shape[0], groups, input.shape[1] // groups, -1).permute(1, 2, 0, 3).flatten(2) 172 | c = (b @ a.transpose(1, 2) if transpose else a @ b.transpose(1, 2)).reshape(weight_shape) 173 | return c.contiguous(memory_format=(torch.channels_last if input.stride(1) == 1 else torch.contiguous_format)) 174 | 175 | # General case => cuDNN. 176 | name = 'aten::cudnn_convolution_transpose_backward_weight' if transpose else 'aten::cudnn_convolution_backward_weight' 177 | flags = [torch.backends.cudnn.benchmark, torch.backends.cudnn.deterministic, torch.backends.cudnn.allow_tf32] 178 | return torch._C._jit_get_operation(name)(weight_shape, grad_output, input, padding, stride, dilation, groups, *flags) 179 | 180 | @staticmethod 181 | def backward(ctx, grad2_grad_weight): 182 | grad_output, input = ctx.saved_tensors 183 | grad_output_shape = ctx.grad_output_shape 184 | input_shape = ctx.input_shape 185 | grad2_grad_output = None 186 | grad2_input = None 187 | 188 | if ctx.needs_input_grad[0]: 189 | grad2_grad_output = Conv2d.apply(input, grad2_grad_weight, None) 190 | assert grad2_grad_output.shape == grad_output_shape 191 | 192 | if ctx.needs_input_grad[1]: 193 | p = calc_output_padding(input_shape=input_shape, output_shape=grad_output_shape) 194 | op = _conv2d_gradfix(transpose=(not transpose), weight_shape=weight_shape, output_padding=p, **common_kwargs) 195 | grad2_input = op.apply(grad_output, grad2_grad_weight, None) 196 | assert grad2_input.shape == input_shape 197 | 198 | return grad2_grad_output, grad2_input 199 | 200 | _conv2d_gradfix_cache[key] = Conv2d 201 | return Conv2d 202 | 203 | #---------------------------------------------------------------------------- 204 | -------------------------------------------------------------------------------- /add_turbo/torch_utils/ops/conv2d_resample.py: -------------------------------------------------------------------------------- 1 | # Copyright (c) 2021, NVIDIA CORPORATION & AFFILIATES. All rights reserved. 2 | # 3 | # NVIDIA CORPORATION and its licensors retain all intellectual property 4 | # and proprietary rights in and to this software, related documentation 5 | # and any modifications thereto. Any use, reproduction, disclosure or 6 | # distribution of this software and related documentation without an express 7 | # license agreement from NVIDIA CORPORATION is strictly prohibited. 8 | 9 | """2D convolution with optional up/downsampling.""" 10 | 11 | import torch 12 | 13 | from .. import misc 14 | from . import conv2d_gradfix 15 | from . 
import upfirdn2d 16 | from .upfirdn2d import _parse_padding 17 | from .upfirdn2d import _get_filter_size 18 | 19 | #---------------------------------------------------------------------------- 20 | 21 | def _get_weight_shape(w): 22 | with misc.suppress_tracer_warnings(): # this value will be treated as a constant 23 | shape = [int(sz) for sz in w.shape] 24 | misc.assert_shape(w, shape) 25 | return shape 26 | 27 | #---------------------------------------------------------------------------- 28 | 29 | def _conv2d_wrapper(x, w, stride=1, padding=0, groups=1, transpose=False, flip_weight=True): 30 | """Wrapper for the underlying `conv2d()` and `conv_transpose2d()` implementations. 31 | """ 32 | _out_channels, _in_channels_per_group, kh, kw = _get_weight_shape(w) 33 | 34 | # Flip weight if requested. 35 | # Note: conv2d() actually performs correlation (flip_weight=True) not convolution (flip_weight=False). 36 | if not flip_weight and (kw > 1 or kh > 1): 37 | w = w.flip([2, 3]) 38 | 39 | # Execute using conv2d_gradfix. 40 | op = conv2d_gradfix.conv_transpose2d if transpose else conv2d_gradfix.conv2d 41 | return op(x, w, stride=stride, padding=padding, groups=groups) 42 | 43 | #---------------------------------------------------------------------------- 44 | 45 | @misc.profiled_function 46 | def conv2d_resample(x, w, f=None, up=1, down=1, padding=0, groups=1, flip_weight=True, flip_filter=False): 47 | r"""2D convolution with optional up/downsampling. 48 | 49 | Padding is performed only once at the beginning, not between the operations. 50 | 51 | Args: 52 | x: Input tensor of shape 53 | `[batch_size, in_channels, in_height, in_width]`. 54 | w: Weight tensor of shape 55 | `[out_channels, in_channels//groups, kernel_height, kernel_width]`. 56 | f: Low-pass filter for up/downsampling. Must be prepared beforehand by 57 | calling upfirdn2d.setup_filter(). None = identity (default). 58 | up: Integer upsampling factor (default: 1). 59 | down: Integer downsampling factor (default: 1). 60 | padding: Padding with respect to the upsampled image. Can be a single number 61 | or a list/tuple `[x, y]` or `[x_before, x_after, y_before, y_after]` 62 | (default: 0). 63 | groups: Split input channels into N groups (default: 1). 64 | flip_weight: False = convolution, True = correlation (default: True). 65 | flip_filter: False = convolution, True = correlation (default: False). 66 | 67 | Returns: 68 | Tensor of the shape `[batch_size, num_channels, out_height, out_width]`. 69 | """ 70 | # Validate arguments. 71 | assert isinstance(x, torch.Tensor) and (x.ndim == 4) 72 | assert isinstance(w, torch.Tensor) and (w.ndim == 4) and (w.dtype == x.dtype) 73 | assert f is None or (isinstance(f, torch.Tensor) and f.ndim in [1, 2] and f.dtype == torch.float32) 74 | assert isinstance(up, int) and (up >= 1) 75 | assert isinstance(down, int) and (down >= 1) 76 | assert isinstance(groups, int) and (groups >= 1) 77 | out_channels, in_channels_per_group, kh, kw = _get_weight_shape(w) 78 | fw, fh = _get_filter_size(f) 79 | px0, px1, py0, py1 = _parse_padding(padding) 80 | 81 | # Adjust padding to account for up/downsampling. 82 | if up > 1: 83 | px0 += (fw + up - 1) // 2 84 | px1 += (fw - up) // 2 85 | py0 += (fh + up - 1) // 2 86 | py1 += (fh - up) // 2 87 | if down > 1: 88 | px0 += (fw - down + 1) // 2 89 | px1 += (fw - down) // 2 90 | py0 += (fh - down + 1) // 2 91 | py1 += (fh - down) // 2 92 | 93 | # Fast path: 1x1 convolution with downsampling only => downsample first, then convolve. 
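# A minimal usage sketch of conv2d_resample() (the shapes and the 4-tap filter are
# illustrative assumptions, not from the original code):
#
#     f = upfirdn2d.setup_filter([1, 3, 3, 1], device='cuda')   # low-pass filter for resampling
#     x = torch.randn([4, 64, 32, 32], device='cuda')           # [N, C_in, H, W]
#     w = torch.randn([128, 64, 3, 3], device='cuda')           # [C_out, C_in, kH, kW]
#     y = conv2d_resample(x=x, w=w, f=f, up=2, padding=1)       # -> [4, 128, 64, 64]
#
# With up=2 this takes the "upsampling" fast path below: a transposed, strided
# convolution followed by low-pass filtering with upfirdn2d.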
94 | if kw == 1 and kh == 1 and (down > 1 and up == 1): 95 | x = upfirdn2d.upfirdn2d(x=x, f=f, down=down, padding=[px0,px1,py0,py1], flip_filter=flip_filter) 96 | x = _conv2d_wrapper(x=x, w=w, groups=groups, flip_weight=flip_weight) 97 | return x 98 | 99 | # Fast path: 1x1 convolution with upsampling only => convolve first, then upsample. 100 | if kw == 1 and kh == 1 and (up > 1 and down == 1): 101 | x = _conv2d_wrapper(x=x, w=w, groups=groups, flip_weight=flip_weight) 102 | x = upfirdn2d.upfirdn2d(x=x, f=f, up=up, padding=[px0,px1,py0,py1], gain=up**2, flip_filter=flip_filter) 103 | return x 104 | 105 | # Fast path: downsampling only => use strided convolution. 106 | if down > 1 and up == 1: 107 | x = upfirdn2d.upfirdn2d(x=x, f=f, padding=[px0,px1,py0,py1], flip_filter=flip_filter) 108 | x = _conv2d_wrapper(x=x, w=w, stride=down, groups=groups, flip_weight=flip_weight) 109 | return x 110 | 111 | # Fast path: upsampling with optional downsampling => use transpose strided convolution. 112 | if up > 1: 113 | if groups == 1: 114 | w = w.transpose(0, 1) 115 | else: 116 | w = w.reshape(groups, out_channels // groups, in_channels_per_group, kh, kw) 117 | w = w.transpose(1, 2) 118 | w = w.reshape(groups * in_channels_per_group, out_channels // groups, kh, kw) 119 | px0 -= kw - 1 120 | px1 -= kw - up 121 | py0 -= kh - 1 122 | py1 -= kh - up 123 | pxt = max(min(-px0, -px1), 0) 124 | pyt = max(min(-py0, -py1), 0) 125 | x = _conv2d_wrapper(x=x, w=w, stride=up, padding=[pyt,pxt], groups=groups, transpose=True, flip_weight=(not flip_weight)) 126 | x = upfirdn2d.upfirdn2d(x=x, f=f, padding=[px0+pxt,px1+pxt,py0+pyt,py1+pyt], gain=up**2, flip_filter=flip_filter) 127 | if down > 1: 128 | x = upfirdn2d.upfirdn2d(x=x, f=f, down=down, flip_filter=flip_filter) 129 | return x 130 | 131 | # Fast path: no up/downsampling, padding supported by the underlying implementation => use plain conv2d. 132 | if up == 1 and down == 1: 133 | if px0 == px1 and py0 == py1 and px0 >= 0 and py0 >= 0: 134 | return _conv2d_wrapper(x=x, w=w, padding=[py0,px0], groups=groups, flip_weight=flip_weight) 135 | 136 | # Fallback: Generic reference implementation. 137 | x = upfirdn2d.upfirdn2d(x=x, f=(f if up > 1 else None), up=up, padding=[px0,px1,py0,py1], gain=up**2, flip_filter=flip_filter) 138 | x = _conv2d_wrapper(x=x, w=w, groups=groups, flip_weight=flip_weight) 139 | if down > 1: 140 | x = upfirdn2d.upfirdn2d(x=x, f=f, down=down, flip_filter=flip_filter) 141 | return x 142 | 143 | #---------------------------------------------------------------------------- 144 | -------------------------------------------------------------------------------- /add_turbo/torch_utils/ops/filtered_lrelu.cpp: -------------------------------------------------------------------------------- 1 | // Copyright (c) 2021, NVIDIA CORPORATION & AFFILIATES. All rights reserved. 2 | // 3 | // NVIDIA CORPORATION and its licensors retain all intellectual property 4 | // and proprietary rights in and to this software, related documentation 5 | // and any modifications thereto. Any use, reproduction, disclosure or 6 | // distribution of this software and related documentation without an express 7 | // license agreement from NVIDIA CORPORATION is strictly prohibited. 
8 | 9 | #include 10 | #include 11 | #include 12 | #include "filtered_lrelu.h" 13 | 14 | //------------------------------------------------------------------------ 15 | 16 | static std::tuple filtered_lrelu( 17 | torch::Tensor x, torch::Tensor fu, torch::Tensor fd, torch::Tensor b, torch::Tensor si, 18 | int up, int down, int px0, int px1, int py0, int py1, int sx, int sy, float gain, float slope, float clamp, bool flip_filters, bool writeSigns) 19 | { 20 | // Set CUDA device. 21 | TORCH_CHECK(x.is_cuda(), "x must reside on CUDA device"); 22 | const at::cuda::OptionalCUDAGuard device_guard(device_of(x)); 23 | 24 | // Validate arguments. 25 | TORCH_CHECK(fu.device() == x.device() && fd.device() == x.device() && b.device() == x.device(), "all input tensors must reside on the same device"); 26 | TORCH_CHECK(fu.dtype() == torch::kFloat && fd.dtype() == torch::kFloat, "fu and fd must be float32"); 27 | TORCH_CHECK(b.dtype() == x.dtype(), "x and b must have the same dtype"); 28 | TORCH_CHECK(x.dtype() == torch::kHalf || x.dtype() == torch::kFloat, "x and b must be float16 or float32"); 29 | TORCH_CHECK(x.dim() == 4, "x must be rank 4"); 30 | TORCH_CHECK(x.size(0) * x.size(1) <= INT_MAX && x.size(2) <= INT_MAX && x.size(3) <= INT_MAX, "x is too large"); 31 | TORCH_CHECK(x.numel() > 0, "x is empty"); 32 | TORCH_CHECK((fu.dim() == 1 || fu.dim() == 2) && (fd.dim() == 1 || fd.dim() == 2), "fu and fd must be rank 1 or 2"); 33 | TORCH_CHECK(fu.size(0) <= INT_MAX && fu.size(-1) <= INT_MAX, "fu is too large"); 34 | TORCH_CHECK(fd.size(0) <= INT_MAX && fd.size(-1) <= INT_MAX, "fd is too large"); 35 | TORCH_CHECK(fu.numel() > 0, "fu is empty"); 36 | TORCH_CHECK(fd.numel() > 0, "fd is empty"); 37 | TORCH_CHECK(b.dim() == 1 && b.size(0) == x.size(1), "b must be a vector with the same number of channels as x"); 38 | TORCH_CHECK(up >= 1 && down >= 1, "up and down must be at least 1"); 39 | 40 | // Figure out how much shared memory is available on the device. 41 | int maxSharedBytes = 0; 42 | AT_CUDA_CHECK(cudaDeviceGetAttribute(&maxSharedBytes, cudaDevAttrMaxSharedMemoryPerBlockOptin, x.device().index())); 43 | int sharedKB = maxSharedBytes >> 10; 44 | 45 | // Populate enough launch parameters to check if a CUDA kernel exists. 46 | filtered_lrelu_kernel_params p; 47 | p.up = up; 48 | p.down = down; 49 | p.fuShape = make_int2((int)fu.size(-1), fu.dim() == 2 ? (int)fu.size(0) : 0); // shape [n, 0] indicates separable filter. 50 | p.fdShape = make_int2((int)fd.size(-1), fd.dim() == 2 ? (int)fd.size(0) : 0); 51 | filtered_lrelu_kernel_spec test_spec = choose_filtered_lrelu_kernel(p, sharedKB); 52 | if (!test_spec.exec) 53 | { 54 | // No kernel found - return empty tensors and indicate missing kernel with return code of -1. 55 | return std::make_tuple(torch::Tensor(), torch::Tensor(), -1); 56 | } 57 | 58 | // Input/output element size. 59 | int64_t sz = (x.dtype() == torch::kHalf) ? 2 : 4; 60 | 61 | // Input sizes. 62 | int64_t xw = (int)x.size(3); 63 | int64_t xh = (int)x.size(2); 64 | int64_t fut_w = (int)fu.size(-1) - 1; 65 | int64_t fut_h = (int)fu.size(0) - 1; 66 | int64_t fdt_w = (int)fd.size(-1) - 1; 67 | int64_t fdt_h = (int)fd.size(0) - 1; 68 | 69 | // Logical size of upsampled buffer. 
70 | int64_t cw = xw * up + (px0 + px1) - fut_w; 71 | int64_t ch = xh * up + (py0 + py1) - fut_h; 72 | TORCH_CHECK(cw > fdt_w && ch > fdt_h, "upsampled buffer must be at least the size of downsampling filter"); 73 | TORCH_CHECK(cw <= INT_MAX && ch <= INT_MAX, "upsampled buffer is too large"); 74 | 75 | // Compute output size and allocate. 76 | int64_t yw = (cw - fdt_w + (down - 1)) / down; 77 | int64_t yh = (ch - fdt_h + (down - 1)) / down; 78 | TORCH_CHECK(yw > 0 && yh > 0, "output must be at least 1x1"); 79 | TORCH_CHECK(yw <= INT_MAX && yh <= INT_MAX, "output is too large"); 80 | torch::Tensor y = torch::empty({x.size(0), x.size(1), yh, yw}, x.options(), x.suggest_memory_format()); 81 | 82 | // Allocate sign tensor. 83 | torch::Tensor so; 84 | torch::Tensor s = si; 85 | bool readSigns = !!s.numel(); 86 | int64_t sw_active = 0; // Active width of sign tensor. 87 | if (writeSigns) 88 | { 89 | sw_active = yw * down - (down - 1) + fdt_w; // Active width in elements. 90 | int64_t sh = yh * down - (down - 1) + fdt_h; // Height = active height. 91 | int64_t sw = (sw_active + 15) & ~15; // Width = active width in elements, rounded up to multiple of 16. 92 | TORCH_CHECK(sh <= INT_MAX && (sw >> 2) <= INT_MAX, "signs is too large"); 93 | s = so = torch::empty({x.size(0), x.size(1), sh, sw >> 2}, x.options().dtype(torch::kUInt8), at::MemoryFormat::Contiguous); 94 | } 95 | else if (readSigns) 96 | sw_active = s.size(3) << 2; 97 | 98 | // Validate sign tensor if in use. 99 | if (readSigns || writeSigns) 100 | { 101 | TORCH_CHECK(s.is_contiguous(), "signs must be contiguous"); 102 | TORCH_CHECK(s.dtype() == torch::kUInt8, "signs must be uint8"); 103 | TORCH_CHECK(s.device() == x.device(), "signs must reside on the same device as x"); 104 | TORCH_CHECK(s.dim() == 4, "signs must be rank 4"); 105 | TORCH_CHECK(s.size(0) == x.size(0) && s.size(1) == x.size(1), "signs must have same batch & channels as x"); 106 | TORCH_CHECK(s.size(2) <= INT_MAX && s.size(3) <= INT_MAX, "signs is too large"); 107 | } 108 | 109 | // Populate rest of CUDA kernel parameters. 110 | p.x = x.data_ptr(); 111 | p.y = y.data_ptr(); 112 | p.b = b.data_ptr(); 113 | p.s = (readSigns || writeSigns) ? s.data_ptr() : 0; 114 | p.fu = fu.data_ptr(); 115 | p.fd = fd.data_ptr(); 116 | p.pad0 = make_int2(px0, py0); 117 | p.gain = gain; 118 | p.slope = slope; 119 | p.clamp = clamp; 120 | p.flip = (flip_filters) ? 1 : 0; 121 | p.xShape = make_int4((int)x.size(3), (int)x.size(2), (int)x.size(1), (int)x.size(0)); 122 | p.yShape = make_int4((int)y.size(3), (int)y.size(2), (int)y.size(1), (int)y.size(0)); 123 | p.sShape = (readSigns || writeSigns) ? make_int2((int)s.size(3), (int)s.size(2)) : make_int2(0, 0); // Width is in bytes. Contiguous. 124 | p.sOfs = make_int2(sx, sy); 125 | p.swLimit = (sw_active + 3) >> 2; // Rounded up to bytes. 126 | 127 | // x, y, b strides are in bytes. 128 | p.xStride = make_longlong4(sz * x.stride(3), sz * x.stride(2), sz * x.stride(1), sz * x.stride(0)); 129 | p.yStride = make_longlong4(sz * y.stride(3), sz * y.stride(2), sz * y.stride(1), sz * y.stride(0)); 130 | p.bStride = sz * b.stride(0); 131 | 132 | // fu, fd strides are in elements. 133 | p.fuStride = make_longlong3(fu.stride(-1), fu.dim() == 2 ? fu.stride(0) : 0, 0); 134 | p.fdStride = make_longlong3(fd.stride(-1), fd.dim() == 2 ? fd.stride(0) : 0, 0); 135 | 136 | // Determine if indices don't fit in int32. Support negative strides although Torch currently never produces those. 
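// Worked example of the size bookkeeping above (all numbers are illustrative, not
// taken from the original code): with xw = 32, up = 2, px0 = px1 = 3 and an 8-tap
// upsampling filter (fut_w = 7), the upsampled buffer width is cw = 32*2 + 6 - 7 = 63.
// With an 8-tap downsampling filter (fdt_w = 7) and down = 2, the output width is
// yw = (63 - 7 + 1) / 2 = 28 (integer division). When signs are written, the active
// sign width is sw_active = yw*down - (down - 1) + fdt_w = 56 - 1 + 7 = 62 elements,
// rounded up to sw = 64 (a multiple of 16) and packed 4 signs per byte (sw >> 2).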
137 | bool index64b = false; 138 | if (std::abs(p.bStride * x.size(1)) > INT_MAX) index64b = true; 139 | if (std::min(x.size(0) * p.xStride.w, 0ll) + std::min(x.size(1) * p.xStride.z, 0ll) + std::min(x.size(2) * p.xStride.y, 0ll) + std::min(x.size(3) * p.xStride.x, 0ll) < -INT_MAX) index64b = true; 140 | if (std::max(x.size(0) * p.xStride.w, 0ll) + std::max(x.size(1) * p.xStride.z, 0ll) + std::max(x.size(2) * p.xStride.y, 0ll) + std::max(x.size(3) * p.xStride.x, 0ll) > INT_MAX) index64b = true; 141 | if (std::min(y.size(0) * p.yStride.w, 0ll) + std::min(y.size(1) * p.yStride.z, 0ll) + std::min(y.size(2) * p.yStride.y, 0ll) + std::min(y.size(3) * p.yStride.x, 0ll) < -INT_MAX) index64b = true; 142 | if (std::max(y.size(0) * p.yStride.w, 0ll) + std::max(y.size(1) * p.yStride.z, 0ll) + std::max(y.size(2) * p.yStride.y, 0ll) + std::max(y.size(3) * p.yStride.x, 0ll) > INT_MAX) index64b = true; 143 | if (s.numel() > INT_MAX) index64b = true; 144 | 145 | // Choose CUDA kernel. 146 | filtered_lrelu_kernel_spec spec = { 0 }; 147 | AT_DISPATCH_FLOATING_TYPES_AND_HALF(x.scalar_type(), "filtered_lrelu_cuda", [&] 148 | { 149 | if constexpr (sizeof(scalar_t) <= 4) // Exclude doubles. constexpr prevents template instantiation. 150 | { 151 | // Choose kernel based on index type, datatype and sign read/write modes. 152 | if (!index64b && writeSigns && !readSigns) spec = choose_filtered_lrelu_kernel(p, sharedKB); 153 | else if (!index64b && !writeSigns && readSigns) spec = choose_filtered_lrelu_kernel(p, sharedKB); 154 | else if (!index64b && !writeSigns && !readSigns) spec = choose_filtered_lrelu_kernel(p, sharedKB); 155 | else if ( index64b && writeSigns && !readSigns) spec = choose_filtered_lrelu_kernel(p, sharedKB); 156 | else if ( index64b && !writeSigns && readSigns) spec = choose_filtered_lrelu_kernel(p, sharedKB); 157 | else if ( index64b && !writeSigns && !readSigns) spec = choose_filtered_lrelu_kernel(p, sharedKB); 158 | } 159 | }); 160 | TORCH_CHECK(spec.exec, "internal error - CUDA kernel not found") // This should not happen because we tested earlier that kernel exists. 161 | 162 | // Launch CUDA kernel. 163 | void* args[] = {&p}; 164 | int bx = spec.numWarps * 32; 165 | int gx = (p.yShape.x - 1) / spec.tileOut.x + 1; 166 | int gy = (p.yShape.y - 1) / spec.tileOut.y + 1; 167 | int gz = p.yShape.z * p.yShape.w; 168 | 169 | // Repeat multiple horizontal tiles in a CTA? 170 | if (spec.xrep) 171 | { 172 | p.tilesXrep = spec.xrep; 173 | p.tilesXdim = gx; 174 | 175 | gx = (gx + p.tilesXrep - 1) / p.tilesXrep; 176 | std::swap(gx, gy); 177 | } 178 | else 179 | { 180 | p.tilesXrep = 0; 181 | p.tilesXdim = 0; 182 | } 183 | 184 | // Launch filter setup kernel. 185 | AT_CUDA_CHECK(cudaLaunchKernel(spec.setup, 1, 1024, args, 0, at::cuda::getCurrentCUDAStream())); 186 | 187 | // Copy kernels to constant memory. 188 | if ( writeSigns && !readSigns) AT_CUDA_CHECK((copy_filters(at::cuda::getCurrentCUDAStream()))); 189 | else if (!writeSigns && readSigns) AT_CUDA_CHECK((copy_filters(at::cuda::getCurrentCUDAStream()))); 190 | else if (!writeSigns && !readSigns) AT_CUDA_CHECK((copy_filters(at::cuda::getCurrentCUDAStream()))); 191 | 192 | // Set cache and shared memory configurations for main kernel. 193 | AT_CUDA_CHECK(cudaFuncSetCacheConfig(spec.exec, cudaFuncCachePreferShared)); 194 | if (spec.dynamicSharedKB) // Need dynamically allocated shared memory? 
195 | AT_CUDA_CHECK(cudaFuncSetAttribute(spec.exec, cudaFuncAttributeMaxDynamicSharedMemorySize, spec.dynamicSharedKB << 10)); 196 | AT_CUDA_CHECK(cudaFuncSetSharedMemConfig(spec.exec, cudaSharedMemBankSizeFourByte)); 197 | 198 | // Launch main kernel. 199 | const int maxSubGz = 65535; // CUDA maximum for block z dimension. 200 | for (int zofs=0; zofs < gz; zofs += maxSubGz) // Do multiple launches if gz is too big. 201 | { 202 | p.blockZofs = zofs; 203 | int subGz = std::min(maxSubGz, gz - zofs); 204 | AT_CUDA_CHECK(cudaLaunchKernel(spec.exec, dim3(gx, gy, subGz), bx, args, spec.dynamicSharedKB << 10, at::cuda::getCurrentCUDAStream())); 205 | } 206 | 207 | // Done. 208 | return std::make_tuple(y, so, 0); 209 | } 210 | 211 | //------------------------------------------------------------------------ 212 | 213 | static torch::Tensor filtered_lrelu_act(torch::Tensor x, torch::Tensor si, int sx, int sy, float gain, float slope, float clamp, bool writeSigns) 214 | { 215 | // Set CUDA device. 216 | TORCH_CHECK(x.is_cuda(), "x must reside on CUDA device"); 217 | const at::cuda::OptionalCUDAGuard device_guard(device_of(x)); 218 | 219 | // Validate arguments. 220 | TORCH_CHECK(x.dim() == 4, "x must be rank 4"); 221 | TORCH_CHECK(x.size(0) * x.size(1) <= INT_MAX && x.size(2) <= INT_MAX && x.size(3) <= INT_MAX, "x is too large"); 222 | TORCH_CHECK(x.numel() > 0, "x is empty"); 223 | TORCH_CHECK(x.dtype() == torch::kHalf || x.dtype() == torch::kFloat || x.dtype() == torch::kDouble, "x must be float16, float32 or float64"); 224 | 225 | // Output signs if we don't have sign input. 226 | torch::Tensor so; 227 | torch::Tensor s = si; 228 | bool readSigns = !!s.numel(); 229 | if (writeSigns) 230 | { 231 | int64_t sw = x.size(3); 232 | sw = (sw + 15) & ~15; // Round to a multiple of 16 for coalescing. 233 | s = so = torch::empty({x.size(0), x.size(1), x.size(2), sw >> 2}, x.options().dtype(torch::kUInt8), at::MemoryFormat::Contiguous); 234 | } 235 | 236 | // Validate sign tensor if in use. 237 | if (readSigns || writeSigns) 238 | { 239 | TORCH_CHECK(s.is_contiguous(), "signs must be contiguous"); 240 | TORCH_CHECK(s.dtype() == torch::kUInt8, "signs must be uint8"); 241 | TORCH_CHECK(s.device() == x.device(), "signs must reside on the same device as x"); 242 | TORCH_CHECK(s.dim() == 4, "signs must be rank 4"); 243 | TORCH_CHECK(s.size(0) == x.size(0) && s.size(1) == x.size(1), "signs must have same batch & channels as x"); 244 | TORCH_CHECK(s.size(2) <= INT_MAX && (s.size(3) << 2) <= INT_MAX, "signs tensor is too large"); 245 | } 246 | 247 | // Initialize CUDA kernel parameters. 248 | filtered_lrelu_act_kernel_params p; 249 | p.x = x.data_ptr(); 250 | p.s = (readSigns || writeSigns) ? s.data_ptr() : 0; 251 | p.gain = gain; 252 | p.slope = slope; 253 | p.clamp = clamp; 254 | p.xShape = make_int4((int)x.size(3), (int)x.size(2), (int)x.size(1), (int)x.size(0)); 255 | p.xStride = make_longlong4(x.stride(3), x.stride(2), x.stride(1), x.stride(0)); 256 | p.sShape = (readSigns || writeSigns) ? make_int2((int)s.size(3) << 2, (int)s.size(2)) : make_int2(0, 0); // Width is in elements. Contiguous. 257 | p.sOfs = make_int2(sx, sy); 258 | 259 | // Choose CUDA kernel. 
260 | void* func = 0; 261 | AT_DISPATCH_FLOATING_TYPES_AND_HALF(x.scalar_type(), "filtered_lrelu_act_cuda", [&] 262 | { 263 | if (writeSigns) 264 | func = choose_filtered_lrelu_act_kernel(); 265 | else if (readSigns) 266 | func = choose_filtered_lrelu_act_kernel(); 267 | else 268 | func = choose_filtered_lrelu_act_kernel(); 269 | }); 270 | TORCH_CHECK(func, "internal error - CUDA kernel not found"); 271 | 272 | // Launch CUDA kernel. 273 | void* args[] = {&p}; 274 | int bx = 128; // 4 warps per block. 275 | 276 | // Logical size of launch = writeSigns ? p.s : p.x 277 | uint32_t gx = writeSigns ? p.sShape.x : p.xShape.x; 278 | uint32_t gy = writeSigns ? p.sShape.y : p.xShape.y; 279 | uint32_t gz = p.xShape.z * p.xShape.w; // Same as in p.sShape if signs are in use. 280 | gx = (gx - 1) / bx + 1; 281 | 282 | // Make sure grid y and z dimensions are within CUDA launch limits. Kernel loops internally to do the rest. 283 | const uint32_t gmax = 65535; 284 | gy = std::min(gy, gmax); 285 | gz = std::min(gz, gmax); 286 | 287 | // Launch. 288 | AT_CUDA_CHECK(cudaLaunchKernel(func, dim3(gx, gy, gz), bx, args, 0, at::cuda::getCurrentCUDAStream())); 289 | return so; 290 | } 291 | 292 | //------------------------------------------------------------------------ 293 | 294 | PYBIND11_MODULE(TORCH_EXTENSION_NAME, m) 295 | { 296 | m.def("filtered_lrelu", &filtered_lrelu); // The whole thing. 297 | m.def("filtered_lrelu_act_", &filtered_lrelu_act); // Activation and sign tensor handling only. Modifies data tensor in-place. 298 | } 299 | 300 | //------------------------------------------------------------------------ 301 | -------------------------------------------------------------------------------- /add_turbo/torch_utils/ops/filtered_lrelu.h: -------------------------------------------------------------------------------- 1 | // Copyright (c) 2021, NVIDIA CORPORATION & AFFILIATES. All rights reserved. 2 | // 3 | // NVIDIA CORPORATION and its licensors retain all intellectual property 4 | // and proprietary rights in and to this software, related documentation 5 | // and any modifications thereto. Any use, reproduction, disclosure or 6 | // distribution of this software and related documentation without an express 7 | // license agreement from NVIDIA CORPORATION is strictly prohibited. 8 | 9 | #include 10 | 11 | //------------------------------------------------------------------------ 12 | // CUDA kernel parameters. 13 | 14 | struct filtered_lrelu_kernel_params 15 | { 16 | // These parameters decide which kernel to use. 17 | int up; // upsampling ratio (1, 2, 4) 18 | int down; // downsampling ratio (1, 2, 4) 19 | int2 fuShape; // [size, 1] | [size, size] 20 | int2 fdShape; // [size, 1] | [size, size] 21 | 22 | int _dummy; // Alignment. 23 | 24 | // Rest of the parameters. 25 | const void* x; // Input tensor. 26 | void* y; // Output tensor. 27 | const void* b; // Bias tensor. 28 | unsigned char* s; // Sign tensor in/out. NULL if unused. 29 | const float* fu; // Upsampling filter. 30 | const float* fd; // Downsampling filter. 31 | 32 | int2 pad0; // Left/top padding. 33 | float gain; // Additional gain factor. 34 | float slope; // Leaky ReLU slope on negative side. 35 | float clamp; // Clamp after nonlinearity. 36 | int flip; // Filter kernel flip for gradient computation. 37 | 38 | int tilesXdim; // Original number of horizontal output tiles. 39 | int tilesXrep; // Number of horizontal tiles per CTA. 40 | int blockZofs; // Block z offset to support large minibatch, channel dimensions. 
41 | 42 | int4 xShape; // [width, height, channel, batch] 43 | int4 yShape; // [width, height, channel, batch] 44 | int2 sShape; // [width, height] - width is in bytes. Contiguous. Zeros if unused. 45 | int2 sOfs; // [ofs_x, ofs_y] - offset between upsampled data and sign tensor. 46 | int swLimit; // Active width of sign tensor in bytes. 47 | 48 | longlong4 xStride; // Strides of all tensors except signs, same component order as shapes. 49 | longlong4 yStride; // 50 | int64_t bStride; // 51 | longlong3 fuStride; // 52 | longlong3 fdStride; // 53 | }; 54 | 55 | struct filtered_lrelu_act_kernel_params 56 | { 57 | void* x; // Input/output, modified in-place. 58 | unsigned char* s; // Sign tensor in/out. NULL if unused. 59 | 60 | float gain; // Additional gain factor. 61 | float slope; // Leaky ReLU slope on negative side. 62 | float clamp; // Clamp after nonlinearity. 63 | 64 | int4 xShape; // [width, height, channel, batch] 65 | longlong4 xStride; // Input/output tensor strides, same order as in shape. 66 | int2 sShape; // [width, height] - width is in elements. Contiguous. Zeros if unused. 67 | int2 sOfs; // [ofs_x, ofs_y] - offset between upsampled data and sign tensor. 68 | }; 69 | 70 | //------------------------------------------------------------------------ 71 | // CUDA kernel specialization. 72 | 73 | struct filtered_lrelu_kernel_spec 74 | { 75 | void* setup; // Function for filter kernel setup. 76 | void* exec; // Function for main operation. 77 | int2 tileOut; // Width/height of launch tile. 78 | int numWarps; // Number of warps per thread block, determines launch block size. 79 | int xrep; // For processing multiple horizontal tiles per thread block. 80 | int dynamicSharedKB; // How much dynamic shared memory the exec kernel wants. 81 | }; 82 | 83 | //------------------------------------------------------------------------ 84 | // CUDA kernel selection. 85 | 86 | template filtered_lrelu_kernel_spec choose_filtered_lrelu_kernel(const filtered_lrelu_kernel_params& p, int sharedKB); 87 | template void* choose_filtered_lrelu_act_kernel(void); 88 | template cudaError_t copy_filters(cudaStream_t stream); 89 | 90 | //------------------------------------------------------------------------ 91 | -------------------------------------------------------------------------------- /add_turbo/torch_utils/ops/filtered_lrelu.py: -------------------------------------------------------------------------------- 1 | # Copyright (c) 2021, NVIDIA CORPORATION & AFFILIATES. All rights reserved. 2 | # 3 | # NVIDIA CORPORATION and its licensors retain all intellectual property 4 | # and proprietary rights in and to this software, related documentation 5 | # and any modifications thereto. Any use, reproduction, disclosure or 6 | # distribution of this software and related documentation without an express 7 | # license agreement from NVIDIA CORPORATION is strictly prohibited. 8 | 9 | import os 10 | import numpy as np 11 | import torch 12 | import warnings 13 | 14 | from .. import custom_ops 15 | from .. import misc 16 | from . import upfirdn2d 17 | from . 
import bias_act 18 | 19 | #---------------------------------------------------------------------------- 20 | 21 | _plugin = None 22 | 23 | def _init(): 24 | global _plugin 25 | if _plugin is None: 26 | _plugin = custom_ops.get_plugin( 27 | module_name='filtered_lrelu_plugin', 28 | sources=['filtered_lrelu.cpp', 'filtered_lrelu_wr.cu', 'filtered_lrelu_rd.cu', 'filtered_lrelu_ns.cu'], 29 | headers=['filtered_lrelu.h', 'filtered_lrelu.cu'], 30 | source_dir=os.path.dirname(__file__), 31 | extra_cuda_cflags=['--use_fast_math', '--allow-unsupported-compiler'], 32 | ) 33 | return True 34 | 35 | def _get_filter_size(f): 36 | if f is None: 37 | return 1, 1 38 | assert isinstance(f, torch.Tensor) 39 | assert 1 <= f.ndim <= 2 40 | return f.shape[-1], f.shape[0] # width, height 41 | 42 | def _parse_padding(padding): 43 | if isinstance(padding, int): 44 | padding = [padding, padding] 45 | assert isinstance(padding, (list, tuple)) 46 | assert all(isinstance(x, (int, np.integer)) for x in padding) 47 | padding = [int(x) for x in padding] 48 | if len(padding) == 2: 49 | px, py = padding 50 | padding = [px, px, py, py] 51 | px0, px1, py0, py1 = padding 52 | return px0, px1, py0, py1 53 | 54 | #---------------------------------------------------------------------------- 55 | 56 | def filtered_lrelu(x, fu=None, fd=None, b=None, up=1, down=1, padding=0, gain=np.sqrt(2), slope=0.2, clamp=None, flip_filter=False, impl='cuda'): 57 | r"""Filtered leaky ReLU for a batch of 2D images. 58 | 59 | Performs the following sequence of operations for each channel: 60 | 61 | 1. Add channel-specific bias if provided (`b`). 62 | 63 | 2. Upsample the image by inserting N-1 zeros after each pixel (`up`). 64 | 65 | 3. Pad the image with the specified number of zeros on each side (`padding`). 66 | Negative padding corresponds to cropping the image. 67 | 68 | 4. Convolve the image with the specified upsampling FIR filter (`fu`), shrinking it 69 | so that the footprint of all output pixels lies within the input image. 70 | 71 | 5. Multiply each value by the provided gain factor (`gain`). 72 | 73 | 6. Apply leaky ReLU activation function to each value. 74 | 75 | 7. Clamp each value between -clamp and +clamp, if `clamp` parameter is provided. 76 | 77 | 8. Convolve the image with the specified downsampling FIR filter (`fd`), shrinking 78 | it so that the footprint of all output pixels lies within the input image. 79 | 80 | 9. Downsample the image by keeping every Nth pixel (`down`). 81 | 82 | The fused op is considerably more efficient than performing the same calculation 83 | using standard PyTorch ops. It supports gradients of arbitrary order. 84 | 85 | Args: 86 | x: Float32/float16/float64 input tensor of the shape 87 | `[batch_size, num_channels, in_height, in_width]`. 88 | fu: Float32 upsampling FIR filter of the shape 89 | `[filter_height, filter_width]` (non-separable), 90 | `[filter_taps]` (separable), or 91 | `None` (identity). 92 | fd: Float32 downsampling FIR filter of the shape 93 | `[filter_height, filter_width]` (non-separable), 94 | `[filter_taps]` (separable), or 95 | `None` (identity). 96 | b: Bias vector, or `None` to disable. Must be a 1D tensor of the same type 97 | as `x`. The length of vector must must match the channel dimension of `x`. 98 | up: Integer upsampling factor (default: 1). 99 | down: Integer downsampling factor. (default: 1). 100 | padding: Padding with respect to the upsampled image. 
Can be a single number 101 | or a list/tuple `[x, y]` or `[x_before, x_after, y_before, y_after]` 102 | (default: 0). 103 | gain: Overall scaling factor for signal magnitude (default: sqrt(2)). 104 | slope: Slope on the negative side of leaky ReLU (default: 0.2). 105 | clamp: Maximum magnitude for leaky ReLU output (default: None). 106 | flip_filter: False = convolution, True = correlation (default: False). 107 | impl: Implementation to use. Can be `'ref'` or `'cuda'` (default: `'cuda'`). 108 | 109 | Returns: 110 | Tensor of the shape `[batch_size, num_channels, out_height, out_width]`. 111 | """ 112 | assert isinstance(x, torch.Tensor) 113 | assert impl in ['ref', 'cuda'] 114 | if impl == 'cuda' and x.device.type == 'cuda' and _init(): 115 | return _filtered_lrelu_cuda(up=up, down=down, padding=padding, gain=gain, slope=slope, clamp=clamp, flip_filter=flip_filter).apply(x, fu, fd, b, None, 0, 0) 116 | return _filtered_lrelu_ref(x, fu=fu, fd=fd, b=b, up=up, down=down, padding=padding, gain=gain, slope=slope, clamp=clamp, flip_filter=flip_filter) 117 | 118 | #---------------------------------------------------------------------------- 119 | 120 | @misc.profiled_function 121 | def _filtered_lrelu_ref(x, fu=None, fd=None, b=None, up=1, down=1, padding=0, gain=np.sqrt(2), slope=0.2, clamp=None, flip_filter=False): 122 | """Slow and memory-inefficient reference implementation of `filtered_lrelu()` using 123 | existing `upfirdn2n()` and `bias_act()` ops. 124 | """ 125 | assert isinstance(x, torch.Tensor) and x.ndim == 4 126 | fu_w, fu_h = _get_filter_size(fu) 127 | fd_w, fd_h = _get_filter_size(fd) 128 | if b is not None: 129 | assert isinstance(b, torch.Tensor) and b.dtype == x.dtype 130 | misc.assert_shape(b, [x.shape[1]]) 131 | assert isinstance(up, int) and up >= 1 132 | assert isinstance(down, int) and down >= 1 133 | px0, px1, py0, py1 = _parse_padding(padding) 134 | assert gain == float(gain) and gain > 0 135 | assert slope == float(slope) and slope >= 0 136 | assert clamp is None or (clamp == float(clamp) and clamp >= 0) 137 | 138 | # Calculate output size. 139 | batch_size, channels, in_h, in_w = x.shape 140 | in_dtype = x.dtype 141 | out_w = (in_w * up + (px0 + px1) - (fu_w - 1) - (fd_w - 1) + (down - 1)) // down 142 | out_h = (in_h * up + (py0 + py1) - (fu_h - 1) - (fd_h - 1) + (down - 1)) // down 143 | 144 | # Compute using existing ops. 145 | x = bias_act.bias_act(x=x, b=b) # Apply bias. 146 | x = upfirdn2d.upfirdn2d(x=x, f=fu, up=up, padding=[px0, px1, py0, py1], gain=up**2, flip_filter=flip_filter) # Upsample. 147 | x = bias_act.bias_act(x=x, act='lrelu', alpha=slope, gain=gain, clamp=clamp) # Bias, leaky ReLU, clamp. 148 | x = upfirdn2d.upfirdn2d(x=x, f=fd, down=down, flip_filter=flip_filter) # Downsample. 149 | 150 | # Check output shape & dtype. 151 | misc.assert_shape(x, [batch_size, channels, out_h, out_w]) 152 | assert x.dtype == in_dtype 153 | return x 154 | 155 | #---------------------------------------------------------------------------- 156 | 157 | _filtered_lrelu_cuda_cache = dict() 158 | 159 | def _filtered_lrelu_cuda(up=1, down=1, padding=0, gain=np.sqrt(2), slope=0.2, clamp=None, flip_filter=False): 160 | """Fast CUDA implementation of `filtered_lrelu()` using custom ops. 
161 | """ 162 | assert isinstance(up, int) and up >= 1 163 | assert isinstance(down, int) and down >= 1 164 | px0, px1, py0, py1 = _parse_padding(padding) 165 | assert gain == float(gain) and gain > 0 166 | gain = float(gain) 167 | assert slope == float(slope) and slope >= 0 168 | slope = float(slope) 169 | assert clamp is None or (clamp == float(clamp) and clamp >= 0) 170 | clamp = float(clamp if clamp is not None else 'inf') 171 | 172 | # Lookup from cache. 173 | key = (up, down, px0, px1, py0, py1, gain, slope, clamp, flip_filter) 174 | if key in _filtered_lrelu_cuda_cache: 175 | return _filtered_lrelu_cuda_cache[key] 176 | 177 | # Forward op. 178 | class FilteredLReluCuda(torch.autograd.Function): 179 | @staticmethod 180 | def forward(ctx, x, fu, fd, b, si, sx, sy): # pylint: disable=arguments-differ 181 | assert isinstance(x, torch.Tensor) and x.ndim == 4 182 | 183 | # Replace empty up/downsample kernels with full 1x1 kernels (faster than separable). 184 | if fu is None: 185 | fu = torch.ones([1, 1], dtype=torch.float32, device=x.device) 186 | if fd is None: 187 | fd = torch.ones([1, 1], dtype=torch.float32, device=x.device) 188 | assert 1 <= fu.ndim <= 2 189 | assert 1 <= fd.ndim <= 2 190 | 191 | # Replace separable 1x1 kernels with full 1x1 kernels when scale factor is 1. 192 | if up == 1 and fu.ndim == 1 and fu.shape[0] == 1: 193 | fu = fu.square()[None] 194 | if down == 1 and fd.ndim == 1 and fd.shape[0] == 1: 195 | fd = fd.square()[None] 196 | 197 | # Missing sign input tensor. 198 | if si is None: 199 | si = torch.empty([0]) 200 | 201 | # Missing bias tensor. 202 | if b is None: 203 | b = torch.zeros([x.shape[1]], dtype=x.dtype, device=x.device) 204 | 205 | # Construct internal sign tensor only if gradients are needed. 206 | write_signs = (si.numel() == 0) and (x.requires_grad or b.requires_grad) 207 | 208 | # Warn if input storage strides are not in decreasing order due to e.g. channels-last layout. 209 | strides = [x.stride(i) for i in range(x.ndim) if x.size(i) > 1] 210 | if any(a < b for a, b in zip(strides[:-1], strides[1:])): 211 | warnings.warn("low-performance memory layout detected in filtered_lrelu input", RuntimeWarning) 212 | 213 | # Call C++/Cuda plugin if datatype is supported. 214 | if x.dtype in [torch.float16, torch.float32]: 215 | if torch.cuda.current_stream(x.device) != torch.cuda.default_stream(x.device): 216 | warnings.warn("filtered_lrelu called with non-default cuda stream but concurrent execution is not supported", RuntimeWarning) 217 | y, so, return_code = _plugin.filtered_lrelu(x, fu, fd, b, si, up, down, px0, px1, py0, py1, sx, sy, gain, slope, clamp, flip_filter, write_signs) 218 | else: 219 | return_code = -1 220 | 221 | # No Cuda kernel found? Fall back to generic implementation. Still more memory efficient than the reference implementation because 222 | # only the bit-packed sign tensor is retained for gradient computation. 223 | if return_code < 0: 224 | warnings.warn("filtered_lrelu called with parameters that have no optimized CUDA kernel, using generic fallback", RuntimeWarning) 225 | 226 | y = x.add(b.unsqueeze(-1).unsqueeze(-1)) # Add bias. 227 | y = upfirdn2d.upfirdn2d(x=y, f=fu, up=up, padding=[px0, px1, py0, py1], gain=up**2, flip_filter=flip_filter) # Upsample. 228 | so = _plugin.filtered_lrelu_act_(y, si, sx, sy, gain, slope, clamp, write_signs) # Activation function and sign handling. Modifies y in-place. 229 | y = upfirdn2d.upfirdn2d(x=y, f=fd, down=down, flip_filter=flip_filter) # Downsample. 
230 | 231 | # Prepare for gradient computation. 232 | ctx.save_for_backward(fu, fd, (si if si.numel() else so)) 233 | ctx.x_shape = x.shape 234 | ctx.y_shape = y.shape 235 | ctx.s_ofs = sx, sy 236 | return y 237 | 238 | @staticmethod 239 | def backward(ctx, dy): # pylint: disable=arguments-differ 240 | fu, fd, si = ctx.saved_tensors 241 | _, _, xh, xw = ctx.x_shape 242 | _, _, yh, yw = ctx.y_shape 243 | sx, sy = ctx.s_ofs 244 | dx = None # 0 245 | dfu = None; assert not ctx.needs_input_grad[1] 246 | dfd = None; assert not ctx.needs_input_grad[2] 247 | db = None # 3 248 | dsi = None; assert not ctx.needs_input_grad[4] 249 | dsx = None; assert not ctx.needs_input_grad[5] 250 | dsy = None; assert not ctx.needs_input_grad[6] 251 | 252 | if ctx.needs_input_grad[0] or ctx.needs_input_grad[3]: 253 | pp = [ 254 | (fu.shape[-1] - 1) + (fd.shape[-1] - 1) - px0, 255 | xw * up - yw * down + px0 - (up - 1), 256 | (fu.shape[0] - 1) + (fd.shape[0] - 1) - py0, 257 | xh * up - yh * down + py0 - (up - 1), 258 | ] 259 | gg = gain * (up ** 2) / (down ** 2) 260 | ff = (not flip_filter) 261 | sx = sx - (fu.shape[-1] - 1) + px0 262 | sy = sy - (fu.shape[0] - 1) + py0 263 | dx = _filtered_lrelu_cuda(up=down, down=up, padding=pp, gain=gg, slope=slope, clamp=None, flip_filter=ff).apply(dy, fd, fu, None, si, sx, sy) 264 | 265 | if ctx.needs_input_grad[3]: 266 | db = dx.sum([0, 2, 3]) 267 | 268 | return dx, dfu, dfd, db, dsi, dsx, dsy 269 | 270 | # Add to cache. 271 | _filtered_lrelu_cuda_cache[key] = FilteredLReluCuda 272 | return FilteredLReluCuda 273 | 274 | #---------------------------------------------------------------------------- 275 | -------------------------------------------------------------------------------- /add_turbo/torch_utils/ops/filtered_lrelu_ns.cu: -------------------------------------------------------------------------------- 1 | // Copyright (c) 2021, NVIDIA CORPORATION & AFFILIATES. All rights reserved. 2 | // 3 | // NVIDIA CORPORATION and its licensors retain all intellectual property 4 | // and proprietary rights in and to this software, related documentation 5 | // and any modifications thereto. Any use, reproduction, disclosure or 6 | // distribution of this software and related documentation without an express 7 | // license agreement from NVIDIA CORPORATION is strictly prohibited. 8 | 9 | #include "filtered_lrelu.cu" 10 | 11 | // Template/kernel specializations for no signs mode (no gradients required). 12 | 13 | // Full op, 32-bit indexing. 14 | template filtered_lrelu_kernel_spec choose_filtered_lrelu_kernel(const filtered_lrelu_kernel_params& p, int sharedKB); 15 | template filtered_lrelu_kernel_spec choose_filtered_lrelu_kernel(const filtered_lrelu_kernel_params& p, int sharedKB); 16 | 17 | // Full op, 64-bit indexing. 18 | template filtered_lrelu_kernel_spec choose_filtered_lrelu_kernel(const filtered_lrelu_kernel_params& p, int sharedKB); 19 | template filtered_lrelu_kernel_spec choose_filtered_lrelu_kernel(const filtered_lrelu_kernel_params& p, int sharedKB); 20 | 21 | // Activation/signs only for generic variant. 64-bit indexing. 22 | template void* choose_filtered_lrelu_act_kernel(void); 23 | template void* choose_filtered_lrelu_act_kernel(void); 24 | template void* choose_filtered_lrelu_act_kernel(void); 25 | 26 | // Copy filters to constant memory. 
27 | template cudaError_t copy_filters(cudaStream_t stream); 28 | -------------------------------------------------------------------------------- /add_turbo/torch_utils/ops/filtered_lrelu_rd.cu: -------------------------------------------------------------------------------- 1 | // Copyright (c) 2021, NVIDIA CORPORATION & AFFILIATES. All rights reserved. 2 | // 3 | // NVIDIA CORPORATION and its licensors retain all intellectual property 4 | // and proprietary rights in and to this software, related documentation 5 | // and any modifications thereto. Any use, reproduction, disclosure or 6 | // distribution of this software and related documentation without an express 7 | // license agreement from NVIDIA CORPORATION is strictly prohibited. 8 | 9 | #include "filtered_lrelu.cu" 10 | 11 | // Template/kernel specializations for sign read mode. 12 | 13 | // Full op, 32-bit indexing. 14 | template filtered_lrelu_kernel_spec choose_filtered_lrelu_kernel(const filtered_lrelu_kernel_params& p, int sharedKB); 15 | template filtered_lrelu_kernel_spec choose_filtered_lrelu_kernel(const filtered_lrelu_kernel_params& p, int sharedKB); 16 | 17 | // Full op, 64-bit indexing. 18 | template filtered_lrelu_kernel_spec choose_filtered_lrelu_kernel(const filtered_lrelu_kernel_params& p, int sharedKB); 19 | template filtered_lrelu_kernel_spec choose_filtered_lrelu_kernel(const filtered_lrelu_kernel_params& p, int sharedKB); 20 | 21 | // Activation/signs only for generic variant. 64-bit indexing. 22 | template void* choose_filtered_lrelu_act_kernel(void); 23 | template void* choose_filtered_lrelu_act_kernel(void); 24 | template void* choose_filtered_lrelu_act_kernel(void); 25 | 26 | // Copy filters to constant memory. 27 | template cudaError_t copy_filters(cudaStream_t stream); 28 | -------------------------------------------------------------------------------- /add_turbo/torch_utils/ops/filtered_lrelu_wr.cu: -------------------------------------------------------------------------------- 1 | // Copyright (c) 2021, NVIDIA CORPORATION & AFFILIATES. All rights reserved. 2 | // 3 | // NVIDIA CORPORATION and its licensors retain all intellectual property 4 | // and proprietary rights in and to this software, related documentation 5 | // and any modifications thereto. Any use, reproduction, disclosure or 6 | // distribution of this software and related documentation without an express 7 | // license agreement from NVIDIA CORPORATION is strictly prohibited. 8 | 9 | #include "filtered_lrelu.cu" 10 | 11 | // Template/kernel specializations for sign write mode. 12 | 13 | // Full op, 32-bit indexing. 14 | template filtered_lrelu_kernel_spec choose_filtered_lrelu_kernel(const filtered_lrelu_kernel_params& p, int sharedKB); 15 | template filtered_lrelu_kernel_spec choose_filtered_lrelu_kernel(const filtered_lrelu_kernel_params& p, int sharedKB); 16 | 17 | // Full op, 64-bit indexing. 18 | template filtered_lrelu_kernel_spec choose_filtered_lrelu_kernel(const filtered_lrelu_kernel_params& p, int sharedKB); 19 | template filtered_lrelu_kernel_spec choose_filtered_lrelu_kernel(const filtered_lrelu_kernel_params& p, int sharedKB); 20 | 21 | // Activation/signs only for generic variant. 64-bit indexing. 22 | template void* choose_filtered_lrelu_act_kernel(void); 23 | template void* choose_filtered_lrelu_act_kernel(void); 24 | template void* choose_filtered_lrelu_act_kernel(void); 25 | 26 | // Copy filters to constant memory. 
27 | template cudaError_t copy_filters(cudaStream_t stream); 28 | -------------------------------------------------------------------------------- /add_turbo/torch_utils/ops/fma.py: -------------------------------------------------------------------------------- 1 | # Copyright (c) 2021, NVIDIA CORPORATION & AFFILIATES. All rights reserved. 2 | # 3 | # NVIDIA CORPORATION and its licensors retain all intellectual property 4 | # and proprietary rights in and to this software, related documentation 5 | # and any modifications thereto. Any use, reproduction, disclosure or 6 | # distribution of this software and related documentation without an express 7 | # license agreement from NVIDIA CORPORATION is strictly prohibited. 8 | 9 | """Fused multiply-add, with slightly faster gradients than `torch.addcmul()`.""" 10 | 11 | import torch 12 | 13 | #---------------------------------------------------------------------------- 14 | 15 | def fma(a, b, c): # => a * b + c 16 | return _FusedMultiplyAdd.apply(a, b, c) 17 | 18 | #---------------------------------------------------------------------------- 19 | 20 | class _FusedMultiplyAdd(torch.autograd.Function): # a * b + c 21 | @staticmethod 22 | def forward(ctx, a, b, c): # pylint: disable=arguments-differ 23 | out = torch.addcmul(c, a, b) 24 | ctx.save_for_backward(a, b) 25 | ctx.c_shape = c.shape 26 | return out 27 | 28 | @staticmethod 29 | def backward(ctx, dout): # pylint: disable=arguments-differ 30 | a, b = ctx.saved_tensors 31 | c_shape = ctx.c_shape 32 | da = None 33 | db = None 34 | dc = None 35 | 36 | if ctx.needs_input_grad[0]: 37 | da = _unbroadcast(dout * b, a.shape) 38 | 39 | if ctx.needs_input_grad[1]: 40 | db = _unbroadcast(dout * a, b.shape) 41 | 42 | if ctx.needs_input_grad[2]: 43 | dc = _unbroadcast(dout, c_shape) 44 | 45 | return da, db, dc 46 | 47 | #---------------------------------------------------------------------------- 48 | 49 | def _unbroadcast(x, shape): 50 | extra_dims = x.ndim - len(shape) 51 | assert extra_dims >= 0 52 | dim = [i for i in range(x.ndim) if x.shape[i] > 1 and (i < extra_dims or shape[i - extra_dims] == 1)] 53 | if len(dim): 54 | x = x.sum(dim=dim, keepdim=True) 55 | if extra_dims: 56 | x = x.reshape(-1, *x.shape[extra_dims+1:]) 57 | assert x.shape == shape 58 | return x 59 | 60 | #---------------------------------------------------------------------------- 61 | -------------------------------------------------------------------------------- /add_turbo/torch_utils/ops/grid_sample_gradfix.py: -------------------------------------------------------------------------------- 1 | # Copyright (c) 2021, NVIDIA CORPORATION & AFFILIATES. All rights reserved. 2 | # 3 | # NVIDIA CORPORATION and its licensors retain all intellectual property 4 | # and proprietary rights in and to this software, related documentation 5 | # and any modifications thereto. Any use, reproduction, disclosure or 6 | # distribution of this software and related documentation without an express 7 | # license agreement from NVIDIA CORPORATION is strictly prohibited. 8 | 9 | """Custom replacement for `torch.nn.functional.grid_sample` that 10 | supports arbitrarily high order gradients between the input and output. 
11 | Only works on 2D images and assumes 12 | `mode='bilinear'`, `padding_mode='zeros'`, `align_corners=False`.""" 13 | 14 | import torch 15 | from pkg_resources import parse_version 16 | 17 | # pylint: disable=redefined-builtin 18 | # pylint: disable=arguments-differ 19 | # pylint: disable=protected-access 20 | 21 | #---------------------------------------------------------------------------- 22 | 23 | enabled = False # Enable the custom op by setting this to true. 24 | _use_pytorch_1_11_api = parse_version(torch.__version__) >= parse_version('1.11.0a') # Allow prerelease builds of 1.11 25 | 26 | #---------------------------------------------------------------------------- 27 | 28 | def grid_sample(input, grid): 29 | if _should_use_custom_op(): 30 | return _GridSample2dForward.apply(input, grid) 31 | return torch.nn.functional.grid_sample(input=input, grid=grid, mode='bilinear', padding_mode='zeros', align_corners=False) 32 | 33 | #---------------------------------------------------------------------------- 34 | 35 | def _should_use_custom_op(): 36 | return enabled 37 | 38 | #---------------------------------------------------------------------------- 39 | 40 | class _GridSample2dForward(torch.autograd.Function): 41 | @staticmethod 42 | def forward(ctx, input, grid): 43 | assert input.ndim == 4 44 | assert grid.ndim == 4 45 | output = torch.nn.functional.grid_sample(input=input, grid=grid, mode='bilinear', padding_mode='zeros', align_corners=False) 46 | ctx.save_for_backward(input, grid) 47 | return output 48 | 49 | @staticmethod 50 | def backward(ctx, grad_output): 51 | input, grid = ctx.saved_tensors 52 | grad_input, grad_grid = _GridSample2dBackward.apply(grad_output, input, grid) 53 | return grad_input, grad_grid 54 | 55 | #---------------------------------------------------------------------------- 56 | 57 | class _GridSample2dBackward(torch.autograd.Function): 58 | @staticmethod 59 | def forward(ctx, grad_output, input, grid): 60 | op = torch._C._jit_get_operation('aten::grid_sampler_2d_backward') 61 | if _use_pytorch_1_11_api: 62 | output_mask = (ctx.needs_input_grad[1], ctx.needs_input_grad[2]) 63 | grad_input, grad_grid = op(grad_output, input, grid, 0, 0, False, output_mask) 64 | else: 65 | grad_input, grad_grid = op(grad_output, input, grid, 0, 0, False) 66 | ctx.save_for_backward(grid) 67 | return grad_input, grad_grid 68 | 69 | @staticmethod 70 | def backward(ctx, grad2_grad_input, grad2_grad_grid): 71 | _ = grad2_grad_grid # unused 72 | grid, = ctx.saved_tensors 73 | grad2_grad_output = None 74 | grad2_input = None 75 | grad2_grid = None 76 | 77 | if ctx.needs_input_grad[0]: 78 | grad2_grad_output = _GridSample2dForward.apply(grad2_grad_input, grid) 79 | 80 | assert not ctx.needs_input_grad[2] 81 | return grad2_grad_output, grad2_input, grad2_grid 82 | 83 | #---------------------------------------------------------------------------- 84 | -------------------------------------------------------------------------------- /add_turbo/torch_utils/ops/upfirdn2d.cpp: -------------------------------------------------------------------------------- 1 | // Copyright (c) 2021, NVIDIA CORPORATION & AFFILIATES. All rights reserved. 2 | // 3 | // NVIDIA CORPORATION and its licensors retain all intellectual property 4 | // and proprietary rights in and to this software, related documentation 5 | // and any modifications thereto. 
Any use, reproduction, disclosure or 6 | // distribution of this software and related documentation without an express 7 | // license agreement from NVIDIA CORPORATION is strictly prohibited. 8 | 9 | #include 10 | #include 11 | #include 12 | #include "upfirdn2d.h" 13 | 14 | //------------------------------------------------------------------------ 15 | 16 | static torch::Tensor upfirdn2d(torch::Tensor x, torch::Tensor f, int upx, int upy, int downx, int downy, int padx0, int padx1, int pady0, int pady1, bool flip, float gain) 17 | { 18 | // Validate arguments. 19 | TORCH_CHECK(x.is_cuda(), "x must reside on CUDA device"); 20 | TORCH_CHECK(f.device() == x.device(), "f must reside on the same device as x"); 21 | TORCH_CHECK(f.dtype() == torch::kFloat, "f must be float32"); 22 | TORCH_CHECK(x.numel() <= INT_MAX, "x is too large"); 23 | TORCH_CHECK(f.numel() <= INT_MAX, "f is too large"); 24 | TORCH_CHECK(x.numel() > 0, "x has zero size"); 25 | TORCH_CHECK(f.numel() > 0, "f has zero size"); 26 | TORCH_CHECK(x.dim() == 4, "x must be rank 4"); 27 | TORCH_CHECK(f.dim() == 2, "f must be rank 2"); 28 | TORCH_CHECK((x.size(0)-1)*x.stride(0) + (x.size(1)-1)*x.stride(1) + (x.size(2)-1)*x.stride(2) + (x.size(3)-1)*x.stride(3) <= INT_MAX, "x memory footprint is too large"); 29 | TORCH_CHECK(f.size(0) >= 1 && f.size(1) >= 1, "f must be at least 1x1"); 30 | TORCH_CHECK(upx >= 1 && upy >= 1, "upsampling factor must be at least 1"); 31 | TORCH_CHECK(downx >= 1 && downy >= 1, "downsampling factor must be at least 1"); 32 | 33 | // Create output tensor. 34 | const at::cuda::OptionalCUDAGuard device_guard(device_of(x)); 35 | int outW = ((int)x.size(3) * upx + padx0 + padx1 - (int)f.size(1) + downx) / downx; 36 | int outH = ((int)x.size(2) * upy + pady0 + pady1 - (int)f.size(0) + downy) / downy; 37 | TORCH_CHECK(outW >= 1 && outH >= 1, "output must be at least 1x1"); 38 | torch::Tensor y = torch::empty({x.size(0), x.size(1), outH, outW}, x.options(), x.suggest_memory_format()); 39 | TORCH_CHECK(y.numel() <= INT_MAX, "output is too large"); 40 | TORCH_CHECK((y.size(0)-1)*y.stride(0) + (y.size(1)-1)*y.stride(1) + (y.size(2)-1)*y.stride(2) + (y.size(3)-1)*y.stride(3) <= INT_MAX, "output memory footprint is too large"); 41 | 42 | // Initialize CUDA kernel parameters. 43 | upfirdn2d_kernel_params p; 44 | p.x = x.data_ptr(); 45 | p.f = f.data_ptr(); 46 | p.y = y.data_ptr(); 47 | p.up = make_int2(upx, upy); 48 | p.down = make_int2(downx, downy); 49 | p.pad0 = make_int2(padx0, pady0); 50 | p.flip = (flip) ? 1 : 0; 51 | p.gain = gain; 52 | p.inSize = make_int4((int)x.size(3), (int)x.size(2), (int)x.size(1), (int)x.size(0)); 53 | p.inStride = make_int4((int)x.stride(3), (int)x.stride(2), (int)x.stride(1), (int)x.stride(0)); 54 | p.filterSize = make_int2((int)f.size(1), (int)f.size(0)); 55 | p.filterStride = make_int2((int)f.stride(1), (int)f.stride(0)); 56 | p.outSize = make_int4((int)y.size(3), (int)y.size(2), (int)y.size(1), (int)y.size(0)); 57 | p.outStride = make_int4((int)y.stride(3), (int)y.stride(2), (int)y.stride(1), (int)y.stride(0)); 58 | p.sizeMajor = (p.inStride.z == 1) ? p.inSize.w : p.inSize.w * p.inSize.z; 59 | p.sizeMinor = (p.inStride.z == 1) ? p.inSize.z : 1; 60 | 61 | // Choose CUDA kernel. 62 | upfirdn2d_kernel_spec spec; 63 | AT_DISPATCH_FLOATING_TYPES_AND_HALF(x.scalar_type(), "upfirdn2d_cuda", [&] 64 | { 65 | spec = choose_upfirdn2d_kernel(p); 66 | }); 67 | 68 | // Set looping options. 
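    // Note: loopMajor makes each thread block iterate over several batch/channel slices so that
    // launchMajor (used as the grid's z-dimension below) never exceeds 16384.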
69 | p.loopMajor = (p.sizeMajor - 1) / 16384 + 1; 70 | p.loopMinor = spec.loopMinor; 71 | p.loopX = spec.loopX; 72 | p.launchMinor = (p.sizeMinor - 1) / p.loopMinor + 1; 73 | p.launchMajor = (p.sizeMajor - 1) / p.loopMajor + 1; 74 | 75 | // Compute grid size. 76 | dim3 blockSize, gridSize; 77 | if (spec.tileOutW < 0) // large 78 | { 79 | blockSize = dim3(4, 32, 1); 80 | gridSize = dim3( 81 | ((p.outSize.y - 1) / blockSize.x + 1) * p.launchMinor, 82 | (p.outSize.x - 1) / (blockSize.y * p.loopX) + 1, 83 | p.launchMajor); 84 | } 85 | else // small 86 | { 87 | blockSize = dim3(256, 1, 1); 88 | gridSize = dim3( 89 | ((p.outSize.y - 1) / spec.tileOutH + 1) * p.launchMinor, 90 | (p.outSize.x - 1) / (spec.tileOutW * p.loopX) + 1, 91 | p.launchMajor); 92 | } 93 | 94 | // Launch CUDA kernel. 95 | void* args[] = {&p}; 96 | AT_CUDA_CHECK(cudaLaunchKernel(spec.kernel, gridSize, blockSize, args, 0, at::cuda::getCurrentCUDAStream())); 97 | return y; 98 | } 99 | 100 | //------------------------------------------------------------------------ 101 | 102 | PYBIND11_MODULE(TORCH_EXTENSION_NAME, m) 103 | { 104 | m.def("upfirdn2d", &upfirdn2d); 105 | } 106 | 107 | //------------------------------------------------------------------------ 108 | -------------------------------------------------------------------------------- /add_turbo/torch_utils/ops/upfirdn2d.h: -------------------------------------------------------------------------------- 1 | // Copyright (c) 2021, NVIDIA CORPORATION & AFFILIATES. All rights reserved. 2 | // 3 | // NVIDIA CORPORATION and its licensors retain all intellectual property 4 | // and proprietary rights in and to this software, related documentation 5 | // and any modifications thereto. Any use, reproduction, disclosure or 6 | // distribution of this software and related documentation without an express 7 | // license agreement from NVIDIA CORPORATION is strictly prohibited. 8 | 9 | #include 10 | 11 | //------------------------------------------------------------------------ 12 | // CUDA kernel parameters. 13 | 14 | struct upfirdn2d_kernel_params 15 | { 16 | const void* x; 17 | const float* f; 18 | void* y; 19 | 20 | int2 up; 21 | int2 down; 22 | int2 pad0; 23 | int flip; 24 | float gain; 25 | 26 | int4 inSize; // [width, height, channel, batch] 27 | int4 inStride; 28 | int2 filterSize; // [width, height] 29 | int2 filterStride; 30 | int4 outSize; // [width, height, channel, batch] 31 | int4 outStride; 32 | int sizeMinor; 33 | int sizeMajor; 34 | 35 | int loopMinor; 36 | int loopMajor; 37 | int loopX; 38 | int launchMinor; 39 | int launchMajor; 40 | }; 41 | 42 | //------------------------------------------------------------------------ 43 | // CUDA kernel specialization. 44 | 45 | struct upfirdn2d_kernel_spec 46 | { 47 | void* kernel; 48 | int tileOutW; 49 | int tileOutH; 50 | int loopMinor; 51 | int loopX; 52 | }; 53 | 54 | //------------------------------------------------------------------------ 55 | // CUDA kernel selection. 56 | 57 | template upfirdn2d_kernel_spec choose_upfirdn2d_kernel(const upfirdn2d_kernel_params& p); 58 | 59 | //------------------------------------------------------------------------ 60 | -------------------------------------------------------------------------------- /add_turbo/torch_utils/ops/upfirdn2d.py: -------------------------------------------------------------------------------- 1 | # Copyright (c) 2021, NVIDIA CORPORATION & AFFILIATES. All rights reserved. 
2 | # 3 | # NVIDIA CORPORATION and its licensors retain all intellectual property 4 | # and proprietary rights in and to this software, related documentation 5 | # and any modifications thereto. Any use, reproduction, disclosure or 6 | # distribution of this software and related documentation without an express 7 | # license agreement from NVIDIA CORPORATION is strictly prohibited. 8 | 9 | """Custom PyTorch ops for efficient resampling of 2D images.""" 10 | 11 | import os 12 | import numpy as np 13 | import torch 14 | 15 | from .. import custom_ops 16 | from .. import misc 17 | from . import conv2d_gradfix 18 | 19 | #---------------------------------------------------------------------------- 20 | 21 | _plugin = None 22 | 23 | def _init(): 24 | global _plugin 25 | if _plugin is None: 26 | _plugin = custom_ops.get_plugin( 27 | module_name='upfirdn2d_plugin', 28 | sources=['upfirdn2d.cpp', 'upfirdn2d.cu'], 29 | headers=['upfirdn2d.h'], 30 | source_dir=os.path.dirname(__file__), 31 | extra_cuda_cflags=['--use_fast_math', '--allow-unsupported-compiler'], 32 | ) 33 | return True 34 | 35 | def _parse_scaling(scaling): 36 | if isinstance(scaling, int): 37 | scaling = [scaling, scaling] 38 | assert isinstance(scaling, (list, tuple)) 39 | assert all(isinstance(x, int) for x in scaling) 40 | sx, sy = scaling 41 | assert sx >= 1 and sy >= 1 42 | return sx, sy 43 | 44 | def _parse_padding(padding): 45 | if isinstance(padding, int): 46 | padding = [padding, padding] 47 | assert isinstance(padding, (list, tuple)) 48 | assert all(isinstance(x, int) for x in padding) 49 | if len(padding) == 2: 50 | padx, pady = padding 51 | padding = [padx, padx, pady, pady] 52 | padx0, padx1, pady0, pady1 = padding 53 | return padx0, padx1, pady0, pady1 54 | 55 | def _get_filter_size(f): 56 | if f is None: 57 | return 1, 1 58 | assert isinstance(f, torch.Tensor) and f.ndim in [1, 2] 59 | fw = f.shape[-1] 60 | fh = f.shape[0] 61 | with misc.suppress_tracer_warnings(): 62 | fw = int(fw) 63 | fh = int(fh) 64 | misc.assert_shape(f, [fh, fw][:f.ndim]) 65 | assert fw >= 1 and fh >= 1 66 | return fw, fh 67 | 68 | #---------------------------------------------------------------------------- 69 | 70 | def setup_filter(f, device=torch.device('cpu'), normalize=True, flip_filter=False, gain=1, separable=None): 71 | r"""Convenience function to setup 2D FIR filter for `upfirdn2d()`. 72 | 73 | Args: 74 | f: Torch tensor, numpy array, or python list of the shape 75 | `[filter_height, filter_width]` (non-separable), 76 | `[filter_taps]` (separable), 77 | `[]` (impulse), or 78 | `None` (identity). 79 | device: Result device (default: cpu). 80 | normalize: Normalize the filter so that it retains the magnitude 81 | for constant input signal (DC)? (default: True). 82 | flip_filter: Flip the filter? (default: False). 83 | gain: Overall scaling factor for signal magnitude (default: 1). 84 | separable: Return a separable filter? (default: select automatically). 85 | 86 | Returns: 87 | Float32 tensor of the shape 88 | `[filter_height, filter_width]` (non-separable) or 89 | `[filter_taps]` (separable). 90 | """ 91 | # Validate. 92 | if f is None: 93 | f = 1 94 | f = torch.as_tensor(f, dtype=torch.float32) 95 | assert f.ndim in [0, 1, 2] 96 | assert f.numel() > 0 97 | if f.ndim == 0: 98 | f = f[np.newaxis] 99 | 100 | # Separable? 
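    # A 1-D filter with at least 8 taps is treated as separable by default;
    # shorter 1-D filters are expanded into the full 2-D outer-product filter f.ger(f) below.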
101 | if separable is None: 102 | separable = (f.ndim == 1 and f.numel() >= 8) 103 | if f.ndim == 1 and not separable: 104 | f = f.ger(f) 105 | assert f.ndim == (1 if separable else 2) 106 | 107 | # Apply normalize, flip, gain, and device. 108 | if normalize: 109 | f /= f.sum() 110 | if flip_filter: 111 | f = f.flip(list(range(f.ndim))) 112 | f = f * (gain ** (f.ndim / 2)) 113 | f = f.to(device=device) 114 | return f 115 | 116 | #---------------------------------------------------------------------------- 117 | 118 | def upfirdn2d(x, f, up=1, down=1, padding=0, flip_filter=False, gain=1, impl='cuda'): 119 | r"""Pad, upsample, filter, and downsample a batch of 2D images. 120 | 121 | Performs the following sequence of operations for each channel: 122 | 123 | 1. Upsample the image by inserting N-1 zeros after each pixel (`up`). 124 | 125 | 2. Pad the image with the specified number of zeros on each side (`padding`). 126 | Negative padding corresponds to cropping the image. 127 | 128 | 3. Convolve the image with the specified 2D FIR filter (`f`), shrinking it 129 | so that the footprint of all output pixels lies within the input image. 130 | 131 | 4. Downsample the image by keeping every Nth pixel (`down`). 132 | 133 | This sequence of operations bears close resemblance to scipy.signal.upfirdn(). 134 | The fused op is considerably more efficient than performing the same calculation 135 | using standard PyTorch ops. It supports gradients of arbitrary order. 136 | 137 | Args: 138 | x: Float32/float64/float16 input tensor of the shape 139 | `[batch_size, num_channels, in_height, in_width]`. 140 | f: Float32 FIR filter of the shape 141 | `[filter_height, filter_width]` (non-separable), 142 | `[filter_taps]` (separable), or 143 | `None` (identity). 144 | up: Integer upsampling factor. Can be a single int or a list/tuple 145 | `[x, y]` (default: 1). 146 | down: Integer downsampling factor. Can be a single int or a list/tuple 147 | `[x, y]` (default: 1). 148 | padding: Padding with respect to the upsampled image. Can be a single number 149 | or a list/tuple `[x, y]` or `[x_before, x_after, y_before, y_after]` 150 | (default: 0). 151 | flip_filter: False = convolution, True = correlation (default: False). 152 | gain: Overall scaling factor for signal magnitude (default: 1). 153 | impl: Implementation to use. Can be `'ref'` or `'cuda'` (default: `'cuda'`). 154 | 155 | Returns: 156 | Tensor of the shape `[batch_size, num_channels, out_height, out_width]`. 157 | """ 158 | assert isinstance(x, torch.Tensor) 159 | assert impl in ['ref', 'cuda'] 160 | if impl == 'cuda' and x.device.type == 'cuda' and _init(): 161 | return _upfirdn2d_cuda(up=up, down=down, padding=padding, flip_filter=flip_filter, gain=gain).apply(x, f) 162 | return _upfirdn2d_ref(x, f, up=up, down=down, padding=padding, flip_filter=flip_filter, gain=gain) 163 | 164 | #---------------------------------------------------------------------------- 165 | 166 | @misc.profiled_function 167 | def _upfirdn2d_ref(x, f, up=1, down=1, padding=0, flip_filter=False, gain=1): 168 | """Slow reference implementation of `upfirdn2d()` using standard PyTorch ops. 169 | """ 170 | # Validate arguments. 
171 | assert isinstance(x, torch.Tensor) and x.ndim == 4 172 | if f is None: 173 | f = torch.ones([1, 1], dtype=torch.float32, device=x.device) 174 | assert isinstance(f, torch.Tensor) and f.ndim in [1, 2] 175 | assert f.dtype == torch.float32 and not f.requires_grad 176 | batch_size, num_channels, in_height, in_width = x.shape 177 | upx, upy = _parse_scaling(up) 178 | downx, downy = _parse_scaling(down) 179 | padx0, padx1, pady0, pady1 = _parse_padding(padding) 180 | 181 | # Check that upsampled buffer is not smaller than the filter. 182 | upW = in_width * upx + padx0 + padx1 183 | upH = in_height * upy + pady0 + pady1 184 | assert upW >= f.shape[-1] and upH >= f.shape[0] 185 | 186 | # Upsample by inserting zeros. 187 | x = x.reshape([batch_size, num_channels, in_height, 1, in_width, 1]) 188 | x = torch.nn.functional.pad(x, [0, upx - 1, 0, 0, 0, upy - 1]) 189 | x = x.reshape([batch_size, num_channels, in_height * upy, in_width * upx]) 190 | 191 | # Pad or crop. 192 | x = torch.nn.functional.pad(x, [max(padx0, 0), max(padx1, 0), max(pady0, 0), max(pady1, 0)]) 193 | x = x[:, :, max(-pady0, 0) : x.shape[2] - max(-pady1, 0), max(-padx0, 0) : x.shape[3] - max(-padx1, 0)] 194 | 195 | # Setup filter. 196 | f = f * (gain ** (f.ndim / 2)) 197 | f = f.to(x.dtype) 198 | if not flip_filter: 199 | f = f.flip(list(range(f.ndim))) 200 | 201 | # Convolve with the filter. 202 | f = f[np.newaxis, np.newaxis].repeat([num_channels, 1] + [1] * f.ndim) 203 | if f.ndim == 4: 204 | x = conv2d_gradfix.conv2d(input=x, weight=f, groups=num_channels) 205 | else: 206 | x = conv2d_gradfix.conv2d(input=x, weight=f.unsqueeze(2), groups=num_channels) 207 | x = conv2d_gradfix.conv2d(input=x, weight=f.unsqueeze(3), groups=num_channels) 208 | 209 | # Downsample by throwing away pixels. 210 | x = x[:, :, ::downy, ::downx] 211 | return x 212 | 213 | #---------------------------------------------------------------------------- 214 | 215 | _upfirdn2d_cuda_cache = dict() 216 | 217 | def _upfirdn2d_cuda(up=1, down=1, padding=0, flip_filter=False, gain=1): 218 | """Fast CUDA implementation of `upfirdn2d()` using custom ops. 219 | """ 220 | # Parse arguments. 221 | upx, upy = _parse_scaling(up) 222 | downx, downy = _parse_scaling(down) 223 | padx0, padx1, pady0, pady1 = _parse_padding(padding) 224 | 225 | # Lookup from cache. 226 | key = (upx, upy, downx, downy, padx0, padx1, pady0, pady1, flip_filter, gain) 227 | if key in _upfirdn2d_cuda_cache: 228 | return _upfirdn2d_cuda_cache[key] 229 | 230 | # Forward op. 231 | class Upfirdn2dCuda(torch.autograd.Function): 232 | @staticmethod 233 | def forward(ctx, x, f): # pylint: disable=arguments-differ 234 | assert isinstance(x, torch.Tensor) and x.ndim == 4 235 | if f is None: 236 | f = torch.ones([1, 1], dtype=torch.float32, device=x.device) 237 | if f.ndim == 1 and f.shape[0] == 1: 238 | f = f.square().unsqueeze(0) # Convert separable-1 into full-1x1. 
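                # (The separable path applies the filter once per axis, so the equivalent full 1x1 kernel is the single tap squared.)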
239 | assert isinstance(f, torch.Tensor) and f.ndim in [1, 2] 240 | y = x 241 | if f.ndim == 2: 242 | y = _plugin.upfirdn2d(y, f, upx, upy, downx, downy, padx0, padx1, pady0, pady1, flip_filter, gain) 243 | else: 244 | y = _plugin.upfirdn2d(y, f.unsqueeze(0), upx, 1, downx, 1, padx0, padx1, 0, 0, flip_filter, 1.0) 245 | y = _plugin.upfirdn2d(y, f.unsqueeze(1), 1, upy, 1, downy, 0, 0, pady0, pady1, flip_filter, gain) 246 | ctx.save_for_backward(f) 247 | ctx.x_shape = x.shape 248 | return y 249 | 250 | @staticmethod 251 | def backward(ctx, dy): # pylint: disable=arguments-differ 252 | f, = ctx.saved_tensors 253 | _, _, ih, iw = ctx.x_shape 254 | _, _, oh, ow = dy.shape 255 | fw, fh = _get_filter_size(f) 256 | p = [ 257 | fw - padx0 - 1, 258 | iw * upx - ow * downx + padx0 - upx + 1, 259 | fh - pady0 - 1, 260 | ih * upy - oh * downy + pady0 - upy + 1, 261 | ] 262 | dx = None 263 | df = None 264 | 265 | if ctx.needs_input_grad[0]: 266 | dx = _upfirdn2d_cuda(up=down, down=up, padding=p, flip_filter=(not flip_filter), gain=gain).apply(dy, f) 267 | 268 | assert not ctx.needs_input_grad[1] 269 | return dx, df 270 | 271 | # Add to cache. 272 | _upfirdn2d_cuda_cache[key] = Upfirdn2dCuda 273 | return Upfirdn2dCuda 274 | 275 | #---------------------------------------------------------------------------- 276 | 277 | def filter2d(x, f, padding=0, flip_filter=False, gain=1, impl='cuda'): 278 | r"""Filter a batch of 2D images using the given 2D FIR filter. 279 | 280 | By default, the result is padded so that its shape matches the input. 281 | User-specified padding is applied on top of that, with negative values 282 | indicating cropping. Pixels outside the image are assumed to be zero. 283 | 284 | Args: 285 | x: Float32/float64/float16 input tensor of the shape 286 | `[batch_size, num_channels, in_height, in_width]`. 287 | f: Float32 FIR filter of the shape 288 | `[filter_height, filter_width]` (non-separable), 289 | `[filter_taps]` (separable), or 290 | `None` (identity). 291 | padding: Padding with respect to the output. Can be a single number or a 292 | list/tuple `[x, y]` or `[x_before, x_after, y_before, y_after]` 293 | (default: 0). 294 | flip_filter: False = convolution, True = correlation (default: False). 295 | gain: Overall scaling factor for signal magnitude (default: 1). 296 | impl: Implementation to use. Can be `'ref'` or `'cuda'` (default: `'cuda'`). 297 | 298 | Returns: 299 | Tensor of the shape `[batch_size, num_channels, out_height, out_width]`. 300 | """ 301 | padx0, padx1, pady0, pady1 = _parse_padding(padding) 302 | fw, fh = _get_filter_size(f) 303 | p = [ 304 | padx0 + fw // 2, 305 | padx1 + (fw - 1) // 2, 306 | pady0 + fh // 2, 307 | pady1 + (fh - 1) // 2, 308 | ] 309 | return upfirdn2d(x, f, padding=p, flip_filter=flip_filter, gain=gain, impl=impl) 310 | 311 | #---------------------------------------------------------------------------- 312 | 313 | def upsample2d(x, f, up=2, padding=0, flip_filter=False, gain=1, impl='cuda'): 314 | r"""Upsample a batch of 2D images using the given 2D FIR filter. 315 | 316 | By default, the result is padded so that its shape is a multiple of the input. 317 | User-specified padding is applied on top of that, with negative values 318 | indicating cropping. Pixels outside the image are assumed to be zero. 319 | 320 | Args: 321 | x: Float32/float64/float16 input tensor of the shape 322 | `[batch_size, num_channels, in_height, in_width]`. 
323 | f: Float32 FIR filter of the shape 324 | `[filter_height, filter_width]` (non-separable), 325 | `[filter_taps]` (separable), or 326 | `None` (identity). 327 | up: Integer upsampling factor. Can be a single int or a list/tuple 328 | `[x, y]` (default: 1). 329 | padding: Padding with respect to the output. Can be a single number or a 330 | list/tuple `[x, y]` or `[x_before, x_after, y_before, y_after]` 331 | (default: 0). 332 | flip_filter: False = convolution, True = correlation (default: False). 333 | gain: Overall scaling factor for signal magnitude (default: 1). 334 | impl: Implementation to use. Can be `'ref'` or `'cuda'` (default: `'cuda'`). 335 | 336 | Returns: 337 | Tensor of the shape `[batch_size, num_channels, out_height, out_width]`. 338 | """ 339 | upx, upy = _parse_scaling(up) 340 | padx0, padx1, pady0, pady1 = _parse_padding(padding) 341 | fw, fh = _get_filter_size(f) 342 | p = [ 343 | padx0 + (fw + upx - 1) // 2, 344 | padx1 + (fw - upx) // 2, 345 | pady0 + (fh + upy - 1) // 2, 346 | pady1 + (fh - upy) // 2, 347 | ] 348 | return upfirdn2d(x, f, up=up, padding=p, flip_filter=flip_filter, gain=gain*upx*upy, impl=impl) 349 | 350 | #---------------------------------------------------------------------------- 351 | 352 | def downsample2d(x, f, down=2, padding=0, flip_filter=False, gain=1, impl='cuda'): 353 | r"""Downsample a batch of 2D images using the given 2D FIR filter. 354 | 355 | By default, the result is padded so that its shape is a fraction of the input. 356 | User-specified padding is applied on top of that, with negative values 357 | indicating cropping. Pixels outside the image are assumed to be zero. 358 | 359 | Args: 360 | x: Float32/float64/float16 input tensor of the shape 361 | `[batch_size, num_channels, in_height, in_width]`. 362 | f: Float32 FIR filter of the shape 363 | `[filter_height, filter_width]` (non-separable), 364 | `[filter_taps]` (separable), or 365 | `None` (identity). 366 | down: Integer downsampling factor. Can be a single int or a list/tuple 367 | `[x, y]` (default: 1). 368 | padding: Padding with respect to the input. Can be a single number or a 369 | list/tuple `[x, y]` or `[x_before, x_after, y_before, y_after]` 370 | (default: 0). 371 | flip_filter: False = convolution, True = correlation (default: False). 372 | gain: Overall scaling factor for signal magnitude (default: 1). 373 | impl: Implementation to use. Can be `'ref'` or `'cuda'` (default: `'cuda'`). 374 | 375 | Returns: 376 | Tensor of the shape `[batch_size, num_channels, out_height, out_width]`. 
377 | """ 378 | downx, downy = _parse_scaling(down) 379 | padx0, padx1, pady0, pady1 = _parse_padding(padding) 380 | fw, fh = _get_filter_size(f) 381 | p = [ 382 | padx0 + (fw - downx + 1) // 2, 383 | padx1 + (fw - downx) // 2, 384 | pady0 + (fh - downy + 1) // 2, 385 | pady1 + (fh - downy) // 2, 386 | ] 387 | return upfirdn2d(x, f, down=down, padding=p, flip_filter=flip_filter, gain=gain, impl=impl) 388 | 389 | #---------------------------------------------------------------------------- 390 | -------------------------------------------------------------------------------- /add_turbo/vit_utils.py: -------------------------------------------------------------------------------- 1 | # MIT License 2 | # 3 | # Copyright (c) 2021 Intel ISL (Intel Intelligent Systems Lab) 4 | # 5 | # Permission is hereby granted, free of charge, to any person obtaining a copy 6 | # of this software and associated documentation files (the "Software"), to deal 7 | # in the Software without restriction, including without limitation the rights 8 | # to use, copy, modify, merge, publish, distribute, sublicense, and/or sell 9 | # copies of the Software, and to permit persons to whom the Software is 10 | # furnished to do so, subject to the following conditions: 11 | # 12 | # The above copyright notice and this permission notice shall be included in all 13 | # copies or substantial portions of the Software. 14 | # 15 | # THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR 16 | # IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, 17 | # FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE 18 | # AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER 19 | # LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, 20 | # OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE 21 | # SOFTWARE. 
22 | # 23 | # Based on code from https://github.com/isl-org/DPT 24 | 25 | """Flexible configuration and feature extraction of timm VisionTransformers.""" 26 | 27 | import types 28 | import math 29 | from typing import Callable 30 | 31 | import torch 32 | import torch.nn as nn 33 | import torch.nn.functional as F 34 | 35 | 36 | class AddReadout(nn.Module): 37 | def __init__(self, start_index: bool = 1): 38 | super(AddReadout, self).__init__() 39 | self.start_index = start_index 40 | 41 | def forward(self, x: torch.Tensor) -> torch.Tensor: 42 | if self.start_index == 2: 43 | readout = (x[:, 0] + x[:, 1]) / 2 44 | else: 45 | readout = x[:, 0] 46 | return x[:, self.start_index:] + readout.unsqueeze(1) 47 | 48 | 49 | class Transpose(nn.Module): 50 | def __init__(self, dim0: int, dim1: int): 51 | super(Transpose, self).__init__() 52 | self.dim0 = dim0 53 | self.dim1 = dim1 54 | 55 | def forward(self, x: torch.Tensor) -> torch.Tensor: 56 | x = x.transpose(self.dim0, self.dim1) 57 | return x.contiguous() 58 | 59 | 60 | def forward_vit(pretrained: nn.Module, x: torch.Tensor) -> dict: 61 | _, _, H, W = x.size() 62 | _ = pretrained.model.forward_flex(x) 63 | return {k: pretrained.rearrange(v) for k, v in activations.items()} 64 | 65 | 66 | def _resize_pos_embed(self, posemb: torch.Tensor, gs_h: int, gs_w: int) -> torch.Tensor: 67 | posemb_tok, posemb_grid = ( 68 | posemb[:, : self.start_index], 69 | posemb[0, self.start_index :], 70 | ) 71 | 72 | gs_old = int(math.sqrt(len(posemb_grid))) 73 | 74 | posemb_grid = posemb_grid.reshape(1, gs_old, gs_old, -1).permute(0, 3, 1, 2) 75 | posemb_grid = F.interpolate(posemb_grid, size=(gs_h, gs_w), mode="bilinear", align_corners=False) 76 | posemb_grid = posemb_grid.permute(0, 2, 3, 1).reshape(1, gs_h * gs_w, -1) 77 | 78 | posemb = torch.cat([posemb_tok, posemb_grid], dim=1) 79 | 80 | return posemb 81 | 82 | 83 | def forward_flex(self, x: torch.Tensor) -> torch.Tensor: 84 | # patch proj and dynamically resize 85 | B, C, H, W = x.size() 86 | x = self.patch_embed.proj(x).flatten(2).transpose(1, 2) 87 | pos_embed = self._resize_pos_embed( 88 | self.pos_embed, H // self.patch_size[1], W // self.patch_size[0] 89 | ) 90 | 91 | # add cls token 92 | cls_tokens = self.cls_token.expand( 93 | x.size(0), -1, -1 94 | ) 95 | x = torch.cat((cls_tokens, x), dim=1) 96 | 97 | # forward pass 98 | x = x + pos_embed 99 | x = self.pos_drop(x) 100 | 101 | for blk in self.blocks: 102 | x = blk(x) 103 | 104 | x = self.norm(x) 105 | return x 106 | 107 | 108 | activations = {} 109 | 110 | 111 | def get_activation(name: str) -> Callable: 112 | def hook(model, input, output): 113 | activations[name] = output 114 | return hook 115 | 116 | 117 | def make_vit_backbone( 118 | model: nn.Module, 119 | patch_size: list[int] = [16, 16], 120 | hooks: list[int] = [2, 5, 8, 11], 121 | hook_patch: bool = True, 122 | start_index: list[int] = 1, 123 | ): 124 | assert len(hooks) == 4 125 | 126 | pretrained = nn.Module() 127 | pretrained.model = model 128 | 129 | # add hooks 130 | pretrained.model.blocks[hooks[0]].register_forward_hook(get_activation('0')) 131 | pretrained.model.blocks[hooks[1]].register_forward_hook(get_activation('1')) 132 | pretrained.model.blocks[hooks[2]].register_forward_hook(get_activation('2')) 133 | pretrained.model.blocks[hooks[3]].register_forward_hook(get_activation('3')) 134 | if hook_patch: 135 | pretrained.model.pos_drop.register_forward_hook(get_activation('4')) 136 | 137 | # configure readout 138 | pretrained.rearrange = nn.Sequential(AddReadout(start_index), Transpose(1, 
2))
139 |     pretrained.model.start_index = start_index
140 |     pretrained.model.patch_size = patch_size
141 | 
142 |     # We inject this function into the VisionTransformer instances so that
143 |     # we can use it with interpolated position embeddings without modifying the library source.
144 |     pretrained.model.forward_flex = types.MethodType(forward_flex, pretrained.model)
145 |     pretrained.model._resize_pos_embed = types.MethodType(
146 |         _resize_pos_embed, pretrained.model
147 |     )
148 | 
149 |     return pretrained
150 | 
--------------------------------------------------------------------------------
/mobile_diffusion_unet/README.md:
--------------------------------------------------------------------------------
1 | ## MobileDiffusion UNET Reproduction
2 | 
3 | ### Overview
4 | 
5 | Our team has successfully replicated the UNet compression and acceleration aspects of MobileDiffusion,
6 | including the pruning and merging of the down and up layers, as well as the weight sharing in the attention mechanism. A minimal sketch of the attention-side changes is given at the end of this README.
7 | 
8 | ### Usage and Application
9 | ```shell
10 | export MODEL_PATH="XXX"
11 | export DATASET_PATH="XXX"
12 | 
13 | accelerate launch --mixed_precision="fp16" distill_training_md.py \
14 |   --pretrained_model_name_or_path=$MODEL_PATH \
15 |   --train_data_dir=$DATASET_PATH \
16 |   --resolution=512 --center_crop --random_flip \
17 |   --train_batch_size=16 \
18 |   --gradient_accumulation_steps=4 \
19 |   --gradient_checkpointing \
20 |   --max_train_steps=25000 \
21 |   --distill_level="md" \
22 |   --prepare_unet="True" \
23 |   --output_weight=10 \
24 |   --feature_weight=0.5 \
25 |   --learning_rate=1e-04 \
26 |   --max_grad_norm=1 \
27 |   --lr_scheduler="constant" --lr_warmup_steps=0 \
28 |   --output_dir="XXX" \
29 |   --student_unet_weight="XXX"
30 | 
31 | ```
32 | 
33 | ### Contribution and Collaboration
34 | 
35 | We are open to contributions from the community to refine and improve upon the current implementation.
36 | 
37 | ### Acknowledgments
38 | 
39 | We would like to acknowledge the original authors of MobileDiffusion for their innovative approach to model
40 | compression and acceleration. Our work is a testament to the impact of their research in the field of
41 | efficient machine learning models.
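### Appendix: attention sketch

The block below is a minimal, self-contained PyTorch sketch of the two attention-side ideas mentioned in the Overview: a single projection shared between keys and values, and ReLU used in place of softmax. The class and parameter names are illustrative only; the actual implementation in this repository is `AttentionMobile` in `mdAttention.py`.

```python
import math

import torch
import torch.nn as nn


class SharedKVReluAttention(nn.Module):
    """Toy self-attention with a shared K/V projection and ReLU instead of softmax."""

    def __init__(self, dim: int, heads: int = 8):
        super().__init__()
        assert dim % heads == 0
        self.heads = heads
        self.head_dim = dim // heads
        self.to_q = nn.Linear(dim, dim, bias=False)
        self.to_kv = nn.Linear(dim, dim, bias=False)  # one weight matrix reused for both K and V
        self.to_out = nn.Linear(dim, dim)

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        b, n, _ = x.shape
        q = self.to_q(x)
        k = v = self.to_kv(x)  # weight sharing: keys and values come from the same projection
        q, k, v = (t.view(b, n, self.heads, self.head_dim).transpose(1, 2) for t in (q, k, v))
        attn = torch.relu(q @ k.transpose(-2, -1) / math.sqrt(self.head_dim))  # softmax -> relu
        out = (attn @ v).transpose(1, 2).reshape(b, n, -1)
        return self.to_out(out)


if __name__ == "__main__":
    x = torch.randn(2, 64, 320)
    print(SharedKVReluAttention(320)(x).shape)  # torch.Size([2, 64, 320])
```

Sharing a single K/V projection halves the key/value parameters relative to separate `to_k`/`to_v` matrices, which is one of the size reductions this reproduction applies to the self-attention layers.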
42 | -------------------------------------------------------------------------------- /mobile_diffusion_unet/__init__.py: -------------------------------------------------------------------------------- 1 | import os,sys 2 | 3 | curPath = os.path.abspath(os.path.dirname(__file__)) 4 | sys.path.append(curPath) -------------------------------------------------------------------------------- /mobile_diffusion_unet/mdAttention.py: -------------------------------------------------------------------------------- 1 | from typing import Callable, Optional, Union 2 | 3 | import torch 4 | import torch.nn as nn 5 | 6 | from diffusers.models.attention_processor import SpatialNorm 7 | from diffusers.utils import USE_PEFT_BACKEND 8 | from diffusers.models.lora import LoRACompatibleLinear 9 | 10 | import math 11 | 12 | class AttentionMobile(nn.Module): 13 | def __init__( 14 | self, 15 | query_dim: int, 16 | cross_attention_dim: Optional[int] = None, 17 | heads: int = 8, 18 | dim_head: int = 64, 19 | dropout: float = 0.0, 20 | bias: bool = False, 21 | upcast_attention: bool = False, 22 | upcast_softmax: bool = False, 23 | cross_attention_norm: Optional[str] = None, 24 | cross_attention_norm_num_groups: int = 32, 25 | added_kv_proj_dim: Optional[int] = None, 26 | norm_num_groups: Optional[int] = None, 27 | spatial_norm_dim: Optional[int] = None, 28 | out_bias: bool = True, 29 | scale_qk: bool = True, 30 | only_cross_attention: bool = False, 31 | eps: float = 1e-5, 32 | rescale_output_factor: float = 1.0, 33 | residual_connection: bool = False, 34 | _from_deprecated_attn_block: bool = False, 35 | # processor: Optional["AttnProcessor"] = None, 36 | out_dim: int = None, 37 | ): 38 | super().__init__() 39 | self.inner_dim = out_dim if out_dim is not None else dim_head * heads 40 | self.query_dim = query_dim 41 | self.cross_attention_dim = cross_attention_dim if cross_attention_dim is not None else query_dim 42 | self.upcast_attention = upcast_attention 43 | self.upcast_softmax = upcast_softmax 44 | self.rescale_output_factor = rescale_output_factor 45 | self.residual_connection = residual_connection 46 | self.dropout = dropout 47 | self.fused_projections = False 48 | self.out_dim = out_dim if out_dim is not None else query_dim 49 | 50 | # we make use of this private variable to know whether this class is loaded 51 | # with an deprecated state dict so that we can convert it on the fly 52 | self._from_deprecated_attn_block = _from_deprecated_attn_block 53 | 54 | self.scale_qk = scale_qk 55 | self.scale = dim_head**-0.5 if self.scale_qk else 1.0 56 | 57 | self.heads = out_dim // dim_head if out_dim is not None else heads 58 | # for slice_size > 0 the attention score computation 59 | # is split across the batch axis to save memory 60 | # You can set slice_size with `set_attention_slice` 61 | self.sliceable_head_dim = heads 62 | 63 | self.added_kv_proj_dim = added_kv_proj_dim 64 | self.only_cross_attention = only_cross_attention 65 | 66 | # if self.added_kv_proj_dim is None and self.only_cross_attention: 67 | # raise ValueError( 68 | # "`only_cross_attention` can only be set to True if `added_kv_proj_dim` is not None. Make sure to set either `only_cross_attention=False` or define `added_kv_proj_dim`." 
69 | # ) 70 | 71 | if norm_num_groups is not None: 72 | self.group_norm = nn.GroupNorm(num_channels=query_dim, num_groups=norm_num_groups, eps=eps, affine=True) 73 | else: 74 | self.group_norm = None 75 | 76 | if spatial_norm_dim is not None: 77 | self.spatial_norm = SpatialNorm(f_channels=query_dim, zq_channels=spatial_norm_dim) 78 | else: 79 | self.spatial_norm = None 80 | 81 | if cross_attention_norm is None: 82 | self.norm_cross = None 83 | elif cross_attention_norm == "layer_norm": 84 | self.norm_cross = nn.LayerNorm(self.cross_attention_dim) 85 | elif cross_attention_norm == "group_norm": 86 | if self.added_kv_proj_dim is not None: 87 | # The given `encoder_hidden_states` are initially of shape 88 | # (batch_size, seq_len, added_kv_proj_dim) before being projected 89 | # to (batch_size, seq_len, cross_attention_dim). The norm is applied 90 | # before the projection, so we need to use `added_kv_proj_dim` as 91 | # the number of channels for the group norm. 92 | norm_cross_num_channels = added_kv_proj_dim 93 | else: 94 | norm_cross_num_channels = self.cross_attention_dim 95 | 96 | self.norm_cross = nn.GroupNorm( 97 | num_channels=norm_cross_num_channels, num_groups=cross_attention_norm_num_groups, eps=1e-5, affine=True 98 | ) 99 | else: 100 | raise ValueError( 101 | f"unknown cross_attention_norm: {cross_attention_norm}. Should be None, 'layer_norm' or 'group_norm'" 102 | ) 103 | 104 | if USE_PEFT_BACKEND: 105 | linear_cls = nn.Linear 106 | else: 107 | linear_cls = LoRACompatibleLinear 108 | 109 | self.linear_cls = linear_cls 110 | self.to_q = linear_cls(query_dim, self.inner_dim, bias=bias) 111 | 112 | if not self.only_cross_attention: # SA 113 | # only relevant for the `AddedKVProcessor` classes 114 | self.to_kv = linear_cls(self.cross_attention_dim, self.inner_dim, bias=bias) 115 | else: # CA 116 | self.to_k = linear_cls(self.cross_attention_dim, self.inner_dim, bias=bias) 117 | self.to_v = linear_cls(self.cross_attention_dim, self.inner_dim, bias=bias) 118 | 119 | # if self.added_kv_proj_dim is not None: 120 | # self.add_kv_proj = linear_cls(added_kv_proj_dim, self.inner_dim) 121 | 122 | self.to_out = nn.ModuleList([]) 123 | self.to_out.append(linear_cls(self.inner_dim, self.out_dim, bias=out_bias)) 124 | self.to_out.append(nn.Dropout(dropout)) 125 | 126 | def forward( 127 | self, 128 | hidden_states: torch.FloatTensor, 129 | encoder_hidden_states: Optional[torch.FloatTensor] = None, 130 | attention_mask: Optional[torch.FloatTensor] = None, 131 | temb: Optional[torch.FloatTensor] = None, 132 | scale: float = 1.0, 133 | ) -> torch.Tensor: 134 | residual = hidden_states 135 | if self.spatial_norm is not None: 136 | hidden_states = self.spatial_norm(hidden_states, temb) 137 | 138 | input_ndim = hidden_states.ndim 139 | 140 | if input_ndim == 4: 141 | batch_size, channel, height, width = hidden_states.shape 142 | hidden_states = hidden_states.view(batch_size, channel, height * width).transpose(1, 2) 143 | 144 | batch_size, sequence_length, _ = ( 145 | hidden_states.shape if encoder_hidden_states is None else encoder_hidden_states.shape 146 | ) 147 | 148 | if attention_mask is not None: 149 | pass 150 | # attention_mask = self.prepare_attention_mask(attention_mask, sequence_length, batch_size) 151 | # # scaled_dot_product_attention expects attention_mask shape to be 152 | # # (batch, heads, source_length, target_length) 153 | # attention_mask = attention_mask.view(batch_size, self.heads, -1, attention_mask.shape[-1]) 154 | 155 | if self.group_norm is not None: 156 | hidden_states = 
self.group_norm(hidden_states.transpose(1, 2)).transpose(1, 2) 157 | 158 | args = () if USE_PEFT_BACKEND else (scale,) 159 | query = self.to_q(hidden_states, *args) 160 | 161 | if encoder_hidden_states is None: 162 | encoder_hidden_states = hidden_states 163 | elif self.norm_cross: 164 | # encoder_hidden_states = self.norm_encoder_hidden_states(encoder_hidden_states) 165 | pass 166 | 167 | if not self.only_cross_attention: # SA 168 | key = self.to_kv(encoder_hidden_states, *args) 169 | value = self.to_kv(encoder_hidden_states, *args) 170 | else: 171 | key = self.to_k(encoder_hidden_states, *args) 172 | value = self.to_v(encoder_hidden_states, *args) 173 | 174 | inner_dim = key.shape[-1] 175 | head_dim = inner_dim // self.heads 176 | 177 | query = query.view(batch_size, -1, self.heads, head_dim).transpose(1, 2) 178 | 179 | key = key.view(batch_size, -1, self.heads, head_dim).transpose(1, 2) 180 | value = value.view(batch_size, -1, self.heads, head_dim).transpose(1, 2) 181 | 182 | ## start scaled_dot_product_attention 183 | # L, S = query.size(-2), key.size(-2) 184 | scale_factor = 1 / math.sqrt(query.size(-1)) 185 | # attn_bias = torch.zeros(L, S, dtype=query.dtype).to(hidden_states.device) 186 | 187 | attn_weight = query @ key.transpose(-2, -1) * scale_factor 188 | # attn_weight += attn_bias 189 | attn_weight = torch.relu(attn_weight) # softmax -> relu 190 | # attn_weight = torch.dropout(attn_weight, 0.0, train=True) 191 | hidden_states = attn_weight @ value 192 | ## end scaled_dot_product_attention 193 | 194 | hidden_states = hidden_states.transpose(1, 2).reshape(batch_size, -1, self.heads * head_dim) 195 | hidden_states = hidden_states.to(query.dtype) 196 | 197 | # linear proj 198 | hidden_states = self.to_out[0](hidden_states, *args) 199 | # dropout 200 | hidden_states = self.to_out[1](hidden_states) 201 | 202 | if input_ndim == 4: 203 | hidden_states = hidden_states.transpose(-1, -2).reshape(batch_size, channel, height, width) 204 | 205 | if self.residual_connection: 206 | hidden_states = hidden_states + residual 207 | 208 | hidden_states = hidden_states / self.rescale_output_factor 209 | 210 | return hidden_states -------------------------------------------------------------------------------- /mobile_diffusion_unet/mdBasicTransformerBlock.py: -------------------------------------------------------------------------------- 1 | from typing import Any, Dict, List, Optional, Tuple, Union 2 | 3 | import torch 4 | import torch.nn as nn 5 | # from diffusers.models.attention_processor import Attention 6 | from diffusers.models.attention import FeedForward 7 | from mdAttention import AttentionMobile as Attention 8 | 9 | class BasicTransformerMobileBlock(nn.Module): 10 | def __init__( 11 | self, 12 | dim: int, 13 | num_attention_heads: int, 14 | attention_head_dim: int, 15 | dropout=0.0, 16 | cross_attention_dim: Optional[int] = None, 17 | activation_fn: str = "geglu", 18 | # num_embeds_ada_norm: Optional[int] = None, 19 | attention_bias: bool = False, 20 | only_cross_attention: bool = False, 21 | # double_self_attention: bool = False, 22 | upcast_attention: bool = False, 23 | norm_elementwise_affine: bool = True, 24 | # norm_type: str = "layer_norm", # 'layer_norm', 'ada_norm', 'ada_norm_zero', 'ada_norm_single' 25 | norm_eps: float = 1e-5, 26 | final_dropout: bool = False, 27 | # attention_type: str = "default", 28 | # positional_embeddings: Optional[str] = None, 29 | # num_positional_embeddings: Optional[int] = None, 30 | # ada_norm_continous_conditioning_embedding_dim: Optional[int] = 
None, 31 | # ada_norm_bias: Optional[int] = None, 32 | ff_inner_dim: Optional[int] = None, 33 | ff_bias: bool = True, 34 | attention_out_bias: bool = True, 35 | ): 36 | super().__init__() 37 | self.only_cross_attention = only_cross_attention 38 | 39 | # self.pos_embed = None 40 | 41 | # Define 3 blocks. Each block has its own normalization layer. 42 | # 1. Self-Attn 43 | if not only_cross_attention: 44 | self.norm1 = nn.LayerNorm(dim, elementwise_affine=norm_elementwise_affine, eps=norm_eps) 45 | 46 | self.attn1 = Attention( 47 | query_dim=dim, 48 | heads=num_attention_heads, 49 | dim_head=attention_head_dim, 50 | dropout=dropout, 51 | bias=attention_bias, 52 | cross_attention_dim=cross_attention_dim if only_cross_attention else None, 53 | upcast_attention=upcast_attention, 54 | out_bias=attention_out_bias, 55 | only_cross_attention=False 56 | ) 57 | 58 | # 2. Cross-Attn 59 | self.norm2 = nn.LayerNorm(dim, norm_eps, norm_elementwise_affine) 60 | self.attn2 = Attention( 61 | query_dim=dim, 62 | cross_attention_dim=cross_attention_dim, 63 | heads=num_attention_heads, 64 | dim_head=attention_head_dim, 65 | dropout=dropout, 66 | bias=attention_bias, 67 | upcast_attention=upcast_attention, 68 | out_bias=attention_out_bias, 69 | only_cross_attention=True 70 | ) # is self-attn if encoder_hidden_states is none 71 | 72 | # 3. Feed-forward 73 | self.norm3 = nn.LayerNorm(dim, norm_eps, norm_elementwise_affine) 74 | self.ff = FeedForward( 75 | dim, 76 | dropout=dropout, 77 | activation_fn=activation_fn, 78 | final_dropout=final_dropout, 79 | inner_dim=ff_inner_dim, 80 | bias=ff_bias, 81 | ) 82 | 83 | def forward( 84 | self, 85 | hidden_states: torch.FloatTensor, 86 | attention_mask: Optional[torch.FloatTensor] = None, 87 | encoder_hidden_states: Optional[torch.FloatTensor] = None, 88 | encoder_attention_mask: Optional[torch.FloatTensor] = None, 89 | timestep: Optional[torch.LongTensor] = None, 90 | cross_attention_kwargs: Dict[str, Any] = None, 91 | class_labels: Optional[torch.LongTensor] = None, 92 | added_cond_kwargs: Optional[Dict[str, torch.Tensor]] = None, 93 | ) -> torch.FloatTensor: 94 | # batch_size = hidden_states.shape[0] 95 | lora_scale = cross_attention_kwargs.get("scale", 1.0) if cross_attention_kwargs is not None else 1.0 96 | 97 | cross_attention_kwargs = cross_attention_kwargs.copy() if cross_attention_kwargs is not None else {} 98 | # gligen_kwargs = cross_attention_kwargs.pop("gligen", None) 99 | 100 | # 0. Self-Attention 101 | if not self.only_cross_attention: 102 | norm_hidden_states = self.norm1(hidden_states) 103 | 104 | # if self.pos_embed is not None: 105 | # norm_hidden_states = self.pos_embed(norm_hidden_states) 106 | 107 | attn_output = self.attn1( 108 | norm_hidden_states, 109 | encoder_hidden_states=encoder_hidden_states if self.only_cross_attention else None, 110 | attention_mask=attention_mask, 111 | **cross_attention_kwargs, 112 | ) 113 | 114 | hidden_states = attn_output + hidden_states 115 | if hidden_states.ndim == 4: 116 | hidden_states = hidden_states.squeeze(1) 117 | 118 | # 3. Cross-Attention 119 | norm_hidden_states = self.norm2(hidden_states) 120 | attn_output = self.attn2( 121 | norm_hidden_states, 122 | encoder_hidden_states=encoder_hidden_states, 123 | attention_mask=encoder_attention_mask, 124 | **cross_attention_kwargs, 125 | ) 126 | hidden_states = attn_output + hidden_states 127 | 128 | # 4. 
Feed-forward 129 | norm_hidden_states = self.norm3(hidden_states) 130 | ff_output = self.ff(norm_hidden_states, scale=lora_scale) 131 | 132 | hidden_states = ff_output + hidden_states 133 | if hidden_states.ndim == 4: 134 | hidden_states = hidden_states.squeeze(1) 135 | 136 | return hidden_states 137 | 138 | -------------------------------------------------------------------------------- /mobile_diffusion_unet/mdCADownBlock.py: -------------------------------------------------------------------------------- 1 | from typing import Any, Dict, List, Optional, Tuple, Union 2 | 3 | import torch 4 | import torch.nn as nn 5 | from diffusers.models.resnet import ResnetBlock2D, Downsample2D 6 | from diffusers.utils import is_torch_version, logging 7 | 8 | from mdTransformer2DModel import Transformer2DMobileModel 9 | from mdSepResnetBlock2D import SeparableResnetBlock2D 10 | 11 | class CADownBlock2DMobile(nn.Module): 12 | def __init__( 13 | self, 14 | in_channels: int, 15 | out_channels: int, 16 | temb_channels: int, 17 | dropout: float = 0.0, 18 | num_layers: int = 1, 19 | transformer_layers_per_block: Union[int, Tuple[int]] = 1, 20 | resnet_eps: float = 1e-6, 21 | resnet_time_scale_shift: str = "default", 22 | resnet_act_fn: str = "swish", 23 | resnet_groups: int = 32, 24 | resnet_pre_norm: bool = True, 25 | num_attention_heads: int = 1, 26 | cross_attention_dim: int = 1280, 27 | output_scale_factor: float = 1.0, 28 | downsample_padding: int = 1, 29 | add_downsample: bool = True, 30 | # dual_cross_attention: bool = False, 31 | use_linear_projection: bool = False, 32 | only_cross_attention: bool = False, 33 | upcast_attention: bool = False, 34 | # attention_type: str = "default", 35 | resnet_type: str = "ResnetBlock2D", 36 | ): 37 | super().__init__() 38 | resnets = [] 39 | attentions = [] 40 | 41 | self.num_attention_heads = num_attention_heads 42 | 43 | for i in range(num_layers): 44 | in_channels = in_channels if i == 0 else out_channels 45 | 46 | if resnet_type == "SeparableResnetBlock2D": 47 | res = SeparableResnetBlock2D( 48 | in_channels=in_channels, 49 | out_channels=out_channels, 50 | temb_channels=temb_channels, 51 | eps=resnet_eps, 52 | groups=resnet_groups, 53 | dropout=dropout, 54 | time_embedding_norm=resnet_time_scale_shift, 55 | non_linearity=resnet_act_fn, 56 | output_scale_factor=output_scale_factor, 57 | pre_norm=resnet_pre_norm, 58 | ) 59 | else: 60 | res = ResnetBlock2D( 61 | in_channels=in_channels, 62 | out_channels=out_channels, 63 | temb_channels=temb_channels, 64 | eps=resnet_eps, 65 | groups=resnet_groups, 66 | dropout=dropout, 67 | time_embedding_norm=resnet_time_scale_shift, 68 | non_linearity=resnet_act_fn, 69 | output_scale_factor=output_scale_factor, 70 | pre_norm=resnet_pre_norm, 71 | ) 72 | resnets.append(res) 73 | 74 | attentions.append( 75 | Transformer2DMobileModel( 76 | num_attention_heads, 77 | out_channels // num_attention_heads, 78 | in_channels=out_channels, 79 | num_layers=transformer_layers_per_block, 80 | cross_attention_dim=cross_attention_dim, 81 | norm_num_groups=resnet_groups, 82 | use_linear_projection=use_linear_projection, 83 | only_cross_attention=only_cross_attention, 84 | upcast_attention=upcast_attention, 85 | # attention_type=attention_type, 86 | ) 87 | ) 88 | 89 | self.attentions = nn.ModuleList(attentions) 90 | self.resnets = nn.ModuleList(resnets) 91 | 92 | if add_downsample: 93 | self.downsamplers = nn.ModuleList( 94 | [ 95 | Downsample2D( 96 | out_channels, use_conv=True, out_channels=out_channels, padding=downsample_padding, name="op" 
97 | ) 98 | ] 99 | ) 100 | else: 101 | self.downsamplers = None 102 | 103 | self.gradient_checkpointing = False 104 | 105 | def forward( 106 | self, 107 | hidden_states: torch.FloatTensor, 108 | temb: Optional[torch.FloatTensor] = None, 109 | encoder_hidden_states: Optional[torch.FloatTensor] = None, 110 | attention_mask: Optional[torch.FloatTensor] = None, 111 | cross_attention_kwargs: Optional[Dict[str, Any]] = None, 112 | encoder_attention_mask: Optional[torch.FloatTensor] = None, 113 | additional_residuals: Optional[torch.FloatTensor] = None, 114 | ) -> Tuple[torch.FloatTensor, Tuple[torch.FloatTensor, ...]]: 115 | # output_states = () 116 | 117 | lora_scale = cross_attention_kwargs.get("scale", 1.0) if cross_attention_kwargs is not None else 1.0 118 | 119 | blocks = list(zip(self.resnets, self.attentions)) 120 | # print(f"CADownBlock2DMobile train = {self.training}, grad = {self.gradient_checkpointing}") 121 | 122 | for i, (resnet, attn) in enumerate(blocks): 123 | if self.training and self.gradient_checkpointing: 124 | 125 | def create_custom_forward(module, return_dict=None): 126 | def custom_forward(*inputs): 127 | if return_dict is not None: 128 | return module(*inputs, return_dict=return_dict) 129 | else: 130 | return module(*inputs) 131 | 132 | return custom_forward 133 | 134 | ckpt_kwargs: Dict[str, Any] = {"use_reentrant": False} if is_torch_version(">=", "1.11.0") else {} 135 | hidden_states = torch.utils.checkpoint.checkpoint( 136 | create_custom_forward(resnet), 137 | hidden_states, 138 | temb, 139 | **ckpt_kwargs, 140 | ) 141 | hidden_states = attn( 142 | hidden_states, 143 | encoder_hidden_states=encoder_hidden_states, 144 | cross_attention_kwargs=cross_attention_kwargs, 145 | attention_mask=attention_mask, 146 | encoder_attention_mask=encoder_attention_mask, 147 | return_dict=False, 148 | )[0] 149 | else: 150 | hidden_states = resnet(hidden_states, temb, scale=lora_scale) 151 | hidden_states = attn( 152 | hidden_states, 153 | encoder_hidden_states=encoder_hidden_states, 154 | cross_attention_kwargs=cross_attention_kwargs, 155 | attention_mask=attention_mask, 156 | encoder_attention_mask=encoder_attention_mask, 157 | return_dict=False, 158 | )[0] 159 | 160 | # apply additional residuals to the output of the last pair of resnet and attention blocks 161 | if i == len(blocks) - 1 and additional_residuals is not None: 162 | hidden_states = hidden_states + additional_residuals 163 | 164 | # output_states = output_states + (hidden_states,) 165 | 166 | if self.downsamplers is not None: 167 | for downsampler in self.downsamplers: 168 | hidden_states = downsampler(hidden_states, scale=lora_scale) 169 | 170 | # output_states = output_states + (hidden_states,) 171 | 172 | return hidden_states #, output_states -------------------------------------------------------------------------------- /mobile_diffusion_unet/mdCAUpBlock.py: -------------------------------------------------------------------------------- 1 | from typing import Any, Dict, List, Optional, Tuple, Union 2 | 3 | import torch 4 | import torch.nn as nn 5 | from diffusers.models.resnet import ResnetBlock2D, Upsample2D 6 | from diffusers.utils import is_torch_version, logging 7 | 8 | from mdTransformer2DModel import Transformer2DMobileModel 9 | 10 | class CAUpBlock2DMobile(nn.Module): 11 | def __init__( 12 | self, 13 | in_channels: int, 14 | out_channels: int, 15 | # prev_output_channel: int, 16 | temb_channels: int, 17 | # resolution_idx: Optional[int] = None, 18 | dropout: float = 0.0, 19 | num_layers: int = 1, 20 | 
transformer_layers_per_block: Union[int, Tuple[int]] = 1, 21 | resnet_eps: float = 1e-6, 22 | resnet_time_scale_shift: str = "default", 23 | resnet_act_fn: str = "swish", 24 | resnet_groups: int = 32, 25 | resnet_pre_norm: bool = True, 26 | num_attention_heads: int = 1, 27 | cross_attention_dim: int = 1280, 28 | output_scale_factor: float = 1.0, 29 | add_upsample: bool = True, 30 | # dual_cross_attention: bool = False, 31 | use_linear_projection: bool = False, 32 | only_cross_attention: bool = False, 33 | upcast_attention: bool = False, 34 | # attention_type: str = "default", 35 | ): 36 | super().__init__() 37 | resnets = [] 38 | attentions = [] 39 | 40 | self.num_attention_heads = num_attention_heads 41 | 42 | for i in range(num_layers): 43 | resnets.append( 44 | ResnetBlock2D( 45 | in_channels=in_channels if i == 0 else out_channels, 46 | out_channels=out_channels, 47 | temb_channels=temb_channels, 48 | eps=resnet_eps, 49 | groups=resnet_groups, 50 | dropout=dropout, 51 | time_embedding_norm=resnet_time_scale_shift, 52 | non_linearity=resnet_act_fn, 53 | output_scale_factor=output_scale_factor, 54 | pre_norm=resnet_pre_norm, 55 | ) 56 | ) 57 | attentions.append( 58 | Transformer2DMobileModel( 59 | num_attention_heads, 60 | out_channels // num_attention_heads, 61 | in_channels=out_channels, 62 | num_layers=transformer_layers_per_block, 63 | cross_attention_dim=cross_attention_dim, 64 | norm_num_groups=resnet_groups, 65 | use_linear_projection=use_linear_projection, 66 | only_cross_attention=only_cross_attention, 67 | upcast_attention=upcast_attention, 68 | # attention_type=attention_type, 69 | ) 70 | ) 71 | self.attentions = nn.ModuleList(attentions) 72 | self.resnets = nn.ModuleList(resnets) 73 | 74 | if add_upsample: 75 | self.upsamplers = nn.ModuleList([Upsample2D(out_channels, use_conv=True, out_channels=out_channels)]) 76 | else: 77 | self.upsamplers = None 78 | 79 | self.gradient_checkpointing = False 80 | 81 | def forward( 82 | self, 83 | hidden_states: torch.FloatTensor, 84 | res_hidden_states: torch.FloatTensor, 85 | temb: Optional[torch.FloatTensor] = None, 86 | encoder_hidden_states: Optional[torch.FloatTensor] = None, 87 | cross_attention_kwargs: Optional[Dict[str, Any]] = None, 88 | upsample_size: Optional[int] = None, 89 | attention_mask: Optional[torch.FloatTensor] = None, 90 | encoder_attention_mask: Optional[torch.FloatTensor] = None, 91 | ) -> torch.FloatTensor: 92 | lora_scale = cross_attention_kwargs.get("scale", 1.0) if cross_attention_kwargs is not None else 1.0 93 | 94 | for i, (resnet, attn) in enumerate(zip(self.resnets, self.attentions)): 95 | if i == 0: 96 | hidden_states = torch.cat([hidden_states, res_hidden_states], dim=1) 97 | 98 | if self.training and self.gradient_checkpointing: 99 | 100 | def create_custom_forward(module, return_dict=None): 101 | def custom_forward(*inputs): 102 | if return_dict is not None: 103 | return module(*inputs, return_dict=return_dict) 104 | else: 105 | return module(*inputs) 106 | 107 | return custom_forward 108 | 109 | ckpt_kwargs: Dict[str, Any] = {"use_reentrant": False} if is_torch_version(">=", "1.11.0") else {} 110 | hidden_states = torch.utils.checkpoint.checkpoint( 111 | create_custom_forward(resnet), 112 | hidden_states, 113 | temb, 114 | **ckpt_kwargs, 115 | ) 116 | hidden_states = attn( 117 | hidden_states, 118 | encoder_hidden_states=encoder_hidden_states, 119 | cross_attention_kwargs=cross_attention_kwargs, 120 | attention_mask=attention_mask, 121 | encoder_attention_mask=encoder_attention_mask, 122 | 
return_dict=False, 123 | )[0] 124 | else: 125 | hidden_states = resnet(hidden_states, temb, scale=lora_scale) 126 | hidden_states = attn( 127 | hidden_states, 128 | encoder_hidden_states=encoder_hidden_states, 129 | cross_attention_kwargs=cross_attention_kwargs, 130 | attention_mask=attention_mask, 131 | encoder_attention_mask=encoder_attention_mask, 132 | return_dict=False, 133 | )[0] 134 | 135 | if self.upsamplers is not None: 136 | for upsampler in self.upsamplers: 137 | hidden_states = upsampler(hidden_states, upsample_size, scale=lora_scale) 138 | 139 | return hidden_states -------------------------------------------------------------------------------- /mobile_diffusion_unet/mdDownBlock2D.py: -------------------------------------------------------------------------------- 1 | from typing import Any, Dict, List, Optional, Tuple, Union 2 | 3 | import torch 4 | import torch.nn as nn 5 | from diffusers.models.resnet import ResnetBlock2D, Downsample2D 6 | from diffusers.utils import is_torch_version 7 | 8 | class DownBlock2DMobile(nn.Module): 9 | def __init__( 10 | self, 11 | in_channels: int, 12 | out_channels: int, 13 | temb_channels: int, 14 | dropout: float = 0.0, 15 | num_layers: int = 1, 16 | resnet_eps: float = 1e-6, 17 | resnet_time_scale_shift: str = "default", 18 | resnet_act_fn: str = "swish", 19 | resnet_groups: int = 32, 20 | resnet_pre_norm: bool = True, 21 | output_scale_factor: float = 1.0, 22 | add_downsample: bool = True, 23 | downsample_padding: int = 1, 24 | ): 25 | super().__init__() 26 | resnets = [] 27 | 28 | for i in range(num_layers): 29 | in_channels = in_channels if i == 0 else out_channels 30 | resnets.append( 31 | ResnetBlock2D( 32 | in_channels=in_channels, 33 | out_channels=out_channels, 34 | temb_channels=temb_channels, 35 | eps=resnet_eps, 36 | groups=resnet_groups, 37 | dropout=dropout, 38 | time_embedding_norm=resnet_time_scale_shift, 39 | non_linearity=resnet_act_fn, 40 | output_scale_factor=output_scale_factor, 41 | pre_norm=resnet_pre_norm, 42 | ) 43 | ) 44 | 45 | self.resnets = nn.ModuleList(resnets) 46 | 47 | if add_downsample: 48 | self.downsamplers = nn.ModuleList( 49 | [ 50 | Downsample2D( 51 | out_channels, use_conv=True, out_channels=out_channels, padding=downsample_padding, name="op" 52 | ) 53 | ] 54 | ) 55 | else: 56 | self.downsamplers = None 57 | 58 | self.gradient_checkpointing = False 59 | 60 | def forward( 61 | self, hidden_states: torch.FloatTensor, temb: Optional[torch.FloatTensor] = None, scale: float = 1.0 62 | ) -> Tuple[torch.FloatTensor, Tuple[torch.FloatTensor, ...]]: 63 | for resnet in self.resnets: 64 | if self.training and self.gradient_checkpointing: 65 | 66 | def create_custom_forward(module): 67 | def custom_forward(*inputs): 68 | return module(*inputs) 69 | 70 | return custom_forward 71 | 72 | if is_torch_version(">=", "1.11.0"): 73 | hidden_states = torch.utils.checkpoint.checkpoint( 74 | create_custom_forward(resnet), hidden_states, temb, use_reentrant=False 75 | ) 76 | else: 77 | hidden_states = torch.utils.checkpoint.checkpoint( 78 | create_custom_forward(resnet), hidden_states, temb 79 | ) 80 | else: 81 | hidden_states = resnet(hidden_states, temb, scale=scale) 82 | 83 | if self.downsamplers is not None: 84 | for downsampler in self.downsamplers: 85 | hidden_states = downsampler(hidden_states, scale=scale) 86 | 87 | return hidden_states -------------------------------------------------------------------------------- /mobile_diffusion_unet/mdSepResnetBlock2D.py: 
-------------------------------------------------------------------------------- 1 | from typing import Optional, Tuple, Union 2 | 3 | import torch 4 | import torch.nn as nn 5 | import torch.nn.functional as F 6 | from functools import partial 7 | 8 | from diffusers.utils import USE_PEFT_BACKEND 9 | from diffusers.models.activations import get_activation 10 | from diffusers.models.lora import LoRACompatibleConv, LoRACompatibleLinear 11 | from diffusers.models.normalization import AdaGroupNorm 12 | from diffusers.models.attention_processor import SpatialNorm 13 | from diffusers.models.downsampling import ( # noqa 14 | Downsample1D, 15 | Downsample2D, 16 | FirDownsample2D, 17 | KDownsample2D, 18 | downsample_2d, 19 | ) 20 | from diffusers.models.upsampling import ( # noqa 21 | FirUpsample2D, 22 | KUpsample2D, 23 | Upsample1D, 24 | Upsample2D, 25 | upfirdn2d_native, 26 | upsample_2d, 27 | ) 28 | 29 | class SeparableConv2d(nn.Module): 30 | def __init__(self, in_channels, out_channels, kernel_size, stride=1, padding=0, bias=True): # bias is forwarded to the convs so conv_shortcut_bias (used below) is honored 31 | super(SeparableConv2d, self).__init__() 32 | self.depthwise = nn.Conv2d(in_channels, in_channels, # depthwise: one spatial filter per input channel 33 | kernel_size, stride, padding, 34 | groups=in_channels, bias=bias) 35 | self.pointwise = nn.Conv2d(in_channels, out_channels, # pointwise: 1x1 conv that mixes channels 36 | kernel_size=1, 37 | stride=1, 38 | padding=0, 39 | groups=1, bias=bias) 40 | 41 | def forward(self, x, scale: float = 1.0): # scale is accepted (and ignored) so LoRA-style calls conv(x, scale) below do not fail 42 | x = self.depthwise(x) 43 | x = self.pointwise(x) 44 | return x 45 | 46 | class SeparableResnetBlock2D(nn.Module): 47 | def __init__( 48 | self, 49 | *, 50 | in_channels: int, 51 | out_channels: Optional[int] = None, 52 | conv_shortcut: bool = False, 53 | dropout: float = 0.0, 54 | temb_channels: int = 512, 55 | groups: int = 32, 56 | groups_out: Optional[int] = None, 57 | pre_norm: bool = True, 58 | eps: float = 1e-6, 59 | non_linearity: str = "swish", 60 | skip_time_act: bool = False, 61 | time_embedding_norm: str = "default", # default, scale_shift, ada_group, spatial 62 | kernel: Optional[torch.FloatTensor] = None, 63 | output_scale_factor: float = 1.0, 64 | use_in_shortcut: Optional[bool] = None, 65 | up: bool = False, 66 | down: bool = False, 67 | conv_shortcut_bias: bool = True, 68 | conv_2d_out_channels: Optional[int] = None, 69 | ): 70 | super().__init__() 71 | self.pre_norm = pre_norm 72 | self.pre_norm = True 73 | self.in_channels = in_channels 74 | out_channels = in_channels if out_channels is None else out_channels 75 | self.out_channels = out_channels 76 | self.use_conv_shortcut = conv_shortcut 77 | self.up = up 78 | self.down = down 79 | self.output_scale_factor = output_scale_factor 80 | self.time_embedding_norm = time_embedding_norm 81 | self.skip_time_act = skip_time_act 82 | 83 | linear_cls = nn.Linear if USE_PEFT_BACKEND else LoRACompatibleLinear 84 | conv_cls = SeparableConv2d 85 | 86 | if groups_out is None: 87 | groups_out = groups 88 | 89 | if self.time_embedding_norm == "ada_group": 90 | self.norm1 = AdaGroupNorm(temb_channels, in_channels, groups, eps=eps) 91 | elif self.time_embedding_norm == "spatial": 92 | self.norm1 = SpatialNorm(in_channels, temb_channels) 93 | else: 94 | self.norm1 = torch.nn.GroupNorm(num_groups=groups, num_channels=in_channels, eps=eps, affine=True) 95 | 96 | self.conv1 = conv_cls(in_channels, out_channels, kernel_size=3, stride=1, padding=1) 97 | 98 | if temb_channels is not None: 99 | if self.time_embedding_norm == "default": 100 | self.time_emb_proj = linear_cls(temb_channels, out_channels) 101 | elif self.time_embedding_norm == "scale_shift": 102 | self.time_emb_proj = linear_cls(temb_channels, 2 *
out_channels) 103 | elif self.time_embedding_norm == "ada_group" or self.time_embedding_norm == "spatial": 104 | self.time_emb_proj = None 105 | else: 106 | raise ValueError(f"unknown time_embedding_norm : {self.time_embedding_norm} ") 107 | else: 108 | self.time_emb_proj = None 109 | 110 | if self.time_embedding_norm == "ada_group": 111 | self.norm2 = AdaGroupNorm(temb_channels, out_channels, groups_out, eps=eps) 112 | elif self.time_embedding_norm == "spatial": 113 | self.norm2 = SpatialNorm(out_channels, temb_channels) 114 | else: 115 | self.norm2 = torch.nn.GroupNorm(num_groups=groups_out, num_channels=out_channels, eps=eps, affine=True) 116 | 117 | self.dropout = torch.nn.Dropout(dropout) 118 | conv_2d_out_channels = conv_2d_out_channels or out_channels 119 | self.conv2 = conv_cls(out_channels, conv_2d_out_channels, kernel_size=3, stride=1, padding=1) 120 | 121 | self.nonlinearity = get_activation(non_linearity) 122 | 123 | self.upsample = self.downsample = None 124 | if self.up: 125 | if kernel == "fir": 126 | fir_kernel = (1, 3, 3, 1) 127 | self.upsample = lambda x: upsample_2d(x, kernel=fir_kernel) 128 | elif kernel == "sde_vp": 129 | self.upsample = partial(F.interpolate, scale_factor=2.0, mode="nearest") 130 | else: 131 | self.upsample = Upsample2D(in_channels, use_conv=False) 132 | elif self.down: 133 | if kernel == "fir": 134 | fir_kernel = (1, 3, 3, 1) 135 | self.downsample = lambda x: downsample_2d(x, kernel=fir_kernel) 136 | elif kernel == "sde_vp": 137 | self.downsample = partial(F.avg_pool2d, kernel_size=2, stride=2) 138 | else: 139 | self.downsample = Downsample2D(in_channels, use_conv=False, padding=1, name="op") 140 | 141 | self.use_in_shortcut = self.in_channels != conv_2d_out_channels if use_in_shortcut is None else use_in_shortcut 142 | 143 | self.conv_shortcut = None 144 | if self.use_in_shortcut: 145 | self.conv_shortcut = conv_cls( 146 | in_channels, 147 | conv_2d_out_channels, 148 | kernel_size=1, 149 | stride=1, 150 | padding=0, 151 | bias=conv_shortcut_bias, 152 | ) 153 | 154 | def forward( 155 | self, 156 | input_tensor: torch.FloatTensor, 157 | temb: torch.FloatTensor, 158 | scale: float = 1.0, 159 | ) -> torch.FloatTensor: 160 | hidden_states = input_tensor 161 | 162 | if self.time_embedding_norm == "ada_group" or self.time_embedding_norm == "spatial": 163 | hidden_states = self.norm1(hidden_states, temb) 164 | else: 165 | hidden_states = self.norm1(hidden_states) 166 | 167 | hidden_states = self.nonlinearity(hidden_states) 168 | 169 | if self.upsample is not None: 170 | # upsample_nearest_nhwc fails with large batch sizes. 
see https://github.com/huggingface/diffusers/issues/984 171 | if hidden_states.shape[0] >= 64: 172 | input_tensor = input_tensor.contiguous() 173 | hidden_states = hidden_states.contiguous() 174 | input_tensor = ( 175 | self.upsample(input_tensor, scale=scale) 176 | if isinstance(self.upsample, Upsample2D) 177 | else self.upsample(input_tensor) 178 | ) 179 | hidden_states = ( 180 | self.upsample(hidden_states, scale=scale) 181 | if isinstance(self.upsample, Upsample2D) 182 | else self.upsample(hidden_states) 183 | ) 184 | elif self.downsample is not None: 185 | input_tensor = ( 186 | self.downsample(input_tensor, scale=scale) 187 | if isinstance(self.downsample, Downsample2D) 188 | else self.downsample(input_tensor) 189 | ) 190 | hidden_states = ( 191 | self.downsample(hidden_states, scale=scale) 192 | if isinstance(self.downsample, Downsample2D) 193 | else self.downsample(hidden_states) 194 | ) 195 | 196 | hidden_states = self.conv1(hidden_states, scale) if not USE_PEFT_BACKEND else self.conv1(hidden_states) 197 | 198 | if self.time_emb_proj is not None: 199 | if not self.skip_time_act: 200 | temb = self.nonlinearity(temb) 201 | temb = ( 202 | self.time_emb_proj(temb, scale)[:, :, None, None] 203 | if not USE_PEFT_BACKEND 204 | else self.time_emb_proj(temb)[:, :, None, None] 205 | ) 206 | 207 | if temb is not None and self.time_embedding_norm == "default": 208 | hidden_states = hidden_states + temb 209 | 210 | if self.time_embedding_norm == "ada_group" or self.time_embedding_norm == "spatial": 211 | hidden_states = self.norm2(hidden_states, temb) 212 | else: 213 | hidden_states = self.norm2(hidden_states) 214 | 215 | if temb is not None and self.time_embedding_norm == "scale_shift": 216 | scale, shift = torch.chunk(temb, 2, dim=1) 217 | hidden_states = hidden_states * (1 + scale) + shift 218 | 219 | hidden_states = self.nonlinearity(hidden_states) 220 | 221 | hidden_states = self.dropout(hidden_states) 222 | hidden_states = self.conv2(hidden_states, scale) if not USE_PEFT_BACKEND else self.conv2(hidden_states) 223 | 224 | if self.conv_shortcut is not None: 225 | input_tensor = ( 226 | self.conv_shortcut(input_tensor, scale) if not USE_PEFT_BACKEND else self.conv_shortcut(input_tensor) 227 | ) 228 | 229 | output_tensor = (input_tensor + hidden_states) / self.output_scale_factor 230 | 231 | return output_tensor 232 | -------------------------------------------------------------------------------- /mobile_diffusion_unet/mdTransformer2DModel.py: -------------------------------------------------------------------------------- 1 | from dataclasses import dataclass 2 | from typing import Any, Dict, List, Optional, Tuple, Union 3 | 4 | import torch 5 | import torch.nn as nn 6 | from diffusers.models.modeling_utils import ModelMixin 7 | from diffusers.configuration_utils import ConfigMixin, register_to_config 8 | from diffusers.utils import USE_PEFT_BACKEND, BaseOutput 9 | from diffusers.models.lora import LoRACompatibleConv, LoRACompatibleLinear 10 | from diffusers.utils import is_torch_version, logging 11 | 12 | from mdBasicTransformerBlock import BasicTransformerMobileBlock 13 | 14 | @dataclass 15 | class Transformer2DMobileModelOutput(BaseOutput): 16 | sample: torch.FloatTensor 17 | 18 | class Transformer2DMobileModel(ModelMixin, ConfigMixin): 19 | _supports_gradient_checkpointing = True 20 | 21 | @register_to_config 22 | def __init__( 23 | self, 24 | num_attention_heads: int = 16, 25 | attention_head_dim: int = 88, 26 | in_channels: Optional[int] = None, 27 | out_channels: Optional[int] = 
None, 28 | num_layers: int = 1, 29 | dropout: float = 0.0, 30 | norm_num_groups: int = 32, 31 | cross_attention_dim: Optional[int] = None, 32 | attention_bias: bool = False, 33 | # sample_size: Optional[int] = None, 34 | # num_vector_embeds: Optional[int] = None, 35 | # patch_size: Optional[int] = None, 36 | activation_fn: str = "geglu", 37 | # num_embeds_ada_norm: Optional[int] = None, 38 | use_linear_projection: bool = False, 39 | only_cross_attention: bool = False, 40 | # double_self_attention: bool = False, 41 | upcast_attention: bool = False, 42 | # norm_type: str = "layer_norm", 43 | norm_elementwise_affine: bool = True, 44 | norm_eps: float = 1e-5, 45 | # attention_type: str = "default", 46 | # caption_channels: int = None, 47 | ): 48 | super().__init__() 49 | self.use_linear_projection = use_linear_projection 50 | self.num_attention_heads = num_attention_heads 51 | self.attention_head_dim = attention_head_dim 52 | inner_dim = num_attention_heads * attention_head_dim 53 | 54 | conv_cls = nn.Conv2d if USE_PEFT_BACKEND else LoRACompatibleConv 55 | linear_cls = nn.Linear if USE_PEFT_BACKEND else LoRACompatibleLinear 56 | 57 | # self.is_input_continuous = (in_channels is not None) and (patch_size is None) 58 | # self.is_input_vectorized = num_vector_embeds is not None 59 | # self.is_input_patches = in_channels is not None and patch_size is not None 60 | 61 | self.in_channels = in_channels 62 | self.norm = torch.nn.GroupNorm(num_groups=norm_num_groups, num_channels=in_channels, eps=1e-6, affine=True) 63 | if use_linear_projection: 64 | self.proj_in = linear_cls(in_channels, inner_dim) 65 | else: 66 | self.proj_in = conv_cls(in_channels, inner_dim, kernel_size=1, stride=1, padding=0) 67 | 68 | # 3. Define transformers blocks 69 | self.transformer_blocks = nn.ModuleList( 70 | [ 71 | BasicTransformerMobileBlock( 72 | inner_dim, 73 | num_attention_heads, 74 | attention_head_dim, 75 | dropout=dropout, 76 | cross_attention_dim=cross_attention_dim, 77 | activation_fn=activation_fn, 78 | # num_embeds_ada_norm=num_embeds_ada_norm, 79 | attention_bias=attention_bias, 80 | only_cross_attention=only_cross_attention, 81 | # double_self_attention=double_self_attention, 82 | upcast_attention=upcast_attention, 83 | # norm_type=norm_type, 84 | norm_elementwise_affine=norm_elementwise_affine, 85 | norm_eps=norm_eps, 86 | # attention_type=attention_type, 87 | ) 88 | for d in range(num_layers) 89 | ] 90 | ) 91 | 92 | # 4. 
Define output layers 93 | self.out_channels = in_channels if out_channels is None else out_channels 94 | if use_linear_projection: 95 | self.proj_out = linear_cls(inner_dim, in_channels) 96 | else: 97 | self.proj_out = conv_cls(inner_dim, in_channels, kernel_size=1, stride=1, padding=0) 98 | 99 | self.gradient_checkpointing = False 100 | 101 | 102 | 103 | def _set_gradient_checkpointing(self, module, value=False): 104 | if hasattr(module, "gradient_checkpointing"): 105 | module.gradient_checkpointing = value 106 | 107 | def forward( 108 | self, 109 | hidden_states: torch.Tensor, 110 | encoder_hidden_states: Optional[torch.Tensor] = None, 111 | timestep: Optional[torch.LongTensor] = None, 112 | added_cond_kwargs: Dict[str, torch.Tensor] = None, 113 | class_labels: Optional[torch.LongTensor] = None, 114 | cross_attention_kwargs: Dict[str, Any] = None, 115 | attention_mask: Optional[torch.Tensor] = None, 116 | encoder_attention_mask: Optional[torch.Tensor] = None, 117 | return_dict: bool = True, 118 | ): 119 | lora_scale = cross_attention_kwargs.get("scale", 1.0) if cross_attention_kwargs is not None else 1.0 120 | 121 | batch, _, height, width = hidden_states.shape 122 | residual = hidden_states 123 | 124 | hidden_states = self.norm(hidden_states) 125 | 126 | # 1. Input 127 | if not self.use_linear_projection: 128 | hidden_states = ( 129 | self.proj_in(hidden_states, scale=lora_scale) 130 | if not USE_PEFT_BACKEND 131 | else self.proj_in(hidden_states) 132 | ) 133 | inner_dim = hidden_states.shape[1] 134 | hidden_states = hidden_states.permute(0, 2, 3, 1).reshape(batch, height * width, inner_dim) 135 | else: 136 | inner_dim = hidden_states.shape[1] 137 | hidden_states = hidden_states.permute(0, 2, 3, 1).reshape(batch, height * width, inner_dim) 138 | hidden_states = ( 139 | self.proj_in(hidden_states, scale=lora_scale) 140 | if not USE_PEFT_BACKEND 141 | else self.proj_in(hidden_states) 142 | ) 143 | 144 | # 2. Blocks 145 | for block in self.transformer_blocks: 146 | if self.training and self.gradient_checkpointing: 147 | 148 | def create_custom_forward(module, return_dict=None): 149 | def custom_forward(*inputs): 150 | if return_dict is not None: 151 | return module(*inputs, return_dict=return_dict) 152 | else: 153 | return module(*inputs) 154 | 155 | return custom_forward 156 | 157 | ckpt_kwargs: Dict[str, Any] = {"use_reentrant": False} if is_torch_version(">=", "1.11.0") else {} 158 | hidden_states = torch.utils.checkpoint.checkpoint( 159 | create_custom_forward(block), 160 | hidden_states, 161 | attention_mask, 162 | encoder_hidden_states, 163 | encoder_attention_mask, 164 | timestep, 165 | cross_attention_kwargs, 166 | class_labels, 167 | **ckpt_kwargs, 168 | ) 169 | else: 170 | hidden_states = block( 171 | hidden_states, 172 | attention_mask=attention_mask, 173 | encoder_hidden_states=encoder_hidden_states, 174 | encoder_attention_mask=encoder_attention_mask, 175 | timestep=timestep, 176 | cross_attention_kwargs=cross_attention_kwargs, 177 | class_labels=class_labels, 178 | ) 179 | 180 | # 3. 
Output 181 | if not self.use_linear_projection: 182 | hidden_states = hidden_states.reshape(batch, height, width, inner_dim).permute(0, 3, 1, 2).contiguous() 183 | hidden_states = ( 184 | self.proj_out(hidden_states, scale=lora_scale) 185 | if not USE_PEFT_BACKEND 186 | else self.proj_out(hidden_states) 187 | ) 188 | else: 189 | hidden_states = ( 190 | self.proj_out(hidden_states, scale=lora_scale) 191 | if not USE_PEFT_BACKEND 192 | else self.proj_out(hidden_states) 193 | ) 194 | hidden_states = hidden_states.reshape(batch, height, width, inner_dim).permute(0, 3, 1, 2).contiguous() 195 | 196 | output = hidden_states + residual 197 | 198 | if not return_dict: 199 | return (output,) 200 | 201 | return Transformer2DMobileModelOutput(sample=output) 202 | -------------------------------------------------------------------------------- /mobile_diffusion_unet/mdUpBlock2D.py: -------------------------------------------------------------------------------- 1 | from typing import Any, Dict, List, Optional, Tuple, Union 2 | 3 | import torch 4 | import torch.nn as nn 5 | from diffusers.models.resnet import ResnetBlock2D, Upsample2D 6 | from diffusers.utils import is_torch_version, logging 7 | 8 | class UpBlock2DMobile(nn.Module): 9 | def __init__( 10 | self, 11 | in_channels: int, 12 | # prev_output_channel: int, 13 | out_channels: int, 14 | temb_channels: int, 15 | # resolution_idx: Optional[int] = None, 16 | dropout: float = 0.0, 17 | num_layers: int = 1, 18 | resnet_eps: float = 1e-6, 19 | resnet_time_scale_shift: str = "default", 20 | resnet_act_fn: str = "swish", 21 | resnet_groups: int = 32, 22 | resnet_pre_norm: bool = True, 23 | output_scale_factor: float = 1.0, 24 | add_upsample: bool = True, 25 | ): 26 | super().__init__() 27 | resnets = [] 28 | 29 | for i in range(num_layers): 30 | resnets.append( 31 | ResnetBlock2D( 32 | in_channels=in_channels if i == 0 else out_channels, 33 | out_channels=out_channels, 34 | temb_channels=temb_channels, 35 | eps=resnet_eps, 36 | groups=resnet_groups, 37 | dropout=dropout, 38 | time_embedding_norm=resnet_time_scale_shift, 39 | non_linearity=resnet_act_fn, 40 | output_scale_factor=output_scale_factor, 41 | pre_norm=resnet_pre_norm, 42 | ) 43 | ) 44 | 45 | self.resnets = nn.ModuleList(resnets) 46 | 47 | if add_upsample: 48 | self.upsamplers = nn.ModuleList([Upsample2D(out_channels, use_conv=True, out_channels=out_channels)]) 49 | else: 50 | self.upsamplers = None 51 | 52 | self.gradient_checkpointing = False 53 | 54 | def forward( 55 | self, 56 | hidden_states: torch.FloatTensor, 57 | res_hidden_states: torch.FloatTensor, 58 | temb: Optional[torch.FloatTensor] = None, 59 | upsample_size: Optional[int] = None, 60 | scale: float = 1.0, 61 | ) -> torch.FloatTensor: 62 | 63 | for i, resnet in enumerate(self.resnets): 64 | if i == 0: 65 | hidden_states = torch.cat([hidden_states, res_hidden_states], dim=1) 66 | 67 | if self.training and self.gradient_checkpointing: 68 | 69 | def create_custom_forward(module): 70 | def custom_forward(*inputs): 71 | return module(*inputs) 72 | 73 | return custom_forward 74 | 75 | if is_torch_version(">=", "1.11.0"): 76 | hidden_states = torch.utils.checkpoint.checkpoint( 77 | create_custom_forward(resnet), hidden_states, temb, use_reentrant=False 78 | ) 79 | else: 80 | hidden_states = torch.utils.checkpoint.checkpoint( 81 | create_custom_forward(resnet), hidden_states, temb 82 | ) 83 | else: 84 | hidden_states = resnet(hidden_states, temb, scale=scale) 85 | 86 | if self.upsamplers is not None: 87 | for upsampler in 
self.upsamplers: 88 | hidden_states = upsampler(hidden_states, upsample_size, scale=scale) 89 | 90 | return hidden_states -------------------------------------------------------------------------------- /progressive_distillation_for_sd/README.md: -------------------------------------------------------------------------------- 1 | ## Progressive Distillation in Latent Space for Stable Diffusion 2 | ### Overview 3 | Progressive Distillation was originally developed by Google for diffusion models operating in pixel space. We have adapted it to the latent space of Stable Diffusion models. This repository contains the reimplemented code, which is built on the diffusers library. 4 | 5 | ### Implementation Details 6 | The provided example demonstrates how to distill the original 1000-step diffusion process down to 32 steps. The same code can be reused to distill further, from 32 steps to 16, from 16 to 8, and so on, as needed. 7 | 8 | ### Usage 9 | For best results, we recommend starting from a model trained with velocity-prediction (v-prediction). If you start from an epsilon-prediction model, the signal-to-noise ratio (SNR) handling must be adjusted accordingly. 10 | ```shell 11 | export MODEL_PATH= 12 | export DATASET_PATH= 13 | export CACHE_PATH= 14 | export OUTPUT_DIR= 15 | 16 | accelerate launch --mixed_precision="fp16" train_text_to_image_pd.py \ 17 | --pretrained_model_name_or_path=$MODEL_PATH \ 18 | --dataset_name=$DATASET_PATH \ 19 | --cache_dir=$CACHE_PATH \ 20 | --resolution=512 --center_crop --random_flip \ 21 | --train_batch_size=32 \ 22 | --gradient_accumulation_steps=4 \ 23 | --gradient_checkpointing \ 24 | --max_train_steps=4000 \ 25 | --num_steps_stu=32 \ 26 | --num_steps_tea=1000 \ 27 | --learning_rate=1e-04 \ 28 | --max_grad_norm=1 \ 29 | --lr_scheduler="constant_with_warmup" --lr_warmup_steps=100 \ 30 | --output_dir=$OUTPUT_DIR 31 | ``` 32 | ### Note 33 | Please ensure that the **diffusers** library is installed before running the code. The exact distillation code and hyperparameters may need to be adapted to your specific model and setup. For reference, a minimal sketch of how the two-step-teacher / one-step-student target is constructed appears at the end of this document. 34 | 35 | ### Acknowledgements 36 | This work builds upon the innovative methods developed by Google and applies them to a new domain. We appreciate the contributions of the original authors and the open-source community. -------------------------------------------------------------------------------- /requirements.txt: -------------------------------------------------------------------------------- 1 | accelerate 2 | bitsandbytes 3 | cmake 4 | datasets 5 | diffusers 6 | einops 7 | huggingface-hub 8 | imageio 9 | matplotlib 10 | ninja 11 | opencv-python 12 | tensorboard 13 | torch 14 | torch-fidelity 15 | torchmetrics 16 | torchvision 17 | tqdm 18 | transformers 19 | xformers 20 | --------------------------------------------------------------------------------
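As referenced in the progressive-distillation README above, the following is a minimal sketch of how the two-step-teacher / one-step-student target can be formed for a v-prediction student in latent space. It is not the repository's `train_text_to_image_pd.py`: the function name, the `teacher_unet` call convention, and the precomputed `alphas`/`sigmas` arrays are illustrative assumptions, and scalar timesteps are used for clarity (batched training would gather and reshape the coefficients to `[B, 1, 1, 1]`).

```python
# Illustrative sketch only: not the repository's train_text_to_image_pd.py.
# Assumes a v-prediction teacher UNet (diffusers-style, returning .sample) and
# precomputed per-timestep coefficients with z_t = alpha[t]*x0 + sigma[t]*eps
# and v = alpha[t]*eps - sigma[t]*x0. Timesteps are scalars for clarity.
import torch


@torch.no_grad()
def two_step_teacher_target_v(teacher_unet, z_t, t, t_mid, t_next, alphas, sigmas, encoder_hidden_states):
    """Run the teacher for two DDIM steps (t -> t_mid -> t_next) and convert the
    endpoint into a one-step v-prediction target for the student at timestep t."""
    a_t, s_t = alphas[t], sigmas[t]
    a_mid, s_mid = alphas[t_mid], sigmas[t_mid]
    a_next, s_next = alphas[t_next], sigmas[t_next]

    # Teacher DDIM step 1: t -> t_mid
    v = teacher_unet(z_t, t, encoder_hidden_states=encoder_hidden_states).sample
    x0 = a_t * z_t - s_t * v
    eps = s_t * z_t + a_t * v
    z_mid = a_mid * x0 + s_mid * eps

    # Teacher DDIM step 2: t_mid -> t_next
    v = teacher_unet(z_mid, t_mid, encoder_hidden_states=encoder_hidden_states).sample
    x0 = a_mid * z_mid - s_mid * v
    eps = s_mid * z_mid + a_mid * v
    z_next = a_next * x0 + s_next * eps

    # Find the x0/eps that a single student DDIM step from z_t would need in
    # order to land on z_next exactly, then express the pair as a v-space target.
    x0_tgt = (z_next - (s_next / s_t) * z_t) / (a_next - (s_next / s_t) * a_t)
    eps_tgt = (z_t - a_t * x0_tgt) / s_t
    return a_t * eps_tgt - s_t * x0_tgt
```

The student is then trained with a mean-squared error (optionally SNR-weighted) between its own v-prediction at `(z_t, t)` and this target; once trained, the student serves as the teacher for the next halving round (32 to 16, 16 to 8, and so on).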