├── README.md ├── assets └── text2img.png ├── paths.py ├── requirements.txt ├── sd_fused ├── __init__.py ├── app │ ├── __init__.py │ ├── helpers.py │ ├── sd.py │ └── setup.py ├── clip │ ├── __init__.py │ ├── clip_embedding.py │ ├── container.py │ ├── parser │ │ ├── __init__.py │ │ ├── add_delimiter4words.py │ │ ├── add_split_maker4emphasis.py │ │ ├── clean_spaces.py │ │ ├── diffuse_prompt.py │ │ ├── expand_delimiters.py │ │ └── prompt_choices.py │ └── text_segment.py ├── layers │ ├── __init__.py │ ├── activation │ │ ├── __init__.py │ │ ├── geglu.py │ │ └── silu.py │ ├── auto_encoder │ │ ├── __init__.py │ │ ├── decoder.py │ │ └── encoder.py │ ├── base │ │ ├── __init__.py │ │ ├── base.py │ │ ├── module.py │ │ ├── module_list.py │ │ ├── sequential.py │ │ └── types.py │ ├── basic │ │ ├── __init__.py │ │ ├── conv2d.py │ │ ├── group_norm.py │ │ ├── identity.py │ │ ├── layer_norm.py │ │ └── linear.py │ ├── blocks │ │ ├── __init__.py │ │ ├── attention │ │ │ ├── __init__.py │ │ │ ├── base_attention.py │ │ │ ├── compute │ │ │ │ ├── __init__.py │ │ │ │ ├── attention.py │ │ │ │ ├── auto_chunk_size.py │ │ │ │ ├── chunked_attention.py │ │ │ │ ├── flash_attention.py │ │ │ │ ├── join_spatial_dim.py │ │ │ │ ├── scale_qk.py │ │ │ │ ├── standard_attention.py │ │ │ │ ├── tome.py │ │ │ │ └── weighted_values.py │ │ │ ├── cross_attention.py │ │ │ └── self_attention.py │ │ ├── basic │ │ │ ├── __init__.py │ │ │ ├── gn_conv.py │ │ │ ├── gn_silu_conv.py │ │ │ └── ln_geglu_linear.py │ │ ├── spatial │ │ │ ├── __init__.py │ │ │ ├── ae │ │ │ │ ├── __init__.py │ │ │ │ ├── down_encoder.py │ │ │ │ └── up_decoder.py │ │ │ ├── base │ │ │ │ ├── __init__.py │ │ │ │ ├── down.py │ │ │ │ └── up.py │ │ │ ├── cross_attention │ │ │ │ ├── __init__.py │ │ │ │ ├── down.py │ │ │ │ └── up.py │ │ │ ├── output_states.py │ │ │ ├── resampling │ │ │ │ ├── __init__.py │ │ │ │ ├── downsample2d.py │ │ │ │ └── upsample2d.py │ │ │ ├── resnet.py │ │ │ └── unet_mid │ │ │ │ ├── __init__.py │ │ │ │ ├── cross_attention.py │ │ │ │ └── self_attention.py │ │ └── transformer │ │ │ ├── __init__.py │ │ │ ├── basic_transformer.py │ │ │ └── spatial_transformer.py │ ├── distribution │ │ ├── __init__.py │ │ └── diag_gaussian.py │ ├── embedding │ │ ├── __init__.py │ │ ├── time_step_emb.py │ │ └── time_steps.py │ ├── external │ │ ├── __init__.py │ │ └── rearrange.py │ └── modifiers │ │ ├── __init__.py │ │ └── half_weights.py ├── models │ ├── __init__.py │ ├── ae_kl.py │ ├── config.py │ ├── convert │ │ ├── __init__.py │ │ ├── states.py │ │ ├── unet │ │ │ └── diffusers2fused.py │ │ └── vae │ │ │ ├── diffusers2fused.py │ │ │ └── sd2diffusers.py │ ├── modifiers │ │ ├── __init__.py │ │ ├── flash_attention.py │ │ ├── half_weights.py │ │ ├── split_attention.py │ │ └── tome.py │ └── unet_conditional.py ├── scheduler │ ├── __init__.py │ ├── ddim.py │ └── scheduler.py └── utils │ ├── __init__.py │ ├── cuda │ ├── __init__.py │ ├── clear_cuda.py │ └── free_memory.py │ ├── diverse │ ├── __init__.py │ ├── product_args.py │ ├── separate.py │ ├── single.py │ └── to_list.py │ ├── image │ ├── __init__.py │ ├── image2tensor.py │ ├── image_base64.py │ ├── image_size.py │ ├── open_image.py │ ├── tensor2images.py │ └── types.py │ ├── parameters │ ├── __init__.py │ ├── batch_parameters.py │ ├── group_parameters.py │ ├── parameters.py │ └── parameters_list.py │ ├── tensors │ ├── __init__.py │ ├── generate_noise.py │ ├── normalize.py │ ├── slerp.py │ └── to_tensor.py │ └── typing.py └── setup.py /README.md: -------------------------------------------------------------------------------- 1 
| # Stable-Diffusion + Fused CUDA kernels = FUN! 2 | 3 | ## Introduction 4 | 5 | This is a re-written implementation of Stable-Diffusion (SD) based on the original [diffusers](https://github.com/huggingface/diffusers) and [stable-diffusion](https://github.com/CompVis/stable-diffusion) repositories (all kudos to the original programmers). 6 | 7 | The goal of this reimplementation is to provide clearer, more readable, and more easily extensible code that is simple to modify. 8 | Unfortunately, the original code is very difficult to read due to its lack of proper typing and variable naming, among other factors. 9 | 10 | ## For the impatient: 11 | 12 | ### Emphasis 13 | 14 | Using the notation `(a few words):weight` you can give emphasis (high number), reduce emphasis (small number), or even avoid the subject (negative number). 15 | The words (tokens) inside the parentheses are given a weight that is passed down to the attention calculation, enhancing, attenuating, or negating the attention to the given token. 16 | 17 | Below is a small test where the word `cyberpunk` is given different emphasis values. 18 | 19 | ```python 20 | weight = torch.linspace(-1.2, 4.2, 32).tolist() 21 | choices = '|'.join(map(str, weight)) 22 | 23 | out = pipeline.generate( 24 | prompt=f"portrait, woman, cyberpunk:[{choices}], digital art, detailed, epic, beautiful", 25 | steps=24, 26 | scale=11, 27 | height=512, 28 | width=512, 29 | seed=1658926406, 30 | eta=0.6, 31 | show=True, 32 | batch_size=8, 33 | ) 34 | ``` 35 | 36 | 37 | 38 | 39 | ### Batched sweep 40 | 41 | Any input parameter can be passed as a list for sweeping, and multiple sweeps can be combined. 42 | For example: 43 | 44 | ```python 45 | out = pipeline.generate( 46 | prompt="portrait, woman, cyberpunk, digital art, detailed, epic, beautiful", 47 | steps=26, 48 | height=512, 49 | width=512, 50 | seed=1331366415, 51 | eta=torch.linspace(-1, 1, 64).tolist(), 52 | show=True, 53 | batch_size=8, 54 | ) 55 | ``` 56 | 57 | 58 | 59 | 60 | ### Seed-Interpolations 61 | 62 | You can interpolate between two known seeds (`seed` and `sub_seed`) for any number of interpolation factors. 63 | 64 | ```python 65 | # pipeline.tome(None) 66 | out = pipeline.generate( 67 | prompt="portrait, woman, cyberpunk, digital art, detailed, epic, beautiful", 68 | steps=26, 69 | height=512, 70 | width=512, 71 | seed=3783195593, 72 | sub_seed=2148348002, 73 | interpolation=torch.linspace(0, 1, 64).tolist(), 74 | eta=0.6, 75 | show=True, 76 | batch_size=8, 77 | ) 78 | ``` 79 | 80 | 81 | 82 | 100 | 101 | 102 | 103 | 104 | 124 | 125 | 126 | 127 | 146 | 147 | 148 | ## Kernel fusion 149 | 150 | This is an ongoing effort to fuse as many layers as possible, making the model more memory-friendly and faster. 151 | 152 | ## Installation 153 | 154 | ```bash 155 | pip install -U git+https://github.com/tfernd/sd-fused 156 | ``` 157 | 158 | ## Text2Image generation 159 | 160 | Base code for text-to-image generation.
161 | 162 | ```python 163 | from IPython.display import display 164 | from sd_fused.app import StableDiffusion 165 | 166 | # Assuming you downloaded SD and put it in the folder below 167 | pipeline = StableDiffusion('.pretrained/stable-diffusion') 168 | 169 | # If you have a GPU with 3-4 Gb, use the line below 170 | # pipeline.set_low_ram().half_weights().cuda() 171 | pipeline.half().cuda() 172 | pipeline.split_attention(cross_attention_chunks=1) 173 | # if you have xformers installed, use the line below 174 | # pipeline.flash_attention() 175 | 176 | out = pipeline.generate( 177 | prompt='portrait of zombie, digital art, detailed, artistic', 178 | negative_prompt='old man', 179 | steps=28, 180 | scale=11, 181 | height=512, 182 | width=512, 183 | seed=42, 184 | show=True 185 | ) 186 | ``` 187 | 188 | ![portrait of zombie, digital art, detailed, artistic](assets/text2img.png) 189 | -------------------------------------------------------------------------------- /assets/text2img.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/tfernd/sd-fused/50570f983cc00dd4bd0dc415d0b515da638311b0/assets/text2img.png -------------------------------------------------------------------------------- /paths.py: -------------------------------------------------------------------------------- 1 | import argparse 2 | import os 3 | import sys 4 | import modules.safe 5 | 6 | script_path = os.path.dirname(os.path.dirname(os.path.realpath(__file__))) 7 | 8 | # Parse the --data-dir flag first so we can use it as a base for our other argument default values 9 | parser = argparse.ArgumentParser(add_help=False) 10 | parser.add_argument("--data-dir", type=str, default=os.path.dirname(os.path.dirname(os.path.realpath(__file__))), help="base path where all user data is stored",) 11 | cmd_opts_pre = parser.parse_known_args()[0] 12 | data_path = cmd_opts_pre.data_dir 13 | models_path = os.path.join(data_path, "models") 14 | 15 | # data_path = cmd_opts_pre.data 16 | sys.path.insert(0, script_path) 17 | 18 | # search for directory of stable diffusion in following places 19 | sd_path = None 20 | possible_sd_paths = [ 21 | os.path.join(script_path, '/content/gdrive/MyDrive/sd/stablediffusion'), 22 | os.path.join(script_path, '/content/sd/stablediffusion'), 23 | '.', 24 | os.path.dirname(script_path) 25 | ] 26 | for possible_sd_path in possible_sd_paths: 27 | if os.path.exists(os.path.join(possible_sd_path, 'ldm/models/diffusion/ddpm.py')): 28 | sd_path = os.path.abspath(possible_sd_path) 29 | break 30 | 31 | assert sd_path is not None, "Couldn't find Stable Diffusion in any of: " + str(possible_sd_paths) 32 | 33 | path_dirs = [ 34 | (sd_path, 'ldm', 'Stable Diffusion', []), 35 | (os.path.join(sd_path, 'src/taming-transformers'), 'taming', 'Taming Transformers', []), 36 | (os.path.join(sd_path, 'src/codeformer'), 'inference_codeformer.py', 'CodeFormer', []), 37 | (os.path.join(sd_path, 'src/blip'), 'models/blip.py', 'BLIP', []), 38 | (os.path.join(sd_path, 'src/k-diffusion'), 'k_diffusion/sampling.py', 'k_diffusion', ["atstart"]), 39 | ] 40 | 41 | paths = {} 42 | 43 | for d, must_exist, what, options in path_dirs: 44 | must_exist_path = os.path.abspath(os.path.join(script_path, d, must_exist)) 45 | if not os.path.exists(must_exist_path): 46 | print(f"Warning: {what} not found at path {must_exist_path}", file=sys.stderr) 47 | else: 48 | d = os.path.abspath(d) 49 | if "atstart" in options: 50 | sys.path.insert(0, d) 51 | else: 52 | sys.path.append(d) 53 | 
paths[what] = d 54 | 55 | class Prioritize: 56 | def __init__(self, name): 57 | self.name = name 58 | self.path = None 59 | 60 | def __enter__(self): 61 | self.path = sys.path.copy() 62 | sys.path = [paths[self.name]] + sys.path 63 | 64 | def __exit__(self, exc_type, exc_val, exc_tb): 65 | sys.path = self.path 66 | self.path = None -------------------------------------------------------------------------------- /requirements.txt: -------------------------------------------------------------------------------- 1 | ipywidgets>=7,<8 2 | typing_extensions>=4.3 3 | torch 4 | transformers>=4.24 5 | einops>=0.6 6 | tqdm>=4.64 7 | Pillow>=9.3 8 | scipy 9 | ftfy 10 | validators>=0.20 -------------------------------------------------------------------------------- /sd_fused/__init__.py: -------------------------------------------------------------------------------- 1 | from .app import StableDiffusion 2 | -------------------------------------------------------------------------------- /sd_fused/app/__init__.py: -------------------------------------------------------------------------------- 1 | from .sd import StableDiffusion 2 | -------------------------------------------------------------------------------- /sd_fused/app/helpers.py: -------------------------------------------------------------------------------- 1 | from __future__ import annotations 2 | from typing import Optional 3 | 4 | from pathlib import Path 5 | from datetime import datetime 6 | from copy import deepcopy 7 | 8 | from PIL import Image 9 | from PIL.PngImagePlugin import PngInfo 10 | 11 | import random 12 | import torch 13 | from torch import Tensor 14 | 15 | from ..models import AutoencoderKL, UNet2DConditional 16 | from ..clip import ClipEmbedding 17 | from ..utils.tensors import slerp, generate_noise 18 | 19 | MAGIC = 0.18215 20 | 21 | 22 | class Helpers: 23 | # version: str 24 | model_name: str 25 | 26 | save_dir: Path 27 | 28 | clip: ClipEmbedding 29 | vae: AutoencoderKL 30 | unet: UNet2DConditional 31 | 32 | device: torch.device 33 | dtype: torch.dtype 34 | 35 | @property 36 | def latent_channels(self) -> int: 37 | """Latent-space channel size.""" 38 | 39 | return self.unet.out_channels 40 | 41 | @property 42 | def is_true_inpainting(self) -> bool: 43 | """RunwayMl true inpainting model.""" 44 | 45 | return self.unet.in_channels == 4 and self.latent_channels == 9 46 | 47 | def save_image( 48 | self, 49 | image: Image.Image, 50 | png_info: Optional[PngInfo] = None, 51 | ID: Optional[int] = None, 52 | ) -> Path: 53 | """Save the image using the provided metadata information.""" 54 | 55 | now = datetime.now() 56 | timestamp = now.strftime(r"%Y-%m-%d %H-%M-%S.%f") 57 | 58 | if ID is None: 59 | ID = random.randint(0, 2**64) 60 | 61 | self.save_dir.mkdir(parents=True, exist_ok=True) 62 | 63 | path = self.save_dir / f"{timestamp} - {ID:x}.SD.png" 64 | image.save(path, bitmap_format="png", pnginfo=png_info) 65 | 66 | return path 67 | 68 | @torch.no_grad() 69 | def encode(self, data: Tensor) -> Tensor: 70 | """Encodes (stochastically) a RGB image into a latent vector.""" 71 | 72 | return self.vae.encode(data).sample().mul(MAGIC) 73 | 74 | @torch.no_grad() 75 | def decode(self, latents: Tensor) -> Tensor: 76 | """Decode latent vector into an RGB image.""" 77 | 78 | return self.vae.decode(latents.div(MAGIC)) 79 | 80 | @torch.no_grad() 81 | def get_context( 82 | self, 83 | negative_prompts: list[str], 84 | prompts: Optional[list[str]], 85 | ) -> tuple[Tensor, Optional[Tensor]]: 86 | """Creates a context Tensor (negative + positive 
prompt) and a emphasis weights.""" 87 | 88 | texts = deepcopy(negative_prompts) 89 | if prompts is not None: 90 | texts.extend(deepcopy(prompts)) 91 | 92 | context, weight = self.clip(texts, self.device, self.dtype) 93 | 94 | return context, weight 95 | 96 | def generate_noise( 97 | self, 98 | seeds: list[int], 99 | sub_seeds: Optional[list[int]], 100 | interpolations: Optional[Tensor], 101 | height: int, 102 | width: int, 103 | batch_size: int, 104 | ) -> Tensor: 105 | """Generate random noise with individual seeds per batch and 106 | possible sub-seed interpolation.""" 107 | 108 | shape = (batch_size, self.latent_channels, height // 8, width // 8) 109 | noise = generate_noise(shape, seeds, self.device, self.dtype) 110 | if sub_seeds is None: 111 | return noise 112 | 113 | assert interpolations is not None 114 | sub_noise = generate_noise(shape, sub_seeds, self.device, self.dtype) 115 | noise = slerp(noise, sub_noise, interpolations) 116 | 117 | return noise 118 | -------------------------------------------------------------------------------- /sd_fused/app/sd.py: -------------------------------------------------------------------------------- 1 | from __future__ import annotations 2 | from typing import Optional 3 | 4 | from pathlib import Path 5 | from tqdm.auto import trange, tqdm 6 | from PIL import Image 7 | from IPython.display import display 8 | 9 | from copy import deepcopy 10 | import torch 11 | from torch import Tensor 12 | 13 | from ..models import AutoencoderKL, UNet2DConditional 14 | from ..clip import ClipEmbedding 15 | from ..utils.cuda import clear_cuda 16 | from ..scheduler import Scheduler, DDIMScheduler 17 | from ..utils.image import tensor2images 18 | from ..utils.image import ImageType, ResizeModes 19 | from ..utils.typing import MaybeIterable 20 | from ..utils.diverse import to_list, product_args 21 | from ..clip.parser import prompts_choices 22 | from ..utils.tensors import random_seeds 23 | from ..utils.parameters import Parameters, ParametersList, group_parameters, batch_parameters 24 | from .setup import Setup 25 | from .helpers import Helpers 26 | 27 | 28 | class StableDiffusion(Setup, Helpers): 29 | version: str = "0.6.0" 30 | 31 | def __init__( 32 | self, 33 | path: str | Path, 34 | *, 35 | save_dir: str | Path = "./gallery", 36 | model_name: Optional[str] = None, 37 | ) -> None: 38 | """Load Stable-Diffusion diffusers checkpoint.""" 39 | 40 | self.path = path = Path(path) 41 | self.save_dir = Path(save_dir) 42 | self.model_name = model_name or path.name 43 | 44 | assert path.is_dir() 45 | 46 | self.clip = ClipEmbedding(path / "tokenizer", path / "text_encoder") 47 | 48 | self.vae = AutoencoderKL.from_diffusers(path / "vae") 49 | self.unet = UNet2DConditional.from_diffusers(path / "unet") 50 | 51 | # initialize 52 | self.low_ram(False) 53 | self.split_attention(None) 54 | self.flash_attention(False) 55 | self.tome(None) 56 | 57 | def generate( 58 | self, 59 | *, 60 | eta: MaybeIterable[float] = 0, 61 | steps: MaybeIterable[int] = 32, 62 | height: MaybeIterable[int] = 512, 63 | width: MaybeIterable[int] = 512, 64 | negative_prompt: MaybeIterable[str] = "", 65 | # optionals 66 | scale: Optional[MaybeIterable[float]] = 7.5, 67 | prompt: Optional[MaybeIterable[str]] = None, 68 | img: Optional[MaybeIterable[ImageType]] = None, 69 | mask: Optional[MaybeIterable[ImageType]] = None, 70 | strength: Optional[MaybeIterable[float]] = None, 71 | mode: Optional[ResizeModes] = None, 72 | seed: Optional[MaybeIterable[int]] = None, 73 | # sub_seed: Optional[int] = None, # TODO 
Iterable? 74 | # seed_interpolation: Optional[MaybeIterable[float]] = None, 75 | # latents: Optional[Tensor] = None, # TODO 76 | batch_size: int = 1, 77 | repeat: int = 1, 78 | show: bool = True, 79 | share_seed: bool = True, 80 | ) -> list[tuple[Image.Image, Path, Parameters]]: 81 | """Create a list of parameters and group them 82 | into batches to be processed. 83 | """ 84 | 85 | if seed is not None: 86 | repeat = 1 87 | 88 | if prompt is not None: 89 | prompt = prompts_choices(prompt) 90 | negative_prompt = prompts_choices(negative_prompt) 91 | 92 | list_kwargs = product_args( 93 | eta=eta, 94 | steps=steps, 95 | scale=scale, 96 | height=height, 97 | width=width, 98 | negative_prompt=negative_prompt, 99 | prompt=prompt, 100 | img=img, 101 | mask=mask, 102 | strength=strength, 103 | # seed_interpolation=seed_interpolation, 104 | ) 105 | size = len(list_kwargs) 106 | list_kwargs = deepcopy(list_kwargs * repeat) 107 | 108 | # if seeds are given or share-seed set 109 | # each repeated-iteration has the same seed 110 | if seed is not None or share_seed: 111 | if seed is None: 112 | seed = random_seeds(repeat) 113 | seeds = [s for s in to_list(seed) for _ in range(size)] 114 | 115 | # otherwise each iteration has it's own unique seed 116 | else: 117 | seeds = random_seeds(size * repeat) 118 | 119 | # create parameters list and group/batch them 120 | parameters = [ 121 | Parameters(**kwargs, mode=mode, seed=seed, device=self.device, dtype=self.dtype) # sub_seed=sub_seed 122 | for (seed, kwargs) in zip(seeds, list_kwargs) 123 | ] 124 | groups = group_parameters(parameters) 125 | batched_parameters = batch_parameters(groups, batch_size) 126 | 127 | out: list[tuple[Image.Image, Path, Parameters]] = [] 128 | for params in tqdm(batched_parameters, desc="Generating batches"): 129 | ipp = self.generate_from_parameters(ParametersList(params)) 130 | out.extend(ipp) 131 | 132 | if show: 133 | for image, path, parameters in ipp: 134 | print(parameters) 135 | display(image) 136 | 137 | return out 138 | 139 | @torch.no_grad() 140 | def generate_from_parameters( 141 | self, 142 | pL: ParametersList, 143 | ) -> list[tuple[Image.Image, Path, Parameters]]: 144 | 145 | context, weight = self.get_context(pL.negative_prompts, pL.prompts) 146 | 147 | # TODO make general 148 | scheduler = DDIMScheduler( 149 | pL.steps, pL.shape(self.latent_channels), pL.seeds, pL.strength, self.device, self.dtype 150 | ) 151 | 152 | enc = lambda x: self.encode(x) if x is not None else None 153 | image_latents = enc(pL.images_data) 154 | mask_latents = enc(pL.masks_data) 155 | masked_image_latents = enc(pL.masked_images_data) 156 | 157 | latents = scheduler.prepare_latents(image_latents, mask_latents, masked_image_latents) 158 | # TODO add to scheduler 159 | latents = self.denoise_latents(scheduler, latents, context, weight, pL.unconditional, pL.scales, pL.etas) 160 | 161 | data = self.decode(latents) 162 | images = tensor2images(data) 163 | 164 | paths: list[Path] = [] 165 | for parameter, image in zip(pL, images): 166 | path = self.save_image(image, parameter.png_info) 167 | paths.append(path) 168 | 169 | return list(zip(images, paths, list(pL))) 170 | 171 | def denoise_latents( 172 | self, 173 | scheduler: Scheduler, 174 | latents: Tensor, 175 | context: Tensor, 176 | weight: Optional[Tensor], 177 | unconditional: bool, 178 | scales: Optional[Tensor], 179 | etas: Optional[Tensor], 180 | ) -> Tensor: 181 | """Main loop where latents are denoised.""" 182 | 183 | clear_cuda() 184 | for index in trange(scheduler.skip_timestep, 
scheduler.steps, desc="Denoising latents"): 185 | timestep = int(scheduler.timesteps[index].item()) 186 | 187 | pred_noise = scheduler.pred_noise( 188 | self.unet, latents, timestep, context, weight, scales, unconditional, self.use_low_ram 189 | ) 190 | latents = scheduler.step(pred_noise, latents, index, etas=etas) 191 | 192 | del pred_noise 193 | clear_cuda() 194 | 195 | return latents 196 | 197 | def __repr__(self) -> str: 198 | name = self.__class__.__qualname__ 199 | 200 | return f'{name}(model="{self.model_name}", version="{self.version}")' 201 | -------------------------------------------------------------------------------- /sd_fused/app/setup.py: -------------------------------------------------------------------------------- 1 | from __future__ import annotations 2 | from typing import Optional 3 | from typing_extensions import Self 4 | 5 | import torch 6 | 7 | from ..utils.typing import Literal 8 | from ..utils.cuda import clear_cuda 9 | from ..models import AutoencoderKL, UNet2DConditional 10 | from ..layers.blocks.attention.compute import ChunkType 11 | from ..layers.base.types import Device 12 | from ..layers.base.base import Base 13 | 14 | 15 | class Setup(Base): 16 | use_low_ram: bool 17 | 18 | vae: AutoencoderKL 19 | unet: UNet2DConditional 20 | 21 | def low_ram(self, use: bool = True) -> Self: 22 | """Split context into two passes to save memory.""" 23 | 24 | self.use_low_ram = use 25 | 26 | return self 27 | 28 | def to(self, *, device: Optional[Device] = None, dtype: Optional[torch.dtype] = None) -> Self: 29 | if device is not None: 30 | self.device = device 31 | if dtype is not None: 32 | self.dtype = dtype 33 | 34 | self.unet.to(device=self.device, dtype=dtype) 35 | self.vae.to(device=self.device, dtype=dtype) 36 | 37 | return self 38 | 39 | def cuda(self) -> Self: 40 | """Send unet and auto-encoder to cuda.""" 41 | 42 | clear_cuda() 43 | 44 | return self.to(device="cuda") 45 | 46 | def cpu(self) -> Self: 47 | """Send unet and auto-encoder to cpu.""" 48 | 49 | return self.to(device="cpu") 50 | 51 | def half(self) -> Self: 52 | """Use half-precision for unet and auto-encoder.""" 53 | 54 | return self.to(dtype=torch.float16) 55 | 56 | def float(self) -> Self: 57 | """Use full-precision for unet and auto-encoder.""" 58 | 59 | return self.to(dtype=torch.float32) 60 | 61 | def half_weights(self, use: bool = True) -> Self: 62 | """Store the weights in half-precision but 63 | compute forward pass in full precision. 64 | Useful for GPUs that gives NaN when used in half-precision. 65 | """ 66 | 67 | self.unet.half_weights(use) 68 | self.vae.half_weights(use) 69 | 70 | return self 71 | 72 | def split_attention( 73 | self, 74 | chunks: Optional[int | Literal["auto"]] = "auto", 75 | chunk_types: Optional[ChunkType] = None, 76 | ) -> Self: 77 | """Split cross-attention computation into chunks.""" 78 | 79 | # TODO this should not be here... 
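# note: "batch" chunking splits the attention computation over the flattened batch*heads dimension, while "sequence" chunking splits it over the query tokens (see layers/blocks/attention/compute/chunked_attention.py)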
80 | # default to batch if not set 81 | if chunks is not None and chunk_types is None: 82 | chunk_types = "batch" 83 | 84 | self.unet.split_attention(chunks, chunk_types) 85 | self.vae.split_attention(chunks, chunk_types) 86 | 87 | return self 88 | 89 | def flash_attention(self, use: bool = True) -> Self: 90 | """Use xformers flash-attention.""" 91 | 92 | self.unet.flash_attention(use) 93 | self.vae.flash_attention(use) 94 | 95 | return self 96 | 97 | def tome(self, r: Optional[int | float] = None) -> Self: 98 | """Merge similar tokens.""" 99 | 100 | self.unet.tome(r) 101 | self.vae.tome(r) 102 | 103 | return self 104 | -------------------------------------------------------------------------------- /sd_fused/clip/__init__.py: -------------------------------------------------------------------------------- 1 | from .clip_embedding import ClipEmbedding 2 | -------------------------------------------------------------------------------- /sd_fused/clip/clip_embedding.py: -------------------------------------------------------------------------------- 1 | from __future__ import annotations 2 | from typing import Optional 3 | 4 | from functools import lru_cache 5 | from pathlib import Path 6 | 7 | import torch 8 | 9 | from transformers.models.clip.modeling_clip import CLIPTextModel 10 | from transformers.models.clip.tokenization_clip import CLIPTokenizer 11 | 12 | from ..layers.base.types import Device 13 | from .text_segment import TextSegment 14 | from .container import TensorAndWeight, TensorAndMaybeWeight 15 | from .parser import ( 16 | clean_spaces, 17 | add_delimiter4words, 18 | expand_delimiters, 19 | add_split_maker4emphasis, 20 | split_prompt_into_segments, 21 | ) 22 | 23 | MAX_TOKENS = 77 24 | 25 | 26 | class ClipEmbedding: 27 | """Convert a text to embeddings using a CLIP model.""" 28 | 29 | tokenizer: CLIPTokenizer 30 | text_encoder: CLIPTextModel 31 | 32 | def __init__( 33 | self, 34 | tokenizer_path: str | Path, 35 | text_encoder_path: str | Path, 36 | ) -> None: 37 | # no need for CUDA for simple embeddings... 
38 | self.tokenizer = CLIPTokenizer.from_pretrained(tokenizer_path) 39 | self.text_encoder = CLIPTextModel.from_pretrained(text_encoder_path) # type: ignore 40 | 41 | # token ids markers 42 | assert self.tokenizer.bos_token_id is not None 43 | assert self.tokenizer.eos_token_id is not None 44 | assert self.tokenizer.pad_token_id is not None 45 | 46 | self.bos_token_id = self.tokenizer.bos_token_id 47 | self.eos_token_id = self.tokenizer.eos_token_id 48 | self.pad_token_id = self.tokenizer.pad_token_id 49 | 50 | @staticmethod 51 | def parse_emphasis(prompt: str) -> str: 52 | """Parse emphasis notation.""" 53 | 54 | prompt = add_delimiter4words(prompt) 55 | prompt = expand_delimiters(prompt) 56 | prompt = add_split_maker4emphasis(prompt) 57 | 58 | return prompt 59 | 60 | @lru_cache(maxsize=None) 61 | def get_ids_and_weights(self, prompt: str) -> TensorAndWeight: 62 | """Get the token id and weight for a given prompt.""" 63 | 64 | prompt = clean_spaces(prompt) 65 | prompt = self.parse_emphasis(prompt) 66 | 67 | segments = [TextSegment(t) for t in split_prompt_into_segments(prompt)] 68 | 69 | ids: list[int] = [] 70 | weights: list[float] = [] 71 | for n, seg in enumerate(segments): 72 | seg_ids = self.tokenizer.encode(seg.prompt) 73 | 74 | # remove initial/final ids 75 | seg_ids = seg_ids[1:-1] 76 | 77 | ids.extend(seg_ids) 78 | weights.extend([seg.weight] * (len(seg_ids))) 79 | 80 | # add padding and initial/final ids 81 | pad_size = MAX_TOKENS - len(ids) - 2 82 | assert pad_size >= 0, "Text too big, it will result in truncation" 83 | 84 | ids = [ 85 | self.bos_token_id, 86 | *ids, 87 | *[self.pad_token_id] * pad_size, 88 | self.eos_token_id, 89 | ] 90 | weights = [1, *weights, *[1] * pad_size, 1] 91 | 92 | return TensorAndWeight(torch.tensor([ids]), torch.tensor([weights]).float()) 93 | 94 | @lru_cache(maxsize=None) 95 | @torch.no_grad() 96 | def get_embedding(self, prompt: str) -> TensorAndWeight: 97 | """Creates an embedding/weights for a prompt and cache it.""" 98 | 99 | ids, weight = self.get_ids_and_weights(prompt) 100 | emb = self.text_encoder(ids)[0] 101 | 102 | return TensorAndWeight(emb, weight) 103 | 104 | def __call__( 105 | self, 106 | prompt: str | list[str] = "", 107 | device: Optional[Device] = None, 108 | dtype: Optional[torch.dtype] = None, 109 | ) -> TensorAndMaybeWeight: 110 | """Creates embeddings/weights for a prompt and send to the correct device/dtype.""" 111 | 112 | if isinstance(prompt, str): 113 | prompt = [prompt] 114 | values = [self.get_embedding(t) for t in prompt] 115 | emb = torch.cat([v.tensor for v in values]) 116 | weight = torch.cat([v.weight for v in values]) 117 | 118 | emb = emb.to(device=device, dtype=dtype, non_blocking=True) 119 | 120 | # special case where all weights are one 121 | if weight.eq(1).all(): 122 | return TensorAndMaybeWeight(emb) 123 | 124 | weight = weight.to(device=device, dtype=dtype, non_blocking=True) 125 | 126 | return TensorAndMaybeWeight(emb, weight) 127 | -------------------------------------------------------------------------------- /sd_fused/clip/container.py: -------------------------------------------------------------------------------- 1 | from __future__ import annotations 2 | from typing import NamedTuple, Optional 3 | 4 | from torch import Tensor 5 | 6 | 7 | class TensorAndWeight(NamedTuple): 8 | tensor: Tensor 9 | weight: Tensor 10 | 11 | 12 | class TensorAndMaybeWeight(NamedTuple): 13 | tensor: Tensor 14 | weight: Optional[Tensor] = None 15 | 
-------------------------------------------------------------------------------- /sd_fused/clip/parser/__init__.py: -------------------------------------------------------------------------------- 1 | from .clean_spaces import clean_spaces 2 | from .add_delimiter4words import add_delimiter4words 3 | from .add_split_maker4emphasis import add_split_maker4emphasis, split_prompt_into_segments 4 | from .expand_delimiters import expand_delimiters 5 | from .prompt_choices import prompt_choices, prompts_choices 6 | 7 | from .diffuse_prompt import diffuse_prompt 8 | -------------------------------------------------------------------------------- /sd_fused/clip/parser/add_delimiter4words.py: -------------------------------------------------------------------------------- 1 | from __future__ import annotations 2 | 3 | import re 4 | 5 | 6 | def add_delimiter4words(prompt: str) -> str: 7 | """Replaces `word:weight` -> `(word):weight`.""" 8 | 9 | prompt = re.sub(r"(\w+):([+-]?\d+(?:.\d+)?)", r"(\1):\2", prompt) 10 | 11 | return prompt 12 | -------------------------------------------------------------------------------- /sd_fused/clip/parser/add_split_maker4emphasis.py: -------------------------------------------------------------------------------- 1 | from __future__ import annotations 2 | 3 | import re 4 | 5 | 6 | def add_split_maker4emphasis(prompt: str) -> str: 7 | """Add ⏎ to the begginig and end of (..):value""" 8 | 9 | pattern = r"(\(.+?\):[+-]?\d+(?:.\d+)?)" 10 | prompt = re.sub(pattern, r"⏎\1⏎", prompt) 11 | 12 | return prompt 13 | 14 | 15 | def split_prompt_into_segments(prompt: str) -> list[str]: 16 | """Split a prompt at ⏎ to give prompt-segments.""" 17 | 18 | return prompt.split("⏎") 19 | -------------------------------------------------------------------------------- /sd_fused/clip/parser/clean_spaces.py: -------------------------------------------------------------------------------- 1 | from __future__ import annotations 2 | 3 | import re 4 | 5 | 6 | def clean_spaces(prompt: str) -> str: 7 | """Clean-up spaces/return characters.""" 8 | 9 | prompt = prompt.replace("\n", " ") 10 | prompt = re.sub(r"[ ]+", r" ", prompt) 11 | prompt = prompt.strip() 12 | 13 | return prompt 14 | -------------------------------------------------------------------------------- /sd_fused/clip/parser/diffuse_prompt.py: -------------------------------------------------------------------------------- 1 | from __future__ import annotations 2 | 3 | import re 4 | 5 | import math 6 | from typing import Optional 7 | import torch 8 | 9 | 10 | def diffuse_prompt( 11 | prompt: str, 12 | vmax: float = 1, 13 | size: int = 1, 14 | seed: Optional[int] = None, 15 | ) -> list[str]: 16 | """Diffuse attention-weights to a prompt.""" 17 | 18 | assert ":" not in prompt 19 | 20 | pattern = re.compile(r"(\w+)") 21 | words = pattern.split(prompt) 22 | n = len(words) 23 | 24 | generator = torch.Generator() 25 | if seed is not None: 26 | generator.manual_seed(seed) 27 | 28 | weigths = torch.randn(size, n, generator=generator) 29 | weigths = weigths.cumsum(0) / math.sqrt(size) * vmax 30 | weigths += 1 - weigths[[0]] # start at the same weight 31 | weigths = weigths.tolist() 32 | 33 | prompts: list[str] = [] 34 | for weight in weigths: 35 | prompt = "".join([f"{w}:{a:.3f}" if pattern.search(w) else w for w, a in zip(words, weight)]) 36 | prompts.append(prompt) 37 | 38 | return prompts 39 | -------------------------------------------------------------------------------- /sd_fused/clip/parser/expand_delimiters.py: 
-------------------------------------------------------------------------------- 1 | from __future__ import annotations 2 | 3 | import re 4 | 5 | FACTOR = 1.1 6 | MAX_EMPHASIS = 8 7 | 8 | 9 | def expand_delimiters(prompt: str) -> str: 10 | """Replace `(^n ... )^n` with `( ... ):factor^n`""" 11 | 12 | delimiters = [(r"\(" * repeat, r"\)" * repeat, repeat) for repeat in range(MAX_EMPHASIS, 0, -1)] 13 | 14 | avoid = r"\(\)" 15 | for left, right, repeat in delimiters: 16 | pattern = f"{left}([^{avoid}]+?){right}([^:]|$)" 17 | repl = f"(\\1):{FACTOR**repeat:.4f}\\2" 18 | prompt = re.sub(pattern, repl, prompt) 19 | 20 | # recover back parantheses and brackets 21 | prompt = prompt.replace(r"\(", "(").replace(r"\)", ")") 22 | prompt = prompt.replace(r"\[", "[").replace(r"\]", "]") 23 | 24 | return prompt 25 | -------------------------------------------------------------------------------- /sd_fused/clip/parser/prompt_choices.py: -------------------------------------------------------------------------------- 1 | from __future__ import annotations 2 | 3 | import re 4 | 5 | from ...utils.diverse import to_list 6 | from ...utils.typing import MaybeIterable 7 | 8 | 9 | def prompt_choices(prompt: str) -> list[str]: 10 | """Create a set of prompt word-choices from 11 | `[word 1 | another word | yet another one]` 12 | """ 13 | 14 | pattern = re.compile(r"\[([^\[\]]+)\]") 15 | 16 | temp: list[str] = [prompt] 17 | prompts: list[str] = [] 18 | 19 | while len(temp) != 0: 20 | prompt = temp.pop() 21 | 22 | match = pattern.search(prompt) 23 | if match is not None: 24 | start, end = match.span() 25 | 26 | choices = match.group(1).split("|") 27 | assert len(choices) > 1 28 | 29 | for choice in choices: 30 | choice = choice.strip() 31 | new_text = "".join([prompt[:start], choice, prompt[end:]]) 32 | temp.append(new_text.strip()) 33 | else: 34 | prompts.append(prompt) 35 | 36 | return prompts[::-1] 37 | 38 | 39 | def prompts_choices(prompts: MaybeIterable[str]) -> list[str]: 40 | """Create a set of prompt word-choices from 41 | `[word 1 | another word | yet another one]`""" 42 | 43 | return [choice for prompt in to_list(prompts) for choice in prompt_choices(prompt)] 44 | -------------------------------------------------------------------------------- /sd_fused/clip/text_segment.py: -------------------------------------------------------------------------------- 1 | from __future__ import annotations 2 | 3 | import re 4 | 5 | 6 | class TextSegment: 7 | """Split `(prompt):weight` for parsing.""" 8 | 9 | prompt: str 10 | weight: float = 1 11 | 12 | def __init__(self, prompt: str) -> None: 13 | self.prompt = prompt 14 | 15 | pattern = r"\((.+?)\):([+-]?\d+(?:.\d+)?)" 16 | match = re.match(pattern, prompt) 17 | if match: 18 | self.prompt = match.group(1) 19 | self.weight = float(match.group(2)) 20 | 21 | def __repr__(self) -> str: 22 | return f'{self.__class__.__qualname__}(text="{self.prompt}"; weight={self.weight})' 23 | -------------------------------------------------------------------------------- /sd_fused/layers/__init__.py: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/tfernd/sd-fused/50570f983cc00dd4bd0dc415d0b515da638311b0/sd_fused/layers/__init__.py -------------------------------------------------------------------------------- /sd_fused/layers/activation/__init__.py: -------------------------------------------------------------------------------- 1 | from .geglu import GEGLU 2 | from .silu import SiLU 3 | 
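# --- Illustrative sketch (not part of the repository file above) ---
# The two activations re-exported here: SiLU is x * sigmoid(x), and GEGLU
# projects to twice the output width and gates one half with GELU of the
# other (see geglu.py and silu.py below).
import torch
import torch.nn.functional as F

x = torch.randn(2, 8)
silu_out = x * torch.sigmoid(x)   # same as F.silu(x)

proj_out = torch.randn(2, 16)     # stand-in for Linear(dim_in, 2 * dim_out)(x)
h, gate = proj_out.chunk(2, dim=-1)
geglu_out = F.gelu(gate) * h      # matches GEGLU.__call__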
-------------------------------------------------------------------------------- /sd_fused/layers/activation/geglu.py: -------------------------------------------------------------------------------- 1 | from __future__ import annotations 2 | 3 | import torch.nn.functional as F 4 | from torch import Tensor 5 | 6 | from ..base import Module 7 | from ..basic import Linear 8 | 9 | 10 | class GEGLU(Module): 11 | def __init__(self, dim_in: int, dim_out: int) -> None: 12 | super().__init__() 13 | 14 | self.dim_in = dim_in 15 | self.dim_out = dim_out 16 | 17 | self.proj = Linear(dim_in, 2 * dim_out) 18 | 19 | def __call__(self, x: Tensor) -> Tensor: 20 | x, gate = self.proj(x).chunk(2, dim=-1) 21 | 22 | return F.gelu(gate).mul_(x) 23 | -------------------------------------------------------------------------------- /sd_fused/layers/activation/silu.py: -------------------------------------------------------------------------------- 1 | from __future__ import annotations 2 | 3 | from torch import Tensor 4 | 5 | from ..base import Module 6 | 7 | 8 | class SiLU(Module): 9 | def __call__(self, x: Tensor) -> Tensor: 10 | return x.sigmoid().mul_(x) 11 | -------------------------------------------------------------------------------- /sd_fused/layers/auto_encoder/__init__.py: -------------------------------------------------------------------------------- 1 | from .encoder import Encoder 2 | from .decoder import Decoder 3 | -------------------------------------------------------------------------------- /sd_fused/layers/auto_encoder/decoder.py: -------------------------------------------------------------------------------- 1 | from __future__ import annotations 2 | 3 | from torch import Tensor 4 | 5 | from ..base import Module, ModuleList 6 | from ..basic import Conv2d 7 | from ..blocks.basic import GroupNormSiLUConv2d 8 | from ..blocks.spatial import UpDecoderBlock2D, UNetMidBlock2DSelfAttention 9 | 10 | 11 | class Decoder(Module): 12 | def __init__( 13 | self, 14 | *, 15 | in_channels: int, 16 | out_channels: int, 17 | block_out_channels: tuple[int, ...], 18 | layers_per_block: int, 19 | resnet_groups: int, 20 | ) -> None: 21 | super().__init__() 22 | 23 | self.in_channels = in_channels 24 | self.out_channels = out_channels 25 | self.block_out_channels = block_out_channels 26 | self.layers_per_block = layers_per_block 27 | self.resnet_groups = resnet_groups 28 | 29 | num_blocks = len(block_out_channels) 30 | 31 | self.pre_process = Conv2d(in_channels, block_out_channels[-1], kernel_size=3, padding=1) 32 | 33 | # mid 34 | self.mid_block = UNetMidBlock2DSelfAttention( 35 | in_channels=block_out_channels[-1], 36 | temb_channels=None, 37 | num_layers=1, 38 | resnet_groups=resnet_groups, 39 | attn_num_head_channels=None, 40 | ) 41 | 42 | # up 43 | reversed_block_out_channels = list(reversed(block_out_channels)) 44 | output_channel = reversed_block_out_channels[0] 45 | self.up_blocks = ModuleList[UpDecoderBlock2D]() 46 | for i in range(num_blocks): 47 | is_final_block = i == num_blocks - 1 48 | 49 | prev_output_channel = output_channel 50 | output_channel = reversed_block_out_channels[i] 51 | 52 | block = UpDecoderBlock2D( 53 | in_channels=prev_output_channel, 54 | out_channels=output_channel, 55 | num_layers=layers_per_block + 1, 56 | resnet_groups=resnet_groups, 57 | add_upsample=not is_final_block, 58 | ) 59 | 60 | self.up_blocks.append(block) 61 | prev_output_channel = output_channel 62 | 63 | # out 64 | self.post_process = GroupNormSiLUConv2d( 65 | resnet_groups, 66 | block_out_channels[0], 67 | 
out_channels, 68 | kernel_size=3, 69 | padding=1, 70 | ) 71 | 72 | def __call__(self, x: Tensor) -> Tensor: 73 | x = self.pre_process(x) 74 | x = self.mid_block(x) 75 | 76 | for up_block in self.up_blocks: 77 | x = up_block(x) 78 | 79 | return self.post_process(x) 80 | -------------------------------------------------------------------------------- /sd_fused/layers/auto_encoder/encoder.py: -------------------------------------------------------------------------------- 1 | from __future__ import annotations 2 | 3 | from torch import Tensor 4 | 5 | from ..base import Module, ModuleList 6 | from ..basic import Conv2d 7 | from ..blocks.basic import GroupNormSiLUConv2d 8 | from ..blocks.spatial import DownEncoderBlock2D, UNetMidBlock2DSelfAttention 9 | 10 | 11 | class Encoder(Module): 12 | def __init__( 13 | self, 14 | *, 15 | in_channels: int, 16 | out_channels: int, 17 | block_out_channels: tuple[int, ...], 18 | layers_per_block: int, 19 | resnet_groups: int, 20 | double_z: bool, 21 | ) -> None: 22 | super().__init__() 23 | 24 | self.in_channels = in_channels 25 | self.out_channels = out_channels 26 | self.block_out_channels = block_out_channels 27 | self.layers_per_block = layers_per_block 28 | self.resnet_groups = resnet_groups 29 | self.double_z = double_z 30 | 31 | num_blocks = len(block_out_channels) 32 | 33 | self.pre_process = Conv2d(in_channels, block_out_channels[0], kernel_size=3, padding=1) 34 | 35 | # down 36 | output_channel = block_out_channels[0] 37 | self.down_blocks = ModuleList[DownEncoderBlock2D]() 38 | for i in range(num_blocks): 39 | is_final_block = i == num_blocks - 1 40 | 41 | input_channel = output_channel 42 | output_channel = block_out_channels[i] 43 | 44 | block = DownEncoderBlock2D( 45 | in_channels=input_channel, 46 | out_channels=output_channel, 47 | num_layers=layers_per_block, 48 | resnet_groups=resnet_groups, 49 | downsample_padding=0, 50 | add_downsample=not is_final_block, 51 | ) 52 | self.down_blocks.append(block) 53 | 54 | # mid 55 | self.mid_block = UNetMidBlock2DSelfAttention( 56 | in_channels=block_out_channels[-1], 57 | temb_channels=None, 58 | num_layers=1, 59 | resnet_groups=resnet_groups, 60 | attn_num_head_channels=None, 61 | ) 62 | 63 | # out 64 | conv_out_channels = 2 * out_channels if double_z else out_channels 65 | self.post_process = GroupNormSiLUConv2d( 66 | resnet_groups, 67 | block_out_channels[-1], 68 | conv_out_channels, 69 | kernel_size=3, 70 | padding=1, 71 | ) 72 | 73 | def __call__(self, x: Tensor) -> Tensor: 74 | x = self.pre_process(x) 75 | 76 | for down_block in self.down_blocks: 77 | x = down_block(x) 78 | 79 | x = self.mid_block(x) 80 | 81 | return self.post_process(x) 82 | -------------------------------------------------------------------------------- /sd_fused/layers/base/__init__.py: -------------------------------------------------------------------------------- 1 | from .module import Module 2 | from .sequential import Sequential 3 | from .module_list import ModuleList 4 | -------------------------------------------------------------------------------- /sd_fused/layers/base/base.py: -------------------------------------------------------------------------------- 1 | from __future__ import annotations 2 | 3 | import torch 4 | 5 | from .types import Device 6 | 7 | 8 | class Base: 9 | dtype: torch.dtype = torch.float16 10 | device: Device = "cuda" 11 | -------------------------------------------------------------------------------- /sd_fused/layers/base/module.py: 
-------------------------------------------------------------------------------- 1 | from __future__ import annotations 2 | from typing import Optional 3 | from typing_extensions import Self 4 | 5 | from abc import ABC 6 | 7 | import torch 8 | import torch.nn as nn 9 | from torch import Tensor 10 | 11 | from .base import Base 12 | from .types import Device 13 | 14 | 15 | class Module(Base, ABC): 16 | # TODO Types? 17 | def __call__(self, *args, **kwargs) -> Tensor: 18 | raise NotImplementedError 19 | 20 | def named_modules(self) -> dict[str, Module]: 21 | modules: dict[str, Module] = {} 22 | 23 | modules[""] = self 24 | for key, value in self.__dict__.items(): 25 | if isinstance(value, Module): 26 | modules[key] = value 27 | 28 | # recursive call 29 | for sub_key, sub_value in value.named_modules().items(): 30 | if sub_key != "": 31 | modules[f"{key}.{sub_key}"] = sub_value 32 | 33 | return modules 34 | 35 | def state_dict(self) -> dict[str, nn.Parameter]: 36 | params: dict[str, nn.Parameter] = {} 37 | 38 | for key, value in self.__dict__.items(): 39 | # single parameter 40 | if isinstance(value, nn.Parameter): 41 | params[key] = value 42 | 43 | # recursive call 44 | elif isinstance(value, Module): 45 | for sub_key, sub_value in value.state_dict().items(): 46 | params[f"{key}.{sub_key}"] = sub_value 47 | 48 | return params 49 | 50 | def load_state_dict(self, state: dict[str, nn.Parameter] | dict[str, Tensor], strict: bool = False) -> Self: 51 | current_state = self.state_dict() 52 | assert len(current_state) == len(state) 53 | 54 | for key, value in current_state.items(): 55 | assert key in state 56 | 57 | new_value = state[key].data 58 | if strict: 59 | assert new_value.shape == value.shape 60 | 61 | new_value = new_value.to(device=self.device, dtype=self.dtype, non_blocking=True) 62 | value.data = new_value 63 | 64 | return self 65 | 66 | def float(self) -> Self: 67 | return self.to(dtype=torch.float32) 68 | 69 | def half(self) -> Self: 70 | return self.to(dtype=torch.float16) 71 | 72 | def cpu(self) -> Self: 73 | return self.to(device=f"cpu") 74 | 75 | def cuda(self, index: int = 0) -> Self: 76 | return self.to(device=f"cuda:{index}") 77 | 78 | def to(self, *, device: Optional[Device] = None, dtype: Optional[torch.dtype] = None): 79 | if device is not None: 80 | self.device = device 81 | if dtype is not None: 82 | self.dtype = dtype 83 | 84 | for key, value in self.state_dict().items(): 85 | data = value.data 86 | value.data = data.to(device=device, dtype=dtype, non_blocking=True) 87 | 88 | return self 89 | -------------------------------------------------------------------------------- /sd_fused/layers/base/module_list.py: -------------------------------------------------------------------------------- 1 | from __future__ import annotations 2 | from typing import Generator, Generic, NoReturn 3 | 4 | import torch.nn as nn 5 | 6 | from ...utils.typing import TypeVar, TypeVarTuple, Unpack 7 | from .module import Module 8 | 9 | 10 | T = TypeVar("T", bound=Module) 11 | Ts = TypeVarTuple("Ts") # ? bound 12 | 13 | 14 | class _ModuleSequence(Module): 15 | layers: tuple[Module, ...] 
| list[Module] 16 | 17 | def state_dict(self) -> dict[str, nn.Parameter]: 18 | params: dict[str, nn.Parameter] = {} 19 | for index, layer in enumerate(self.layers): 20 | for key, value in layer.state_dict().items(): 21 | params[f"{index}.{key}"] = value 22 | 23 | return params 24 | 25 | def named_modules(self) -> dict[str, Module]: 26 | modules: dict[str, Module] = {} 27 | for index, layer in enumerate(self.layers): 28 | for key, value in layer.named_modules().items(): 29 | name = str(index) if key == "" else f"{index}.{key}" 30 | modules[name] = value 31 | 32 | return modules 33 | 34 | def __call__(self) -> NoReturn: 35 | raise ValueError(f"{self.__class__.__qualname__} is not callable.") 36 | 37 | 38 | class ModuleList(_ModuleSequence, Generic[T]): 39 | layers: list[T] 40 | 41 | def __init__(self, *layers: T) -> None: 42 | self.layers = list(layers) 43 | 44 | def append(self, layer: T) -> None: 45 | self.layers.append(layer) 46 | 47 | def __iter__(self) -> Generator[T, None, None]: 48 | for layer in self.layers: 49 | yield layer 50 | 51 | 52 | # # ! only for debug 53 | # class ModuleTuple(Module, Generic[Unpack[Ts]]): 54 | # layers: tuple[Unpack[Ts]] 55 | 56 | # def __init__(self, *layers: Unpack[Ts]) -> None: 57 | # self.layers = layers 58 | 59 | # def __iter__(self): # ? type? 60 | # for layer in self.layers: 61 | # yield layer 62 | -------------------------------------------------------------------------------- /sd_fused/layers/base/sequential.py: -------------------------------------------------------------------------------- 1 | from __future__ import annotations 2 | 3 | from torch import Tensor 4 | 5 | from .module import Module 6 | from .module_list import ModuleList, T 7 | 8 | 9 | class Sequential(ModuleList[T], Module): 10 | def __call__(self, x: Tensor) -> Tensor: 11 | for layer in self.layers: 12 | x = layer(x) 13 | 14 | return x 15 | -------------------------------------------------------------------------------- /sd_fused/layers/base/types.py: -------------------------------------------------------------------------------- 1 | from __future__ import annotations 2 | from typing import Union 3 | 4 | import torch 5 | 6 | Device = Union[str, torch.device] 7 | -------------------------------------------------------------------------------- /sd_fused/layers/basic/__init__.py: -------------------------------------------------------------------------------- 1 | from .identity import Identity 2 | 3 | from .linear import Linear 4 | from .conv2d import Conv2d 5 | 6 | from .layer_norm import LayerNorm 7 | from .group_norm import GroupNorm 8 | -------------------------------------------------------------------------------- /sd_fused/layers/basic/conv2d.py: -------------------------------------------------------------------------------- 1 | from __future__ import annotations 2 | from typing import Optional 3 | 4 | from functools import partial 5 | 6 | import torch 7 | import torch.nn as nn 8 | import torch.nn.functional as F 9 | from torch import Tensor 10 | 11 | from ..base import Module 12 | from ..modifiers import HalfWeightsModule, half_weights 13 | 14 | 15 | class Conv2d(HalfWeightsModule, Module): 16 | def __init__( 17 | self, 18 | in_channels: int, 19 | out_channels: Optional[int] = None, 20 | *, 21 | kernel_size: int = 1, 22 | stride: int = 1, 23 | padding: int = 0, 24 | groups: int = 1, 25 | dilation: int = 1, 26 | bias: bool = True, 27 | ) -> None: 28 | assert in_channels % groups == 0 29 | 30 | self.in_channels = in_channels 31 | self.out_channels = out_channels = out_channels 
or in_channels 32 | self.kernel_size = kernel_size 33 | self.stride = stride 34 | self.padding = padding 35 | self.groups = groups 36 | self.dilation = dilation 37 | 38 | # TODO duplication 39 | empty = partial(torch.empty, dtype=self.dtype, device=self.device) 40 | parameter = partial(nn.Parameter, requires_grad=False) 41 | 42 | w = empty(out_channels, in_channels // groups, kernel_size, kernel_size) 43 | self.weight = parameter(w) 44 | self.bias = parameter(empty(out_channels)) if bias else None 45 | 46 | @half_weights 47 | def __call__(self, x: Tensor) -> Tensor: 48 | return F.conv2d(x, self.weight, self.bias, self.stride, self.padding, self.dilation, self.groups) 49 | -------------------------------------------------------------------------------- /sd_fused/layers/basic/group_norm.py: -------------------------------------------------------------------------------- 1 | from __future__ import annotations 2 | 3 | from functools import partial 4 | 5 | import torch 6 | import torch.nn as nn 7 | import torch.nn.functional as F 8 | from torch import Tensor 9 | 10 | from ..base import Module 11 | from ..modifiers import HalfWeightsModule, half_weights 12 | 13 | 14 | class GroupNorm(HalfWeightsModule, Module): 15 | def __init__( 16 | self, 17 | num_groups: int, 18 | num_channels: int, 19 | *, 20 | eps: float = 1e-6, 21 | affine: bool = True, 22 | ) -> None: 23 | self.num_groups = num_groups 24 | self.num_channels = num_channels 25 | self.eps = eps 26 | self.affine = affine 27 | 28 | empty = partial(torch.empty, dtype=self.dtype, device=self.device) 29 | parameter = partial(nn.Parameter, requires_grad=False) 30 | 31 | self.weight = parameter(empty(num_channels)) if affine else None 32 | self.bias = parameter(empty(num_channels)) if affine else None 33 | 34 | @half_weights 35 | def __call__(self, x: Tensor) -> Tensor: 36 | return F.group_norm(x, self.num_groups, self.weight, self.bias, self.eps) 37 | -------------------------------------------------------------------------------- /sd_fused/layers/basic/identity.py: -------------------------------------------------------------------------------- 1 | from __future__ import annotations 2 | 3 | from torch import Tensor 4 | 5 | from ..base import Module 6 | 7 | 8 | class Identity(Module): 9 | def __call__(self, x: Tensor) -> Tensor: 10 | return x 11 | -------------------------------------------------------------------------------- /sd_fused/layers/basic/layer_norm.py: -------------------------------------------------------------------------------- 1 | from __future__ import annotations 2 | 3 | from functools import partial 4 | 5 | import torch 6 | import torch.nn as nn 7 | import torch.nn.functional as F 8 | from torch import Tensor 9 | 10 | from ..base import Module 11 | from ..modifiers import HalfWeightsModule, half_weights 12 | 13 | 14 | class LayerNorm(HalfWeightsModule, Module): 15 | def __init__( 16 | self, 17 | shape: int | tuple[int, ...], 18 | *, 19 | eps: float = 1e-6, 20 | elementwise_affine: bool = True, 21 | ) -> None: 22 | self.shape = shape = shape if isinstance(shape, tuple) else (shape,) 23 | self.eps = eps 24 | self.elementwise_affine = elementwise_affine 25 | 26 | empty = partial(torch.empty, dtype=self.dtype, device=self.device) 27 | parameter = partial(nn.Parameter, requires_grad=False) 28 | 29 | self.weight = parameter(empty(shape)) if elementwise_affine else NotImplemented 30 | self.bias = parameter(empty(shape)) if elementwise_affine else NotImplemented 31 | 32 | @half_weights 33 | def __call__(self, x: Tensor) -> Tensor: 34 | 
return F.layer_norm(x, self.shape, self.weight, self.bias, self.eps) 35 | -------------------------------------------------------------------------------- /sd_fused/layers/basic/linear.py: -------------------------------------------------------------------------------- 1 | from __future__ import annotations 2 | from typing import Optional 3 | 4 | from functools import partial 5 | 6 | import torch 7 | import torch.nn as nn 8 | import torch.nn.functional as F 9 | from torch import Tensor 10 | 11 | from ..base import Module 12 | from ..modifiers import HalfWeightsModule, half_weights 13 | 14 | 15 | class Linear(HalfWeightsModule, Module): 16 | def __init__( 17 | self, 18 | in_features: int, 19 | out_features: Optional[int] = None, 20 | *, 21 | bias: bool = True, 22 | ) -> None: 23 | self.in_features = in_features 24 | self.out_features = out_features = out_features or in_features 25 | 26 | empty = partial(torch.empty, dtype=self.dtype, device=self.device) 27 | parameter = partial(nn.Parameter, requires_grad=False) 28 | 29 | self.weight = parameter(empty(out_features, in_features)) 30 | self.bias = parameter(empty(out_features)) if bias else None 31 | 32 | @half_weights 33 | def __call__(self, x: Tensor) -> Tensor: 34 | return F.linear(x, self.weight, self.bias) 35 | -------------------------------------------------------------------------------- /sd_fused/layers/blocks/__init__.py: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/tfernd/sd-fused/50570f983cc00dd4bd0dc415d0b515da638311b0/sd_fused/layers/blocks/__init__.py -------------------------------------------------------------------------------- /sd_fused/layers/blocks/attention/__init__.py: -------------------------------------------------------------------------------- 1 | from .self_attention import SelfAttention 2 | from .cross_attention import CrossAttention 3 | -------------------------------------------------------------------------------- /sd_fused/layers/blocks/attention/base_attention.py: -------------------------------------------------------------------------------- 1 | from __future__ import annotations 2 | from typing import Optional 3 | 4 | from torch import Tensor 5 | 6 | from ....utils.typing import Literal 7 | from .compute import attention, ChunkType 8 | 9 | 10 | class BaseAttention: 11 | attention_chunks: Optional[int | Literal["auto"]] = None 12 | chunk_type: Optional[ChunkType] = None 13 | use_flash_attention: bool = False 14 | tome_r: Optional[int | float] = None 15 | 16 | def attention(self, q: Tensor, k: Tensor, v: Tensor, weights: Optional[Tensor] = None) -> Tensor: 17 | return attention( 18 | q, 19 | k, 20 | v, 21 | chunks=self.attention_chunks, 22 | chunk_type=self.chunk_type, 23 | use_flash_attention=self.use_flash_attention, 24 | tome_r=self.tome_r, 25 | weights=weights, 26 | ) 27 | -------------------------------------------------------------------------------- /sd_fused/layers/blocks/attention/compute/__init__.py: -------------------------------------------------------------------------------- 1 | from .attention import attention 2 | from .auto_chunk_size import ChunkType 3 | -------------------------------------------------------------------------------- /sd_fused/layers/blocks/attention/compute/attention.py: -------------------------------------------------------------------------------- 1 | from __future__ import annotations 2 | from typing import Optional 3 | 4 | from torch import Tensor 5 | 6 | from .....utils.typing import Literal 7 | 
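# the helpers imported below implement the individual pieces: emphasis weighting, ToMe token merging, chunk sizing, and the three attention paths (standard, chunked, flash) that attention() dispatches between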
from .join_spatial_dim import join_spatial_dim 8 | from .auto_chunk_size import auto_chunk_size, ChunkType 9 | from .standard_attention import standard_attention 10 | from .chunked_attention import batch_chunked_attention, sequence_chunked_attention 11 | from .flash_attention import flash_attention 12 | from .weighted_values import weighted_values 13 | from .tome import token_average 14 | 15 | 16 | def attention( 17 | q: Tensor, # (B, heads, H, W, C) 18 | k: Tensor, # (B, heads, *[H,W] | T', C) 19 | v: Tensor, # (B, heads, *[H,W] | T', C) 20 | *, 21 | weights: Optional[Tensor] = None, # (B, T') 22 | chunks: Optional[int | Literal["auto"]] = None, 23 | chunk_type: Optional[ChunkType] = None, 24 | use_flash_attention: bool = False, 25 | tome_r: Optional[int | float] = None, 26 | ) -> Tensor: 27 | """General attention computation.""" 28 | 29 | # header 30 | is_self_attention = q.shape == k.shape 31 | dtype = q.dtype 32 | B, heads, H, W, C = q.shape 33 | T = H * W 34 | 35 | if weights is not None: 36 | assert not is_self_attention 37 | 38 | v = weighted_values(v, weights) # ? what about the keys? 39 | 40 | q, k, v = join_spatial_dim(q, k, v) 41 | Tl = k.size(2) 42 | 43 | chunks = auto_chunk_size(chunks, B, heads, T, Tl, C, dtype, chunk_type) 44 | 45 | bias = None # ! ToMe bias is not used for now. 46 | if is_self_attention and tome_r is not None: 47 | k, v, _ = token_average(k, v, tome_r) 48 | 49 | if chunks is not None: 50 | assert not use_flash_attention 51 | 52 | if chunk_type is None or chunk_type == "batch": 53 | out = batch_chunked_attention(q, k, v, chunks, bias) 54 | else: 55 | out = sequence_chunked_attention(q, k, v, chunks, bias) 56 | 57 | elif use_flash_attention: 58 | assert chunk_type is None 59 | assert tome_r is None 60 | 61 | out = flash_attention(q, k, v, bias) 62 | else: 63 | out = standard_attention(q, k, v, bias) 64 | 65 | out = out.unflatten(2, (H, W)) # separate-spatial-dim 66 | 67 | return out 68 | -------------------------------------------------------------------------------- /sd_fused/layers/blocks/attention/compute/auto_chunk_size.py: -------------------------------------------------------------------------------- 1 | from __future__ import annotations 2 | from typing import Optional 3 | 4 | import torch 5 | 6 | from .....utils.typing import Literal 7 | from .....utils.cuda import free_memory 8 | 9 | 10 | ChunkType = Literal["batch", "sequence"] 11 | 12 | 13 | def auto_chunk_size( 14 | chunks: Optional[int | Literal["auto"]], 15 | B: int, 16 | heads: int, 17 | T: int, 18 | Tl: int, 19 | C: int, 20 | dtype: torch.dtype, 21 | chunk_type: Optional[ChunkType], 22 | ) -> Optional[int]: 23 | """Determine the maximum chunk size according to the available free memory.""" 24 | 25 | if chunks != "auto": 26 | return chunks 27 | 28 | B *= heads # ! ugly but...
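# (the chunked attention kernels flatten batch and heads into one dimension, so the memory estimate below treats B*heads as the effective batch size)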
29 | 30 | assert chunk_type is not None 31 | assert dtype in (torch.float32, torch.float16) 32 | 33 | num_bytes = 2 if dtype == torch.float16 else 4 34 | free = free_memory() 35 | 36 | if chunk_type is None or chunk_type == "batch": 37 | # Memory used: (2*Bchunks*T*Tl + Bchunks*T*C) * num_bytes 38 | Bchunks = free // (num_bytes * T * (C + 2 * Tl)) 39 | 40 | if Bchunks >= B: 41 | return None 42 | return Bchunks 43 | 44 | # Memory used: (2*B*Tchunk*Tl + B*Tchunk*C) * num_bytes 45 | Tchunks = free // (num_bytes * B * (C + 2 * Tl)) 46 | 47 | if Tchunks >= T: 48 | return None 49 | return Tchunks 50 | -------------------------------------------------------------------------------- /sd_fused/layers/blocks/attention/compute/chunked_attention.py: -------------------------------------------------------------------------------- 1 | from __future__ import annotations 2 | from typing import Optional 3 | 4 | import torch 5 | import torch.nn.functional as F 6 | from torch import Tensor 7 | 8 | from .scale_qk import scale_qk 9 | 10 | 11 | def batch_chunked_attention( 12 | q: Tensor, # (B, heads, T, C) 13 | k: Tensor, # (B, heads, T', C) 14 | v: Tensor, # (B, heads, T', C) 15 | chunks: int, 16 | bias: Optional[Tensor] = None, # (B, heads, T', C) 17 | ) -> Tensor: 18 | """Batch-chunked attention computation.""" 19 | 20 | assert chunks >= 1 21 | B, heads, T, C = q.shape 22 | 23 | # join batch-heads 24 | q = q.flatten(0, 1) 25 | k = k.flatten(0, 1) 26 | v = v.flatten(0, 1) 27 | 28 | q, k = scale_qk(q, k) 29 | kT = k.transpose(-1, -2) 30 | 31 | out = torch.empty_like(q) 32 | for i in range(0, B * heads, chunks): 33 | s = slice(i, min(i + chunks, B * heads)) 34 | 35 | score = q[s] @ kT[s] 36 | if bias is not None: 37 | score += bias[s] 38 | attn = F.softmax(score, dim=-1, dtype=q.dtype) 39 | del score 40 | 41 | out[s] = attn @ v[s] 42 | del attn 43 | 44 | return out.unflatten(0, (B, heads)) 45 | 46 | 47 | def sequence_chunked_attention( 48 | q: Tensor, # (B, heads, T, C) 49 | k: Tensor, # (B, heads, T', C) 50 | v: Tensor, # (B, heads, T', C) 51 | chunks: int, 52 | bias: Optional[Tensor] = None, # (B, heads, T', C) 53 | ) -> Tensor: 54 | """Sequence-chunked attention computation.""" 55 | 56 | # https://github.com/Doggettx/stable-diffusion/blob/main/ldm/modules/attention.py#L209 57 | 58 | assert chunks >= 1 59 | B, heads, T, C = q.shape 60 | 61 | # join batch-heads 62 | q, k, v = map(lambda x: x.flatten(0, 1), (q, k, v)) 63 | 64 | q, k = scale_qk(q, k) 65 | kT = k.transpose(-1, -2) 66 | 67 | out = torch.empty_like(q) 68 | for i in range(0, T, chunks): 69 | s = slice(i, min(i + chunks, T)) 70 | 71 | score = q[:, s] @ kT 72 | if bias is not None: 73 | score += bias[:, s] 74 | attn = F.softmax(score, dim=-1, dtype=q.dtype) 75 | del score 76 | 77 | out[:, s] = attn @ v 78 | del attn 79 | 80 | return out.unflatten(0, (B, heads)) 81 | -------------------------------------------------------------------------------- /sd_fused/layers/blocks/attention/compute/flash_attention.py: -------------------------------------------------------------------------------- 1 | from __future__ import annotations 2 | from typing import Optional 3 | 4 | from torch import Tensor 5 | 6 | try: 7 | from xformers.ops import memory_efficient_attention # type: ignore 8 | except ImportError: 9 | memory_efficient_attention = None 10 | 11 | 12 | def flash_attention( 13 | q: Tensor, # (B, heads, T, C) 14 | k: Tensor, # (B, heads, T', C) 15 | v: Tensor, # (B, heads, T', C) 16 | bias: Optional[Tensor] = None, # (B, heads, T', C) 17 | ) -> Tensor: 18 | 
"""xformers flash-attention computation.""" 19 | 20 | assert memory_efficient_attention is not None 21 | 22 | q, k, v = map(lambda x: x.contiguous(), (q, k, v)) 23 | bias = bias.contiguous() if bias is not None else None 24 | 25 | out = memory_efficient_attention(q, k, v, bias) 26 | 27 | return out 28 | -------------------------------------------------------------------------------- /sd_fused/layers/blocks/attention/compute/join_spatial_dim.py: -------------------------------------------------------------------------------- 1 | from __future__ import annotations 2 | 3 | from torch import Tensor 4 | 5 | 6 | def join_spatial_dim(q: Tensor, k: Tensor, v: Tensor) -> tuple[Tensor, Tensor, Tensor]: 7 | "Join height-width spatial dimensions of qkv." 8 | 9 | is_self_attention = q.ndim == k.ndim 10 | 11 | q = q.flatten(2, 3) 12 | if is_self_attention: 13 | k = k.flatten(2, 3) 14 | v = v.flatten(2, 3) 15 | 16 | return q, k, v 17 | -------------------------------------------------------------------------------- /sd_fused/layers/blocks/attention/compute/scale_qk.py: -------------------------------------------------------------------------------- 1 | from __future__ import annotations 2 | 3 | import math 4 | 5 | from torch import Tensor 6 | 7 | 8 | def scale_qk(q: Tensor, k: Tensor) -> tuple[Tensor, Tensor]: 9 | "Scale the query and key." 10 | 11 | C = q.size(-1) 12 | scale = math.pow(C, -1 / 4) 13 | 14 | return q * scale, k * scale 15 | -------------------------------------------------------------------------------- /sd_fused/layers/blocks/attention/compute/standard_attention.py: -------------------------------------------------------------------------------- 1 | from __future__ import annotations 2 | from typing import Optional 3 | 4 | import torch.nn.functional as F 5 | from torch import Tensor 6 | 7 | from .scale_qk import scale_qk 8 | 9 | 10 | def standard_attention( 11 | q: Tensor, # (B, heads, T, C) 12 | k: Tensor, # (B, heads, T', C) 13 | v: Tensor, # (B, heads, T', C) 14 | bias: Optional[Tensor] = None, # (B, heads, T', C) 15 | ) -> Tensor: 16 | """Standard attention computation.""" 17 | 18 | q, k = scale_qk(q, k) 19 | score = q @ k.transpose(-1, -2) 20 | if bias is not None: 21 | score += bias 22 | attn = F.softmax(score, dim=-1, dtype=q.dtype) 23 | del score 24 | 25 | return attn @ v 26 | -------------------------------------------------------------------------------- /sd_fused/layers/blocks/attention/compute/tome.py: -------------------------------------------------------------------------------- 1 | from __future__ import annotations 2 | from typing import Optional 3 | 4 | import torch 5 | from torch import Tensor 6 | import math 7 | 8 | from .....utils.typing import Literal, Protocol 9 | 10 | Modes = Literal["sum", "mean"] 11 | 12 | 13 | class Merge(Protocol): 14 | def __call__(self, x: Tensor, mode: Modes = "mean") -> Tensor: 15 | ... 
16 | 17 | 18 | def tome(metric: Tensor, r: int | float) -> Merge: 19 | # https://arxiv.org/abs/2210.09461 20 | # https://github.com/facebookresearch/ToMe/blob/main/tome/merge.py 21 | 22 | B, T, C = metric.shape 23 | 24 | if not isinstance(r, int): 25 | r = math.floor(r * T) 26 | 27 | assert 0 < r <= T / 2 # max 50% reduction 28 | 29 | with torch.no_grad(): 30 | metric = metric / metric.norm(dim=2, keepdim=True) 31 | a, b = metric[:, ::2], metric[:, 1::2] 32 | score = a @ b.transpose(1, 2) 33 | del a, b 34 | 35 | node_max, node_idx = score.max(dim=2, keepdim=True) 36 | del score 37 | edge_idx = node_max.argsort(dim=1, descending=True) 38 | 39 | # Unmerged/Merged Tokens 40 | size = (math.ceil(T / 2) - r, r) 41 | unmerged_idx, src_idx = edge_idx.split(size, dim=1) 42 | 43 | dst_idx = node_idx.gather(dim=1, index=src_idx) 44 | 45 | def merge(x: Tensor, mode: Modes = "mean") -> Tensor: 46 | src, dst = x[:, ::2], x[:, 1::2] 47 | B, T, C = src.shape 48 | 49 | unmerged = src.gather(dim=1, index=unmerged_idx.expand(-1, -1, C)) 50 | src = src.gather(dim=1, index=src_idx.expand(-1, -1, C)) 51 | dst = dst.scatter_reduce(1, dst_idx.expand(-1, -1, C), src, reduce=mode) 52 | 53 | return torch.cat([unmerged, dst], dim=1) 54 | 55 | return merge 56 | 57 | 58 | def merge_weighted_average( 59 | merge: Merge, 60 | x: Tensor, 61 | size: Optional[Tensor] = None, 62 | ) -> tuple[Tensor, Tensor]: 63 | B, T, C = x.shape 64 | 65 | if size is None: 66 | size = torch.ones(B, T, 1, dtype=x.dtype, device=x.device) 67 | 68 | x = merge(x * size, mode="sum") 69 | size = merge(size, mode="sum") 70 | 71 | x = x / size 72 | 73 | return x, size 74 | 75 | 76 | def token_average( 77 | k: Tensor, # (B, heads, T', C) 78 | v: Tensor, # (B, heads, T', C) 79 | r: int | float, 80 | ) -> tuple[Tensor, Tensor, Tensor]: 81 | B, heads, Tl, C = k.shape 82 | 83 | # join batch-heads 84 | k = k.flatten(0, 1) 85 | v = v.flatten(0, 1) 86 | 87 | merge = tome(k, r) 88 | 89 | k, size = merge_weighted_average(merge, k) 90 | v, size = merge_weighted_average(merge, v) 91 | bias = size.log() 92 | 93 | # separate batch-heads 94 | k = k.unflatten(0, (B, heads)) 95 | v = v.unflatten(0, (B, heads)) 96 | bias = bias.unflatten(0, (B, heads)) 97 | 98 | return k, v, bias 99 | -------------------------------------------------------------------------------- /sd_fused/layers/blocks/attention/compute/weighted_values.py: -------------------------------------------------------------------------------- 1 | from __future__ import annotations 2 | from typing import Optional 3 | 4 | from torch import Tensor 5 | 6 | from einops import rearrange 7 | 8 | 9 | def weighted_values( 10 | x: Tensor, # (B, heads, T, C) 11 | weights: Optional[Tensor] = None, # (B, T) 12 | ) -> Tensor: 13 | "Modify the values to give emphasis to some tokens." 
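# (the weights typically come from prompt-emphasis parsing; emphasis is applied by scaling the values rather than the attention scores)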
14 | 15 | if weights is None: 16 | return x 17 | 18 | weights = rearrange(weights, "B T -> B 1 T 1") 19 | 20 | return x * weights 21 | -------------------------------------------------------------------------------- /sd_fused/layers/blocks/attention/cross_attention.py: -------------------------------------------------------------------------------- 1 | from __future__ import annotations 2 | from typing import Optional 3 | 4 | from torch import Tensor 5 | 6 | from ...external import Rearrange 7 | from ...base import Module 8 | from ...basic import LayerNorm, Linear 9 | from .base_attention import BaseAttention 10 | 11 | 12 | class CrossAttention(BaseAttention, Module): 13 | def __init__( 14 | self, 15 | *, 16 | query_features: int, 17 | context_features: Optional[int], 18 | head_features: int, 19 | num_heads: int, 20 | ) -> None: 21 | super().__init__() 22 | 23 | is_cross_attention = context_features is not None 24 | 25 | context_features = context_features or query_features 26 | inner_dim = head_features * num_heads 27 | 28 | self.query_features = query_features 29 | self.context_features = context_features 30 | self.head_features = head_features 31 | self.num_heads = num_heads 32 | 33 | # TODO These 4+1 operations can be fused 34 | self.norm = LayerNorm(query_features) 35 | self.to_q = Linear(query_features, inner_dim, bias=False) 36 | self.to_k = Linear(context_features, inner_dim, bias=False) 37 | self.to_v = Linear(context_features, inner_dim, bias=False) 38 | 39 | self.to_out = Linear(inner_dim, query_features) 40 | 41 | self.heads_to_batch = Rearrange("B H W (heads C) -> B heads H W C", heads=num_heads) 42 | self.heads_to_batch2 = self.heads_to_batch 43 | if is_cross_attention: 44 | self.heads_to_batch2 = Rearrange("B T (heads C) -> B heads T C", heads=num_heads) 45 | 46 | self.heads_to_channel = Rearrange("B heads H W C -> B H W (heads C)") 47 | 48 | def __call__( 49 | self, 50 | x: Tensor, 51 | *, 52 | context: Optional[Tensor] = None, 53 | weights: Optional[Tensor] = None, 54 | ) -> Tensor: 55 | 56 | xin = x 57 | x = self.norm(x) 58 | context = context if context is not None else x 59 | 60 | # key, query, value projections 61 | q = self.heads_to_batch(self.to_q(x)) 62 | k = self.heads_to_batch2(self.to_k(context)) 63 | v = self.heads_to_batch2(self.to_v(context)) 64 | 65 | x = self.attention(q, k, v, weights) 66 | del q, k, v 67 | x = self.heads_to_channel(x) 68 | 69 | return xin + self.to_out(x) 70 | -------------------------------------------------------------------------------- /sd_fused/layers/blocks/attention/self_attention.py: -------------------------------------------------------------------------------- 1 | from __future__ import annotations 2 | from typing import Optional 3 | 4 | from torch import Tensor 5 | 6 | from ...external import Rearrange 7 | from ...base import Module 8 | from ...basic import GroupNorm, Linear 9 | from .base_attention import BaseAttention 10 | 11 | 12 | class SelfAttention(BaseAttention, Module): 13 | def __init__( 14 | self, 15 | *, 16 | in_features: int, 17 | head_features: Optional[int], 18 | num_groups: int, 19 | ) -> None: 20 | super().__init__() 21 | 22 | self.in_features = in_features 23 | self.head_features = head_features 24 | self.num_groups = num_groups 25 | 26 | head_features = head_features or in_features 27 | num_heads = in_features // head_features 28 | 29 | # TODO These 4+1 operations can be fused 30 | self.group_norm = GroupNorm(num_groups, in_features) 31 | self.query = Linear(in_features) 32 | self.key = Linear(in_features) 33 
| self.value = Linear(in_features) 34 | 35 | self.proj_attn = Linear(in_features) 36 | 37 | self.channel_last = Rearrange("B C H W -> B H W C") 38 | self.channel_first = Rearrange("B H W C -> B C H W") 39 | self.heads_to_batch = Rearrange("B H W (heads C) -> B heads H W C", heads=num_heads) 40 | self.heads_to_channel = Rearrange("B heads H W C -> B H W (heads C)") 41 | 42 | def __call__(self, x: Tensor) -> Tensor: 43 | B, C, H, W = x.shape 44 | 45 | xin = x 46 | x = self.group_norm(x) 47 | x = self.channel_last(x) 48 | 49 | # key, query, value projections 50 | q = self.heads_to_batch(self.query(x)) 51 | k = self.heads_to_batch(self.key(x)) 52 | v = self.heads_to_batch(self.value(x)) 53 | 54 | x = self.attention(q, k, v) 55 | del q, k, v 56 | x = self.proj_attn(self.heads_to_channel(x)) 57 | 58 | # output 59 | x = self.channel_first(x, H=H, W=W) 60 | 61 | return xin + x 62 | -------------------------------------------------------------------------------- /sd_fused/layers/blocks/basic/__init__.py: -------------------------------------------------------------------------------- 1 | from .gn_conv import GroupNormConv2d 2 | from .gn_silu_conv import GroupNormSiLUConv2d 3 | 4 | from .ln_geglu_linear import LayerNormGEGLULinear 5 | -------------------------------------------------------------------------------- /sd_fused/layers/blocks/basic/gn_conv.py: -------------------------------------------------------------------------------- 1 | from __future__ import annotations 2 | from typing import Optional 3 | 4 | from ...base import Sequential 5 | from ...basic import GroupNorm, Conv2d 6 | 7 | # TODO improve types by using ModuleTuple 8 | class GroupNormConv2d(Sequential): 9 | def __init__( 10 | self, 11 | num_groups: int, 12 | in_channels: int, 13 | out_channels: Optional[int] = None, 14 | *, 15 | kernel_size: int = 1, 16 | padding: int = 0, 17 | ) -> None: 18 | layers = ( 19 | GroupNorm(num_groups, in_channels), 20 | Conv2d(in_channels, out_channels, kernel_size=kernel_size, padding=padding), 21 | ) 22 | 23 | super().__init__(*layers) 24 | -------------------------------------------------------------------------------- /sd_fused/layers/blocks/basic/gn_silu_conv.py: -------------------------------------------------------------------------------- 1 | from __future__ import annotations 2 | from typing import Optional 3 | 4 | from ...base import Sequential 5 | from ...basic import GroupNorm, Conv2d 6 | from ...activation import SiLU 7 | 8 | 9 | class GroupNormSiLUConv2d(Sequential): 10 | def __init__( 11 | self, 12 | num_groups: int, 13 | in_channels: int, 14 | out_channels: Optional[int] = None, 15 | *, 16 | kernel_size: int = 1, 17 | padding: int = 0, 18 | ) -> None: 19 | layers = ( 20 | GroupNorm(num_groups, in_channels), 21 | SiLU(), 22 | Conv2d(in_channels, out_channels, kernel_size=kernel_size, padding=padding), 23 | ) 24 | 25 | super().__init__(*layers) 26 | -------------------------------------------------------------------------------- /sd_fused/layers/blocks/basic/ln_geglu_linear.py: -------------------------------------------------------------------------------- 1 | from __future__ import annotations 2 | 3 | from torch import Tensor 4 | 5 | from ...base import Sequential 6 | from ...basic import LayerNorm, Linear 7 | from ...activation import GEGLU 8 | 9 | 10 | class LayerNormGEGLULinear(Sequential): 11 | def __init__(self, dim: int, *, expand: float) -> None: 12 | 13 | self.dim = dim 14 | self.expand = expand 15 | 16 | inner_dim = int(dim * expand) 17 | dim = dim or dim 18 | 19 | layers = ( 
20 | LayerNorm(dim), 21 | GEGLU(dim, inner_dim), 22 | Linear(inner_dim, dim), 23 | ) 24 | 25 | super().__init__(*layers) 26 | 27 | def __call__(self, x: Tensor) -> Tensor: 28 | return x + super().__call__(x) 29 | -------------------------------------------------------------------------------- /sd_fused/layers/blocks/spatial/__init__.py: -------------------------------------------------------------------------------- 1 | from .base import DownBlock2D, UpBlock2D 2 | 3 | from .resampling import Upsample2D, Downsample2D 4 | 5 | from .ae import DownEncoderBlock2D, UpDecoderBlock2D 6 | from .cross_attention import CrossAttentionUpBlock2D, CrossAttentionDownBlock2D 7 | 8 | from .unet_mid import UNetMidBlock2DSelfAttention, UNetMidBlock2DCrossAttention 9 | from .resnet import ResnetBlock2D 10 | -------------------------------------------------------------------------------- /sd_fused/layers/blocks/spatial/ae/__init__.py: -------------------------------------------------------------------------------- 1 | from .down_encoder import DownEncoderBlock2D 2 | from .up_decoder import UpDecoderBlock2D 3 | -------------------------------------------------------------------------------- /sd_fused/layers/blocks/spatial/ae/down_encoder.py: -------------------------------------------------------------------------------- 1 | from __future__ import annotations 2 | 3 | import torch.nn as nn 4 | from torch import Tensor 5 | 6 | from ....base import Module, ModuleList 7 | from ..resampling import Downsample2D 8 | from ..resnet import ResnetBlock2D 9 | 10 | 11 | class DownEncoderBlock2D(Module): 12 | def __init__( 13 | self, 14 | *, 15 | in_channels: int, 16 | out_channels: int, 17 | num_layers: int, 18 | resnet_groups: int, 19 | add_downsample: bool, 20 | downsample_padding: int, 21 | ) -> None: 22 | super().__init__() 23 | 24 | self.in_channels = in_channels 25 | self.out_channels = out_channels 26 | self.num_layers = num_layers 27 | self.resnet_groups = resnet_groups 28 | self.add_downsample = add_downsample 29 | self.downsample_padding = downsample_padding 30 | 31 | self.resnets = ModuleList[ResnetBlock2D]() 32 | for i in range(num_layers): 33 | in_channels = in_channels if i == 0 else out_channels 34 | 35 | self.resnets.append( 36 | ResnetBlock2D( 37 | in_channels=in_channels, 38 | out_channels=out_channels, 39 | temb_channels=None, 40 | num_groups=resnet_groups, 41 | num_out_groups=None, 42 | ) 43 | ) 44 | 45 | if add_downsample: 46 | self.downsampler = Downsample2D(in_channels, out_channels, padding=downsample_padding) 47 | else: 48 | self.downsampler = nn.Identity() 49 | 50 | def __call__(self, x: Tensor) -> Tensor: 51 | for resnet in self.resnets: 52 | x = resnet(x) 53 | 54 | x = self.downsampler(x) 55 | 56 | return x 57 | -------------------------------------------------------------------------------- /sd_fused/layers/blocks/spatial/ae/up_decoder.py: -------------------------------------------------------------------------------- 1 | from __future__ import annotations 2 | 3 | import torch.nn as nn 4 | from torch import Tensor 5 | 6 | from ....base import Module, ModuleList 7 | from ..resampling import Upsample2D 8 | from ..resnet import ResnetBlock2D 9 | 10 | 11 | class UpDecoderBlock2D(Module): 12 | def __init__( 13 | self, 14 | *, 15 | in_channels: int, 16 | out_channels: int, 17 | num_layers: int, 18 | resnet_groups: int, 19 | add_upsample: bool, 20 | ) -> None: 21 | super().__init__() 22 | 23 | self.in_channels = in_channels 24 | self.out_channels = out_channels 25 | self.num_layers = num_layers 26 | 
self.resnet_groups = resnet_groups 27 | self.add_upsample = add_upsample 28 | 29 | self.resnets = ModuleList[ResnetBlock2D]() 30 | for i in range(num_layers): 31 | input_channels = in_channels if i == 0 else out_channels 32 | 33 | self.resnets.append( 34 | ResnetBlock2D( 35 | in_channels=input_channels, 36 | out_channels=out_channels, 37 | temb_channels=None, 38 | num_groups=resnet_groups, 39 | num_out_groups=None, 40 | ) 41 | ) 42 | 43 | if add_upsample: 44 | self.upsampler = Upsample2D(out_channels, out_channels) 45 | else: 46 | self.upsampler = nn.Identity() 47 | 48 | def __call__(self, x: Tensor) -> Tensor: 49 | for resnet in self.resnets: 50 | x = resnet(x) 51 | 52 | x = self.upsampler(x) 53 | 54 | return x 55 | -------------------------------------------------------------------------------- /sd_fused/layers/blocks/spatial/base/__init__.py: -------------------------------------------------------------------------------- 1 | from .down import DownBlock2D 2 | from .up import UpBlock2D 3 | -------------------------------------------------------------------------------- /sd_fused/layers/blocks/spatial/base/down.py: -------------------------------------------------------------------------------- 1 | from __future__ import annotations 2 | from typing import Optional 3 | 4 | import torch.nn as nn 5 | from torch import Tensor 6 | 7 | from ....base import Module, ModuleList 8 | from ..resampling import Downsample2D 9 | from ..resnet import ResnetBlock2D 10 | from ..output_states import OutputStates 11 | 12 | 13 | class DownBlock2D(Module): 14 | def __init__( 15 | self, 16 | *, 17 | in_channels: int, 18 | out_channels: int, 19 | temb_channels: Optional[int], 20 | num_layers: int, 21 | num_groups: int, 22 | add_downsample: bool, 23 | downsample_padding: int, 24 | ) -> None: 25 | super().__init__() 26 | 27 | self.in_channels = in_channels 28 | self.out_channels = out_channels 29 | self.temb_channels = temb_channels 30 | self.num_layers = num_layers 31 | self.num_groups = num_groups 32 | self.add_downsample = add_downsample 33 | self.downsample_padding = downsample_padding 34 | 35 | self.resnets = ModuleList[ResnetBlock2D]() 36 | for i in range(num_layers): 37 | in_channels = in_channels if i == 0 else out_channels 38 | 39 | self.resnets.append( 40 | ResnetBlock2D( 41 | in_channels=in_channels, 42 | out_channels=out_channels, 43 | temb_channels=temb_channels, 44 | num_groups=num_groups, 45 | num_out_groups=None, 46 | ) 47 | ) 48 | 49 | if add_downsample: 50 | self.downsampler = Downsample2D(in_channels, out_channels, padding=downsample_padding) 51 | else: 52 | self.downsampler = None 53 | 54 | def __call__( 55 | self, 56 | x: Tensor, 57 | *, 58 | temb: Optional[Tensor] = None, 59 | ) -> OutputStates: 60 | states: list[Tensor] = [] 61 | for resnet in self.resnets: 62 | x = resnet(x, temb=temb) 63 | states.append(x) 64 | 65 | if self.downsampler is not None: 66 | x = self.downsampler(x) 67 | 68 | states.append(x) 69 | 70 | return OutputStates(x, states) 71 | -------------------------------------------------------------------------------- /sd_fused/layers/blocks/spatial/base/up.py: -------------------------------------------------------------------------------- 1 | from __future__ import annotations 2 | from typing import Optional 3 | 4 | import torch 5 | import torch.nn as nn 6 | from torch import Tensor 7 | 8 | from ....base import Module, ModuleList 9 | from ..resampling import Upsample2D 10 | from ..resnet import ResnetBlock2D 11 | 12 | 13 | class UpBlock2D(Module): 14 | def __init__( 15 | self, 
16 | *, 17 | in_channels: int, 18 | prev_output_channel: int, 19 | out_channels: int, 20 | temb_channels: Optional[int], 21 | num_layers: int, 22 | resnet_groups: int, 23 | add_upsample: bool, 24 | ) -> None: 25 | super().__init__() 26 | 27 | self.in_channels = in_channels 28 | self.prev_output_channel = prev_output_channel 29 | self.out_channels = out_channels 30 | self.temb_channels = temb_channels 31 | self.num_layers = num_layers 32 | self.resnet_groups = resnet_groups 33 | self.add_upsample = add_upsample 34 | 35 | self.resnets = ModuleList[ResnetBlock2D]() 36 | for i in range(num_layers): 37 | if i == num_layers - 1: 38 | res_skip_channels = in_channels 39 | else: 40 | res_skip_channels = out_channels 41 | 42 | if i == 0: 43 | resnet_in_channels = prev_output_channel 44 | else: 45 | resnet_in_channels = out_channels 46 | 47 | self.resnets.append( 48 | ResnetBlock2D( 49 | in_channels=resnet_in_channels + res_skip_channels, 50 | out_channels=out_channels, 51 | temb_channels=temb_channels, 52 | num_groups=resnet_groups, 53 | num_out_groups=None, 54 | ) 55 | ) 56 | 57 | if add_upsample: 58 | self.upsampler = Upsample2D(out_channels) 59 | else: 60 | self.upsampler = nn.Identity() 61 | 62 | def __call__( 63 | self, 64 | x: Tensor, 65 | *, 66 | states: list[Tensor], 67 | temb: Optional[Tensor] = None, 68 | ) -> Tensor: 69 | assert len(states) == self.num_layers 70 | 71 | for resnet, state in zip(self.resnets, states): 72 | x = torch.cat([x, state], dim=1) 73 | x = resnet(x, temb=temb) 74 | 75 | x = self.upsampler(x) 76 | 77 | return x 78 | -------------------------------------------------------------------------------- /sd_fused/layers/blocks/spatial/cross_attention/__init__.py: -------------------------------------------------------------------------------- 1 | from .up import CrossAttentionUpBlock2D 2 | from .down import CrossAttentionDownBlock2D 3 | -------------------------------------------------------------------------------- /sd_fused/layers/blocks/spatial/cross_attention/down.py: -------------------------------------------------------------------------------- 1 | from __future__ import annotations 2 | from typing import Optional 3 | 4 | import torch.nn as nn 5 | from torch import Tensor 6 | 7 | from ....base import Module, ModuleList 8 | from ...transformer import SpatialTransformer 9 | from ..resampling import Downsample2D 10 | from ..resnet import ResnetBlock2D 11 | from ..output_states import OutputStates 12 | 13 | 14 | class CrossAttentionDownBlock2D(Module): 15 | def __init__( 16 | self, 17 | *, 18 | in_channels: int, 19 | out_channels: int, 20 | temb_channels: int, 21 | num_layers: int, 22 | resnet_groups: int, 23 | attn_num_head_channels: int, 24 | cross_attention_dim: int, 25 | downsample_padding: int, 26 | add_downsample: bool, 27 | ) -> None: 28 | super().__init__() 29 | 30 | self.in_channels = in_channels 31 | self.out_channels = out_channels 32 | self.temb_channels = temb_channels 33 | self.num_layers = num_layers 34 | self.resnet_groups = resnet_groups 35 | self.attn_num_head_channels = attn_num_head_channels 36 | self.cross_attention_dim = cross_attention_dim 37 | self.downsample_padding = downsample_padding 38 | self.add_downsample = add_downsample 39 | 40 | self.resnets = ModuleList[ResnetBlock2D]() 41 | self.attentions = ModuleList[SpatialTransformer]() 42 | for i in range(num_layers): 43 | in_channels = in_channels if i == 0 else out_channels 44 | 45 | self.resnets.append( 46 | ResnetBlock2D( 47 | in_channels=in_channels, 48 | out_channels=out_channels, 49 | 
temb_channels=temb_channels, 50 | num_groups=resnet_groups, 51 | num_out_groups=None, 52 | ) 53 | ) 54 | 55 | self.attentions.append( 56 | SpatialTransformer( 57 | in_channels=out_channels, 58 | num_heads=attn_num_head_channels, 59 | head_features=out_channels // attn_num_head_channels, 60 | depth=1, 61 | num_groups=resnet_groups, 62 | context_features=cross_attention_dim, 63 | ) 64 | ) 65 | 66 | if add_downsample: 67 | self.downsampler = Downsample2D(in_channels, out_channels, padding=downsample_padding) 68 | else: 69 | self.downsampler = None 70 | 71 | def __call__( 72 | self, 73 | x: Tensor, 74 | *, 75 | temb: Optional[Tensor] = None, 76 | context: Optional[Tensor] = None, 77 | weights: Optional[Tensor] = None, 78 | ) -> OutputStates: 79 | states: list[Tensor] = [] 80 | for resnet, attn in zip(self.resnets, self.attentions): 81 | x = resnet(x, temb=temb) 82 | x = attn(x, context=context, weights=weights) 83 | 84 | states.append(x) 85 | 86 | if self.downsampler is not None: 87 | x = self.downsampler(x) 88 | 89 | states.append(x) 90 | 91 | return OutputStates(x, states) 92 | -------------------------------------------------------------------------------- /sd_fused/layers/blocks/spatial/cross_attention/up.py: -------------------------------------------------------------------------------- 1 | from __future__ import annotations 2 | from typing import Optional 3 | 4 | import torch 5 | import torch.nn as nn 6 | from torch import Tensor 7 | 8 | from ....base import Module, ModuleList 9 | from ...transformer import SpatialTransformer 10 | from ..resampling import Upsample2D 11 | from ..resnet import ResnetBlock2D 12 | 13 | 14 | class CrossAttentionUpBlock2D(Module): 15 | def __init__( 16 | self, 17 | *, 18 | in_channels: int, 19 | out_channels: int, 20 | prev_output_channel: int, 21 | temb_channels: int, 22 | num_layers: int, 23 | resnet_groups: int, 24 | attn_num_head_channels: int, 25 | cross_attention_dim: int, 26 | add_upsample: bool, 27 | ) -> None: 28 | super().__init__() 29 | 30 | self.in_channels = in_channels 31 | self.out_channels = out_channels 32 | self.prev_output_channel = prev_output_channel 33 | self.temb_channels = temb_channels 34 | self.num_layers = num_layers 35 | self.resnet_groups = resnet_groups 36 | self.attn_num_head_channels = attn_num_head_channels 37 | self.cross_attention_dim = cross_attention_dim 38 | self.add_upsample = add_upsample 39 | 40 | self.resnets = ModuleList[ResnetBlock2D]() 41 | self.attentions = ModuleList[SpatialTransformer]() 42 | for i in range(num_layers): 43 | if i == num_layers - 1: 44 | res_skip_channels = in_channels 45 | else: 46 | res_skip_channels = out_channels 47 | 48 | if i == 0: 49 | resnet_in_channels = prev_output_channel 50 | else: 51 | resnet_in_channels = out_channels 52 | 53 | self.resnets.append( 54 | ResnetBlock2D( 55 | in_channels=resnet_in_channels + res_skip_channels, 56 | out_channels=out_channels, 57 | temb_channels=temb_channels, 58 | num_groups=resnet_groups, 59 | num_out_groups=None, 60 | ) 61 | ) 62 | 63 | self.attentions.append( 64 | SpatialTransformer( 65 | in_channels=out_channels, 66 | num_heads=attn_num_head_channels, 67 | head_features=out_channels // attn_num_head_channels, 68 | depth=1, 69 | num_groups=resnet_groups, 70 | context_features=cross_attention_dim, 71 | ) 72 | ) 73 | 74 | if add_upsample: 75 | self.upsampler = Upsample2D(out_channels) 76 | else: 77 | self.upsampler = nn.Identity() 78 | 79 | def __call__( 80 | self, 81 | x: Tensor, 82 | *, 83 | states: list[Tensor], 84 | temb: Optional[Tensor] = None, 
85 | context: Optional[Tensor] = None, 86 | weights: Optional[Tensor] = None, 87 | ) -> Tensor: 88 | assert len(states) == self.num_layers 89 | 90 | for resnet, attn, state in zip(self.resnets, self.attentions, states): 91 | x = torch.cat([x, state], dim=1) 92 | x = resnet(x, temb=temb) 93 | x = attn(x, context=context, weights=weights) 94 | 95 | x = self.upsampler(x) 96 | 97 | return x 98 | -------------------------------------------------------------------------------- /sd_fused/layers/blocks/spatial/output_states.py: -------------------------------------------------------------------------------- 1 | from __future__ import annotations 2 | from typing import NamedTuple 3 | 4 | from torch import Tensor 5 | 6 | 7 | class OutputStates(NamedTuple): 8 | x: Tensor 9 | states: list[Tensor] 10 | -------------------------------------------------------------------------------- /sd_fused/layers/blocks/spatial/resampling/__init__.py: -------------------------------------------------------------------------------- 1 | from .upsample2d import Upsample2D 2 | from .downsample2d import Downsample2D 3 | -------------------------------------------------------------------------------- /sd_fused/layers/blocks/spatial/resampling/downsample2d.py: -------------------------------------------------------------------------------- 1 | from __future__ import annotations 2 | from typing import Callable, Optional 3 | 4 | from functools import partial 5 | 6 | import torch.nn.functional as F 7 | from torch import Tensor 8 | 9 | from ....base import Module 10 | from ....basic import Conv2d 11 | 12 | 13 | class Downsample2D(Module): 14 | pad: Callable[[Tensor], Tensor] 15 | 16 | def __init__( 17 | self, 18 | in_channels: int, 19 | out_channels: Optional[int], 20 | *, 21 | padding: int, 22 | ) -> None: 23 | super().__init__() 24 | 25 | self.channels = in_channels 26 | self.out_channels = out_channels 27 | self.padding = padding 28 | 29 | out_channels = out_channels or in_channels 30 | stride = 2 31 | 32 | self.conv = Conv2d( 33 | in_channels, 34 | out_channels, 35 | kernel_size=3, 36 | stride=stride, 37 | padding=padding, 38 | ) 39 | 40 | self.pad = lambda x: x 41 | if padding == 0: 42 | # ? Why? 
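# with padding=0 the stride-2 conv would shrink an even-sized input to H/2 - 1; padding one extra row/column on the right/bottom keeps the output at exactly H/2 x W/2, mirroring the asymmetric padding of the reference VAE encoder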
43 | self.pad = partial( 44 | F.pad, 45 | pad=(0, 1, 0, 1), 46 | mode="constant", 47 | value=0, 48 | ) 49 | 50 | def __call__(self, x: Tensor) -> Tensor: 51 | x = self.pad(x) 52 | 53 | return self.conv(x) 54 | -------------------------------------------------------------------------------- /sd_fused/layers/blocks/spatial/resampling/upsample2d.py: -------------------------------------------------------------------------------- 1 | from __future__ import annotations 2 | from typing import Optional 3 | 4 | from functools import partial 5 | 6 | import torch.nn.functional as F 7 | from torch import Tensor 8 | 9 | from ....base import Module 10 | from ....basic import Conv2d 11 | 12 | 13 | class Upsample2D(Module): 14 | def __init__( 15 | self, 16 | channels: int, 17 | out_channels: Optional[int] = None, 18 | ) -> None: 19 | super().__init__() 20 | 21 | self.channels = channels 22 | self.out_channels = out_channels 23 | 24 | out_channels = out_channels or channels 25 | 26 | self.conv = Conv2d(channels, out_channels, kernel_size=3, padding=1) 27 | self.upscale = partial(F.interpolate, mode="nearest", scale_factor=2) 28 | 29 | def __call__(self, x: Tensor) -> Tensor: 30 | x = self.upscale(x) 31 | 32 | return self.conv(x) 33 | -------------------------------------------------------------------------------- /sd_fused/layers/blocks/spatial/resnet.py: -------------------------------------------------------------------------------- 1 | from __future__ import annotations 2 | from typing import Optional 3 | 4 | from torch import Tensor 5 | 6 | from ...external import Rearrange 7 | from ...base import Module, Sequential 8 | from ...activation import SiLU 9 | from ...basic import Conv2d, Linear, Identity 10 | from ..basic import GroupNormSiLUConv2d 11 | 12 | 13 | class ResnetBlock2D(Module): 14 | def __init__( 15 | self, 16 | *, 17 | in_channels: int, 18 | out_channels: Optional[int], 19 | temb_channels: Optional[int], 20 | num_groups: int, 21 | num_out_groups: Optional[int], 22 | ) -> None: 23 | super().__init__() 24 | 25 | self.in_channels = in_channels 26 | self.out_channels = out_channels 27 | self.temb_channels = temb_channels 28 | self.num_groups = num_groups 29 | self.num_out_groups = num_out_groups 30 | 31 | out_channels = out_channels or in_channels 32 | num_out_groups = num_out_groups or num_groups 33 | 34 | if in_channels != out_channels: 35 | self.conv_shortcut = Conv2d(in_channels, out_channels) 36 | else: 37 | self.conv_shortcut = Identity() 38 | 39 | self.pre_process = GroupNormSiLUConv2d(num_groups, in_channels, out_channels, kernel_size=3, padding=1) 40 | 41 | self.post_process = GroupNormSiLUConv2d( 42 | num_out_groups, 43 | out_channels, 44 | out_channels, 45 | kernel_size=3, 46 | padding=1, 47 | ) 48 | 49 | if temb_channels is not None: 50 | self.time_emb_proj = Sequential( 51 | SiLU(), 52 | Linear(temb_channels, out_channels), 53 | Rearrange("b c -> b c 1 1"), 54 | ) 55 | else: 56 | self.time_emb_proj = None 57 | 58 | def __call__(self, x: Tensor, *, temb: Optional[Tensor] = None) -> Tensor: 59 | xin = self.conv_shortcut(x) 60 | 61 | x = self.pre_process(x) 62 | 63 | if self.time_emb_proj is not None: 64 | assert temb is not None 65 | 66 | temb = self.time_emb_proj(temb) 67 | x = x + temb 68 | 69 | x = self.post_process(x) 70 | 71 | return xin + x 72 | -------------------------------------------------------------------------------- /sd_fused/layers/blocks/spatial/unet_mid/__init__.py: -------------------------------------------------------------------------------- 1 | from 
.cross_attention import UNetMidBlock2DCrossAttention 2 | from .self_attention import UNetMidBlock2DSelfAttention 3 | -------------------------------------------------------------------------------- /sd_fused/layers/blocks/spatial/unet_mid/cross_attention.py: -------------------------------------------------------------------------------- 1 | from __future__ import annotations 2 | from typing import Optional 3 | 4 | from torch import Tensor 5 | 6 | from ....base import Module, ModuleList 7 | from ...transformer import SpatialTransformer 8 | from ..resnet import ResnetBlock2D 9 | 10 | 11 | class UNetMidBlock2DCrossAttention(Module): 12 | def __init__( 13 | self, 14 | *, 15 | in_channels: int, 16 | temb_channels: int, 17 | num_layers: int, 18 | resnet_groups: int, 19 | attn_num_head_channels: int, 20 | cross_attention_dim: int, 21 | ) -> None: 22 | super().__init__() 23 | 24 | self.in_channels = in_channels 25 | self.temb_channels = temb_channels 26 | self.num_layers = num_layers 27 | self.resnet_groups = resnet_groups 28 | self.attn_num_head_channels = attn_num_head_channels 29 | self.cross_attention_dim = cross_attention_dim 30 | 31 | resnet_groups = resnet_groups or min(in_channels // 4, 32) 32 | 33 | self.attentions = ModuleList[SpatialTransformer]() 34 | self.resnets = ModuleList[ResnetBlock2D]() 35 | for i in range(num_layers + 1): 36 | if i > 0: 37 | self.attentions.append( 38 | SpatialTransformer( 39 | in_channels=in_channels, 40 | num_heads=attn_num_head_channels, 41 | head_features=in_channels // attn_num_head_channels, 42 | depth=1, 43 | num_groups=resnet_groups, 44 | context_features=cross_attention_dim, 45 | ) 46 | ) 47 | 48 | self.resnets.append( 49 | ResnetBlock2D( 50 | in_channels=in_channels, 51 | out_channels=in_channels, 52 | temb_channels=temb_channels, 53 | num_groups=resnet_groups, 54 | num_out_groups=None, 55 | ) 56 | ) 57 | 58 | def __call__( 59 | self, 60 | x: Tensor, 61 | *, 62 | temb: Optional[Tensor] = None, 63 | context: Optional[Tensor] = None, 64 | weights: Optional[Tensor] = None, 65 | ) -> Tensor: 66 | first_resnet, *rest_resnets = self.resnets 67 | 68 | x = first_resnet(x, temb=temb) 69 | 70 | for attn, resnet in zip(self.attentions, rest_resnets): 71 | x = attn(x, context=context, weights=weights) 72 | x = resnet(x, temb=temb) 73 | 74 | return x 75 | -------------------------------------------------------------------------------- /sd_fused/layers/blocks/spatial/unet_mid/self_attention.py: -------------------------------------------------------------------------------- 1 | from __future__ import annotations 2 | from typing import Optional 3 | 4 | 5 | from torch import Tensor 6 | 7 | from ....base import Module, ModuleList 8 | from ...attention import SelfAttention 9 | from ..resnet import ResnetBlock2D 10 | 11 | 12 | class UNetMidBlock2DSelfAttention(Module): 13 | def __init__( 14 | self, 15 | *, 16 | in_channels: int, 17 | temb_channels: Optional[int], 18 | num_layers: int, 19 | resnet_groups: int, 20 | attn_num_head_channels: Optional[int], 21 | ) -> None: 22 | super().__init__() 23 | 24 | self.in_channels = in_channels 25 | self.temb_channels = temb_channels 26 | self.num_layers = num_layers 27 | self.resnet_groups = resnet_groups 28 | self.attn_num_head_channels = attn_num_head_channels 29 | 30 | resnet_groups = resnet_groups or min(in_channels // 4, 32) 31 | 32 | self.attentions = ModuleList[SelfAttention]() 33 | self.resnets = ModuleList[ResnetBlock2D]() 34 | for i in range(num_layers + 1): 35 | if i > 0: 36 | self.attentions.append( 37 | SelfAttention( 
38 | in_features=in_channels, 39 | num_groups=resnet_groups, 40 | head_features=attn_num_head_channels, 41 | ) 42 | ) 43 | 44 | self.resnets.append( 45 | ResnetBlock2D( 46 | in_channels=in_channels, 47 | out_channels=in_channels, 48 | temb_channels=temb_channels, 49 | num_groups=resnet_groups, 50 | num_out_groups=None, 51 | ) 52 | ) 53 | 54 | def __call__( 55 | self, 56 | x: Tensor, 57 | *, 58 | temb: Optional[Tensor] = None, 59 | ) -> Tensor: 60 | first_resnet, *rest_resnets = self.resnets 61 | 62 | x = first_resnet(x, temb=temb) 63 | 64 | for attn, resnet in zip(self.attentions, rest_resnets): 65 | x = attn(x) 66 | x = resnet(x, temb=temb) 67 | 68 | return x 69 | -------------------------------------------------------------------------------- /sd_fused/layers/blocks/transformer/__init__.py: -------------------------------------------------------------------------------- 1 | from .basic_transformer import BasicTransformer 2 | from .spatial_transformer import SpatialTransformer 3 | -------------------------------------------------------------------------------- /sd_fused/layers/blocks/transformer/basic_transformer.py: -------------------------------------------------------------------------------- 1 | from __future__ import annotations 2 | from typing import Optional 3 | 4 | from torch import Tensor 5 | 6 | from ...base import Module 7 | from ..attention import CrossAttention 8 | from ..basic import LayerNormGEGLULinear 9 | 10 | 11 | class BasicTransformer(Module): 12 | def __init__( 13 | self, 14 | *, 15 | in_features: int, 16 | num_heads: int, 17 | head_features: int, 18 | context_features: Optional[int], 19 | ): 20 | super().__init__() 21 | 22 | self.in_features = in_features 23 | self.num_heads = num_heads 24 | self.head_features = head_features 25 | self.context_features = context_features 26 | 27 | self.attn1 = CrossAttention( 28 | query_features=in_features, 29 | num_heads=num_heads, 30 | head_features=head_features, 31 | context_features=None, 32 | ) 33 | self.attn2 = CrossAttention( 34 | query_features=in_features, 35 | num_heads=num_heads, 36 | head_features=head_features, 37 | context_features=context_features, 38 | ) 39 | 40 | self.ff = LayerNormGEGLULinear(in_features, expand=4) 41 | 42 | def __call__( 43 | self, 44 | x: Tensor, 45 | *, 46 | context: Optional[Tensor] = None, 47 | weights: Optional[Tensor] = None, 48 | ) -> Tensor: 49 | x = self.attn1(x) 50 | x = self.attn2(x, context=context, weights=weights) 51 | x = self.ff(x) 52 | 53 | return x 54 | -------------------------------------------------------------------------------- /sd_fused/layers/blocks/transformer/spatial_transformer.py: -------------------------------------------------------------------------------- 1 | from __future__ import annotations 2 | from typing import Optional 3 | 4 | from torch import Tensor 5 | 6 | from ...external import Rearrange 7 | from ...base import Module, ModuleList 8 | from ...basic import Conv2d 9 | from ..basic import GroupNormConv2d 10 | from .basic_transformer import BasicTransformer 11 | 12 | 13 | class SpatialTransformer(Module): 14 | def __init__( 15 | self, 16 | *, 17 | in_channels: int, 18 | num_heads: int, 19 | head_features: int, 20 | depth: int, 21 | num_groups: int, 22 | context_features: Optional[int], 23 | ): 24 | super().__init__() 25 | 26 | self.in_channels = in_channels 27 | self.num_heads = num_heads 28 | self.head_features = head_features 29 | self.depth = depth 30 | self.num_groups = num_groups 31 | self.context_features = context_features 32 | 33 | inner_dim = 
num_heads * head_features 34 | 35 | self.proj_in = GroupNormConv2d(num_groups, in_channels, inner_dim) 36 | self.proj_out = Conv2d(inner_dim, in_channels) 37 | 38 | self.transformer_blocks = ModuleList[BasicTransformer]() 39 | for _ in range(depth): 40 | self.transformer_blocks.append( 41 | BasicTransformer( 42 | in_features=inner_dim, 43 | num_heads=num_heads, 44 | head_features=head_features, 45 | context_features=context_features, 46 | ) 47 | ) 48 | 49 | self.channel_last = Rearrange("B C H W -> B H W C") 50 | self.channel_first = Rearrange("B H W C -> B C H W") 51 | 52 | def __call__( 53 | self, 54 | x: Tensor, 55 | *, 56 | context: Optional[Tensor] = None, 57 | weights: Optional[Tensor] = None, 58 | ) -> Tensor: 59 | B, C, H, W = x.shape 60 | 61 | xin = x 62 | 63 | x = self.proj_in(x) 64 | x = self.channel_last(x) 65 | 66 | for block in self.transformer_blocks: 67 | x = block(x, context=context, weights=weights) 68 | x = self.channel_first(x, H=H, W=W) 69 | 70 | return xin + self.proj_out(x) 71 | -------------------------------------------------------------------------------- /sd_fused/layers/distribution/__init__.py: -------------------------------------------------------------------------------- 1 | from .diag_gaussian import DiagonalGaussianDistribution 2 | -------------------------------------------------------------------------------- /sd_fused/layers/distribution/diag_gaussian.py: -------------------------------------------------------------------------------- 1 | from __future__ import annotations 2 | from typing import Optional 3 | 4 | import torch 5 | from torch import Tensor 6 | 7 | 8 | class DiagonalGaussianDistribution: 9 | def __init__( 10 | self, 11 | mean: Tensor, 12 | logvar: Tensor, 13 | ) -> None: 14 | super().__init__() 15 | 16 | self.device = mean.device 17 | self.dtype = mean.dtype 18 | 19 | self.mean = mean 20 | 21 | logvar = logvar.clamp(-30, 20) 22 | self.std = torch.exp(logvar / 2) 23 | 24 | def sample(self, generator: Optional[torch.Generator] = None) -> Tensor: 25 | # TODO use seeds? 26 | noise = torch.randn(self.std.shape, generator=generator, device=self.device, dtype=self.dtype) 27 | 28 | return self.mean + self.std * noise 29 | -------------------------------------------------------------------------------- /sd_fused/layers/embedding/__init__.py: -------------------------------------------------------------------------------- 1 | from .time_steps import Timesteps 2 | from .time_step_emb import TimestepEmbedding 3 | -------------------------------------------------------------------------------- /sd_fused/layers/embedding/time_step_emb.py: -------------------------------------------------------------------------------- 1 | from __future__ import annotations 2 | 3 | from ..base import Sequential 4 | from ..basic import Linear, Identity 5 | from ..activation import SiLU 6 | 7 | 8 | class TimestepEmbedding(Sequential): 9 | def __init__( 10 | self, 11 | *, 12 | channel: int, 13 | time_embed_dim: int, 14 | use_silu: bool = True, # !always true 15 | ) -> None: 16 | 17 | self.channel = channel 18 | self.time_embed_dim = time_embed_dim 19 | self.use_silu = use_silu 20 | 21 | layers = ( 22 | Linear(channel, time_embed_dim), 23 | SiLU() if use_silu else Identity(), # ? silu always true? 
24 | Linear(time_embed_dim, time_embed_dim), 25 | ) 26 | 27 | super().__init__(*layers) 28 | -------------------------------------------------------------------------------- /sd_fused/layers/embedding/time_steps.py: -------------------------------------------------------------------------------- 1 | from __future__ import annotations 2 | 3 | import math 4 | 5 | import torch 6 | from torch import Tensor 7 | 8 | from ..base import Module 9 | 10 | 11 | class Timesteps(Module): 12 | freq: Tensor 13 | amplitude: Tensor 14 | phase: Tensor 15 | 16 | def __init__( 17 | self, 18 | *, 19 | num_channels: int, 20 | flip_sin_to_cos: bool, 21 | downscale_freq_shift: float, 22 | scale: float = 1, 23 | max_period: int = 10_000, 24 | ) -> None: 25 | super().__init__() 26 | 27 | self.num_channels = num_channels 28 | self.flip_sin_to_cos = flip_sin_to_cos 29 | self.downscale_freq_shift = downscale_freq_shift 30 | self.scale = scale 31 | self.max_period = max_period 32 | 33 | assert num_channels % 2 == 0 34 | half_dim = num_channels // 2 35 | 36 | idx = torch.arange(half_dim) 37 | 38 | exponent = -math.log(max_period) * idx 39 | exponent /= half_dim - downscale_freq_shift 40 | 41 | freq = exponent.exp() 42 | freq = torch.cat([freq, freq]).unsqueeze(0) 43 | 44 | amplitude = torch.full((1, num_channels), scale) 45 | 46 | zeros = torch.zeros((1, half_dim)) 47 | ones = torch.ones((1, half_dim)) 48 | phase = torch.cat([zeros, ones], dim=1) * torch.pi / 2 49 | 50 | if flip_sin_to_cos: 51 | freq = reverse(freq, half_dim) 52 | phase = reverse(phase, half_dim) 53 | amplitude = reverse(amplitude, half_dim) 54 | 55 | # TODO create nn.Buffer 56 | self.freq = freq 57 | self.phase = phase 58 | self.amplitude = amplitude 59 | 60 | def __call__(self, x: Tensor) -> Tensor: 61 | x = x[..., None] 62 | 63 | kwargs = dict(device=self.device, dtype=self.dtype, non_blocking=True) 64 | self.freq = self.freq.to(**kwargs) 65 | self.phase = self.phase.to(**kwargs) 66 | self.amplitude = self.amplitude.to(**kwargs) 67 | 68 | return self.amplitude * torch.sin(x * self.freq + self.phase) 69 | 70 | 71 | def reverse(x: Tensor, half_dim: int) -> Tensor: 72 | return torch.cat([x[:, half_dim:], x[:, :half_dim]], dim=1) 73 | -------------------------------------------------------------------------------- /sd_fused/layers/external/__init__.py: -------------------------------------------------------------------------------- 1 | from .rearrange import Rearrange 2 | -------------------------------------------------------------------------------- /sd_fused/layers/external/rearrange.py: -------------------------------------------------------------------------------- 1 | from __future__ import annotations 2 | 3 | from torch import Tensor 4 | from einops import rearrange 5 | 6 | from ..base import Module 7 | 8 | 9 | class Rearrange(Module): 10 | def __init__(self, pattern: str, **axes_length: int) -> None: 11 | self.pattern = pattern 12 | self.axes_length = axes_length 13 | 14 | def __call__(self, x: Tensor, **axes_length: int) -> Tensor: 15 | return rearrange(x, self.pattern, **self.axes_length, **axes_length) 16 | 17 | def make_inverse(self) -> Rearrange: 18 | left, right = self.pattern.split("->") 19 | 20 | new = Rearrange(f"{right} -> {left}") 21 | 22 | return new 23 | -------------------------------------------------------------------------------- /sd_fused/layers/modifiers/__init__.py: -------------------------------------------------------------------------------- 1 | from .half_weights import HalfWeightsModule, half_weights 2 | 
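# Usage sketch for the half-weights modifier defined below (illustrative only, not part of the package):
# a HalfWeightsModule stores its parameters in fp16 to save memory, while the @half_weights decorator
# temporarily casts the module and its tensor arguments back to fp32 around each call, e.g.
#   layer = Linear(320, 1280).half_weights(True)  # parameters now kept in fp16
#   y = layer(x)  # forward pass computed in fp32, weights re-halved afterwards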
-------------------------------------------------------------------------------- /sd_fused/layers/modifiers/half_weights.py: -------------------------------------------------------------------------------- 1 | from __future__ import annotations 2 | from typing_extensions import Self 3 | 4 | from functools import wraps 5 | 6 | import torch 7 | from torch import Tensor 8 | 9 | from ..base.module import Module 10 | 11 | 12 | class HalfWeightsModule(Module): 13 | use_half_weights: bool = False 14 | 15 | def half_weights(self, use: bool = True) -> Self: 16 | self.use_half_weights = use 17 | 18 | return self.half() if use or self.dtype == torch.float16 else self.float() 19 | 20 | 21 | def half_weights(fun): 22 | @wraps(fun) 23 | def wrapper(self: HalfWeightsModule, *args, **kwargs): 24 | if self.use_half_weights: 25 | self.float() 26 | 27 | args = tuple(a.float() if isinstance(a, Tensor) else a for a in args) 28 | kwargs = {k: v.float() if isinstance(v, Tensor) else v for k, v in kwargs.items()} 29 | 30 | out = fun(self, *args, **kwargs) 31 | self.half() 32 | else: 33 | out = fun(self, *args, **kwargs) 34 | 35 | return out 36 | 37 | return wrapper 38 | -------------------------------------------------------------------------------- /sd_fused/models/__init__.py: -------------------------------------------------------------------------------- 1 | from .ae_kl import AutoencoderKL 2 | from .unet_conditional import UNet2DConditional 3 | -------------------------------------------------------------------------------- /sd_fused/models/ae_kl.py: -------------------------------------------------------------------------------- 1 | from __future__ import annotations 2 | from typing_extensions import Self 3 | 4 | from pathlib import Path 5 | import json 6 | 7 | import torch 8 | from torch import Tensor 9 | 10 | from ..layers.base import Module 11 | from ..layers.basic import Conv2d 12 | from ..layers.auto_encoder import Encoder, Decoder 13 | from ..layers.distribution import DiagonalGaussianDistribution 14 | from ..utils.tensors import normalize, denormalize 15 | from .config import VaeConfig 16 | from .convert import diffusers2fused_vae 17 | from .convert.states import debug_state_replacements 18 | from .modifiers import HalfWeightsModel, SplitAttentionModel, FlashAttentionModel, ToMeModel 19 | 20 | 21 | class AutoencoderKL(HalfWeightsModel, SplitAttentionModel, FlashAttentionModel, ToMeModel, Module): 22 | @classmethod 23 | def from_config(cls, path: str | Path) -> Self: 24 | """Creates a model from a (diffusers) config file.""" 25 | 26 | path = Path(path) 27 | if path.is_dir(): 28 | path /= "config.json" 29 | assert path.suffix == ".json" 30 | 31 | db = json.load(open(path, "r")) 32 | config = VaeConfig(**db) 33 | 34 | return cls( 35 | in_channels=config.in_channels, 36 | out_channels=config.out_channels, 37 | block_out_channels=tuple(config.block_out_channels), 38 | layers_per_block=config.layers_per_block, 39 | latent_channels=config.latent_channels, 40 | norm_num_groups=config.norm_num_groups, 41 | ) 42 | 43 | def __init__( 44 | self, 45 | *, 46 | in_channels: int = 3, 47 | out_channels: int = 3, 48 | block_out_channels: tuple[int, ...] 
= (128, 256, 512, 512), 49 | layers_per_block: int = 2, 50 | latent_channels: int = 4, 51 | norm_num_groups: int = 32, 52 | ) -> None: 53 | super().__init__() 54 | 55 | self.in_channels = in_channels 56 | self.out_channels = out_channels 57 | self.block_out_channels = block_out_channels 58 | self.layers_per_block = layers_per_block 59 | self.latent_channels = latent_channels 60 | self.norm_num_groups = norm_num_groups 61 | 62 | self.encoder = Encoder( 63 | in_channels=in_channels, 64 | out_channels=latent_channels, 65 | block_out_channels=block_out_channels, 66 | layers_per_block=layers_per_block, 67 | resnet_groups=norm_num_groups, 68 | double_z=True, 69 | ) 70 | 71 | self.decoder = Decoder( 72 | in_channels=latent_channels, 73 | out_channels=out_channels, 74 | block_out_channels=block_out_channels, 75 | layers_per_block=layers_per_block, 76 | resnet_groups=norm_num_groups, 77 | ) 78 | 79 | # TODO very bad names... 80 | self.quant_conv = Conv2d(2 * latent_channels) 81 | self.post_quant_conv = Conv2d(latent_channels) 82 | 83 | def encode(self, x: Tensor) -> DiagonalGaussianDistribution: 84 | """Encode an byte-Tensor into a posterior distribution.""" 85 | 86 | x = normalize(x, self.dtype) 87 | x = self.encoder(x) 88 | 89 | moments = self.quant_conv(x) 90 | mean, logvar = moments.chunk(2, dim=1) 91 | 92 | return DiagonalGaussianDistribution(mean, logvar) 93 | 94 | def decode(self, z: Tensor) -> Tensor: 95 | """Decode the latent's space into an image.""" 96 | 97 | z = self.post_quant_conv(z) 98 | out = self.decoder(z) 99 | 100 | out = denormalize(out) 101 | 102 | return out 103 | 104 | def __call__(self): 105 | raise ValueError("This function is not callable") 106 | 107 | @classmethod 108 | def from_diffusers(cls, path: str | Path) -> Self: 109 | """Load Stable-Diffusion from diffusers checkpoint folder.""" 110 | 111 | path = Path(path) 112 | model = cls.from_config(path) 113 | 114 | state_path = next(path.glob("*.bin")) 115 | old_state = torch.load(state_path, map_location="cpu") 116 | replaced_state = diffusers2fused_vae(old_state) 117 | 118 | debug_state_replacements(model.state_dict(), replaced_state) 119 | 120 | model.load_state_dict(replaced_state) 121 | 122 | return model 123 | -------------------------------------------------------------------------------- /sd_fused/models/config.py: -------------------------------------------------------------------------------- 1 | from __future__ import annotations 2 | from typing import Type 3 | 4 | from dataclasses import dataclass 5 | 6 | from ..layers.blocks.spatial import DownBlock2D, UpBlock2D, CrossAttentionDownBlock2D, CrossAttentionUpBlock2D 7 | 8 | 9 | @dataclass 10 | class VaeConfig: 11 | _class_name: str 12 | _diffusers_version: str 13 | act_fn: str 14 | in_channels: int 15 | latent_channels: int 16 | layers_per_block: int 17 | out_channels: int 18 | sample_size: int 19 | block_out_channels: list[int] 20 | down_block_types: list[str] 21 | up_block_types: list[str] 22 | norm_num_groups: int = 32 23 | 24 | def __post_init__(self) -> None: 25 | assert self._class_name == "AutoencoderKL" 26 | assert self.act_fn == "silu" 27 | 28 | for block in self.down_block_types: 29 | assert block == "DownEncoderBlock2D" 30 | 31 | for block in self.up_block_types: 32 | assert block == "UpDecoderBlock2D" 33 | 34 | 35 | @dataclass 36 | class UnetConfig: 37 | _class_name: str 38 | _diffusers_version: str 39 | act_fn: str 40 | attention_head_dim: int 41 | block_out_channels: list[int] 42 | center_input_sample: bool 43 | cross_attention_dim: int 44 | 
down_block_types: list[str] 45 | downsample_padding: int 46 | flip_sin_to_cos: bool 47 | freq_shift: int 48 | in_channels: int 49 | layers_per_block: int 50 | mid_block_scale_factor: int 51 | norm_eps: float 52 | norm_num_groups: int 53 | out_channels: int 54 | sample_size: int 55 | up_block_types: list[str] 56 | 57 | def __post_init__(self) -> None: 58 | assert self._class_name == "UNet2DConditionModel" 59 | assert self.act_fn == "silu" 60 | 61 | for block in self.down_block_types: 62 | assert block in ("CrossAttnDownBlock2D", "DownBlock2D") 63 | for block in self.up_block_types: 64 | assert block in ("UpBlock2D", "CrossAttnUpBlock2D") 65 | 66 | @property 67 | def down_blocks( 68 | self, 69 | ) -> tuple[Type[CrossAttentionDownBlock2D] | Type[DownBlock2D], ...]: 70 | def get_block( 71 | block: str, 72 | ) -> Type[CrossAttentionDownBlock2D] | Type[DownBlock2D]: 73 | if block == "CrossAttnDownBlock2D": 74 | return CrossAttentionDownBlock2D 75 | if block == "DownBlock2D": 76 | return DownBlock2D 77 | 78 | raise ValueError 79 | 80 | return tuple(get_block(block) for block in self.down_block_types) 81 | 82 | @property 83 | def up_blocks( 84 | self, 85 | ) -> tuple[Type[CrossAttentionUpBlock2D] | Type[UpBlock2D], ...]: 86 | def get_block( 87 | block: str, 88 | ) -> Type[CrossAttentionUpBlock2D] | Type[UpBlock2D]: 89 | if block == "CrossAttnUpBlock2D": 90 | return CrossAttentionUpBlock2D 91 | if block == "UpBlock2D": 92 | return UpBlock2D 93 | 94 | raise ValueError("Invalid") 95 | 96 | return tuple(get_block(block) for block in self.up_block_types) 97 | -------------------------------------------------------------------------------- /sd_fused/models/convert/__init__.py: -------------------------------------------------------------------------------- 1 | from .vae.sd2diffusers import sd2diffusers as sd2diffusers_vae 2 | from .vae.diffusers2fused import diffusers2fused as diffusers2fused_vae 3 | 4 | from .unet.diffusers2fused import diffusers2fused as diffusers2fused_unet 5 | -------------------------------------------------------------------------------- /sd_fused/models/convert/states.py: -------------------------------------------------------------------------------- 1 | from __future__ import annotations 2 | 3 | import re 4 | 5 | import torch.nn as nn 6 | from torch import Tensor 7 | 8 | 9 | def replace_state( 10 | old_state: dict[str, Tensor], 11 | replacements: list[tuple[str, str]], 12 | ) -> dict[str, Tensor]: 13 | """Replace the state-dict with new keys""" 14 | 15 | state: dict[str, Tensor] = {} 16 | for key in old_state.keys(): 17 | new_key = key 18 | for old, new in replacements: 19 | new_key = re.sub(old, new, new_key) 20 | 21 | state[new_key] = old_state[key] 22 | 23 | return state 24 | 25 | 26 | def debug_state_replacements( 27 | state: dict[str, Tensor] | dict[str, nn.Parameter], 28 | replaced_state: dict[str, Tensor], 29 | ) -> None: 30 | good_keys = set(state.keys()) 31 | replaced_keys = set(replaced_state.keys()) 32 | 33 | delta = good_keys - replaced_keys 34 | if len(delta) != 0: 35 | print("miss replacing some keys") 36 | print("=" * 32) 37 | for key in delta: 38 | print(key) 39 | 40 | delta = replaced_keys - good_keys 41 | if len(delta) != 0: 42 | print("wrongly replaced some keys") 43 | print("=" * 32) 44 | for key in replaced_keys - good_keys: 45 | print(key) 46 | -------------------------------------------------------------------------------- /sd_fused/models/convert/unet/diffusers2fused.py: -------------------------------------------------------------------------------- 1 | 
from __future__ import annotations 2 | 3 | from torch import Tensor 4 | 5 | from ..states import replace_state 6 | 7 | 8 | def diffusers2fused(old_state: dict[str, Tensor]) -> dict[str, Tensor]: 9 | """Convert a diffusers checkpoint into a sd-fused one for the unet.""" 10 | 11 | return replace_state(old_state, REPLACEMENTS) 12 | 13 | 14 | # fmt: off 15 | # diffusers to sd-fused replacements 16 | REPLACEMENTS: list[tuple[str, str]] = [ 17 | # Cross-attention 18 | (r"transformer_blocks.(\d).norm([12]).(weight|bias)", r"transformer_blocks.\1.attn\2.norm.\3"), 19 | ## FeedForward (norm) 20 | (r"transformer_blocks.(\d).norm3.(weight|bias)", r"transformer_blocks.\1.ff.0.\2"), 21 | ## FeedForward (geglu) 22 | (r"ff.net.0.proj.(weight|bias)", r"ff.1.proj.\1"), 23 | ## FeedForward-Linear 24 | (r"ff.net.2.(weight|bias)", r"ff.2.\1"), 25 | # up/down samplers 26 | (r"(up|down)samplers.0", r"\1sampler"), 27 | # CrossAttention projection 28 | (r"to_out.0.", r"to_out."), 29 | # TimeEmbedding 30 | (r"time_embedding.linear_1.(weight|bias)", r"time_embedding.0.\1"), 31 | (r"time_embedding.linear_2.(weight|bias)", r"time_embedding.2.\1"), 32 | # resnet-blocks pre/post-process 33 | (r"resnets.(\d).norm1.(bias|weight)", r"resnets.\1.pre_process.0.\2"), 34 | (r"resnets.(\d).conv1.(bias|weight)", r"resnets.\1.pre_process.2.\2"), 35 | (r"resnets.(\d).norm2.(bias|weight)", r"resnets.\1.post_process.0.\2"), 36 | (r"resnets.(\d).conv2.(bias|weight)", r"resnets.\1.post_process.2.\2"), 37 | # resnet-time-embedding 38 | (r"time_emb_proj.(bias|weight)", r"time_emb_proj.1.\1"), 39 | # spatial transformer fused 40 | (r"attentions.(\d).norm.(bias|weight)", r"attentions.\1.proj_in.0.\2"), 41 | (r"attentions.(\d).proj_in.(bias|weight)", r"attentions.\1.proj_in.1.\2"), 42 | # post-processing 43 | (r"conv_norm_out.(bias|weight)", r"post_process.0.\1"), 44 | (r"conv_out.(bias|weight)", r"post_process.2.\1"), 45 | # pre-processing 46 | (r'conv_in.(bias|weight)', r'pre_process.\1') 47 | ] 48 | -------------------------------------------------------------------------------- /sd_fused/models/convert/vae/diffusers2fused.py: -------------------------------------------------------------------------------- 1 | from __future__ import annotations 2 | 3 | from torch import Tensor 4 | 5 | from ..states import replace_state 6 | 7 | # ! Tensor/Parameter? 
8 | def diffusers2fused(old_state: dict[str, Tensor]) -> dict[str, Tensor]: 9 | """Convert a diffusers checkpoint into a sd-fused one for the AutoencoderKL.""" 10 | 11 | return replace_state(old_state, REPLACEMENTS) 12 | 13 | 14 | # fmt: off 15 | # diffusers to sd-fused replacements 16 | REPLACEMENTS: list[tuple[str, str]] = [ 17 | # up/down samplers 18 | (r"(up|down)samplers.0", r"\1sampler"), 19 | # post_process 20 | (r"(decoder|encoder).conv_norm_out.(bias|weight)", r"\1.post_process.0.\2"), 21 | (r"(decoder|encoder).conv_out.(bias|weight)", r"\1.post_process.2.\2"), 22 | # resnet-blocks pre/post-process 23 | (r"resnets.(\d).norm1.(bias|weight)", r"resnets.\1.pre_process.0.\2"), 24 | (r"resnets.(\d).conv1.(bias|weight)", r"resnets.\1.pre_process.2.\2"), 25 | (r"resnets.(\d).norm2.(bias|weight)", r"resnets.\1.post_process.0.\2"), 26 | (r"resnets.(\d).conv2.(bias|weight)", r"resnets.\1.post_process.2.\2"), 27 | # pre-processing 28 | (r'(encoder|decoder).conv_in.(bias|weight)', r'\1.pre_process.\2') 29 | ] 30 | -------------------------------------------------------------------------------- /sd_fused/models/convert/vae/sd2diffusers.py: -------------------------------------------------------------------------------- 1 | from __future__ import annotations 2 | 3 | import re 4 | from torch import Tensor 5 | 6 | # TODO make a config from sd? 7 | # TODO work in progress... 8 | 9 | 10 | def sd2diffusers(old_state: dict[str, Tensor]) -> dict[str, Tensor]: 11 | """Convert a Stable-Diffusion checkpoint into a diffusers-one for the AutoencoderKL.""" 12 | 13 | VAE = "first_stage_model." 14 | 15 | state: dict[str, Tensor] = {} 16 | for key in old_state.keys(): 17 | if not key.startswith(VAE): 18 | continue 19 | 20 | new_key = key.replace(VAE, "") 21 | 22 | for old, new in REPLACEMENTS: 23 | new_key = re.sub(old, new, new_key) 24 | 25 | state[new_key] = old_state[key] 26 | 27 | # reshape attentions weights 28 | # needed due to conv -> linear conversion 29 | if re.match( 30 | r"(encoder|decoder).mid_block.attentions.\d.(key|query|value|proj_attn).weight", 31 | new_key, 32 | ): 33 | state[new_key] = state[new_key].squeeze() 34 | 35 | return state 36 | 37 | 38 | # fmt: off 39 | # Stable-diffusion to diffusers replacements 40 | REPLACEMENTS: list[tuple[str, str]] = [ 41 | # reverse up-block order 42 | (r"up\.0", r"UP.3"), 43 | (r"up\.1", r"UP.2"), 44 | (r"up\.2", r"UP.1"), 45 | (r"up\.3", r"UP.0"), 46 | (r"UP", "up"), # recover lower-case 47 | # short-cut connection 48 | (r"nin_shortcut", r"conv_shortcut"), 49 | # encoder/decoder up/down-blocks 50 | ( 51 | r"(encoder|decoder)\.(down|up)\.(\d)\.block\.(\d)\.(norm\d|conv\d|conv_shortcut)\.(weight|bias)", 52 | r"\1.\2_blocks.\3.resnets.\4.\5.\6", 53 | ), 54 | ( 55 | r"(encoder|decoder)\.(down|up)\.(\d)\.(downsample|upsample)\.conv\.(weight|bias)", 56 | r"\1.\2_blocks.\3.\4rs.0.conv.\5", 57 | ), 58 | # mid-blocks 59 | (r"block_1", r"block_0"), 60 | (r"block_2", r"block_1"), 61 | ( 62 | r"(encoder|decoder)\.mid\.block_(\d)\.(norm\d|conv\d)\.(weight|bias)", 63 | r"\1.mid_block.resnets.\2.\3.\4", 64 | ), 65 | # mid-block attention 66 | (r"\.q\.", r".query."), 67 | (r"\.k\.", r".key."), 68 | (r"\.v\.", r".value."), 69 | (r"attn_1\.proj_out", r"attn_1.proj_attn"), 70 | (r"attn_1\.norm", r"attn_1.group_norm"), 71 | ( 72 | r"(encoder|decoder)\.mid\.attn_1\.(group_norm|query|key|value|proj_attn)\.(weight|bias)", 73 | r"\1.mid_block.attentions.0.\2.\3", 74 | ), 75 | # in-out 76 | (r"(encoder|decoder)\.(norm_out)\.(weight|bias)", r"\1.conv_\2.\3"), 77 | ] 78 | 
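# A minimal, illustrative sketch of the renaming performed by `sd2diffusers`
# (the key and tensor below are hypothetical, not taken from a real checkpoint).
# The VAE prefix is stripped and, because Stable-Diffusion stores the decoder
# up-blocks in reverse order, `up.0` is remapped to `up_blocks.3`:
if __name__ == "__main__":
    import torch

    old_state = {"first_stage_model.decoder.up.0.block.0.norm1.weight": torch.zeros(128)}
    new_state = sd2diffusers(old_state)

    assert "decoder.up_blocks.3.resnets.0.norm1.weight" in new_state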
-------------------------------------------------------------------------------- /sd_fused/models/modifiers/__init__.py: -------------------------------------------------------------------------------- 1 | from .half_weights import HalfWeightsModel 2 | from .split_attention import SplitAttentionModel 3 | from .flash_attention import FlashAttentionModel 4 | from .tome import ToMeModel 5 | -------------------------------------------------------------------------------- /sd_fused/models/modifiers/flash_attention.py: -------------------------------------------------------------------------------- 1 | from __future__ import annotations 2 | from typing_extensions import Self 3 | 4 | from ...layers.base import Module 5 | from ...layers.blocks.attention import CrossAttention, SelfAttention 6 | 7 | 8 | class FlashAttentionModel(Module): 9 | def flash_attention(self, use: bool = True) -> Self: 10 | """Use xformers flash-attention.""" 11 | 12 | for name, module in self.named_modules().items(): 13 | if isinstance(module, (CrossAttention, SelfAttention)): 14 | module.use_flash_attention = use 15 | 16 | if use: 17 | module.attention_chunks = None 18 | module.chunk_type = None 19 | 20 | return self 21 | -------------------------------------------------------------------------------- /sd_fused/models/modifiers/half_weights.py: -------------------------------------------------------------------------------- 1 | from __future__ import annotations 2 | from typing_extensions import Self 3 | 4 | from ...layers.base import Module 5 | from ...layers.blocks.attention import CrossAttention, SelfAttention 6 | from ...layers.modifiers import HalfWeightsModule 7 | 8 | 9 | class HalfWeightsModel(Module): 10 | def half_weights(self, use: bool = True) -> Self: 11 | """Store the weights in half-precision but 12 | compute the forward pass in full precision. 13 | Useful for GPUs that give NaN when used in half-precision.
14 | """ 15 | 16 | for name, module in self.named_modules().items(): 17 | if isinstance(module, HalfWeightsModule): 18 | module.half_weights(use) 19 | 20 | return self 21 | -------------------------------------------------------------------------------- /sd_fused/models/modifiers/split_attention.py: -------------------------------------------------------------------------------- 1 | from __future__ import annotations 2 | from typing import Optional 3 | from typing_extensions import Self 4 | 5 | from ...layers.base import Module 6 | from ...utils.typing import Literal 7 | from ...layers.blocks.attention import CrossAttention, SelfAttention 8 | from ...layers.blocks.attention.compute import ChunkType 9 | 10 | 11 | class SplitAttentionModel(Module): 12 | def split_attention( 13 | self, 14 | chunks: Optional[int | Literal["auto"]] = "auto", 15 | chunk_type: Optional[ChunkType] = None, 16 | ) -> Self: 17 | """Split cross/self-attention computation into chunks.""" 18 | 19 | for name, module in self.named_modules().items(): 20 | if isinstance(module, (CrossAttention, SelfAttention)): 21 | module.attention_chunks = chunks 22 | module.chunk_type = chunk_type 23 | 24 | if chunks is not None: 25 | module.use_flash_attention = False 26 | 27 | return self 28 | -------------------------------------------------------------------------------- /sd_fused/models/modifiers/tome.py: -------------------------------------------------------------------------------- 1 | from __future__ import annotations 2 | from typing import Optional 3 | from typing_extensions import Self 4 | 5 | from ...layers.base import Module 6 | from ...layers.blocks.attention import CrossAttention, SelfAttention 7 | 8 | 9 | class ToMeModel(Module): 10 | # https://arxiv.org/abs/2210.09461 11 | # https://github.com/facebookresearch/ToMe/blob/main/tome/merge.py 12 | 13 | def tome(self, r: Optional[int | float] = None) -> Self: 14 | """Merge similar tokens.""" 15 | 16 | for name, module in self.named_modules().items(): 17 | if isinstance(module, (CrossAttention, SelfAttention)): 18 | module.tome_r = r 19 | 20 | return self 21 | -------------------------------------------------------------------------------- /sd_fused/models/unet_conditional.py: -------------------------------------------------------------------------------- 1 | from __future__ import annotations 2 | from typing import Optional, Type 3 | from typing_extensions import Self 4 | 5 | from pathlib import Path 6 | import json 7 | 8 | import torch 9 | from torch import Tensor 10 | 11 | from ..layers.base import Module, ModuleList 12 | from ..layers.basic import Conv2d 13 | from ..layers.embedding import Timesteps, TimestepEmbedding 14 | from ..layers.blocks.basic import GroupNormSiLUConv2d 15 | from ..layers.blocks.spatial import ( 16 | UNetMidBlock2DCrossAttention, 17 | DownBlock2D, 18 | UpBlock2D, 19 | CrossAttentionDownBlock2D, 20 | CrossAttentionUpBlock2D, 21 | ) 22 | from ..utils.tensors import to_tensor 23 | from .modifiers import HalfWeightsModel, SplitAttentionModel, FlashAttentionModel, ToMeModel 24 | from .config import UnetConfig 25 | from .convert import diffusers2fused_unet 26 | from .convert.states import debug_state_replacements 27 | 28 | 29 | class UNet2DConditional(HalfWeightsModel, SplitAttentionModel, FlashAttentionModel, ToMeModel, Module): 30 | @classmethod 31 | def from_config(cls, path: str | Path) -> Self: 32 | """Creates a model from a (diffusers) config file.""" 33 | 34 | path = Path(path) 35 | if path.is_dir(): 36 | path /= "config.json" 37 | assert 
path.suffix == ".json" 38 | 39 | db = json.load(open(path, "r")) 40 | config = UnetConfig(**db) 41 | 42 | return cls( 43 | in_channels=config.in_channels, 44 | out_channels=config.out_channels, 45 | flip_sin_to_cos=config.flip_sin_to_cos, 46 | freq_shift=config.freq_shift, 47 | down_blocks=config.down_blocks, 48 | up_blocks=config.up_blocks, 49 | block_out_channels=tuple(config.block_out_channels), 50 | layers_per_block=config.layers_per_block, 51 | downsample_padding=config.downsample_padding, 52 | norm_num_groups=config.norm_num_groups, 53 | cross_attention_dim=config.cross_attention_dim, 54 | attention_head_dim=config.attention_head_dim, 55 | ) 56 | 57 | def __init__( 58 | self, 59 | *, 60 | in_channels: int = 4, 61 | out_channels: int = 4, 62 | flip_sin_to_cos: bool = True, 63 | freq_shift: int = 0, 64 | down_blocks: tuple[Type[CrossAttentionDownBlock2D] | Type[DownBlock2D], ...] = ( 65 | CrossAttentionDownBlock2D, 66 | CrossAttentionDownBlock2D, 67 | CrossAttentionDownBlock2D, 68 | DownBlock2D, 69 | ), 70 | up_blocks: tuple[Type[CrossAttentionUpBlock2D] | Type[UpBlock2D], ...] = ( 71 | UpBlock2D, 72 | CrossAttentionUpBlock2D, 73 | CrossAttentionUpBlock2D, 74 | CrossAttentionUpBlock2D, 75 | ), 76 | block_out_channels: tuple[int, ...] = (320, 640, 1280, 1280), 77 | layers_per_block: int = 2, 78 | downsample_padding: int = 1, 79 | norm_num_groups: int = 32, 80 | cross_attention_dim: int = 768, 81 | attention_head_dim: int = 8, 82 | ) -> None: 83 | super().__init__() 84 | 85 | self.in_channels = in_channels 86 | self.out_channels = out_channels 87 | self.flip_sin_to_cos = flip_sin_to_cos 88 | self.freq_shift = freq_shift 89 | self.block_out_channels = block_out_channels 90 | self.layers_per_block = layers_per_block 91 | self.downsample_padding = downsample_padding 92 | self.norm_num_groups = norm_num_groups 93 | self.cross_attention_dim = cross_attention_dim 94 | self.attention_head_dim = attention_head_dim 95 | 96 | time_embed_dim = block_out_channels[0] * 4 97 | 98 | # input 99 | self.pre_process = Conv2d(in_channels, block_out_channels[0], kernel_size=3, padding=1) 100 | 101 | # time 102 | self.time_proj = Timesteps( 103 | num_channels=block_out_channels[0], 104 | flip_sin_to_cos=flip_sin_to_cos, 105 | downscale_freq_shift=freq_shift, 106 | ) 107 | timestep_input_dim = block_out_channels[0] 108 | 109 | self.time_embedding = TimestepEmbedding(channel=timestep_input_dim, time_embed_dim=time_embed_dim) 110 | 111 | # down 112 | output_channel = block_out_channels[0] 113 | self.down_blocks = ModuleList[CrossAttentionDownBlock2D | DownBlock2D]() 114 | for i, block in enumerate(down_blocks): 115 | input_channel = output_channel 116 | output_channel = block_out_channels[i] 117 | is_final_block = i == len(block_out_channels) - 1 118 | 119 | if block == CrossAttentionDownBlock2D: 120 | self.down_blocks.append( 121 | CrossAttentionDownBlock2D( 122 | in_channels=input_channel, 123 | out_channels=output_channel, 124 | temb_channels=time_embed_dim, 125 | num_layers=layers_per_block, 126 | resnet_groups=norm_num_groups, 127 | attn_num_head_channels=attention_head_dim, 128 | cross_attention_dim=cross_attention_dim, 129 | downsample_padding=downsample_padding, 130 | add_downsample=not is_final_block, 131 | ) 132 | ) 133 | elif block == DownBlock2D: 134 | self.down_blocks.append( 135 | DownBlock2D( 136 | in_channels=input_channel, 137 | out_channels=output_channel, 138 | temb_channels=time_embed_dim, 139 | num_layers=layers_per_block, 140 | num_groups=norm_num_groups, 141 | add_downsample=not 
is_final_block, 142 | downsample_padding=downsample_padding, 143 | ) 144 | ) 145 | 146 | # mid 147 | self.mid_block = UNetMidBlock2DCrossAttention( 148 | in_channels=block_out_channels[-1], 149 | temb_channels=time_embed_dim, 150 | cross_attention_dim=cross_attention_dim, 151 | attn_num_head_channels=attention_head_dim, 152 | resnet_groups=norm_num_groups, 153 | num_layers=1, 154 | ) 155 | 156 | # up 157 | reversed_block_out_channels = tuple(reversed(block_out_channels)) 158 | output_channel = reversed_block_out_channels[0] 159 | self.up_blocks = ModuleList[CrossAttentionUpBlock2D | UpBlock2D]() 160 | for i, block in enumerate(up_blocks): 161 | prev_output_channel = output_channel 162 | output_channel = reversed_block_out_channels[i] 163 | input_channel = reversed_block_out_channels[min(i + 1, len(block_out_channels) - 1)] 164 | 165 | is_final_block = i == len(block_out_channels) - 1 166 | 167 | if block == CrossAttentionUpBlock2D: 168 | self.up_blocks.append( 169 | CrossAttentionUpBlock2D( 170 | in_channels=input_channel, 171 | out_channels=output_channel, 172 | prev_output_channel=prev_output_channel, 173 | num_layers=layers_per_block + 1, 174 | temb_channels=time_embed_dim, 175 | add_upsample=not is_final_block, 176 | resnet_groups=norm_num_groups, 177 | cross_attention_dim=cross_attention_dim, 178 | attn_num_head_channels=attention_head_dim, 179 | ) 180 | ) 181 | elif block == UpBlock2D: 182 | self.up_blocks.append( 183 | UpBlock2D( 184 | in_channels=input_channel, 185 | out_channels=output_channel, 186 | prev_output_channel=prev_output_channel, 187 | num_layers=layers_per_block + 1, 188 | temb_channels=time_embed_dim, 189 | add_upsample=not is_final_block, 190 | resnet_groups=norm_num_groups, 191 | ) 192 | ) 193 | prev_output_channel = output_channel 194 | 195 | # out 196 | self.post_process = GroupNormSiLUConv2d( 197 | norm_num_groups, 198 | block_out_channels[0], 199 | out_channels, 200 | kernel_size=3, 201 | padding=1, 202 | ) 203 | 204 | def __call__( 205 | self, 206 | x: Tensor, 207 | timestep: int | Tensor, 208 | context: Tensor, 209 | weights: Optional[Tensor] = None, 210 | ) -> Tensor: 211 | B, C, H, W = x.shape 212 | 213 | # 1. time embedding 214 | timestep = to_tensor(timestep, device=x.device, dtype=x.dtype) 215 | if timestep.size(0) != B: 216 | assert timestep.size(0) == 1 217 | timestep = timestep.expand(B) 218 | 219 | temb = self.time_proj(timestep) 220 | temb = self.time_embedding(temb) 221 | 222 | # 2. pre-process 223 | x = self.pre_process(x) 224 | 225 | # 3. down 226 | # TODO it is possible to make it a list[list[Tensor]]? or is the number of elements wrong? 227 | all_states: list[Tensor] = [x] 228 | for block in self.down_blocks: 229 | if isinstance(block, CrossAttentionDownBlock2D): 230 | x, states = block(x, temb=temb, context=context, weights=weights) 231 | elif isinstance(block, DownBlock2D): 232 | x, states = block(x, temb=temb) 233 | else: 234 | raise ValueError 235 | 236 | all_states.extend(states) 237 | del states 238 | 239 | # 4. mid 240 | x = self.mid_block(x, temb=temb, context=context, weights=weights) 241 | 242 | # 5. up 243 | for block in self.up_blocks: 244 | assert isinstance(block, (CrossAttentionUpBlock2D, UpBlock2D)) 245 | 246 | # ! I don't like the construction of this... 
247 | states = list(all_states.pop() for _ in range(block.num_layers)) 248 | 249 | if isinstance(block, CrossAttentionUpBlock2D): 250 | x = block(x, states=states, temb=temb, context=context, weights=weights) 251 | elif isinstance(block, UpBlock2D): 252 | x = block(x, states=states, temb=temb) 253 | else: 254 | raise ValueError 255 | 256 | del states 257 | del all_states 258 | 259 | # 6. post-process 260 | x = self.post_process(x) 261 | 262 | return x 263 | 264 | @classmethod 265 | def from_diffusers(cls, path: str | Path) -> Self: 266 | """Load Stable-Diffusion from diffusers checkpoint folder.""" 267 | 268 | path = Path(path) 269 | model = cls.from_config(path) 270 | 271 | state_path = next(path.glob("*.bin")) 272 | old_state = torch.load(state_path, map_location="cpu") 273 | replaced_state = diffusers2fused_unet(old_state) 274 | 275 | debug_state_replacements(model.state_dict(), replaced_state) 276 | 277 | model.load_state_dict(replaced_state) 278 | 279 | return model 280 | -------------------------------------------------------------------------------- /sd_fused/scheduler/__init__.py: -------------------------------------------------------------------------------- 1 | from .ddim import Scheduler, DDIMScheduler 2 | -------------------------------------------------------------------------------- /sd_fused/scheduler/ddim.py: -------------------------------------------------------------------------------- 1 | from __future__ import annotations 2 | from typing import Optional 3 | 4 | import math 5 | 6 | import torch 7 | from torch import Tensor 8 | 9 | from ..layers.base.types import Device 10 | from ..utils.tensors import to_tensor 11 | from .scheduler import Scheduler 12 | 13 | TRAINED_STEPS = 1_000 14 | BETA_BEGIN = 0.00085 15 | BETA_END = 0.012 16 | POWER = 2 17 | 18 | 19 | class DDIMScheduler(Scheduler): 20 | """Denoising Diffusion Implicit Models scheduler.""" 21 | 22 | # https://arxiv.org/abs/2010.02502 23 | 24 | ᾱ: Tensor 25 | ϖ: Tensor 26 | σ: Tensor 27 | 28 | @property # TODO Type 29 | def info(self) -> dict[str, str | int]: 30 | return dict( 31 | name=self.__class__.__qualname__, 32 | steps=self.steps, 33 | skip_timestep=self.skip_timestep, 34 | ) 35 | 36 | def __init__( 37 | self, 38 | steps: int, 39 | shape: tuple[int, ...], 40 | seeds: list[int], 41 | strength: Optional[float] = None, # img2img/inpainting 42 | device: Optional[Device] = None, 43 | dtype: torch.dtype = torch.float32, 44 | ) -> None: 45 | super().__init__(steps, shape, seeds, strength, device, dtype) 46 | 47 | assert steps <= TRAINED_STEPS 48 | 49 | # scheduler betas and alphas 50 | β_begin = math.pow(BETA_BEGIN, 1 / POWER) 51 | β_end = math.pow(BETA_END, 1 / POWER) 52 | β = torch.linspace(β_begin, β_end, TRAINED_STEPS).pow(POWER) 53 | β = β.to(torch.float64) # extra-precision 54 | 55 | # increase steps by 1 to account last timestep 56 | steps += 1 57 | 58 | # trimmed timesteps for selection 59 | timesteps = torch.linspace(TRAINED_STEPS - 1, 0, steps).ceil().long() 60 | 61 | # cummulative ᾱ trimmed 62 | α = 1 - β 63 | ᾱ = α.cumprod(dim=0) 64 | ᾱ /= ᾱ.max() # makes last-value=1 65 | ᾱ = ᾱ[timesteps] 66 | ϖ = 1 - ᾱ 67 | del α, β # reminder that is not used anymore 68 | 69 | # standard deviation, eq (16) 70 | σ = torch.sqrt(ϖ[1:] / ϖ[:-1] * (1 - ᾱ[:-1] / ᾱ[1:])) 71 | 72 | # use device/dtype 73 | self.timesteps = timesteps.to(device=device) 74 | self.ᾱ = ᾱ.to(device=device, dtype=dtype) 75 | self.ϖ = ϖ.to(device=device, dtype=dtype) 76 | self.σ = σ.to(device=device, dtype=dtype) 77 | 78 | def add_noise( 79 | self, 
80 | latents: Tensor, 81 | noise: Tensor, 82 | index: int, 83 | ) -> Tensor: 84 | # eq 4 85 | return latents * self.ᾱ[index].sqrt() + noise * self.ϖ[index].sqrt() 86 | 87 | def step( 88 | self, 89 | pred_noise: Tensor, 90 | latents: Tensor, 91 | index: int, 92 | etas: Optional[float | Tensor] = None, 93 | ) -> Tensor: 94 | etas = to_tensor(etas, self.device, self.dtype, add_spatial=True) 95 | 96 | # eq (12) part 1 97 | pred_latent = latents - self.ϖ[index].sqrt() * pred_noise 98 | pred_latent /= self.ᾱ[index].sqrt() 99 | 100 | # eq (12) part 2 101 | temp = 1 - self.ᾱ[index + 1] - self.σ[index].mul(etas).square() 102 | pred_dir = temp.abs().sqrt() * pred_noise 103 | 104 | # eq (12) part 3 105 | noise = self.noise[index] * self.σ[index] * etas 106 | 107 | # full eq (12) 108 | latents = pred_latent * self.ᾱ[index + 1].sqrt() + pred_dir + noise 109 | 110 | return latents 111 | 112 | def __repr__(self) -> str: 113 | name = self.__class__.__qualname__ 114 | 115 | return f"{name}(steps={self.steps})" 116 | -------------------------------------------------------------------------------- /sd_fused/scheduler/scheduler.py: -------------------------------------------------------------------------------- 1 | from __future__ import annotations 2 | from typing import Optional 3 | 4 | from abc import ABC, abstractmethod 5 | 6 | import math 7 | 8 | import torch 9 | from torch import Tensor 10 | 11 | from ..layers.base.types import Device 12 | from ..models import UNet2DConditional 13 | from ..utils.tensors import generate_noise 14 | 15 | 16 | class Scheduler(ABC): 17 | """Base-class for all schedulers.""" 18 | 19 | timesteps: Tensor 20 | 21 | def __init__( 22 | self, 23 | steps: int, 24 | shape: tuple[int, ...], 25 | seeds: list[int], 26 | strength: Optional[float] = None, # img2img/inpainting 27 | device: Optional[Device] = None, 28 | dtype: torch.dtype = torch.float32, 29 | ) -> None: 30 | if strength is not None: 31 | assert 0 < strength <= 1 32 | 33 | self.steps = steps 34 | self.shape = shape 35 | self.seeds = seeds 36 | self.strength = strength 37 | self.device = device 38 | self.dtype = dtype 39 | 40 | # TODO add sub-seeds 41 | self.noise = generate_noise(shape, seeds, device, dtype, steps) 42 | 43 | @property 44 | def skip_timestep(self) -> int: 45 | """Text-to-Image generation starting timestep.""" 46 | 47 | if self.strength is None: 48 | return 0 49 | 50 | return math.ceil(self.steps * (1 - self.strength)) 51 | 52 | @abstractmethod 53 | def add_noise(self, latents: Tensor, noise: Tensor, index: int) -> Tensor: 54 | """Add noise for a timestep.""" 55 | 56 | def prepare_latents( 57 | self, 58 | image_latents: Optional[Tensor] = None, 59 | mask_latents: Optional[Tensor] = None, # ! old-stype inpainting 60 | masked_image_latents: Optional[Tensor] = None, # ! new-style inpainting 61 | ) -> Tensor: 62 | """Prepare initial latents for generation.""" 63 | 64 | noise = self.noise[0] 65 | 66 | if image_latents is None: 67 | return noise 68 | 69 | if mask_latents is None and masked_image_latents is None: 70 | return self.add_noise(image_latents, noise, self.skip_timestep) 71 | 72 | # TODO inpainting 73 | raise NotImplementedError 74 | 75 | @abstractmethod 76 | def step( 77 | self, 78 | pred_noise: Tensor, 79 | latents: Tensor, 80 | index: int, 81 | **kwargs, # ! 
type 82 | ) -> Tensor: 83 | """Get the previous timestep for the latents.""" 84 | 85 | @torch.no_grad() 86 | def pred_noise( 87 | self, 88 | unet: UNet2DConditional, 89 | latents: Tensor, 90 | timestep: int, 91 | context: Tensor, 92 | weights: Optional[Tensor], 93 | scale: Optional[Tensor], 94 | unconditional: bool, 95 | low_ram: bool, 96 | ) -> Tensor: 97 | """Noise prediction for a given timestep.""" 98 | 99 | if unconditional: 100 | assert scale is None 101 | 102 | return unet(latents, timestep, context, weights) 103 | 104 | assert scale is not None 105 | 106 | if low_ram: 107 | negative_context, prompt_context = context.chunk(2, dim=0) 108 | if weights is not None: 109 | negative_weight, prompt_weight = weights.chunk(2, dim=0) 110 | else: 111 | negative_weight = prompt_weight = None 112 | 113 | pred_noise_prompt = unet(latents, timestep, prompt_context, prompt_weight) 114 | pred_noise_negative = unet(latents, timestep, negative_context, negative_weight) 115 | else: 116 | latents = torch.cat([latents] * 2, dim=0) 117 | 118 | pred_noise_all = unet(latents, timestep, context, weights) 119 | pred_noise_negative, pred_noise_prompt = pred_noise_all.chunk(2, dim=0) 120 | 121 | scale = scale[:, None, None, None] # add fake channel/spatial dimensions 122 | 123 | latents = pred_noise_negative + (pred_noise_prompt - pred_noise_negative) * scale 124 | 125 | return latents 126 | -------------------------------------------------------------------------------- /sd_fused/utils/__init__.py: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/tfernd/sd-fused/50570f983cc00dd4bd0dc415d0b515da638311b0/sd_fused/utils/__init__.py -------------------------------------------------------------------------------- /sd_fused/utils/cuda/__init__.py: -------------------------------------------------------------------------------- 1 | from .clear_cuda import clear_cuda 2 | from .free_memory import free_memory 3 | -------------------------------------------------------------------------------- /sd_fused/utils/cuda/clear_cuda.py: -------------------------------------------------------------------------------- 1 | from __future__ import annotations 2 | 3 | import gc 4 | import torch 5 | 6 | 7 | def clear_cuda() -> None: 8 | """Clear CUDA memory and garbage collection.""" 9 | 10 | gc.collect() 11 | if torch.cuda.is_available(): 12 | torch.cuda.empty_cache() 13 | torch.cuda.ipc_collect() 14 | -------------------------------------------------------------------------------- /sd_fused/utils/cuda/free_memory.py: -------------------------------------------------------------------------------- 1 | from __future__ import annotations 2 | 3 | import torch 4 | 5 | 6 | def free_memory() -> int: 7 | """Amount of free memory available.""" 8 | 9 | assert torch.cuda.is_available() 10 | 11 | stats = torch.cuda.memory_stats() 12 | 13 | reserved = stats["reserved_bytes.all.current"] 14 | active = stats["active_bytes.all.current"] 15 | free = torch.cuda.mem_get_info()[0] # type: ignore 16 | 17 | free += reserved - active 18 | 19 | return free 20 | -------------------------------------------------------------------------------- /sd_fused/utils/diverse/__init__.py: -------------------------------------------------------------------------------- 1 | from .product_args import product_args 2 | from .separate import separate 3 | from .single import single 4 | from .to_list import to_list 5 | -------------------------------------------------------------------------------- 
/sd_fused/utils/diverse/product_args.py: -------------------------------------------------------------------------------- 1 | from __future__ import annotations 2 | from typing import Optional 3 | 4 | from itertools import product 5 | 6 | from ..typing import MaybeIterable 7 | from .to_list import to_list 8 | 9 | # TODO types! 10 | def product_args(**kwargs: Optional[MaybeIterable]) -> list[dict]: 11 | """All possible combinations of kwargs.""" 12 | 13 | args = list(kwargs.values()) 14 | keys = list(kwargs.keys()) 15 | 16 | args = tuple(map(to_list, args)) 17 | perms = list(product(*args)) 18 | 19 | return [dict(zip(keys, args)) for args in perms] 20 | -------------------------------------------------------------------------------- /sd_fused/utils/diverse/separate.py: -------------------------------------------------------------------------------- 1 | from __future__ import annotations 2 | from typing import Optional 3 | 4 | from torch import Tensor 5 | 6 | from ...utils.typing import TypeVar 7 | 8 | T = TypeVar("T", Tensor, int, float, str) 9 | 10 | 11 | def separate(xs: list[T | None]) -> Optional[list[T]]: 12 | """Separate a list that may contain `None` into a list 13 | that does not contain `None`, or into `None` itself.""" 14 | 15 | assert len(xs) >= 1 16 | 17 | if xs[0] is None: 18 | for x in xs: 19 | assert x is None 20 | 21 | return None 22 | 23 | for x in xs: 24 | assert x is not None 25 | 26 | return [x for x in xs if x is not None] 27 | -------------------------------------------------------------------------------- /sd_fused/utils/diverse/single.py: -------------------------------------------------------------------------------- 1 | from __future__ import annotations 2 | from typing import Optional, overload 3 | 4 | import torch 5 | 6 | from ...utils.typing import TypeVar 7 | from ...layers.base.types import Device 8 | 9 | T = TypeVar("T", int, float, Device, torch.dtype) 10 | 11 | 12 | @overload 13 | def single(x: set[T]) -> T: 14 | ... 15 | 16 | 17 | @overload 18 | def single(x: set[Optional[T]]) -> Optional[T]: 19 | ... 20 | 21 | 22 | def single(x: set[Optional[T]] | set[T]) -> Optional[T]: 23 | """Get the single element of a set.""" 24 | 25 | assert len(x) == 1 26 | 27 | return x.pop() 28 | -------------------------------------------------------------------------------- /sd_fused/utils/diverse/to_list.py: -------------------------------------------------------------------------------- 1 | from __future__ import annotations 2 | from typing import Optional, overload 3 | 4 | from PIL import Image 5 | from pathlib import Path 6 | 7 | from ..typing import MaybeIterable, T 8 | 9 | 10 | @overload 11 | def to_list(x: MaybeIterable[T]) -> list[T]: 12 | ... 13 | 14 | 15 | @overload 16 | def to_list(x: Optional[MaybeIterable[T]]) -> tuple[None] | list[T]: 17 | ...
18 | 19 | 20 | def to_list(x: Optional[MaybeIterable[T]]) -> tuple[None] | list[T]: 21 | """Convert a `MaybeIterable` into a list.""" 22 | 23 | if x is None: 24 | return (None,) 25 | 26 | if isinstance(x, (int, float, str, Path, Image.Image)): 27 | return [x] # type: ignore 28 | 29 | return list(x) # type: ignore 30 | -------------------------------------------------------------------------------- /sd_fused/utils/image/__init__.py: -------------------------------------------------------------------------------- 1 | from .image_base64 import image_base64 2 | from .image2tensor import image2tensor 3 | from .tensor2images import tensor2images 4 | from .open_image import open_image 5 | from .image_size import image_size 6 | 7 | from .types import ImageType, ResizeModes 8 | -------------------------------------------------------------------------------- /sd_fused/utils/image/image2tensor.py: -------------------------------------------------------------------------------- 1 | from __future__ import annotations 2 | from typing import Optional 3 | 4 | from functools import partial 5 | from PIL import Image 6 | 7 | import math 8 | import numpy as np 9 | 10 | import torch 11 | import torch.nn.functional as F 12 | from torch import Tensor 13 | 14 | from einops import rearrange 15 | 16 | from ...layers.base.types import Device 17 | from .open_image import open_image 18 | from .types import ResizeModes, ImageType 19 | 20 | 21 | def image2tensor( 22 | path: ImageType, 23 | height: Optional[int] = None, 24 | width: Optional[int] = None, 25 | mode: Optional[ResizeModes] = None, 26 | device: Optional[Device] = None, 27 | rescale: Optional[float] = None, 28 | ) -> Tensor: 29 | """Open an image/url as pytorch batched-Tensor (B=1 C H W).""" 30 | 31 | img = open_image(path) 32 | resize = partial(img.resize, resample=Image.LANCZOS) 33 | 34 | if height is None or width is None: 35 | assert height is None and width is None 36 | 37 | width, height = img.size 38 | if rescale is not None: 39 | width = math.ceil(width * rescale) 40 | height = math.ceil(height * rescale) 41 | 42 | if mode == "resize" or rescale is not None: 43 | img = resize((width, height)) 44 | else: 45 | assert mode is not None 46 | 47 | ar = width / height 48 | src_ar = img.width / img.height 49 | 50 | diff = ar > src_ar 51 | if mode == "resize-pad": 52 | diff = not diff 53 | 54 | w = math.ceil(width if diff else height * src_ar) 55 | h = math.ceil(height if not diff else width / src_ar) 56 | 57 | img = resize((w, h)) 58 | 59 | data = torch.from_numpy(np.asarray(img).copy()).to(device) 60 | data = rearrange(data, "H W C -> 1 C H W") 61 | 62 | # crop/padding size 63 | h, w = data.shape[-2:] 64 | dh, dw = height - h, width - w 65 | 66 | pad = (dw // 2, dw - dw // 2, dh // 2, dh - dh // 2) 67 | data = F.pad(data, pad, value=0) 68 | 69 | return data 70 | -------------------------------------------------------------------------------- /sd_fused/utils/image/image_base64.py: -------------------------------------------------------------------------------- 1 | from __future__ import annotations 2 | 3 | import base64 4 | from io import BytesIO 5 | 6 | from .open_image import open_image 7 | from .types import ImageType 8 | 9 | 10 | def image_base64(path: ImageType) -> str: 11 | """Encodes an image as base64 (JPGE) string.""" 12 | 13 | img = open_image(path) 14 | 15 | buffered = BytesIO() 16 | img.save(buffered, format="JPEG") 17 | code = base64.b64encode(buffered.getvalue()) 18 | code = code.decode("ascii") 19 | 20 | return f"data:image/jpg;base64,{code}" 21 
| -------------------------------------------------------------------------------- /sd_fused/utils/image/image_size.py: -------------------------------------------------------------------------------- 1 | from __future__ import annotations 2 | 3 | from .open_image import open_image 4 | from .types import ImageType 5 | 6 | 7 | def image_size(path: ImageType) -> tuple[int, int]: 8 | img = open_image(path) 9 | 10 | width, height = img.size 11 | 12 | return height, width 13 | -------------------------------------------------------------------------------- /sd_fused/utils/image/open_image.py: -------------------------------------------------------------------------------- 1 | from __future__ import annotations 2 | 3 | from PIL import Image 4 | import requests 5 | 6 | from io import BytesIO 7 | import validators 8 | 9 | from .types import ImageType 10 | 11 | 12 | def open_image(path: ImageType) -> Image.Image: 13 | """Open a path or URL as an image.""" 14 | 15 | if isinstance(path, Image.Image): 16 | img = path 17 | elif isinstance(path, str) and validators.url(path): # type: ignore 18 | response = requests.get(path) 19 | img = Image.open(BytesIO(response.content)) 20 | else: 21 | img = Image.open(path) 22 | img = img.convert("RGB") 23 | 24 | # TODO get alpha channel as mask? 25 | 26 | return img 27 | -------------------------------------------------------------------------------- /sd_fused/utils/image/tensor2images.py: -------------------------------------------------------------------------------- 1 | from __future__ import annotations 2 | 3 | from PIL import Image 4 | 5 | import torch 6 | from torch import Tensor 7 | from einops import rearrange 8 | 9 | 10 | def tensor2images(data: Tensor) -> list[Image.Image]: 11 | """Creates a list of images according to the batch size.""" 12 | 13 | assert data.dtype == torch.uint8 14 | 15 | data = rearrange(data, "B C H W -> B H W C").cpu().numpy() 16 | 17 | return [Image.fromarray(v) for v in data] 18 | -------------------------------------------------------------------------------- /sd_fused/utils/image/types.py: -------------------------------------------------------------------------------- 1 | from __future__ import annotations 2 | from typing import Union 3 | 4 | from PIL import Image 5 | from pathlib import Path 6 | 7 | from ..typing import Literal 8 | 9 | ResizeModes = Literal["resize", "resize-crop", "resize-pad"] 10 | ImageType = Union[str, Path, Image.Image] 11 | -------------------------------------------------------------------------------- /sd_fused/utils/parameters/__init__.py: -------------------------------------------------------------------------------- 1 | from .parameters import Parameters 2 | from .parameters_list import ParametersList 3 | 4 | from .batch_parameters import batch_parameters 5 | from .group_parameters import group_parameters 6 | -------------------------------------------------------------------------------- /sd_fused/utils/parameters/batch_parameters.py: -------------------------------------------------------------------------------- 1 | from __future__ import annotations 2 | 3 | from .parameters import Parameters 4 | 5 | 6 | def batch_parameters( 7 | groups: list[list[Parameters]], 8 | batch_size: int, 9 | ) -> list[list[Parameters]]: 10 | """Separate groups into batches of a specified size.""" 11 | 12 | batched_parameters: list[list[Parameters]] = [] 13 | for group in groups: 14 | for i in range(0, len(group), batch_size): 15 | s = slice(i, min(i + batch_size, len(group))) 16 | 17 |
batched_parameters.append(group[s]) 18 | 19 | return batched_parameters 20 | -------------------------------------------------------------------------------- /sd_fused/utils/parameters/group_parameters.py: -------------------------------------------------------------------------------- 1 | from __future__ import annotations 2 | 3 | from .parameters import Parameters 4 | 5 | 6 | def group_parameters(parameters: list[Parameters]) -> list[list[Parameters]]: 7 | """Group parameters that can share a batch.""" 8 | 9 | groups: list[list[Parameters]] = [] 10 | for parameter in parameters: 11 | if len(groups) == 0: 12 | groups.append([parameter]) 13 | continue 14 | 15 | can_share = False 16 | for group in groups: 17 | can_share = True 18 | for other_parameter in group: 19 | can_share &= parameter.can_share_batch(other_parameter) 20 | 21 | if can_share: 22 | group.append(parameter) 23 | break 24 | 25 | if not can_share: 26 | groups.append([parameter]) 27 | 28 | return groups 29 | -------------------------------------------------------------------------------- /sd_fused/utils/parameters/parameters.py: -------------------------------------------------------------------------------- 1 | from __future__ import annotations 2 | from typing import Optional 3 | from typing_extensions import Self 4 | 5 | from dataclasses import dataclass, field 6 | from PIL.PngImagePlugin import PngInfo 7 | 8 | import torch 9 | from torch import Tensor 10 | 11 | from ...layers.base.types import Device 12 | from ..image import ResizeModes, image2tensor, image_base64, ImageType 13 | 14 | SAVE_ARGS = [ 15 | "steps", 16 | "height", 17 | "width", 18 | "seed", 19 | "negative_prompt", 20 | "scale", 21 | "eta", 22 | "sub_seed", 23 | "seed_interpolation", 24 | "prompt", 25 | "strength", 26 | "mode", 27 | "image_base64", 28 | "mask_base64", 29 | ] 30 | 31 | 32 | @dataclass 33 | class Parameters: 34 | """Hold information from a single image generation.""" 35 | 36 | steps: int 37 | height: int 38 | width: int 39 | seed: int 40 | negative_prompt: str 41 | scale: Optional[float] = None 42 | eta: Optional[float] = None # DDIM 43 | sub_seed: Optional[int] = None 44 | seed_interpolation: Optional[float] = None 45 | prompt: Optional[str] = None 46 | img: Optional[ImageType] = None 47 | mask: Optional[ImageType] = None 48 | strength: Optional[float] = None 49 | mode: Optional[ResizeModes] = None 50 | 51 | device: Optional[Device] = field(default=None, repr=False) 52 | dtype: Optional[torch.dtype] = field(default=None, repr=False) 53 | 54 | def __post_init__(self) -> None: 55 | assert self.height % 8 == 0 56 | assert self.width % 8 == 0 57 | 58 | if self.img is None: 59 | assert self.mode is None 60 | assert self.strength is None 61 | assert self.mask is None 62 | else: 63 | assert self.mode is not None 64 | assert self.strength is not None 65 | 66 | if self.sub_seed is None: 67 | assert self.seed_interpolation is None 68 | else: 69 | assert self.seed_interpolation is not None 70 | 71 | if self.prompt is None: 72 | assert self.scale is None 73 | else: 74 | assert self.scale is not None 75 | 76 | @property 77 | def unconditional(self) -> bool: 78 | return self.prompt is None 79 | 80 | def can_share_batch(self, other: Self) -> bool: 81 | """Determine if two parameters can share a batch.""" 82 | 83 | value = self.steps == other.steps 84 | value &= self.strength == other.strength 85 | value &= self.height == other.height 86 | value &= self.width == other.width 87 | value &= self.unconditional == other.unconditional 88 | 89 | return value 90 | 91 | 
@property 92 | def image_data(self) -> Optional[Tensor]: 93 | """Image data as a Tensor.""" 94 | 95 | if self.img is None or self.mode is None: 96 | return 97 | 98 | return image2tensor(self.img, self.height, self.width, self.mode, self.device) 99 | 100 | @property 101 | def mask_data(self) -> Optional[Tensor]: 102 | """Mask data as a Tensor.""" 103 | 104 | if self.mask is None or self.mode is None: 105 | return 106 | 107 | data = image2tensor(self.mask, self.height, self.width, self.mode, self.device) 108 | 109 | # single-channel 110 | data = data.float().mean(dim=1, keepdim=True) 111 | 112 | return data >= 255 / 2 # bool-Tensor 113 | 114 | @property 115 | def image_base64(self) -> Optional[str]: 116 | """Image data as a base64 string.""" 117 | 118 | if self.img is None: 119 | return 120 | 121 | return image_base64(self.img) 122 | 123 | @property 124 | def mask_base64(self) -> Optional[str]: 125 | """Mask data as a base64 string.""" 126 | 127 | if self.mask is None: 128 | return 129 | 130 | return image_base64(self.mask) 131 | 132 | @property 133 | def png_info(self) -> PngInfo: 134 | """PNG metadata.""" 135 | 136 | info = PngInfo() 137 | for key in SAVE_ARGS: 138 | value = getattr(self, key) 139 | 140 | if value is None: 141 | continue 142 | if isinstance(value, (int, float)): 143 | value = str(value) 144 | 145 | info.add_text(f"SD {key}", value) 146 | 147 | return info 148 | -------------------------------------------------------------------------------- /sd_fused/utils/parameters/parameters_list.py: -------------------------------------------------------------------------------- 1 | from __future__ import annotations 2 | from typing import Iterator, Optional 3 | 4 | from functools import lru_cache 5 | 6 | import torch 7 | from torch import Tensor 8 | 9 | from ...layers.base.types import Device 10 | from ..diverse import separate, single 11 | from .parameters import Parameters 12 | 13 | 14 | class ParametersList: 15 | """Hold information from a many image generations.""" 16 | 17 | def __init__(self, parameters: list[Parameters]) -> None: 18 | self.parameters = parameters 19 | 20 | def __len__(self) -> int: 21 | return len(self.parameters) 22 | 23 | def __iter__(self) -> Iterator[Parameters]: 24 | return iter(self.parameters) 25 | 26 | @property 27 | def prompts(self) -> Optional[list[str]]: 28 | prompts = [p.prompt for p in self.parameters] 29 | prompts = separate(prompts) 30 | 31 | return prompts 32 | 33 | @property 34 | def negative_prompts(self) -> list[str]: 35 | return [p.negative_prompt for p in self.parameters] 36 | 37 | @property 38 | def unconditional(self) -> bool: 39 | return self.prompts is None 40 | 41 | @property 42 | def seeds(self) -> list[int]: 43 | return [p.seed for p in self.parameters] 44 | 45 | @property 46 | def sub_seeds(self) -> Optional[list[int]]: 47 | sub_seeds = [p.sub_seed for p in self.parameters] 48 | sub_seeds = separate(sub_seeds) 49 | 50 | return sub_seeds 51 | 52 | @property 53 | def seeds_interpolation(self) -> Optional[Tensor]: 54 | seed_interpolations = [p.seed_interpolation for p in self.parameters] 55 | seed_interpolations = separate(seed_interpolations) 56 | 57 | if seed_interpolations is None: 58 | return None 59 | 60 | return torch.tensor(seed_interpolations, device=self.device, dtype=self.dtype) 61 | 62 | @property 63 | def height(self) -> int: 64 | height = set(p.height for p in self.parameters) 65 | 66 | return single(height) 67 | 68 | @property 69 | def width(self) -> int: 70 | width = set(p.width for p in self.parameters) 71 | 72 | return 
single(width) 73 | 74 | def shape(self, channels: int) -> tuple[int, int, int, int]: 75 | return (len(self), channels, self.height // 8, self.width // 8) 76 | 77 | @property 78 | def steps(self) -> int: 79 | steps = set(p.steps for p in self.parameters) 80 | 81 | return single(steps) 82 | 83 | @property 84 | def strength(self) -> Optional[float]: 85 | strength = set(p.strength for p in self.parameters) 86 | 87 | return single(strength) 88 | 89 | @property 90 | def device(self) -> Optional[Device]: 91 | devices = set(p.device for p in self.parameters) 92 | 93 | return single(devices) 94 | 95 | @property 96 | def dtype(self) -> Optional[torch.dtype]: 97 | dtypes = set(p.dtype for p in self.parameters) 98 | 99 | return single(dtypes) 100 | 101 | @property 102 | def scales(self) -> Optional[Tensor]: 103 | scales = [p.scale for p in self.parameters] 104 | scales = separate(scales) 105 | 106 | if scales is None: 107 | return None 108 | 109 | return torch.tensor(scales, device=self.device, dtype=self.dtype) 110 | 111 | @property 112 | def etas(self) -> Optional[Tensor]: 113 | etas = [p.eta for p in self.parameters] 114 | etas = separate(etas) 115 | 116 | if etas is None: 117 | return None 118 | 119 | return torch.tensor(etas, device=self.device, dtype=self.dtype) 120 | 121 | @property 122 | @lru_cache(None) 123 | def images_data(self) -> Optional[Tensor]: 124 | data = [p.image_data for p in self.parameters] 125 | data = separate(data) 126 | 127 | if data is None: 128 | return None 129 | 130 | return torch.cat(data, dim=0) 131 | 132 | @property 133 | @lru_cache(None) 134 | def masks_data(self) -> Optional[Tensor]: 135 | data = [p.mask_data for p in self.parameters] 136 | data = separate(data) 137 | 138 | if data is None: 139 | return None 140 | 141 | return torch.cat(data, dim=0) 142 | 143 | @property 144 | @lru_cache(None) 145 | def masked_images_data(self) -> Optional[Tensor]: 146 | if self.images_data is None or self.masks_data is None: 147 | return None 148 | 149 | return self.images_data * (~self.masks_data) 150 | -------------------------------------------------------------------------------- /sd_fused/utils/tensors/__init__.py: -------------------------------------------------------------------------------- 1 | from .generate_noise import generate_noise, random_seeds 2 | from .normalize import normalize, denormalize 3 | from .slerp import slerp 4 | from .to_tensor import to_tensor 5 | -------------------------------------------------------------------------------- /sd_fused/utils/tensors/generate_noise.py: -------------------------------------------------------------------------------- 1 | from __future__ import annotations 2 | from typing import Optional 3 | 4 | import random 5 | 6 | import torch 7 | from torch import Tensor 8 | 9 | from ...layers.base.types import Device 10 | 11 | 12 | def generate_noise( 13 | shape: tuple[int, ...], 14 | seeds: list[int], 15 | device: Optional[Device] = None, 16 | dtype: Optional[torch.dtype] = None, 17 | repeat: int = 1, # TODO: support repeat 18 | ) -> Tensor: 19 | """Generate random noise with individual seeds per batch.""" 20 | 21 | batch_size, *rest = shape 22 | assert len(seeds) == batch_size 23 | 24 | extended_shape = (repeat, batch_size, *rest) 25 | 26 | noise = torch.empty(extended_shape) 27 | for n in range(repeat): 28 | for i, s in enumerate(seeds): 29 | generator = torch.Generator() 30 | generator.manual_seed(s + n) 31 | 32 | noise[n, i] = torch.randn(*rest, generator=generator) 33 | 34 | noise = noise.to(device=device, dtype=dtype) 35 | 36 | 
if repeat == 1: 37 | return noise[0] 38 | 39 | return noise 40 | 41 | 42 | def random_seeds(size: int) -> list[int]: 43 | """Generate random seeds.""" 44 | 45 | return [random.randint(0, 2**32 - 1) for _ in range(size)] 46 | -------------------------------------------------------------------------------- /sd_fused/utils/tensors/normalize.py: -------------------------------------------------------------------------------- 1 | from __future__ import annotations 2 | from typing import Optional 3 | 4 | import torch 5 | from torch import Tensor 6 | 7 | 8 | def normalize(data: Tensor, dtype: Optional[torch.dtype] = None) -> Tensor: 9 | """Normalize a byte-Tensor to the [-1, 1] range.""" 10 | 11 | assert data.dtype == torch.uint8 12 | 13 | return data.div(255 / 2).sub(1).to(dtype) 14 | 15 | 16 | def denormalize(data: Tensor) -> Tensor: 17 | """Denormalize a tensor of the range [-1, 1] to a byte-Tensor.""" 18 | 19 | assert data.requires_grad == False 20 | 21 | return data.add(1).mul(255 / 2).clamp(0, 255).byte() 22 | -------------------------------------------------------------------------------- /sd_fused/utils/tensors/slerp.py: -------------------------------------------------------------------------------- 1 | from __future__ import annotations 2 | 3 | import torch 4 | from torch import Tensor 5 | 6 | 7 | # TODO DEBUG 8 | def slerp(a: Tensor, b: Tensor, t: Tensor) -> Tensor: 9 | "Spherical linear interpolation." 10 | 11 | # https://en.wikipedia.org/wiki/Slerp 12 | 13 | # 0 <= t <= 1 14 | assert t.ge(0).all() and t.le(1).all() 15 | 16 | assert a.shape == b.shape 17 | assert t.shape[0] == a.shape[0] 18 | assert a.ndim == 4 19 | assert t.ndim == 1 20 | 21 | t = t[:, None, None, None] 22 | 23 | # ? that's how you normalize? 24 | an = a / a.norm(dim=1, keepdim=True) 25 | bn = b / b.norm(dim=1, keepdim=True) 26 | 27 | Ω = an.mul(bn).sum(1).clamp(-1, 1).acos() 28 | 29 | den = torch.sin(Ω) 30 | 31 | A = torch.sin((1 - t) * Ω) 32 | B = torch.sin(t * Ω) 33 | 34 | return (A * a + B * b) / den 35 | -------------------------------------------------------------------------------- /sd_fused/utils/tensors/to_tensor.py: -------------------------------------------------------------------------------- 1 | from __future__ import annotations 2 | from typing import Optional 3 | 4 | import torch 5 | from torch import Tensor 6 | 7 | from ...layers.base.types import Device 8 | 9 | 10 | def to_tensor( 11 | x: Optional[int | float | Tensor], 12 | device: Optional[Device] = None, 13 | dtype: Optional[torch.dtype] = None, 14 | *, 15 | add_spatial: bool = False, 16 | ) -> Tensor: 17 | """Convert a number to a Tensor with fake channel/spatial dimensions.""" 18 | 19 | if x is None: 20 | x = 0 21 | 22 | if isinstance(x, (int, float)): 23 | x = torch.tensor([x], device=device, dtype=dtype) 24 | else: 25 | assert x.ndim == 1 26 | x = x.to(device=device, dtype=dtype) 27 | 28 | if add_spatial: 29 | x = x.view(-1, 1, 1, 1) 30 | 31 | return x 32 | -------------------------------------------------------------------------------- /sd_fused/utils/typing.py: -------------------------------------------------------------------------------- 1 | from __future__ import annotations 2 | from typing import TypeVar, Union, Iterable 3 | from typing_extensions import Unpack 4 | 5 | import sys 6 | 7 | if sys.version_info >= (3, 8): 8 | from typing import Literal, Final, Protocol 9 | else: 10 | from typing_extensions import Literal, Final, Protocol 11 | 12 | if sys.version_info >= (3, 11): 13 | from typing import TypeVarTuple 14 | else: 15 | from 
typing_extensions import TypeVarTuple 16 | 17 | from .image import ImageType 18 | 19 | T = TypeVar("T", int, float, str, ImageType) 20 | MaybeIterable = Union[T, Iterable[T]] 21 | -------------------------------------------------------------------------------- /setup.py: -------------------------------------------------------------------------------- 1 | # type: ignore 2 | from distutils.core import setup 3 | 4 | from pathlib import Path 5 | 6 | from sd_fused import StableDiffusion 7 | 8 | packages = list(set(str(p.parent) for p in Path("sd_fused").rglob("*.py"))) 9 | 10 | with open("./requirements.txt") as handle: 11 | requirements = [l.strip() for l in handle.read().split()] 12 | 13 | 14 | setup( 15 | name="sd-fused", 16 | version=StableDiffusion.version, 17 | description="Stable-Diffusion + Fused CUDA kernels", 18 | author="Thales Fernandes", 19 | author_email="thalesfdfernandes@gmail.com", 20 | url="https://github.com/tfernd/sd-fused", 21 | python_requires=">=3.7", # TODO Less? 22 | # packages=find_packages(exclude=["/cuda"]), 23 | packages=packages, 24 | install_requires=requirements, 25 | ) 26 | --------------------------------------------------------------------------------
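For reference, a minimal usage sketch of the pieces above: loading the fused models from a local diffusers checkpoint and enabling the memory modifiers. The checkpoint paths are hypothetical placeholders; `from_diffusers` expects a folder holding a `config.json` and a `*.bin` state file.

from sd_fused.models import AutoencoderKL, UNet2DConditional

# load and convert the diffusers weights into the fused layout (paths are placeholders)
unet = UNet2DConditional.from_diffusers("checkpoints/stable-diffusion/unet")
vae = AutoencoderKL.from_diffusers("checkpoints/stable-diffusion/vae")

# keep the weights in half-precision but run the forward pass in full precision
unet.half_weights(True)
vae.half_weights(True)

# split the attention computation into chunks to reduce peak memory
unet.split_attention("auto")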