├── README.md
├── assets
│   └── text2img.png
├── paths.py
├── requirements.txt
├── sd_fused
│   ├── __init__.py
│   ├── app
│   │   ├── __init__.py
│   │   ├── helpers.py
│   │   ├── sd.py
│   │   └── setup.py
│   ├── clip
│   │   ├── __init__.py
│   │   ├── clip_embedding.py
│   │   ├── container.py
│   │   ├── parser
│   │   │   ├── __init__.py
│   │   │   ├── add_delimiter4words.py
│   │   │   ├── add_split_maker4emphasis.py
│   │   │   ├── clean_spaces.py
│   │   │   ├── diffuse_prompt.py
│   │   │   ├── expand_delimiters.py
│   │   │   └── prompt_choices.py
│   │   └── text_segment.py
│   ├── layers
│   │   ├── __init__.py
│   │   ├── activation
│   │   │   ├── __init__.py
│   │   │   ├── geglu.py
│   │   │   └── silu.py
│   │   ├── auto_encoder
│   │   │   ├── __init__.py
│   │   │   ├── decoder.py
│   │   │   └── encoder.py
│   │   ├── base
│   │   │   ├── __init__.py
│   │   │   ├── base.py
│   │   │   ├── module.py
│   │   │   ├── module_list.py
│   │   │   ├── sequential.py
│   │   │   └── types.py
│   │   ├── basic
│   │   │   ├── __init__.py
│   │   │   ├── conv2d.py
│   │   │   ├── group_norm.py
│   │   │   ├── identity.py
│   │   │   ├── layer_norm.py
│   │   │   └── linear.py
│   │   ├── blocks
│   │   │   ├── __init__.py
│   │   │   ├── attention
│   │   │   │   ├── __init__.py
│   │   │   │   ├── base_attention.py
│   │   │   │   ├── compute
│   │   │   │   │   ├── __init__.py
│   │   │   │   │   ├── attention.py
│   │   │   │   │   ├── auto_chunk_size.py
│   │   │   │   │   ├── chunked_attention.py
│   │   │   │   │   ├── flash_attention.py
│   │   │   │   │   ├── join_spatial_dim.py
│   │   │   │   │   ├── scale_qk.py
│   │   │   │   │   ├── standard_attention.py
│   │   │   │   │   ├── tome.py
│   │   │   │   │   └── weighted_values.py
│   │   │   │   ├── cross_attention.py
│   │   │   │   └── self_attention.py
│   │   │   ├── basic
│   │   │   │   ├── __init__.py
│   │   │   │   ├── gn_conv.py
│   │   │   │   ├── gn_silu_conv.py
│   │   │   │   └── ln_geglu_linear.py
│   │   │   ├── spatial
│   │   │   │   ├── __init__.py
│   │   │   │   ├── ae
│   │   │   │   │   ├── __init__.py
│   │   │   │   │   ├── down_encoder.py
│   │   │   │   │   └── up_decoder.py
│   │   │   │   ├── base
│   │   │   │   │   ├── __init__.py
│   │   │   │   │   ├── down.py
│   │   │   │   │   └── up.py
│   │   │   │   ├── cross_attention
│   │   │   │   │   ├── __init__.py
│   │   │   │   │   ├── down.py
│   │   │   │   │   └── up.py
│   │   │   │   ├── output_states.py
│   │   │   │   ├── resampling
│   │   │   │   │   ├── __init__.py
│   │   │   │   │   ├── downsample2d.py
│   │   │   │   │   └── upsample2d.py
│   │   │   │   ├── resnet.py
│   │   │   │   └── unet_mid
│   │   │   │       ├── __init__.py
│   │   │   │       ├── cross_attention.py
│   │   │   │       └── self_attention.py
│   │   │   └── transformer
│   │   │       ├── __init__.py
│   │   │       ├── basic_transformer.py
│   │   │       └── spatial_transformer.py
│   │   ├── distribution
│   │   │   ├── __init__.py
│   │   │   └── diag_gaussian.py
│   │   ├── embedding
│   │   │   ├── __init__.py
│   │   │   ├── time_step_emb.py
│   │   │   └── time_steps.py
│   │   ├── external
│   │   │   ├── __init__.py
│   │   │   └── rearrange.py
│   │   └── modifiers
│   │       ├── __init__.py
│   │       └── half_weights.py
│   ├── models
│   │   ├── __init__.py
│   │   ├── ae_kl.py
│   │   ├── config.py
│   │   ├── convert
│   │   │   ├── __init__.py
│   │   │   ├── states.py
│   │   │   ├── unet
│   │   │   │   └── diffusers2fused.py
│   │   │   └── vae
│   │   │       ├── diffusers2fused.py
│   │   │       └── sd2diffusers.py
│   │   ├── modifiers
│   │   │   ├── __init__.py
│   │   │   ├── flash_attention.py
│   │   │   ├── half_weights.py
│   │   │   ├── split_attention.py
│   │   │   └── tome.py
│   │   └── unet_conditional.py
│   ├── scheduler
│   │   ├── __init__.py
│   │   ├── ddim.py
│   │   └── scheduler.py
│   └── utils
│       ├── __init__.py
│       ├── cuda
│       │   ├── __init__.py
│       │   ├── clear_cuda.py
│       │   └── free_memory.py
│       ├── diverse
│       │   ├── __init__.py
│       │   ├── product_args.py
│       │   ├── separate.py
│       │   ├── single.py
│       │   └── to_list.py
│       ├── image
│       │   ├── __init__.py
│       │   ├── image2tensor.py
│       │   ├── image_base64.py
│       │   ├── image_size.py
│       │   ├── open_image.py
│       │   ├── tensor2images.py
│       │   └── types.py
│       ├── parameters
│       │   ├── __init__.py
│       │   ├── batch_parameters.py
│       │   ├── group_parameters.py
│       │   ├── parameters.py
│       │   └── parameters_list.py
│       ├── tensors
│       │   ├── __init__.py
│       │   ├── generate_noise.py
│       │   ├── normalize.py
│       │   ├── slerp.py
│       │   └── to_tensor.py
│       └── typing.py
└── setup.py
/README.md:
--------------------------------------------------------------------------------
1 | # Stable-Diffusion + Fused CUDA kernels = FUN!
2 |
3 | ## Introduction
4 |
5 | This is a re-written implementation of Stable-Diffusion (SD) based on the original [diffusers](https://github.com/huggingface/diffusers) and [stable-diffusion](https://github.com/CompVis/stable-diffusion) repositories (all kudos for the original programmers).
6 |
7 | The goal of this reimplementation is to provide clearer, more readable, and more easily extensible code.
8 | Unfortunately, the original code is difficult to read due to the lack of proper typing, descriptive variable names, and other factors.
9 |
10 | ## For the impatient
11 |
12 | ### Emphasis
13 |
14 | Using the notation `(a few words):weight` you can add emphasis (weight > 1), reduce emphasis (weight < 1), or even steer away from the subject (negative weight).
15 | The words (tokens) inside the parentheses are given a weight that is passed down to the attention calculation, enhancing, attenuating, or negating the attention paid to those tokens.
16 |
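Under the hood, the parser rewrites the emphasis notation into `(words):weight` segments before tokenization. A minimal sketch using the parser modules shipped in `sd_fused/clip` (the prompt and weight here are arbitrary):

```python
from sd_fused.clip import ClipEmbedding
from sd_fused.clip.parser import split_prompt_into_segments
from sd_fused.clip.text_segment import TextSegment

prompt = "portrait, woman, cyberpunk:2.1, digital art"
parsed = ClipEmbedding.parse_emphasis(prompt)
# -> "portrait, woman, ⏎(cyberpunk):2.1⏎, digital art"

for segment in map(TextSegment, split_prompt_into_segments(parsed)):
    print(segment)  # only the "cyberpunk" segment carries weight 2.1
```
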
17 | Below is a small test where the word `cyberpunk` is given a different emphasis.
18 |
19 | ```python
20 | weight = torch.linspace(-1.2, 4.2, 32).tolist()
21 | choices = '|'.join(map(str, weight))
22 |
23 | out = pipeline.generate(
24 | prompt=f"portrait, woman, cyberpunk:[{choices}], digital art, detailed, epic, beautiful",
25 | steps=24,
26 | scale=11,
27 | height=512,
28 | width=512,
29 | seed=1658926406,
30 | eta=0.6,
31 | show=True,
32 | batch_size=8,
33 | )
34 | ```
35 |
36 |
37 |
38 |
39 | ### Batched sweep
40 |
41 | Any input parameter can be passed as a list to sweep over it, and multiple sweeps can be combined (a combined-sweep sketch follows the example below).
42 | For example:
43 |
44 | ```python
45 | out = pipeline.generate(
46 | prompt="portrait, woman, cyberpunk, digital art, detailed, epic, beautiful",
47 | steps=26,
48 | height=512,
49 | width=512,
50 | seed=1331366415,
51 | eta=torch.linspace(-1, 1, 64).tolist(),
52 | show=True,
53 | batch_size=8,
54 | )
55 | ```
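
Since the keyword arguments are expanded through `product_args`, passing several parameters as lists should sweep every combination of them (a hedged sketch; the values are arbitrary):

```python
out = pipeline.generate(
    prompt="portrait, woman, cyberpunk, digital art, detailed, epic, beautiful",
    steps=26,
    scale=[7.5, 9, 11],                    # 3 values
    eta=torch.linspace(0, 1, 4).tolist(),  # 4 values
    height=512,
    width=512,
    seed=1331366415,
    show=True,
    batch_size=8,
)  # 3 x 4 = 12 parameter combinations
```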
56 |
57 |
58 |
59 |
60 | ### Seed-Interpolations
61 |
62 | You can interpolate between two known seeds (`seed` and `sub_seed`), sweeping the interpolation amount between them.
63 |
64 | ```python
65 | # pipeline.tome(None)
66 | out = pipeline.generate(
67 | prompt="portrait, woman, cyberpunk, digital art, detailed, epic, beautiful",
68 | steps=26,
69 | height=512,
70 | width=512,
71 | seed=3783195593,
72 | sub_seed=2148348002,
73 | interpolation=torch.linspace(0, 1, 64).tolist(),
74 | eta=0.6,
75 | show=True,
76 | batch_size=8,
77 | )
78 | ```
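
Internally, the noise generated from `seed` and `sub_seed` is blended with spherical linear interpolation. A minimal sketch of the idea, assuming the `generate_noise`/`slerp` helpers from `sd_fused.utils.tensors` (the exact call signatures may differ):

```python
import torch
from sd_fused.utils.tensors import generate_noise, slerp

shape = (1, 4, 64, 64)  # latent shape for a single 512x512 image (512 / 8 = 64)
noise = generate_noise(shape, [3783195593], "cpu", torch.float32)
sub_noise = generate_noise(shape, [2148348002], "cpu", torch.float32)

blended = slerp(noise, sub_noise, torch.tensor([0.5]))  # halfway between the two seeds
```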
79 |
80 |
81 |
82 |
100 |
101 |
102 |
103 |
104 |
124 |
125 |
126 |
127 |
146 |
147 |
148 | ## Kernel fusion
149 |
150 | This is an ongoing effort to fuse as many layers as possible, making inference faster and more memory-friendly; an illustrative sketch of one such fused block follows.
151 |
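As an illustration of what "fused" means here, the auto-encoder, for example, uses blocks such as `GroupNormSiLUConv2d` that collapse a normalization → activation → convolution sequence into a single module. A hypothetical, simplified version (not the actual implementation) might look like:

```python
import torch.nn.functional as F
from torch import Tensor, nn

class GroupNormSiLUConv2d(nn.Module):
    """Sketch of a fused GroupNorm -> SiLU -> Conv2d block."""

    def __init__(self, groups: int, in_channels: int, out_channels: int, *, kernel_size: int, padding: int) -> None:
        super().__init__()

        self.norm = nn.GroupNorm(groups, in_channels, eps=1e-6)
        self.conv = nn.Conv2d(in_channels, out_channels, kernel_size, padding=padding)

    def forward(self, x: Tensor) -> Tensor:
        # normalization, activation, and convolution applied back-to-back in one module
        return self.conv(F.silu(self.norm(x)))
```
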
152 | ## Installation
153 |
154 | ```bash
155 | pip install -U git+https://github.com/tfernd/sd-fused
156 | ```
157 |
158 | ## Text2Image generation
159 |
160 | Base code for text-to-image generation.
161 |
162 | ```python
163 | from IPython.display import display
164 | from sd_fused.app import StableDiffusion
165 |
166 | # Assuming you downloaded SD and put it in the folder below
167 | pipeline = StableDiffusion('.pretrained/stable-diffusion')
168 |
169 | # If you have a GPU with 3-4 GB of VRAM, use the line below
170 | # pipeline.set_low_ram().half_weights().cuda()
171 | pipeline.half().cuda()
172 | pipeline.split_attention(cross_attention_chunks=1)
173 | # if you have xformers installed, use the line below
174 | # pipeline.flash_attention()
175 |
176 | out = pipeline.generate(
177 | prompt='portrait of zombie, digital art, detailed, artistic',
178 | negative_prompt='old man',
179 | steps=28,
180 | scale=11,
181 | height=512,
182 | width=512,
183 | seed=42,
184 | show=True
185 | )
186 | ```
187 |
188 | 
189 |
--------------------------------------------------------------------------------
/assets/text2img.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/tfernd/sd-fused/50570f983cc00dd4bd0dc415d0b515da638311b0/assets/text2img.png
--------------------------------------------------------------------------------
/paths.py:
--------------------------------------------------------------------------------
1 | import argparse
2 | import os
3 | import sys
4 | import modules.safe
5 |
6 | script_path = os.path.dirname(os.path.dirname(os.path.realpath(__file__)))
7 |
8 | # Parse the --data-dir flag first so we can use it as a base for our other argument default values
9 | parser = argparse.ArgumentParser(add_help=False)
10 | parser.add_argument("--data-dir", type=str, default=os.path.dirname(os.path.dirname(os.path.realpath(__file__))), help="base path where all user data is stored",)
11 | cmd_opts_pre = parser.parse_known_args()[0]
12 | data_path = cmd_opts_pre.data_dir
13 | models_path = os.path.join(data_path, "models")
14 |
15 | # data_path = cmd_opts_pre.data
16 | sys.path.insert(0, script_path)
17 |
18 | # search for directory of stable diffusion in following places
19 | sd_path = None
20 | possible_sd_paths = [
21 | os.path.join(script_path, '/content/gdrive/MyDrive/sd/stablediffusion'),
22 | os.path.join(script_path, '/content/sd/stablediffusion'),
23 | '.',
24 | os.path.dirname(script_path)
25 | ]
26 | for possible_sd_path in possible_sd_paths:
27 | if os.path.exists(os.path.join(possible_sd_path, 'ldm/models/diffusion/ddpm.py')):
28 | sd_path = os.path.abspath(possible_sd_path)
29 | break
30 |
31 | assert sd_path is not None, "Couldn't find Stable Diffusion in any of: " + str(possible_sd_paths)
32 |
33 | path_dirs = [
34 | (sd_path, 'ldm', 'Stable Diffusion', []),
35 | (os.path.join(sd_path, 'src/taming-transformers'), 'taming', 'Taming Transformers', []),
36 | (os.path.join(sd_path, 'src/codeformer'), 'inference_codeformer.py', 'CodeFormer', []),
37 | (os.path.join(sd_path, 'src/blip'), 'models/blip.py', 'BLIP', []),
38 | (os.path.join(sd_path, 'src/k-diffusion'), 'k_diffusion/sampling.py', 'k_diffusion', ["atstart"]),
39 | ]
40 |
41 | paths = {}
42 |
43 | for d, must_exist, what, options in path_dirs:
44 | must_exist_path = os.path.abspath(os.path.join(script_path, d, must_exist))
45 | if not os.path.exists(must_exist_path):
46 | print(f"Warning: {what} not found at path {must_exist_path}", file=sys.stderr)
47 | else:
48 | d = os.path.abspath(d)
49 | if "atstart" in options:
50 | sys.path.insert(0, d)
51 | else:
52 | sys.path.append(d)
53 | paths[what] = d
54 |
55 | class Prioritize:
56 | def __init__(self, name):
57 | self.name = name
58 | self.path = None
59 |
60 | def __enter__(self):
61 | self.path = sys.path.copy()
62 | sys.path = [paths[self.name]] + sys.path
63 |
64 | def __exit__(self, exc_type, exc_val, exc_tb):
65 | sys.path = self.path
66 | self.path = None
--------------------------------------------------------------------------------
/requirements.txt:
--------------------------------------------------------------------------------
1 | ipywidgets>=7,<8
2 | typing_extensions>=4.3
3 | torch
4 | transformers>=4.24
5 | einops>=0.6
6 | tqdm>=4.64
7 | Pillow>=9.3
8 | scipy
9 | ftfy
10 | validators>=0.20
--------------------------------------------------------------------------------
/sd_fused/__init__.py:
--------------------------------------------------------------------------------
1 | from .app import StableDiffusion
2 |
--------------------------------------------------------------------------------
/sd_fused/app/__init__.py:
--------------------------------------------------------------------------------
1 | from .sd import StableDiffusion
2 |
--------------------------------------------------------------------------------
/sd_fused/app/helpers.py:
--------------------------------------------------------------------------------
1 | from __future__ import annotations
2 | from typing import Optional
3 |
4 | from pathlib import Path
5 | from datetime import datetime
6 | from copy import deepcopy
7 |
8 | from PIL import Image
9 | from PIL.PngImagePlugin import PngInfo
10 |
11 | import random
12 | import torch
13 | from torch import Tensor
14 |
15 | from ..models import AutoencoderKL, UNet2DConditional
16 | from ..clip import ClipEmbedding
17 | from ..utils.tensors import slerp, generate_noise
18 |
19 | MAGIC = 0.18215
20 |
21 |
22 | class Helpers:
23 | # version: str
24 | model_name: str
25 |
26 | save_dir: Path
27 |
28 | clip: ClipEmbedding
29 | vae: AutoencoderKL
30 | unet: UNet2DConditional
31 |
32 | device: torch.device
33 | dtype: torch.dtype
34 |
35 | @property
36 | def latent_channels(self) -> int:
37 | """Latent-space channel size."""
38 |
39 | return self.unet.out_channels
40 |
41 | @property
42 | def is_true_inpainting(self) -> bool:
43 | """RunwayMl true inpainting model."""
44 |
45 | return self.unet.in_channels == 4 and self.latent_channels == 9
46 |
47 | def save_image(
48 | self,
49 | image: Image.Image,
50 | png_info: Optional[PngInfo] = None,
51 | ID: Optional[int] = None,
52 | ) -> Path:
53 | """Save the image using the provided metadata information."""
54 |
55 | now = datetime.now()
56 | timestamp = now.strftime(r"%Y-%m-%d %H-%M-%S.%f")
57 |
58 | if ID is None:
59 | ID = random.randint(0, 2**64)
60 |
61 | self.save_dir.mkdir(parents=True, exist_ok=True)
62 |
63 | path = self.save_dir / f"{timestamp} - {ID:x}.SD.png"
64 | image.save(path, bitmap_format="png", pnginfo=png_info)
65 |
66 | return path
67 |
68 | @torch.no_grad()
69 | def encode(self, data: Tensor) -> Tensor:
70 | """Encodes (stochastically) a RGB image into a latent vector."""
71 |
72 | return self.vae.encode(data).sample().mul(MAGIC)
73 |
74 | @torch.no_grad()
75 | def decode(self, latents: Tensor) -> Tensor:
76 | """Decode latent vector into an RGB image."""
77 |
78 | return self.vae.decode(latents.div(MAGIC))
79 |
80 | @torch.no_grad()
81 | def get_context(
82 | self,
83 | negative_prompts: list[str],
84 | prompts: Optional[list[str]],
85 | ) -> tuple[Tensor, Optional[Tensor]]:
86 | """Creates a context Tensor (negative + positive prompt) and a emphasis weights."""
87 |
88 | texts = deepcopy(negative_prompts)
89 | if prompts is not None:
90 | texts.extend(deepcopy(prompts))
91 |
92 | context, weight = self.clip(texts, self.device, self.dtype)
93 |
94 | return context, weight
95 |
96 | def generate_noise(
97 | self,
98 | seeds: list[int],
99 | sub_seeds: Optional[list[int]],
100 | interpolations: Optional[Tensor],
101 | height: int,
102 | width: int,
103 | batch_size: int,
104 | ) -> Tensor:
105 | """Generate random noise with individual seeds per batch and
106 | possible sub-seed interpolation."""
107 |
108 | shape = (batch_size, self.latent_channels, height // 8, width // 8)
109 | noise = generate_noise(shape, seeds, self.device, self.dtype)
110 | if sub_seeds is None:
111 | return noise
112 |
113 | assert interpolations is not None
114 | sub_noise = generate_noise(shape, sub_seeds, self.device, self.dtype)
115 | noise = slerp(noise, sub_noise, interpolations)
116 |
117 | return noise
118 |
--------------------------------------------------------------------------------
/sd_fused/app/sd.py:
--------------------------------------------------------------------------------
1 | from __future__ import annotations
2 | from typing import Optional
3 |
4 | from pathlib import Path
5 | from tqdm.auto import trange, tqdm
6 | from PIL import Image
7 | from IPython.display import display
8 |
9 | from copy import deepcopy
10 | import torch
11 | from torch import Tensor
12 |
13 | from ..models import AutoencoderKL, UNet2DConditional
14 | from ..clip import ClipEmbedding
15 | from ..utils.cuda import clear_cuda
16 | from ..scheduler import Scheduler, DDIMScheduler
17 | from ..utils.image import tensor2images
18 | from ..utils.image import ImageType, ResizeModes
19 | from ..utils.typing import MaybeIterable
20 | from ..utils.diverse import to_list, product_args
21 | from ..clip.parser import prompts_choices
22 | from ..utils.tensors import random_seeds
23 | from ..utils.parameters import Parameters, ParametersList, group_parameters, batch_parameters
24 | from .setup import Setup
25 | from .helpers import Helpers
26 |
27 |
28 | class StableDiffusion(Setup, Helpers):
29 | version: str = "0.6.0"
30 |
31 | def __init__(
32 | self,
33 | path: str | Path,
34 | *,
35 | save_dir: str | Path = "./gallery",
36 | model_name: Optional[str] = None,
37 | ) -> None:
38 | """Load Stable-Diffusion diffusers checkpoint."""
39 |
40 | self.path = path = Path(path)
41 | self.save_dir = Path(save_dir)
42 | self.model_name = model_name or path.name
43 |
44 | assert path.is_dir()
45 |
46 | self.clip = ClipEmbedding(path / "tokenizer", path / "text_encoder")
47 |
48 | self.vae = AutoencoderKL.from_diffusers(path / "vae")
49 | self.unet = UNet2DConditional.from_diffusers(path / "unet")
50 |
51 | # initialize
52 | self.low_ram(False)
53 | self.split_attention(None)
54 | self.flash_attention(False)
55 | self.tome(None)
56 |
57 | def generate(
58 | self,
59 | *,
60 | eta: MaybeIterable[float] = 0,
61 | steps: MaybeIterable[int] = 32,
62 | height: MaybeIterable[int] = 512,
63 | width: MaybeIterable[int] = 512,
64 | negative_prompt: MaybeIterable[str] = "",
65 | # optionals
66 | scale: Optional[MaybeIterable[float]] = 7.5,
67 | prompt: Optional[MaybeIterable[str]] = None,
68 | img: Optional[MaybeIterable[ImageType]] = None,
69 | mask: Optional[MaybeIterable[ImageType]] = None,
70 | strength: Optional[MaybeIterable[float]] = None,
71 | mode: Optional[ResizeModes] = None,
72 | seed: Optional[MaybeIterable[int]] = None,
73 | # sub_seed: Optional[int] = None, # TODO Iterable?
74 | # seed_interpolation: Optional[MaybeIterable[float]] = None,
75 | # latents: Optional[Tensor] = None, # TODO
76 | batch_size: int = 1,
77 | repeat: int = 1,
78 | show: bool = True,
79 | share_seed: bool = True,
80 | ) -> list[tuple[Image.Image, Path, Parameters]]:
81 | """Create a list of parameters and group them
82 | into batches to be processed.
83 | """
84 |
85 | if seed is not None:
86 | repeat = 1
87 |
88 | if prompt is not None:
89 | prompt = prompts_choices(prompt)
90 | negative_prompt = prompts_choices(negative_prompt)
91 |
92 | list_kwargs = product_args(
93 | eta=eta,
94 | steps=steps,
95 | scale=scale,
96 | height=height,
97 | width=width,
98 | negative_prompt=negative_prompt,
99 | prompt=prompt,
100 | img=img,
101 | mask=mask,
102 | strength=strength,
103 | # seed_interpolation=seed_interpolation,
104 | )
105 | size = len(list_kwargs)
106 | list_kwargs = deepcopy(list_kwargs * repeat)
107 |
108 | # if seeds are given or share-seed set
109 | # each repeated-iteration has the same seed
110 | if seed is not None or share_seed:
111 | if seed is None:
112 | seed = random_seeds(repeat)
113 | seeds = [s for s in to_list(seed) for _ in range(size)]
114 |
115 | # otherwise each iteration has it's own unique seed
116 | else:
117 | seeds = random_seeds(size * repeat)
118 |
119 | # create parameters list and group/batch them
120 | parameters = [
121 | Parameters(**kwargs, mode=mode, seed=seed, device=self.device, dtype=self.dtype) # sub_seed=sub_seed
122 | for (seed, kwargs) in zip(seeds, list_kwargs)
123 | ]
124 | groups = group_parameters(parameters)
125 | batched_parameters = batch_parameters(groups, batch_size)
126 |
127 | out: list[tuple[Image.Image, Path, Parameters]] = []
128 | for params in tqdm(batched_parameters, desc="Generating batches"):
129 | ipp = self.generate_from_parameters(ParametersList(params))
130 | out.extend(ipp)
131 |
132 | if show:
133 | for image, path, parameters in ipp:
134 | print(parameters)
135 | display(image)
136 |
137 | return out
138 |
139 | @torch.no_grad()
140 | def generate_from_parameters(
141 | self,
142 | pL: ParametersList,
143 | ) -> list[tuple[Image.Image, Path, Parameters]]:
144 |
145 | context, weight = self.get_context(pL.negative_prompts, pL.prompts)
146 |
147 | # TODO make general
148 | scheduler = DDIMScheduler(
149 | pL.steps, pL.shape(self.latent_channels), pL.seeds, pL.strength, self.device, self.dtype
150 | )
151 |
152 | enc = lambda x: self.encode(x) if x is not None else None
153 | image_latents = enc(pL.images_data)
154 | mask_latents = enc(pL.masks_data)
155 | masked_image_latents = enc(pL.masked_images_data)
156 |
157 | latents = scheduler.prepare_latents(image_latents, mask_latents, masked_image_latents)
158 | # TODO add to scheduler
159 | latents = self.denoise_latents(scheduler, latents, context, weight, pL.unconditional, pL.scales, pL.etas)
160 |
161 | data = self.decode(latents)
162 | images = tensor2images(data)
163 |
164 | paths: list[Path] = []
165 | for parameter, image in zip(pL, images):
166 | path = self.save_image(image, parameter.png_info)
167 | paths.append(path)
168 |
169 | return list(zip(images, paths, list(pL)))
170 |
171 | def denoise_latents(
172 | self,
173 | scheduler: Scheduler,
174 | latents: Tensor,
175 | context: Tensor,
176 | weight: Optional[Tensor],
177 | unconditional: bool,
178 | scales: Optional[Tensor],
179 | etas: Optional[Tensor],
180 | ) -> Tensor:
181 | """Main loop where latents are denoised."""
182 |
183 | clear_cuda()
184 | for index in trange(scheduler.skip_timestep, scheduler.steps, desc="Denoising latents"):
185 | timestep = int(scheduler.timesteps[index].item())
186 |
187 | pred_noise = scheduler.pred_noise(
188 | self.unet, latents, timestep, context, weight, scales, unconditional, self.use_low_ram
189 | )
190 | latents = scheduler.step(pred_noise, latents, index, etas=etas)
191 |
192 | del pred_noise
193 | clear_cuda()
194 |
195 | return latents
196 |
197 | def __repr__(self) -> str:
198 | name = self.__class__.__qualname__
199 |
200 | return f'{name}(model="{self.model_name}", version="{self.version}")'
201 |
--------------------------------------------------------------------------------
/sd_fused/app/setup.py:
--------------------------------------------------------------------------------
1 | from __future__ import annotations
2 | from typing import Optional
3 | from typing_extensions import Self
4 |
5 | import torch
6 |
7 | from ..utils.typing import Literal
8 | from ..utils.cuda import clear_cuda
9 | from ..models import AutoencoderKL, UNet2DConditional
10 | from ..layers.blocks.attention.compute import ChunkType
11 | from ..layers.base.types import Device
12 | from ..layers.base.base import Base
13 |
14 |
15 | class Setup(Base):
16 | use_low_ram: bool
17 |
18 | vae: AutoencoderKL
19 | unet: UNet2DConditional
20 |
21 | def low_ram(self, use: bool = True) -> Self:
22 | """Split context into two passes to save memory."""
23 |
24 | self.use_low_ram = use
25 |
26 | return self
27 |
28 | def to(self, *, device: Optional[Device] = None, dtype: Optional[torch.dtype] = None) -> Self:
29 | if device is not None:
30 | self.device = device
31 | if dtype is not None:
32 | self.dtype = dtype
33 |
34 | self.unet.to(device=self.device, dtype=dtype)
35 | self.vae.to(device=self.device, dtype=dtype)
36 |
37 | return self
38 |
39 | def cuda(self) -> Self:
40 | """Send unet and auto-encoder to cuda."""
41 |
42 | clear_cuda()
43 |
44 | return self.to(device="cuda")
45 |
46 | def cpu(self) -> Self:
47 | """Send unet and auto-encoder to cpu."""
48 |
49 | return self.to(device="cpu")
50 |
51 | def half(self) -> Self:
52 | """Use half-precision for unet and auto-encoder."""
53 |
54 | return self.to(dtype=torch.float16)
55 |
56 | def float(self) -> Self:
57 | """Use full-precision for unet and auto-encoder."""
58 |
59 | return self.to(dtype=torch.float32)
60 |
61 | def half_weights(self, use: bool = True) -> Self:
62 | """Store the weights in half-precision but
63 | compute forward pass in full precision.
64 | Useful for GPUs that gives NaN when used in half-precision.
65 | """
66 |
67 | self.unet.half_weights(use)
68 | self.vae.half_weights(use)
69 |
70 | return self
71 |
72 | def split_attention(
73 | self,
74 | chunks: Optional[int | Literal["auto"]] = "auto",
75 | chunk_types: Optional[ChunkType] = None,
76 | ) -> Self:
77 | """Split cross-attention computation into chunks."""
78 |
79 | # TODO this should not be here...
80 | # default to batch if not set
81 | if chunks is not None and chunk_types is None:
82 | chunk_types = "batch"
83 |
84 | self.unet.split_attention(chunks, chunk_types)
85 | self.vae.split_attention(chunks, chunk_types)
86 |
87 | return self
88 |
89 | def flash_attention(self, use: bool = True) -> Self:
90 | """Use xformers flash-attention."""
91 |
92 | self.unet.flash_attention(use)
93 | self.vae.flash_attention(use)
94 |
95 | return self
96 |
97 | def tome(self, r: Optional[int | float] = None) -> Self:
98 | """Merge similar tokens."""
99 |
100 | self.unet.tome(r)
101 | self.vae.tome(r)
102 |
103 | return self
104 |
--------------------------------------------------------------------------------
/sd_fused/clip/__init__.py:
--------------------------------------------------------------------------------
1 | from .clip_embedding import ClipEmbedding
2 |
--------------------------------------------------------------------------------
/sd_fused/clip/clip_embedding.py:
--------------------------------------------------------------------------------
1 | from __future__ import annotations
2 | from typing import Optional
3 |
4 | from functools import lru_cache
5 | from pathlib import Path
6 |
7 | import torch
8 |
9 | from transformers.models.clip.modeling_clip import CLIPTextModel
10 | from transformers.models.clip.tokenization_clip import CLIPTokenizer
11 |
12 | from ..layers.base.types import Device
13 | from .text_segment import TextSegment
14 | from .container import TensorAndWeight, TensorAndMaybeWeight
15 | from .parser import (
16 | clean_spaces,
17 | add_delimiter4words,
18 | expand_delimiters,
19 | add_split_maker4emphasis,
20 | split_prompt_into_segments,
21 | )
22 |
23 | MAX_TOKENS = 77
24 |
25 |
26 | class ClipEmbedding:
27 | """Convert a text to embeddings using a CLIP model."""
28 |
29 | tokenizer: CLIPTokenizer
30 | text_encoder: CLIPTextModel
31 |
32 | def __init__(
33 | self,
34 | tokenizer_path: str | Path,
35 | text_encoder_path: str | Path,
36 | ) -> None:
37 | # no need for CUDA for simple embeddings...
38 | self.tokenizer = CLIPTokenizer.from_pretrained(tokenizer_path)
39 | self.text_encoder = CLIPTextModel.from_pretrained(text_encoder_path) # type: ignore
40 |
41 | # token ids markers
42 | assert self.tokenizer.bos_token_id is not None
43 | assert self.tokenizer.eos_token_id is not None
44 | assert self.tokenizer.pad_token_id is not None
45 |
46 | self.bos_token_id = self.tokenizer.bos_token_id
47 | self.eos_token_id = self.tokenizer.eos_token_id
48 | self.pad_token_id = self.tokenizer.pad_token_id
49 |
50 | @staticmethod
51 | def parse_emphasis(prompt: str) -> str:
52 | """Parse emphasis notation."""
53 |
54 | prompt = add_delimiter4words(prompt)
55 | prompt = expand_delimiters(prompt)
56 | prompt = add_split_maker4emphasis(prompt)
57 |
58 | return prompt
59 |
60 | @lru_cache(maxsize=None)
61 | def get_ids_and_weights(self, prompt: str) -> TensorAndWeight:
62 | """Get the token id and weight for a given prompt."""
63 |
64 | prompt = clean_spaces(prompt)
65 | prompt = self.parse_emphasis(prompt)
66 |
67 | segments = [TextSegment(t) for t in split_prompt_into_segments(prompt)]
68 |
69 | ids: list[int] = []
70 | weights: list[float] = []
71 | for n, seg in enumerate(segments):
72 | seg_ids = self.tokenizer.encode(seg.prompt)
73 |
74 | # remove initial/final ids
75 | seg_ids = seg_ids[1:-1]
76 |
77 | ids.extend(seg_ids)
78 | weights.extend([seg.weight] * (len(seg_ids)))
79 |
80 | # add padding and initial/final ids
81 | pad_size = MAX_TOKENS - len(ids) - 2
82 | assert pad_size >= 0, "Text too big, it will result in truncation"
83 |
84 | ids = [
85 | self.bos_token_id,
86 | *ids,
87 | *[self.pad_token_id] * pad_size,
88 | self.eos_token_id,
89 | ]
90 | weights = [1, *weights, *[1] * pad_size, 1]
91 |
92 | return TensorAndWeight(torch.tensor([ids]), torch.tensor([weights]).float())
93 |
94 | @lru_cache(maxsize=None)
95 | @torch.no_grad()
96 | def get_embedding(self, prompt: str) -> TensorAndWeight:
97 | """Creates an embedding/weights for a prompt and cache it."""
98 |
99 | ids, weight = self.get_ids_and_weights(prompt)
100 | emb = self.text_encoder(ids)[0]
101 |
102 | return TensorAndWeight(emb, weight)
103 |
104 | def __call__(
105 | self,
106 | prompt: str | list[str] = "",
107 | device: Optional[Device] = None,
108 | dtype: Optional[torch.dtype] = None,
109 | ) -> TensorAndMaybeWeight:
110 | """Creates embeddings/weights for a prompt and send to the correct device/dtype."""
111 |
112 | if isinstance(prompt, str):
113 | prompt = [prompt]
114 | values = [self.get_embedding(t) for t in prompt]
115 | emb = torch.cat([v.tensor for v in values])
116 | weight = torch.cat([v.weight for v in values])
117 |
118 | emb = emb.to(device=device, dtype=dtype, non_blocking=True)
119 |
120 | # special case where all weights are one
121 | if weight.eq(1).all():
122 | return TensorAndMaybeWeight(emb)
123 |
124 | weight = weight.to(device=device, dtype=dtype, non_blocking=True)
125 |
126 | return TensorAndMaybeWeight(emb, weight)
127 |
--------------------------------------------------------------------------------
/sd_fused/clip/container.py:
--------------------------------------------------------------------------------
1 | from __future__ import annotations
2 | from typing import NamedTuple, Optional
3 |
4 | from torch import Tensor
5 |
6 |
7 | class TensorAndWeight(NamedTuple):
8 | tensor: Tensor
9 | weight: Tensor
10 |
11 |
12 | class TensorAndMaybeWeight(NamedTuple):
13 | tensor: Tensor
14 | weight: Optional[Tensor] = None
15 |
--------------------------------------------------------------------------------
/sd_fused/clip/parser/__init__.py:
--------------------------------------------------------------------------------
1 | from .clean_spaces import clean_spaces
2 | from .add_delimiter4words import add_delimiter4words
3 | from .add_split_maker4emphasis import add_split_maker4emphasis, split_prompt_into_segments
4 | from .expand_delimiters import expand_delimiters
5 | from .prompt_choices import prompt_choices, prompts_choices
6 |
7 | from .diffuse_prompt import diffuse_prompt
8 |
--------------------------------------------------------------------------------
/sd_fused/clip/parser/add_delimiter4words.py:
--------------------------------------------------------------------------------
1 | from __future__ import annotations
2 |
3 | import re
4 |
5 |
6 | def add_delimiter4words(prompt: str) -> str:
7 | """Replaces `word:weight` -> `(word):weight`."""
8 |
9 | prompt = re.sub(r"(\w+):([+-]?\d+(?:\.\d+)?)", r"(\1):\2", prompt)
10 |
11 | return prompt
12 |
--------------------------------------------------------------------------------
/sd_fused/clip/parser/add_split_maker4emphasis.py:
--------------------------------------------------------------------------------
1 | from __future__ import annotations
2 |
3 | import re
4 |
5 |
6 | def add_split_maker4emphasis(prompt: str) -> str:
7 | """Add ⏎ to the begginig and end of (..):value"""
8 |
9 | pattern = r"(\(.+?\):[+-]?\d+(?:.\d+)?)"
10 | prompt = re.sub(pattern, r"⏎\1⏎", prompt)
11 |
12 | return prompt
13 |
14 |
15 | def split_prompt_into_segments(prompt: str) -> list[str]:
16 | """Split a prompt at ⏎ to give prompt-segments."""
17 |
18 | return prompt.split("⏎")
19 |
--------------------------------------------------------------------------------
/sd_fused/clip/parser/clean_spaces.py:
--------------------------------------------------------------------------------
1 | from __future__ import annotations
2 |
3 | import re
4 |
5 |
6 | def clean_spaces(prompt: str) -> str:
7 | """Clean-up spaces/return characters."""
8 |
9 | prompt = prompt.replace("\n", " ")
10 | prompt = re.sub(r"[ ]+", r" ", prompt)
11 | prompt = prompt.strip()
12 |
13 | return prompt
14 |
--------------------------------------------------------------------------------
/sd_fused/clip/parser/diffuse_prompt.py:
--------------------------------------------------------------------------------
1 | from __future__ import annotations
2 |
3 | import re
4 |
5 | import math
6 | from typing import Optional
7 | import torch
8 |
9 |
10 | def diffuse_prompt(
11 | prompt: str,
12 | vmax: float = 1,
13 | size: int = 1,
14 | seed: Optional[int] = None,
15 | ) -> list[str]:
16 | """Diffuse attention-weights to a prompt."""
17 |
18 | assert ":" not in prompt
19 |
20 | pattern = re.compile(r"(\w+)")
21 | words = pattern.split(prompt)
22 | n = len(words)
23 |
24 | generator = torch.Generator()
25 | if seed is not None:
26 | generator.manual_seed(seed)
27 |
28 | weights = torch.randn(size, n, generator=generator)
29 | weights = weights.cumsum(0) / math.sqrt(size) * vmax
30 | weights += 1 - weights[[0]] # start at the same weight
31 | weights = weights.tolist()
32 |
33 | prompts: list[str] = []
34 | for weight in weights:
35 | prompt = "".join([f"{w}:{a:.3f}" if pattern.search(w) else w for w, a in zip(words, weight)])
36 | prompts.append(prompt)
37 |
38 | return prompts
39 |
--------------------------------------------------------------------------------
/sd_fused/clip/parser/expand_delimiters.py:
--------------------------------------------------------------------------------
1 | from __future__ import annotations
2 |
3 | import re
4 |
5 | FACTOR = 1.1  # emphasis multiplier applied per nesting level
6 | MAX_EMPHASIS = 8  # maximum supported nesting depth
7 |
8 |
9 | def expand_delimiters(prompt: str) -> str:
10 | """Replace `(^n ... )^n` with `( ... ):factor^n`"""
11 |
12 | delimiters = [(r"\(" * repeat, r"\)" * repeat, repeat) for repeat in range(MAX_EMPHASIS, 0, -1)]
13 |
14 | avoid = r"\(\)"
15 | for left, right, repeat in delimiters:
16 | pattern = f"{left}([^{avoid}]+?){right}([^:]|$)"
17 | repl = f"(\\1):{FACTOR**repeat:.4f}\\2"
18 | prompt = re.sub(pattern, repl, prompt)
19 |
20 | # recover parentheses and brackets
21 | prompt = prompt.replace(r"\(", "(").replace(r"\)", ")")
22 | prompt = prompt.replace(r"\[", "[").replace(r"\]", "]")
23 |
24 | return prompt
25 |
--------------------------------------------------------------------------------
/sd_fused/clip/parser/prompt_choices.py:
--------------------------------------------------------------------------------
1 | from __future__ import annotations
2 |
3 | import re
4 |
5 | from ...utils.diverse import to_list
6 | from ...utils.typing import MaybeIterable
7 |
8 |
9 | def prompt_choices(prompt: str) -> list[str]:
10 | """Create a set of prompt word-choices from
11 | `[word 1 | another word | yet another one]`
12 | """
13 |
14 | pattern = re.compile(r"\[([^\[\]]+)\]")
15 |
16 | temp: list[str] = [prompt]
17 | prompts: list[str] = []
18 |
19 | while len(temp) != 0:
20 | prompt = temp.pop()
21 |
22 | match = pattern.search(prompt)
23 | if match is not None:
24 | start, end = match.span()
25 |
26 | choices = match.group(1).split("|")
27 | assert len(choices) > 1
28 |
29 | for choice in choices:
30 | choice = choice.strip()
31 | new_text = "".join([prompt[:start], choice, prompt[end:]])
32 | temp.append(new_text.strip())
33 | else:
34 | prompts.append(prompt)
35 |
36 | return prompts[::-1]
37 |
38 |
39 | def prompts_choices(prompts: MaybeIterable[str]) -> list[str]:
40 | """Create a set of prompt word-choices from
41 | `[word 1 | another word | yet another one]`"""
42 |
43 | return [choice for prompt in to_list(prompts) for choice in prompt_choices(prompt)]
44 |
--------------------------------------------------------------------------------
/sd_fused/clip/text_segment.py:
--------------------------------------------------------------------------------
1 | from __future__ import annotations
2 |
3 | import re
4 |
5 |
6 | class TextSegment:
7 | """Split `(prompt):weight` for parsing."""
8 |
9 | prompt: str
10 | weight: float = 1
11 |
12 | def __init__(self, prompt: str) -> None:
13 | self.prompt = prompt
14 |
15 | pattern = r"\((.+?)\):([+-]?\d+(?:.\d+)?)"
16 | match = re.match(pattern, prompt)
17 | if match:
18 | self.prompt = match.group(1)
19 | self.weight = float(match.group(2))
20 |
21 | def __repr__(self) -> str:
22 | return f'{self.__class__.__qualname__}(text="{self.prompt}"; weight={self.weight})'
23 |
--------------------------------------------------------------------------------
/sd_fused/layers/__init__.py:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/tfernd/sd-fused/50570f983cc00dd4bd0dc415d0b515da638311b0/sd_fused/layers/__init__.py
--------------------------------------------------------------------------------
/sd_fused/layers/activation/__init__.py:
--------------------------------------------------------------------------------
1 | from .geglu import GEGLU
2 | from .silu import SiLU
3 |
--------------------------------------------------------------------------------
/sd_fused/layers/activation/geglu.py:
--------------------------------------------------------------------------------
1 | from __future__ import annotations
2 |
3 | import torch.nn.functional as F
4 | from torch import Tensor
5 |
6 | from ..base import Module
7 | from ..basic import Linear
8 |
9 |
10 | class GEGLU(Module):
11 | def __init__(self, dim_in: int, dim_out: int) -> None:
12 | super().__init__()
13 |
14 | self.dim_in = dim_in
15 | self.dim_out = dim_out
16 |
17 | self.proj = Linear(dim_in, 2 * dim_out)
18 |
19 | def __call__(self, x: Tensor) -> Tensor:
20 | x, gate = self.proj(x).chunk(2, dim=-1)
21 |
22 | return F.gelu(gate).mul_(x)
23 |
--------------------------------------------------------------------------------
/sd_fused/layers/activation/silu.py:
--------------------------------------------------------------------------------
1 | from __future__ import annotations
2 |
3 | from torch import Tensor
4 |
5 | from ..base import Module
6 |
7 |
8 | class SiLU(Module):
9 | def __call__(self, x: Tensor) -> Tensor:
10 | return x.sigmoid().mul_(x)
11 |
--------------------------------------------------------------------------------
/sd_fused/layers/auto_encoder/__init__.py:
--------------------------------------------------------------------------------
1 | from .encoder import Encoder
2 | from .decoder import Decoder
3 |
--------------------------------------------------------------------------------
/sd_fused/layers/auto_encoder/decoder.py:
--------------------------------------------------------------------------------
1 | from __future__ import annotations
2 |
3 | from torch import Tensor
4 |
5 | from ..base import Module, ModuleList
6 | from ..basic import Conv2d
7 | from ..blocks.basic import GroupNormSiLUConv2d
8 | from ..blocks.spatial import UpDecoderBlock2D, UNetMidBlock2DSelfAttention
9 |
10 |
11 | class Decoder(Module):
12 | def __init__(
13 | self,
14 | *,
15 | in_channels: int,
16 | out_channels: int,
17 | block_out_channels: tuple[int, ...],
18 | layers_per_block: int,
19 | resnet_groups: int,
20 | ) -> None:
21 | super().__init__()
22 |
23 | self.in_channels = in_channels
24 | self.out_channels = out_channels
25 | self.block_out_channels = block_out_channels
26 | self.layers_per_block = layers_per_block
27 | self.resnet_groups = resnet_groups
28 |
29 | num_blocks = len(block_out_channels)
30 |
31 | self.pre_process = Conv2d(in_channels, block_out_channels[-1], kernel_size=3, padding=1)
32 |
33 | # mid
34 | self.mid_block = UNetMidBlock2DSelfAttention(
35 | in_channels=block_out_channels[-1],
36 | temb_channels=None,
37 | num_layers=1,
38 | resnet_groups=resnet_groups,
39 | attn_num_head_channels=None,
40 | )
41 |
42 | # up
43 | reversed_block_out_channels = list(reversed(block_out_channels))
44 | output_channel = reversed_block_out_channels[0]
45 | self.up_blocks = ModuleList[UpDecoderBlock2D]()
46 | for i in range(num_blocks):
47 | is_final_block = i == num_blocks - 1
48 |
49 | prev_output_channel = output_channel
50 | output_channel = reversed_block_out_channels[i]
51 |
52 | block = UpDecoderBlock2D(
53 | in_channels=prev_output_channel,
54 | out_channels=output_channel,
55 | num_layers=layers_per_block + 1,
56 | resnet_groups=resnet_groups,
57 | add_upsample=not is_final_block,
58 | )
59 |
60 | self.up_blocks.append(block)
61 | prev_output_channel = output_channel
62 |
63 | # out
64 | self.post_process = GroupNormSiLUConv2d(
65 | resnet_groups,
66 | block_out_channels[0],
67 | out_channels,
68 | kernel_size=3,
69 | padding=1,
70 | )
71 |
72 | def __call__(self, x: Tensor) -> Tensor:
73 | x = self.pre_process(x)
74 | x = self.mid_block(x)
75 |
76 | for up_block in self.up_blocks:
77 | x = up_block(x)
78 |
79 | return self.post_process(x)
80 |
--------------------------------------------------------------------------------
/sd_fused/layers/auto_encoder/encoder.py:
--------------------------------------------------------------------------------
1 | from __future__ import annotations
2 |
3 | from torch import Tensor
4 |
5 | from ..base import Module, ModuleList
6 | from ..basic import Conv2d
7 | from ..blocks.basic import GroupNormSiLUConv2d
8 | from ..blocks.spatial import DownEncoderBlock2D, UNetMidBlock2DSelfAttention
9 |
10 |
11 | class Encoder(Module):
12 | def __init__(
13 | self,
14 | *,
15 | in_channels: int,
16 | out_channels: int,
17 | block_out_channels: tuple[int, ...],
18 | layers_per_block: int,
19 | resnet_groups: int,
20 | double_z: bool,
21 | ) -> None:
22 | super().__init__()
23 |
24 | self.in_channels = in_channels
25 | self.out_channels = out_channels
26 | self.block_out_channels = block_out_channels
27 | self.layers_per_block = layers_per_block
28 | self.resnet_groups = resnet_groups
29 | self.double_z = double_z
30 |
31 | num_blocks = len(block_out_channels)
32 |
33 | self.pre_process = Conv2d(in_channels, block_out_channels[0], kernel_size=3, padding=1)
34 |
35 | # down
36 | output_channel = block_out_channels[0]
37 | self.down_blocks = ModuleList[DownEncoderBlock2D]()
38 | for i in range(num_blocks):
39 | is_final_block = i == num_blocks - 1
40 |
41 | input_channel = output_channel
42 | output_channel = block_out_channels[i]
43 |
44 | block = DownEncoderBlock2D(
45 | in_channels=input_channel,
46 | out_channels=output_channel,
47 | num_layers=layers_per_block,
48 | resnet_groups=resnet_groups,
49 | downsample_padding=0,
50 | add_downsample=not is_final_block,
51 | )
52 | self.down_blocks.append(block)
53 |
54 | # mid
55 | self.mid_block = UNetMidBlock2DSelfAttention(
56 | in_channels=block_out_channels[-1],
57 | temb_channels=None,
58 | num_layers=1,
59 | resnet_groups=resnet_groups,
60 | attn_num_head_channels=None,
61 | )
62 |
63 | # out
64 | conv_out_channels = 2 * out_channels if double_z else out_channels
65 | self.post_process = GroupNormSiLUConv2d(
66 | resnet_groups,
67 | block_out_channels[-1],
68 | conv_out_channels,
69 | kernel_size=3,
70 | padding=1,
71 | )
72 |
73 | def __call__(self, x: Tensor) -> Tensor:
74 | x = self.pre_process(x)
75 |
76 | for down_block in self.down_blocks:
77 | x = down_block(x)
78 |
79 | x = self.mid_block(x)
80 |
81 | return self.post_process(x)
82 |
--------------------------------------------------------------------------------
/sd_fused/layers/base/__init__.py:
--------------------------------------------------------------------------------
1 | from .module import Module
2 | from .sequential import Sequential
3 | from .module_list import ModuleList
4 |
--------------------------------------------------------------------------------
/sd_fused/layers/base/base.py:
--------------------------------------------------------------------------------
1 | from __future__ import annotations
2 |
3 | import torch
4 |
5 | from .types import Device
6 |
7 |
8 | class Base:
9 | dtype: torch.dtype = torch.float16
10 | device: Device = "cuda"
11 |
--------------------------------------------------------------------------------
/sd_fused/layers/base/module.py:
--------------------------------------------------------------------------------
1 | from __future__ import annotations
2 | from typing import Optional
3 | from typing_extensions import Self
4 |
5 | from abc import ABC
6 |
7 | import torch
8 | import torch.nn as nn
9 | from torch import Tensor
10 |
11 | from .base import Base
12 | from .types import Device
13 |
14 |
15 | class Module(Base, ABC):
16 | # TODO Types?
17 | def __call__(self, *args, **kwargs) -> Tensor:
18 | raise NotImplementedError
19 |
20 | def named_modules(self) -> dict[str, Module]:
21 | modules: dict[str, Module] = {}
22 |
23 | modules[""] = self
24 | for key, value in self.__dict__.items():
25 | if isinstance(value, Module):
26 | modules[key] = value
27 |
28 | # recursive call
29 | for sub_key, sub_value in value.named_modules().items():
30 | if sub_key != "":
31 | modules[f"{key}.{sub_key}"] = sub_value
32 |
33 | return modules
34 |
35 | def state_dict(self) -> dict[str, nn.Parameter]:
36 | params: dict[str, nn.Parameter] = {}
37 |
38 | for key, value in self.__dict__.items():
39 | # single parameter
40 | if isinstance(value, nn.Parameter):
41 | params[key] = value
42 |
43 | # recursive call
44 | elif isinstance(value, Module):
45 | for sub_key, sub_value in value.state_dict().items():
46 | params[f"{key}.{sub_key}"] = sub_value
47 |
48 | return params
49 |
50 | def load_state_dict(self, state: dict[str, nn.Parameter] | dict[str, Tensor], strict: bool = False) -> Self:
51 | current_state = self.state_dict()
52 | assert len(current_state) == len(state)
53 |
54 | for key, value in current_state.items():
55 | assert key in state
56 |
57 | new_value = state[key].data
58 | if strict:
59 | assert new_value.shape == value.shape
60 |
61 | new_value = new_value.to(device=self.device, dtype=self.dtype, non_blocking=True)
62 | value.data = new_value
63 |
64 | return self
65 |
66 | def float(self) -> Self:
67 | return self.to(dtype=torch.float32)
68 |
69 | def half(self) -> Self:
70 | return self.to(dtype=torch.float16)
71 |
72 | def cpu(self) -> Self:
73 | return self.to(device=f"cpu")
74 |
75 | def cuda(self, index: int = 0) -> Self:
76 | return self.to(device=f"cuda:{index}")
77 |
78 | def to(self, *, device: Optional[Device] = None, dtype: Optional[torch.dtype] = None):
79 | if device is not None:
80 | self.device = device
81 | if dtype is not None:
82 | self.dtype = dtype
83 |
84 | for key, value in self.state_dict().items():
85 | data = value.data
86 | value.data = data.to(device=device, dtype=dtype, non_blocking=True)
87 |
88 | return self
89 |
--------------------------------------------------------------------------------
/sd_fused/layers/base/module_list.py:
--------------------------------------------------------------------------------
1 | from __future__ import annotations
2 | from typing import Generator, Generic, NoReturn
3 |
4 | import torch.nn as nn
5 |
6 | from ...utils.typing import TypeVar, TypeVarTuple, Unpack
7 | from .module import Module
8 |
9 |
10 | T = TypeVar("T", bound=Module)
11 | Ts = TypeVarTuple("Ts") # ? bound
12 |
13 |
14 | class _ModuleSequence(Module):
15 | layers: tuple[Module, ...] | list[Module]
16 |
17 | def state_dict(self) -> dict[str, nn.Parameter]:
18 | params: dict[str, nn.Parameter] = {}
19 | for index, layer in enumerate(self.layers):
20 | for key, value in layer.state_dict().items():
21 | params[f"{index}.{key}"] = value
22 |
23 | return params
24 |
25 | def named_modules(self) -> dict[str, Module]:
26 | modules: dict[str, Module] = {}
27 | for index, layer in enumerate(self.layers):
28 | for key, value in layer.named_modules().items():
29 | name = str(index) if key == "" else f"{index}.{key}"
30 | modules[name] = value
31 |
32 | return modules
33 |
34 | def __call__(self) -> NoReturn:
35 | raise ValueError(f"{self.__class__.__qualname__} is not callable.")
36 |
37 |
38 | class ModuleList(_ModuleSequence, Generic[T]):
39 | layers: list[T]
40 |
41 | def __init__(self, *layers: T) -> None:
42 | self.layers = list(layers)
43 |
44 | def append(self, layer: T) -> None:
45 | self.layers.append(layer)
46 |
47 | def __iter__(self) -> Generator[T, None, None]:
48 | for layer in self.layers:
49 | yield layer
50 |
51 |
52 | # # ! only for debug
53 | # class ModuleTuple(Module, Generic[Unpack[Ts]]):
54 | # layers: tuple[Unpack[Ts]]
55 |
56 | # def __init__(self, *layers: Unpack[Ts]) -> None:
57 | # self.layers = layers
58 |
59 | # def __iter__(self): # ? type?
60 | # for layer in self.layers:
61 | # yield layer
62 |
--------------------------------------------------------------------------------
/sd_fused/layers/base/sequential.py:
--------------------------------------------------------------------------------
1 | from __future__ import annotations
2 |
3 | from torch import Tensor
4 |
5 | from .module import Module
6 | from .module_list import ModuleList, T
7 |
8 |
9 | class Sequential(ModuleList[T], Module):
10 | def __call__(self, x: Tensor) -> Tensor:
11 | for layer in self.layers:
12 | x = layer(x)
13 |
14 | return x
15 |
--------------------------------------------------------------------------------
/sd_fused/layers/base/types.py:
--------------------------------------------------------------------------------
1 | from __future__ import annotations
2 | from typing import Union
3 |
4 | import torch
5 |
6 | Device = Union[str, torch.device]
7 |
--------------------------------------------------------------------------------
/sd_fused/layers/basic/__init__.py:
--------------------------------------------------------------------------------
1 | from .identity import Identity
2 |
3 | from .linear import Linear
4 | from .conv2d import Conv2d
5 |
6 | from .layer_norm import LayerNorm
7 | from .group_norm import GroupNorm
8 |
--------------------------------------------------------------------------------
/sd_fused/layers/basic/conv2d.py:
--------------------------------------------------------------------------------
1 | from __future__ import annotations
2 | from typing import Optional
3 |
4 | from functools import partial
5 |
6 | import torch
7 | import torch.nn as nn
8 | import torch.nn.functional as F
9 | from torch import Tensor
10 |
11 | from ..base import Module
12 | from ..modifiers import HalfWeightsModule, half_weights
13 |
14 |
15 | class Conv2d(HalfWeightsModule, Module):
16 | def __init__(
17 | self,
18 | in_channels: int,
19 | out_channels: Optional[int] = None,
20 | *,
21 | kernel_size: int = 1,
22 | stride: int = 1,
23 | padding: int = 0,
24 | groups: int = 1,
25 | dilation: int = 1,
26 | bias: bool = True,
27 | ) -> None:
28 | assert in_channels % groups == 0
29 |
30 | self.in_channels = in_channels
31 | self.out_channels = out_channels = out_channels or in_channels
32 | self.kernel_size = kernel_size
33 | self.stride = stride
34 | self.padding = padding
35 | self.groups = groups
36 | self.dilation = dilation
37 |
38 | # TODO duplication
39 | empty = partial(torch.empty, dtype=self.dtype, device=self.device)
40 | parameter = partial(nn.Parameter, requires_grad=False)
41 |
42 | w = empty(out_channels, in_channels // groups, kernel_size, kernel_size)
43 | self.weight = parameter(w)
44 | self.bias = parameter(empty(out_channels)) if bias else None
45 |
46 | @half_weights
47 | def __call__(self, x: Tensor) -> Tensor:
48 | return F.conv2d(x, self.weight, self.bias, self.stride, self.padding, self.dilation, self.groups)
49 |
--------------------------------------------------------------------------------
/sd_fused/layers/basic/group_norm.py:
--------------------------------------------------------------------------------
1 | from __future__ import annotations
2 |
3 | from functools import partial
4 |
5 | import torch
6 | import torch.nn as nn
7 | import torch.nn.functional as F
8 | from torch import Tensor
9 |
10 | from ..base import Module
11 | from ..modifiers import HalfWeightsModule, half_weights
12 |
13 |
14 | class GroupNorm(HalfWeightsModule, Module):
15 | def __init__(
16 | self,
17 | num_groups: int,
18 | num_channels: int,
19 | *,
20 | eps: float = 1e-6,
21 | affine: bool = True,
22 | ) -> None:
23 | self.num_groups = num_groups
24 | self.num_channels = num_channels
25 | self.eps = eps
26 | self.affine = affine
27 |
28 | empty = partial(torch.empty, dtype=self.dtype, device=self.device)
29 | parameter = partial(nn.Parameter, requires_grad=False)
30 |
31 | self.weight = parameter(empty(num_channels)) if affine else None
32 | self.bias = parameter(empty(num_channels)) if affine else None
33 |
34 | @half_weights
35 | def __call__(self, x: Tensor) -> Tensor:
36 | return F.group_norm(x, self.num_groups, self.weight, self.bias, self.eps)
37 |
--------------------------------------------------------------------------------
/sd_fused/layers/basic/identity.py:
--------------------------------------------------------------------------------
1 | from __future__ import annotations
2 |
3 | from torch import Tensor
4 |
5 | from ..base import Module
6 |
7 |
8 | class Identity(Module):
9 | def __call__(self, x: Tensor) -> Tensor:
10 | return x
11 |
--------------------------------------------------------------------------------
/sd_fused/layers/basic/layer_norm.py:
--------------------------------------------------------------------------------
1 | from __future__ import annotations
2 |
3 | from functools import partial
4 |
5 | import torch
6 | import torch.nn as nn
7 | import torch.nn.functional as F
8 | from torch import Tensor
9 |
10 | from ..base import Module
11 | from ..modifiers import HalfWeightsModule, half_weights
12 |
13 |
14 | class LayerNorm(HalfWeightsModule, Module):
15 | def __init__(
16 | self,
17 | shape: int | tuple[int, ...],
18 | *,
19 | eps: float = 1e-6,
20 | elementwise_affine: bool = True,
21 | ) -> None:
22 | self.shape = shape = shape if isinstance(shape, tuple) else (shape,)
23 | self.eps = eps
24 | self.elementwise_affine = elementwise_affine
25 |
26 | empty = partial(torch.empty, dtype=self.dtype, device=self.device)
27 | parameter = partial(nn.Parameter, requires_grad=False)
28 |
29 | self.weight = parameter(empty(shape)) if elementwise_affine else NotImplemented
30 | self.bias = parameter(empty(shape)) if elementwise_affine else NotImplemented
31 |
32 | @half_weights
33 | def __call__(self, x: Tensor) -> Tensor:
34 | return F.layer_norm(x, self.shape, self.weight, self.bias, self.eps)
35 |
--------------------------------------------------------------------------------
/sd_fused/layers/basic/linear.py:
--------------------------------------------------------------------------------
1 | from __future__ import annotations
2 | from typing import Optional
3 |
4 | from functools import partial
5 |
6 | import torch
7 | import torch.nn as nn
8 | import torch.nn.functional as F
9 | from torch import Tensor
10 |
11 | from ..base import Module
12 | from ..modifiers import HalfWeightsModule, half_weights
13 |
14 |
15 | class Linear(HalfWeightsModule, Module):
16 | def __init__(
17 | self,
18 | in_features: int,
19 | out_features: Optional[int] = None,
20 | *,
21 | bias: bool = True,
22 | ) -> None:
23 | self.in_features = in_features
24 | self.out_features = out_features = out_features or in_features
25 |
26 | empty = partial(torch.empty, dtype=self.dtype, device=self.device)
27 | parameter = partial(nn.Parameter, requires_grad=False)
28 |
29 | self.weight = parameter(empty(out_features, in_features))
30 | self.bias = parameter(empty(out_features)) if bias else None
31 |
32 | @half_weights
33 | def __call__(self, x: Tensor) -> Tensor:
34 | return F.linear(x, self.weight, self.bias)
35 |
--------------------------------------------------------------------------------
/sd_fused/layers/blocks/__init__.py:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/tfernd/sd-fused/50570f983cc00dd4bd0dc415d0b515da638311b0/sd_fused/layers/blocks/__init__.py
--------------------------------------------------------------------------------
/sd_fused/layers/blocks/attention/__init__.py:
--------------------------------------------------------------------------------
1 | from .self_attention import SelfAttention
2 | from .cross_attention import CrossAttention
3 |
--------------------------------------------------------------------------------
/sd_fused/layers/blocks/attention/base_attention.py:
--------------------------------------------------------------------------------
1 | from __future__ import annotations
2 | from typing import Optional
3 |
4 | from torch import Tensor
5 |
6 | from ....utils.typing import Literal
7 | from .compute import attention, ChunkType
8 |
9 |
10 | class BaseAttention:
11 | attention_chunks: Optional[int | Literal["auto"]] = None
12 | chunk_type: Optional[ChunkType] = None
13 | use_flash_attention: bool = False
14 | tome_r: Optional[int | float] = None
15 |
16 | def attention(self, q: Tensor, k: Tensor, v: Tensor, weights: Optional[Tensor] = None) -> Tensor:
17 | return attention(
18 | q,
19 | k,
20 | v,
21 | chunks=self.attention_chunks,
22 | chunk_type=self.chunk_type,
23 | use_flash_attention=self.use_flash_attention,
24 | tome_r=self.tome_r,
25 | weights=weights,
26 | )
27 |
--------------------------------------------------------------------------------
/sd_fused/layers/blocks/attention/compute/__init__.py:
--------------------------------------------------------------------------------
1 | from .attention import attention
2 | from .auto_chunk_size import ChunkType
3 |
--------------------------------------------------------------------------------
/sd_fused/layers/blocks/attention/compute/attention.py:
--------------------------------------------------------------------------------
1 | from __future__ import annotations
2 | from typing import Optional
3 |
4 | from torch import Tensor
5 |
6 | from .....utils.typing import Literal
7 | from .join_spatial_dim import join_spatial_dim
8 | from .auto_chunk_size import auto_chunk_size, ChunkType
9 | from .standard_attention import standard_attention
10 | from .chunked_attention import batch_chunked_attention, sequence_chunked_attention
11 | from .flash_attention import flash_attention
12 | from .weighted_values import weighted_values
13 | from .tome import token_average
14 |
15 |
16 | def attention(
17 | q: Tensor, # (B, heads, H, W, C)
18 | k: Tensor, # (B, heads, *[H,W] | T', C)
19 | v: Tensor, # (B, heads, *[H,W] | T', C)
20 | *,
21 | weights: Optional[Tensor] = None, # (B, T')
22 | chunks: Optional[int | Literal["auto"]] = None,
23 | chunk_type: Optional[ChunkType] = None,
24 | use_flash_attention: bool = False,
25 | tome_r: Optional[int | float] = None,
26 | ) -> Tensor:
27 | """General attention computation."""
28 |
29 | # header
30 | is_self_attention = q.shape == k.shape
31 | dtype = q.dtype
32 | B, heads, H, W, C = q.shape
33 | T = H * W
34 |
35 | if weights is not None:
36 | assert not is_self_attention
37 |
38 |         v = weighted_values(v, weights)  # ? should the keys be weighted as well?
39 |
40 | q, k, v = join_spatial_dim(q, k, v)
41 | Tl = k.size(2)
42 |
43 | chunks = auto_chunk_size(chunks, B, heads, T, Tl, C, dtype, chunk_type)
44 |
45 | if is_self_attention and tome_r is not None:
46 | k, v, bias = token_average(k, v, tome_r)
47 |     bias = None  # ! not used for now.
48 |
49 | if chunks is not None:
50 | assert not use_flash_attention
51 |
52 | if chunk_type is None or chunk_type == "batch":
53 | out = batch_chunked_attention(q, k, v, chunks, bias)
54 | else:
55 | out = sequence_chunked_attention(q, k, v, chunks, bias)
56 |
57 | elif use_flash_attention:
58 | assert chunk_type is None
59 | assert tome_r is None
60 |
61 | out = flash_attention(q, k, v, bias)
62 | else:
63 | out = standard_attention(q, k, v, bias)
64 |
65 | out = out.unflatten(2, (H, W)) # separate-spatial-dim
66 |
67 | return out
68 |
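A minimal usage sketch of the dispatcher above (assuming the package is importable; shapes are illustrative): the batch-chunked path should agree with the standard path up to floating-point error.

    import torch

    from sd_fused.layers.blocks.attention.compute import attention

    B, heads, H, W, C = 1, 8, 32, 32, 40
    q = torch.randn(B, heads, H, W, C)
    k = torch.randn(B, heads, H, W, C)  # same shape as q -> self-attention path
    v = torch.randn(B, heads, H, W, C)

    out = attention(q, k, v)  # standard (un-chunked) attention
    assert out.shape == (B, heads, H, W, C)

    # batch-chunked variant: process 4 batch-head slices at a time
    out_chunked = attention(q, k, v, chunks=4, chunk_type="batch")
    assert torch.allclose(out, out_chunked, atol=1e-4)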
--------------------------------------------------------------------------------
/sd_fused/layers/blocks/attention/compute/auto_chunk_size.py:
--------------------------------------------------------------------------------
1 | from __future__ import annotations
2 | from typing import Optional
3 |
4 | import torch
5 |
6 | from .....utils.typing import Literal
7 | from .....utils.cuda import free_memory
8 |
9 |
10 | ChunkType = Literal["batch", "sequence"]
11 |
12 |
13 | def auto_chunk_size(
14 | chunks: Optional[int | Literal["auto"]],
15 | B: int,
16 | heads: int,
17 | T: int,
18 | Tl: int,
19 | C: int,
20 | dtype: torch.dtype,
21 | chunk_type: Optional[ChunkType],
22 | ) -> Optional[int]:
23 | """Determine the maximum chunk size according to the available free memory."""
24 |
25 | if chunks != "auto":
26 | return chunks
27 |
28 | B *= heads # ! ugly but...
29 |
30 | assert chunk_type is not None
31 | assert dtype in (torch.float32, torch.float16)
32 |
33 | num_bytes = 2 if dtype == torch.float16 else 4
34 | free = free_memory()
35 |
36 | if chunk_type is None or chunk_type == "batch":
37 | # Memory used: (2*Bchunks*T*Tl + Bchunks*T*C) * num_bytes
38 | Bchunks = free // (num_bytes * T * (C + 2 * Tl))
39 |
40 | if Bchunks >= B:
41 | return None
42 | return Bchunks
43 |
44 | # Memory used: (2*B*Tchunk*Tl + B*Tchunk*C) * num_bytes
45 | Tchunks = free // (num_bytes * B * (C + 2 * Tl))
46 |
47 | if Tchunks >= T:
48 | return None
49 | return Tchunks
50 |
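For intuition, a rough worked example of the "batch" branch above; all numbers are made up.

    num_bytes = 2                  # float16
    T = Tl = 64 * 64               # 4096 query and key positions
    C = 40                         # head dimension
    free = 4 * 1024**3             # pretend 4 GiB of free CUDA memory

    Bchunks = free // (num_bytes * T * (C + 2 * Tl))   # -> 63 batch-head slices per chunk
    # two (T x Tl) score buffers plus the (T x C) output of one chunk now fit in `free`
    assert num_bytes * Bchunks * T * (2 * Tl + C) <= free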
--------------------------------------------------------------------------------
/sd_fused/layers/blocks/attention/compute/chunked_attention.py:
--------------------------------------------------------------------------------
1 | from __future__ import annotations
2 | from typing import Optional
3 |
4 | import torch
5 | import torch.nn.functional as F
6 | from torch import Tensor
7 |
8 | from .scale_qk import scale_qk
9 |
10 |
11 | def batch_chunked_attention(
12 | q: Tensor, # (B, heads, T, C)
13 | k: Tensor, # (B, heads, T', C)
14 | v: Tensor, # (B, heads, T', C)
15 | chunks: int,
16 | bias: Optional[Tensor] = None, # (B, heads, T', C)
17 | ) -> Tensor:
18 | """Batch-chunked attention computation."""
19 |
20 | assert chunks >= 1
21 | B, heads, T, C = q.shape
22 |
23 | # join batch-heads
24 | q = q.flatten(0, 1)
25 | k = k.flatten(0, 1)
26 | v = v.flatten(0, 1)
27 |
28 | q, k = scale_qk(q, k)
29 | kT = k.transpose(-1, -2)
30 |
31 | out = torch.empty_like(q)
32 | for i in range(0, B * heads, chunks):
33 | s = slice(i, min(i + chunks, B * heads))
34 |
35 | score = q[s] @ kT[s]
36 | if bias is not None:
37 | score += bias[s]
38 | attn = F.softmax(score, dim=-1, dtype=q.dtype)
39 | del score
40 |
41 | out[s] = attn @ v[s]
42 | del attn
43 |
44 | return out.unflatten(0, (B, heads))
45 |
46 |
47 | def sequence_chunked_attention(
48 | q: Tensor, # (B, heads, T, C)
49 | k: Tensor, # (B, heads, T', C)
50 | v: Tensor, # (B, heads, T', C)
51 | chunks: int,
52 | bias: Optional[Tensor] = None, # (B, heads, T', C)
53 | ) -> Tensor:
54 | """Sequence-chunked attention computation."""
55 |
56 | # https://github.com/Doggettx/stable-diffusion/blob/main/ldm/modules/attention.py#L209
57 |
58 | assert chunks >= 1
59 | B, heads, T, C = q.shape
60 |
61 | # join batch-heads
62 | q, k, v = map(lambda x: x.flatten(0, 1), (q, k, v))
63 |
64 | q, k = scale_qk(q, k)
65 | kT = k.transpose(-1, -2)
66 |
67 | out = torch.empty_like(q)
68 | for i in range(0, T, chunks):
69 | s = slice(i, min(i + chunks, T))
70 |
71 | score = q[:, s] @ kT
72 | if bias is not None:
73 | score += bias[:, s]
74 | attn = F.softmax(score, dim=-1, dtype=q.dtype)
75 | del score
76 |
77 | out[:, s] = attn @ v
78 | del attn
79 |
80 | return out.unflatten(0, (B, heads))
81 |
--------------------------------------------------------------------------------
/sd_fused/layers/blocks/attention/compute/flash_attention.py:
--------------------------------------------------------------------------------
1 | from __future__ import annotations
2 | from typing import Optional
3 |
4 | from torch import Tensor
5 |
6 | try:
7 | from xformers.ops import memory_efficient_attention # type: ignore
8 | except ImportError:
9 | memory_efficient_attention = None
10 |
11 |
12 | def flash_attention(
13 | q: Tensor, # (B, heads, T, C)
14 | k: Tensor, # (B, heads, T', C)
15 | v: Tensor, # (B, heads, T', C)
16 | bias: Optional[Tensor] = None, # (B, heads, T', C)
17 | ) -> Tensor:
18 | """xformers flash-attention computation."""
19 |
20 | assert memory_efficient_attention is not None
21 |
22 | q, k, v = map(lambda x: x.contiguous(), (q, k, v))
23 | bias = bias.contiguous() if bias is not None else None
24 |
25 | out = memory_efficient_attention(q, k, v, bias)
26 |
27 | return out
28 |
--------------------------------------------------------------------------------
/sd_fused/layers/blocks/attention/compute/join_spatial_dim.py:
--------------------------------------------------------------------------------
1 | from __future__ import annotations
2 |
3 | from torch import Tensor
4 |
5 |
6 | def join_spatial_dim(q: Tensor, k: Tensor, v: Tensor) -> tuple[Tensor, Tensor, Tensor]:
7 | "Join height-width spatial dimensions of qkv."
8 |
9 | is_self_attention = q.ndim == k.ndim
10 |
11 | q = q.flatten(2, 3)
12 | if is_self_attention:
13 | k = k.flatten(2, 3)
14 | v = v.flatten(2, 3)
15 |
16 | return q, k, v
17 |
--------------------------------------------------------------------------------
/sd_fused/layers/blocks/attention/compute/scale_qk.py:
--------------------------------------------------------------------------------
1 | from __future__ import annotations
2 |
3 | import math
4 |
5 | from torch import Tensor
6 |
7 |
8 | def scale_qk(q: Tensor, k: Tensor) -> tuple[Tensor, Tensor]:
9 | "Scale the query and key."
10 |
11 | C = q.size(-1)
12 | scale = math.pow(C, -1 / 4)
13 |
14 | return q * scale, k * scale
15 |
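Splitting the usual 1/sqrt(C) factor evenly between query and key keeps intermediate magnitudes smaller (useful in fp16) while leaving the attention scores unchanged; a quick, purely illustrative check:

    import math
    import torch

    C = 40
    q = torch.randn(2, 77, C, dtype=torch.float64)
    k = torch.randn(2, 77, C, dtype=torch.float64)

    scale = C ** (-1 / 4)
    lhs = (q * scale) @ (k * scale).transpose(-1, -2)   # scaled q and k
    rhs = (q @ k.transpose(-1, -2)) / math.sqrt(C)      # textbook 1/sqrt(C) scaling
    assert torch.allclose(lhs, rhs)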
--------------------------------------------------------------------------------
/sd_fused/layers/blocks/attention/compute/standard_attention.py:
--------------------------------------------------------------------------------
1 | from __future__ import annotations
2 | from typing import Optional
3 |
4 | import torch.nn.functional as F
5 | from torch import Tensor
6 |
7 | from .scale_qk import scale_qk
8 |
9 |
10 | def standard_attention(
11 | q: Tensor, # (B, heads, T, C)
12 | k: Tensor, # (B, heads, T', C)
13 | v: Tensor, # (B, heads, T', C)
14 | bias: Optional[Tensor] = None, # (B, heads, T', C)
15 | ) -> Tensor:
16 | """Standard attention computation."""
17 |
18 | q, k = scale_qk(q, k)
19 | score = q @ k.transpose(-1, -2)
20 | if bias is not None:
21 | score += bias
22 | attn = F.softmax(score, dim=-1, dtype=q.dtype)
23 | del score
24 |
25 | return attn @ v
26 |
--------------------------------------------------------------------------------
/sd_fused/layers/blocks/attention/compute/tome.py:
--------------------------------------------------------------------------------
1 | from __future__ import annotations
2 | from typing import Optional
3 |
4 | import torch
5 | from torch import Tensor
6 | import math
7 |
8 | from .....utils.typing import Literal, Protocol
9 |
10 | Modes = Literal["sum", "mean"]
11 |
12 |
13 | class Merge(Protocol):
14 | def __call__(self, x: Tensor, mode: Modes = "mean") -> Tensor:
15 | ...
16 |
17 |
18 | def tome(metric: Tensor, r: int | float) -> Merge:
19 | # https://arxiv.org/abs/2210.09461
20 | # https://github.com/facebookresearch/ToMe/blob/main/tome/merge.py
21 |
22 | B, T, C = metric.shape
23 |
24 | if not isinstance(r, int):
25 | r = math.floor(r * T)
26 |
27 | assert 0 < r <= T / 2 # max 50% reduction
28 |
29 | with torch.no_grad():
30 | metric = metric / metric.norm(dim=2, keepdim=True)
31 | a, b = metric[:, ::2], metric[:, 1::2]
32 | score = a @ b.transpose(1, 2)
33 | del a, b
34 |
35 | node_max, node_idx = score.max(dim=2, keepdim=True)
36 | del score
37 | edge_idx = node_max.argsort(dim=1, descending=True)
38 |
39 | # Unmerged/Merged Tokens
40 | size = (math.ceil(T / 2) - r, r)
41 | unmerged_idx, src_idx = edge_idx.split(size, dim=1)
42 |
43 | dst_idx = node_idx.gather(dim=1, index=src_idx)
44 |
45 | def merge(x: Tensor, mode: Modes = "mean") -> Tensor:
46 | src, dst = x[:, ::2], x[:, 1::2]
47 | B, T, C = src.shape
48 |
49 | unmerged = src.gather(dim=1, index=unmerged_idx.expand(-1, -1, C))
50 | src = src.gather(dim=1, index=src_idx.expand(-1, -1, C))
51 | dst = dst.scatter_reduce(1, dst_idx.expand(-1, -1, C), src, reduce=mode)
52 |
53 | return torch.cat([unmerged, dst], dim=1)
54 |
55 | return merge
56 |
57 |
58 | def merge_weighted_average(
59 | merge: Merge,
60 | x: Tensor,
61 | size: Optional[Tensor] = None,
62 | ) -> tuple[Tensor, Tensor]:
63 | B, T, C = x.shape
64 |
65 | if size is None:
66 | size = torch.ones(B, T, 1, dtype=x.dtype, device=x.device)
67 |
68 | x = merge(x * size, mode="sum")
69 | size = merge(size, mode="sum")
70 |
71 | x = x / size
72 |
73 | return x, size
74 |
75 |
76 | def token_average(
77 | k: Tensor, # (B, heads, T', C)
78 | v: Tensor, # (B, heads, T', C)
79 | r: int | float,
80 | ) -> tuple[Tensor, Tensor, Tensor]:
81 | B, heads, Tl, C = k.shape
82 |
83 | # join batch-heads
84 | k = k.flatten(0, 1)
85 | v = v.flatten(0, 1)
86 |
87 | merge = tome(k, r)
88 |
89 | k, size = merge_weighted_average(merge, k)
90 | v, size = merge_weighted_average(merge, v)
91 | bias = size.log()
92 |
93 | # separate batch-heads
94 | k = k.unflatten(0, (B, heads))
95 | v = v.unflatten(0, (B, heads))
96 | bias = bias.unflatten(0, (B, heads))
97 |
98 | return k, v, bias
99 |
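A small sketch of what `token_average` does to the key/value sequence length (sizes are illustrative): merging r tokens shortens the sequence from T' to T' - r.

    import torch

    from sd_fused.layers.blocks.attention.compute.tome import token_average

    B, heads, T, C = 1, 8, 256, 40
    k = torch.randn(B, heads, T, C)
    v = torch.randn(B, heads, T, C)

    # merge 25% of the tokens into their most similar neighbours
    k2, v2, bias = token_average(k, v, r=0.25)

    r = int(0.25 * T)  # 64 merged tokens
    assert k2.shape == v2.shape == (B, heads, T - r, C)
    assert bias.shape == (B, heads, T - r, 1)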
--------------------------------------------------------------------------------
/sd_fused/layers/blocks/attention/compute/weighted_values.py:
--------------------------------------------------------------------------------
1 | from __future__ import annotations
2 | from typing import Optional
3 |
4 | from torch import Tensor
5 |
6 | from einops import rearrange
7 |
8 |
9 | def weighted_values(
10 | x: Tensor, # (B, heads, T, C)
11 | weights: Optional[Tensor] = None, # (B, T)
12 | ) -> Tensor:
13 | "Modify the values to give emphasis to some tokens."
14 |
15 | if weights is None:
16 | return x
17 |
18 | weights = rearrange(weights, "B T -> B 1 T 1")
19 |
20 | return x * weights
21 |
--------------------------------------------------------------------------------
/sd_fused/layers/blocks/attention/cross_attention.py:
--------------------------------------------------------------------------------
1 | from __future__ import annotations
2 | from typing import Optional
3 |
4 | from torch import Tensor
5 |
6 | from ...external import Rearrange
7 | from ...base import Module
8 | from ...basic import LayerNorm, Linear
9 | from .base_attention import BaseAttention
10 |
11 |
12 | class CrossAttention(BaseAttention, Module):
13 | def __init__(
14 | self,
15 | *,
16 | query_features: int,
17 | context_features: Optional[int],
18 | head_features: int,
19 | num_heads: int,
20 | ) -> None:
21 | super().__init__()
22 |
23 | is_cross_attention = context_features is not None
24 |
25 | context_features = context_features or query_features
26 | inner_dim = head_features * num_heads
27 |
28 | self.query_features = query_features
29 | self.context_features = context_features
30 | self.head_features = head_features
31 | self.num_heads = num_heads
32 |
33 | # TODO These 4+1 operations can be fused
34 | self.norm = LayerNorm(query_features)
35 | self.to_q = Linear(query_features, inner_dim, bias=False)
36 | self.to_k = Linear(context_features, inner_dim, bias=False)
37 | self.to_v = Linear(context_features, inner_dim, bias=False)
38 |
39 | self.to_out = Linear(inner_dim, query_features)
40 |
41 | self.heads_to_batch = Rearrange("B H W (heads C) -> B heads H W C", heads=num_heads)
42 | self.heads_to_batch2 = self.heads_to_batch
43 | if is_cross_attention:
44 | self.heads_to_batch2 = Rearrange("B T (heads C) -> B heads T C", heads=num_heads)
45 |
46 | self.heads_to_channel = Rearrange("B heads H W C -> B H W (heads C)")
47 |
48 | def __call__(
49 | self,
50 | x: Tensor,
51 | *,
52 | context: Optional[Tensor] = None,
53 | weights: Optional[Tensor] = None,
54 | ) -> Tensor:
55 |
56 | xin = x
57 | x = self.norm(x)
58 | context = context if context is not None else x
59 |
60 | # key, query, value projections
61 | q = self.heads_to_batch(self.to_q(x))
62 | k = self.heads_to_batch2(self.to_k(context))
63 | v = self.heads_to_batch2(self.to_v(context))
64 |
65 | x = self.attention(q, k, v, weights)
66 | del q, k, v
67 | x = self.heads_to_channel(x)
68 |
69 | return xin + self.to_out(x)
70 |
--------------------------------------------------------------------------------
/sd_fused/layers/blocks/attention/self_attention.py:
--------------------------------------------------------------------------------
1 | from __future__ import annotations
2 | from typing import Optional
3 |
4 | from torch import Tensor
5 |
6 | from ...external import Rearrange
7 | from ...base import Module
8 | from ...basic import GroupNorm, Linear
9 | from .base_attention import BaseAttention
10 |
11 |
12 | class SelfAttention(BaseAttention, Module):
13 | def __init__(
14 | self,
15 | *,
16 | in_features: int,
17 | head_features: Optional[int],
18 | num_groups: int,
19 | ) -> None:
20 | super().__init__()
21 |
22 | self.in_features = in_features
23 | self.head_features = head_features
24 | self.num_groups = num_groups
25 |
26 | head_features = head_features or in_features
27 | num_heads = in_features // head_features
28 |
29 | # TODO These 4+1 operations can be fused
30 | self.group_norm = GroupNorm(num_groups, in_features)
31 | self.query = Linear(in_features)
32 | self.key = Linear(in_features)
33 | self.value = Linear(in_features)
34 |
35 | self.proj_attn = Linear(in_features)
36 |
37 | self.channel_last = Rearrange("B C H W -> B H W C")
38 | self.channel_first = Rearrange("B H W C -> B C H W")
39 | self.heads_to_batch = Rearrange("B H W (heads C) -> B heads H W C", heads=num_heads)
40 | self.heads_to_channel = Rearrange("B heads H W C -> B H W (heads C)")
41 |
42 | def __call__(self, x: Tensor) -> Tensor:
43 | B, C, H, W = x.shape
44 |
45 | xin = x
46 | x = self.group_norm(x)
47 | x = self.channel_last(x)
48 |
49 | # key, query, value projections
50 | q = self.heads_to_batch(self.query(x))
51 | k = self.heads_to_batch(self.key(x))
52 | v = self.heads_to_batch(self.value(x))
53 |
54 | x = self.attention(q, k, v)
55 | del q, k, v
56 | x = self.proj_attn(self.heads_to_channel(x))
57 |
58 | # output
59 | x = self.channel_first(x, H=H, W=W)
60 |
61 | return xin + x
62 |
--------------------------------------------------------------------------------
/sd_fused/layers/blocks/basic/__init__.py:
--------------------------------------------------------------------------------
1 | from .gn_conv import GroupNormConv2d
2 | from .gn_silu_conv import GroupNormSiLUConv2d
3 |
4 | from .ln_geglu_linear import LayerNormGEGLULinear
5 |
--------------------------------------------------------------------------------
/sd_fused/layers/blocks/basic/gn_conv.py:
--------------------------------------------------------------------------------
1 | from __future__ import annotations
2 | from typing import Optional
3 |
4 | from ...base import Sequential
5 | from ...basic import GroupNorm, Conv2d
6 |
7 | # TODO improve types by using ModuleTuple
8 | class GroupNormConv2d(Sequential):
9 | def __init__(
10 | self,
11 | num_groups: int,
12 | in_channels: int,
13 | out_channels: Optional[int] = None,
14 | *,
15 | kernel_size: int = 1,
16 | padding: int = 0,
17 | ) -> None:
18 | layers = (
19 | GroupNorm(num_groups, in_channels),
20 | Conv2d(in_channels, out_channels, kernel_size=kernel_size, padding=padding),
21 | )
22 |
23 | super().__init__(*layers)
24 |
--------------------------------------------------------------------------------
/sd_fused/layers/blocks/basic/gn_silu_conv.py:
--------------------------------------------------------------------------------
1 | from __future__ import annotations
2 | from typing import Optional
3 |
4 | from ...base import Sequential
5 | from ...basic import GroupNorm, Conv2d
6 | from ...activation import SiLU
7 |
8 |
9 | class GroupNormSiLUConv2d(Sequential):
10 | def __init__(
11 | self,
12 | num_groups: int,
13 | in_channels: int,
14 | out_channels: Optional[int] = None,
15 | *,
16 | kernel_size: int = 1,
17 | padding: int = 0,
18 | ) -> None:
19 | layers = (
20 | GroupNorm(num_groups, in_channels),
21 | SiLU(),
22 | Conv2d(in_channels, out_channels, kernel_size=kernel_size, padding=padding),
23 | )
24 |
25 | super().__init__(*layers)
26 |
--------------------------------------------------------------------------------
/sd_fused/layers/blocks/basic/ln_geglu_linear.py:
--------------------------------------------------------------------------------
1 | from __future__ import annotations
2 |
3 | from torch import Tensor
4 |
5 | from ...base import Sequential
6 | from ...basic import LayerNorm, Linear
7 | from ...activation import GEGLU
8 |
9 |
10 | class LayerNormGEGLULinear(Sequential):
11 | def __init__(self, dim: int, *, expand: float) -> None:
12 |
13 | self.dim = dim
14 | self.expand = expand
15 |
16 | inner_dim = int(dim * expand)
17 | 
18 |
19 | layers = (
20 | LayerNorm(dim),
21 | GEGLU(dim, inner_dim),
22 | Linear(inner_dim, dim),
23 | )
24 |
25 | super().__init__(*layers)
26 |
27 | def __call__(self, x: Tensor) -> Tensor:
28 | return x + super().__call__(x)
29 |
--------------------------------------------------------------------------------
/sd_fused/layers/blocks/spatial/__init__.py:
--------------------------------------------------------------------------------
1 | from .base import DownBlock2D, UpBlock2D
2 |
3 | from .resampling import Upsample2D, Downsample2D
4 |
5 | from .ae import DownEncoderBlock2D, UpDecoderBlock2D
6 | from .cross_attention import CrossAttentionUpBlock2D, CrossAttentionDownBlock2D
7 |
8 | from .unet_mid import UNetMidBlock2DSelfAttention, UNetMidBlock2DCrossAttention
9 | from .resnet import ResnetBlock2D
10 |
--------------------------------------------------------------------------------
/sd_fused/layers/blocks/spatial/ae/__init__.py:
--------------------------------------------------------------------------------
1 | from .down_encoder import DownEncoderBlock2D
2 | from .up_decoder import UpDecoderBlock2D
3 |
--------------------------------------------------------------------------------
/sd_fused/layers/blocks/spatial/ae/down_encoder.py:
--------------------------------------------------------------------------------
1 | from __future__ import annotations
2 |
3 | import torch.nn as nn
4 | from torch import Tensor
5 |
6 | from ....base import Module, ModuleList
7 | from ..resampling import Downsample2D
8 | from ..resnet import ResnetBlock2D
9 |
10 |
11 | class DownEncoderBlock2D(Module):
12 | def __init__(
13 | self,
14 | *,
15 | in_channels: int,
16 | out_channels: int,
17 | num_layers: int,
18 | resnet_groups: int,
19 | add_downsample: bool,
20 | downsample_padding: int,
21 | ) -> None:
22 | super().__init__()
23 |
24 | self.in_channels = in_channels
25 | self.out_channels = out_channels
26 | self.num_layers = num_layers
27 | self.resnet_groups = resnet_groups
28 | self.add_downsample = add_downsample
29 | self.downsample_padding = downsample_padding
30 |
31 | self.resnets = ModuleList[ResnetBlock2D]()
32 | for i in range(num_layers):
33 | in_channels = in_channels if i == 0 else out_channels
34 |
35 | self.resnets.append(
36 | ResnetBlock2D(
37 | in_channels=in_channels,
38 | out_channels=out_channels,
39 | temb_channels=None,
40 | num_groups=resnet_groups,
41 | num_out_groups=None,
42 | )
43 | )
44 |
45 | if add_downsample:
46 |             self.downsampler = Downsample2D(out_channels, out_channels, padding=downsample_padding)
47 | else:
48 | self.downsampler = nn.Identity()
49 |
50 | def __call__(self, x: Tensor) -> Tensor:
51 | for resnet in self.resnets:
52 | x = resnet(x)
53 |
54 | x = self.downsampler(x)
55 |
56 | return x
57 |
--------------------------------------------------------------------------------
/sd_fused/layers/blocks/spatial/ae/up_decoder.py:
--------------------------------------------------------------------------------
1 | from __future__ import annotations
2 |
3 | import torch.nn as nn
4 | from torch import Tensor
5 |
6 | from ....base import Module, ModuleList
7 | from ..resampling import Upsample2D
8 | from ..resnet import ResnetBlock2D
9 |
10 |
11 | class UpDecoderBlock2D(Module):
12 | def __init__(
13 | self,
14 | *,
15 | in_channels: int,
16 | out_channels: int,
17 | num_layers: int,
18 | resnet_groups: int,
19 | add_upsample: bool,
20 | ) -> None:
21 | super().__init__()
22 |
23 | self.in_channels = in_channels
24 | self.out_channels = out_channels
25 | self.num_layers = num_layers
26 | self.resnet_groups = resnet_groups
27 | self.add_upsample = add_upsample
28 |
29 | self.resnets = ModuleList[ResnetBlock2D]()
30 | for i in range(num_layers):
31 | input_channels = in_channels if i == 0 else out_channels
32 |
33 | self.resnets.append(
34 | ResnetBlock2D(
35 | in_channels=input_channels,
36 | out_channels=out_channels,
37 | temb_channels=None,
38 | num_groups=resnet_groups,
39 | num_out_groups=None,
40 | )
41 | )
42 |
43 | if add_upsample:
44 | self.upsampler = Upsample2D(out_channels, out_channels)
45 | else:
46 | self.upsampler = nn.Identity()
47 |
48 | def __call__(self, x: Tensor) -> Tensor:
49 | for resnet in self.resnets:
50 | x = resnet(x)
51 |
52 | x = self.upsampler(x)
53 |
54 | return x
55 |
--------------------------------------------------------------------------------
/sd_fused/layers/blocks/spatial/base/__init__.py:
--------------------------------------------------------------------------------
1 | from .down import DownBlock2D
2 | from .up import UpBlock2D
3 |
--------------------------------------------------------------------------------
/sd_fused/layers/blocks/spatial/base/down.py:
--------------------------------------------------------------------------------
1 | from __future__ import annotations
2 | from typing import Optional
3 |
4 | import torch.nn as nn
5 | from torch import Tensor
6 |
7 | from ....base import Module, ModuleList
8 | from ..resampling import Downsample2D
9 | from ..resnet import ResnetBlock2D
10 | from ..output_states import OutputStates
11 |
12 |
13 | class DownBlock2D(Module):
14 | def __init__(
15 | self,
16 | *,
17 | in_channels: int,
18 | out_channels: int,
19 | temb_channels: Optional[int],
20 | num_layers: int,
21 | num_groups: int,
22 | add_downsample: bool,
23 | downsample_padding: int,
24 | ) -> None:
25 | super().__init__()
26 |
27 | self.in_channels = in_channels
28 | self.out_channels = out_channels
29 | self.temb_channels = temb_channels
30 | self.num_layers = num_layers
31 | self.num_groups = num_groups
32 | self.add_downsample = add_downsample
33 | self.downsample_padding = downsample_padding
34 |
35 | self.resnets = ModuleList[ResnetBlock2D]()
36 | for i in range(num_layers):
37 | in_channels = in_channels if i == 0 else out_channels
38 |
39 | self.resnets.append(
40 | ResnetBlock2D(
41 | in_channels=in_channels,
42 | out_channels=out_channels,
43 | temb_channels=temb_channels,
44 | num_groups=num_groups,
45 | num_out_groups=None,
46 | )
47 | )
48 |
49 | if add_downsample:
50 |             self.downsampler = Downsample2D(out_channels, out_channels, padding=downsample_padding)
51 | else:
52 | self.downsampler = None
53 |
54 | def __call__(
55 | self,
56 | x: Tensor,
57 | *,
58 | temb: Optional[Tensor] = None,
59 | ) -> OutputStates:
60 | states: list[Tensor] = []
61 | for resnet in self.resnets:
62 | x = resnet(x, temb=temb)
63 | states.append(x)
64 |
65 | if self.downsampler is not None:
66 | x = self.downsampler(x)
67 |
68 | states.append(x)
69 |
70 | return OutputStates(x, states)
71 |
--------------------------------------------------------------------------------
/sd_fused/layers/blocks/spatial/base/up.py:
--------------------------------------------------------------------------------
1 | from __future__ import annotations
2 | from typing import Optional
3 |
4 | import torch
5 | import torch.nn as nn
6 | from torch import Tensor
7 |
8 | from ....base import Module, ModuleList
9 | from ..resampling import Upsample2D
10 | from ..resnet import ResnetBlock2D
11 |
12 |
13 | class UpBlock2D(Module):
14 | def __init__(
15 | self,
16 | *,
17 | in_channels: int,
18 | prev_output_channel: int,
19 | out_channels: int,
20 | temb_channels: Optional[int],
21 | num_layers: int,
22 | resnet_groups: int,
23 | add_upsample: bool,
24 | ) -> None:
25 | super().__init__()
26 |
27 | self.in_channels = in_channels
28 | self.prev_output_channel = prev_output_channel
29 | self.out_channels = out_channels
30 | self.temb_channels = temb_channels
31 | self.num_layers = num_layers
32 | self.resnet_groups = resnet_groups
33 | self.add_upsample = add_upsample
34 |
35 | self.resnets = ModuleList[ResnetBlock2D]()
36 | for i in range(num_layers):
37 | if i == num_layers - 1:
38 | res_skip_channels = in_channels
39 | else:
40 | res_skip_channels = out_channels
41 |
42 | if i == 0:
43 | resnet_in_channels = prev_output_channel
44 | else:
45 | resnet_in_channels = out_channels
46 |
47 | self.resnets.append(
48 | ResnetBlock2D(
49 | in_channels=resnet_in_channels + res_skip_channels,
50 | out_channels=out_channels,
51 | temb_channels=temb_channels,
52 | num_groups=resnet_groups,
53 | num_out_groups=None,
54 | )
55 | )
56 |
57 | if add_upsample:
58 | self.upsampler = Upsample2D(out_channels)
59 | else:
60 | self.upsampler = nn.Identity()
61 |
62 | def __call__(
63 | self,
64 | x: Tensor,
65 | *,
66 | states: list[Tensor],
67 | temb: Optional[Tensor] = None,
68 | ) -> Tensor:
69 | assert len(states) == self.num_layers
70 |
71 | for resnet, state in zip(self.resnets, states):
72 | x = torch.cat([x, state], dim=1)
73 | x = resnet(x, temb=temb)
74 |
75 | x = self.upsampler(x)
76 |
77 | return x
78 |
--------------------------------------------------------------------------------
/sd_fused/layers/blocks/spatial/cross_attention/__init__.py:
--------------------------------------------------------------------------------
1 | from .up import CrossAttentionUpBlock2D
2 | from .down import CrossAttentionDownBlock2D
3 |
--------------------------------------------------------------------------------
/sd_fused/layers/blocks/spatial/cross_attention/down.py:
--------------------------------------------------------------------------------
1 | from __future__ import annotations
2 | from typing import Optional
3 |
4 | import torch.nn as nn
5 | from torch import Tensor
6 |
7 | from ....base import Module, ModuleList
8 | from ...transformer import SpatialTransformer
9 | from ..resampling import Downsample2D
10 | from ..resnet import ResnetBlock2D
11 | from ..output_states import OutputStates
12 |
13 |
14 | class CrossAttentionDownBlock2D(Module):
15 | def __init__(
16 | self,
17 | *,
18 | in_channels: int,
19 | out_channels: int,
20 | temb_channels: int,
21 | num_layers: int,
22 | resnet_groups: int,
23 | attn_num_head_channels: int,
24 | cross_attention_dim: int,
25 | downsample_padding: int,
26 | add_downsample: bool,
27 | ) -> None:
28 | super().__init__()
29 |
30 | self.in_channels = in_channels
31 | self.out_channels = out_channels
32 | self.temb_channels = temb_channels
33 | self.num_layers = num_layers
34 | self.resnet_groups = resnet_groups
35 | self.attn_num_head_channels = attn_num_head_channels
36 | self.cross_attention_dim = cross_attention_dim
37 | self.downsample_padding = downsample_padding
38 | self.add_downsample = add_downsample
39 |
40 | self.resnets = ModuleList[ResnetBlock2D]()
41 | self.attentions = ModuleList[SpatialTransformer]()
42 | for i in range(num_layers):
43 | in_channels = in_channels if i == 0 else out_channels
44 |
45 | self.resnets.append(
46 | ResnetBlock2D(
47 | in_channels=in_channels,
48 | out_channels=out_channels,
49 | temb_channels=temb_channels,
50 | num_groups=resnet_groups,
51 | num_out_groups=None,
52 | )
53 | )
54 |
55 | self.attentions.append(
56 | SpatialTransformer(
57 | in_channels=out_channels,
58 | num_heads=attn_num_head_channels,
59 | head_features=out_channels // attn_num_head_channels,
60 | depth=1,
61 | num_groups=resnet_groups,
62 | context_features=cross_attention_dim,
63 | )
64 | )
65 |
66 | if add_downsample:
67 |             self.downsampler = Downsample2D(out_channels, out_channels, padding=downsample_padding)
68 | else:
69 | self.downsampler = None
70 |
71 | def __call__(
72 | self,
73 | x: Tensor,
74 | *,
75 | temb: Optional[Tensor] = None,
76 | context: Optional[Tensor] = None,
77 | weights: Optional[Tensor] = None,
78 | ) -> OutputStates:
79 | states: list[Tensor] = []
80 | for resnet, attn in zip(self.resnets, self.attentions):
81 | x = resnet(x, temb=temb)
82 | x = attn(x, context=context, weights=weights)
83 |
84 | states.append(x)
85 |
86 | if self.downsampler is not None:
87 | x = self.downsampler(x)
88 |
89 | states.append(x)
90 |
91 | return OutputStates(x, states)
92 |
--------------------------------------------------------------------------------
/sd_fused/layers/blocks/spatial/cross_attention/up.py:
--------------------------------------------------------------------------------
1 | from __future__ import annotations
2 | from typing import Optional
3 |
4 | import torch
5 | import torch.nn as nn
6 | from torch import Tensor
7 |
8 | from ....base import Module, ModuleList
9 | from ...transformer import SpatialTransformer
10 | from ..resampling import Upsample2D
11 | from ..resnet import ResnetBlock2D
12 |
13 |
14 | class CrossAttentionUpBlock2D(Module):
15 | def __init__(
16 | self,
17 | *,
18 | in_channels: int,
19 | out_channels: int,
20 | prev_output_channel: int,
21 | temb_channels: int,
22 | num_layers: int,
23 | resnet_groups: int,
24 | attn_num_head_channels: int,
25 | cross_attention_dim: int,
26 | add_upsample: bool,
27 | ) -> None:
28 | super().__init__()
29 |
30 | self.in_channels = in_channels
31 | self.out_channels = out_channels
32 | self.prev_output_channel = prev_output_channel
33 | self.temb_channels = temb_channels
34 | self.num_layers = num_layers
35 | self.resnet_groups = resnet_groups
36 | self.attn_num_head_channels = attn_num_head_channels
37 | self.cross_attention_dim = cross_attention_dim
38 | self.add_upsample = add_upsample
39 |
40 | self.resnets = ModuleList[ResnetBlock2D]()
41 | self.attentions = ModuleList[SpatialTransformer]()
42 | for i in range(num_layers):
43 | if i == num_layers - 1:
44 | res_skip_channels = in_channels
45 | else:
46 | res_skip_channels = out_channels
47 |
48 | if i == 0:
49 | resnet_in_channels = prev_output_channel
50 | else:
51 | resnet_in_channels = out_channels
52 |
53 | self.resnets.append(
54 | ResnetBlock2D(
55 | in_channels=resnet_in_channels + res_skip_channels,
56 | out_channels=out_channels,
57 | temb_channels=temb_channels,
58 | num_groups=resnet_groups,
59 | num_out_groups=None,
60 | )
61 | )
62 |
63 | self.attentions.append(
64 | SpatialTransformer(
65 | in_channels=out_channels,
66 | num_heads=attn_num_head_channels,
67 | head_features=out_channels // attn_num_head_channels,
68 | depth=1,
69 | num_groups=resnet_groups,
70 | context_features=cross_attention_dim,
71 | )
72 | )
73 |
74 | if add_upsample:
75 | self.upsampler = Upsample2D(out_channels)
76 | else:
77 | self.upsampler = nn.Identity()
78 |
79 | def __call__(
80 | self,
81 | x: Tensor,
82 | *,
83 | states: list[Tensor],
84 | temb: Optional[Tensor] = None,
85 | context: Optional[Tensor] = None,
86 | weights: Optional[Tensor] = None,
87 | ) -> Tensor:
88 | assert len(states) == self.num_layers
89 |
90 | for resnet, attn, state in zip(self.resnets, self.attentions, states):
91 | x = torch.cat([x, state], dim=1)
92 | x = resnet(x, temb=temb)
93 | x = attn(x, context=context, weights=weights)
94 |
95 | x = self.upsampler(x)
96 |
97 | return x
98 |
--------------------------------------------------------------------------------
/sd_fused/layers/blocks/spatial/output_states.py:
--------------------------------------------------------------------------------
1 | from __future__ import annotations
2 | from typing import NamedTuple
3 |
4 | from torch import Tensor
5 |
6 |
7 | class OutputStates(NamedTuple):
8 | x: Tensor
9 | states: list[Tensor]
10 |
--------------------------------------------------------------------------------
/sd_fused/layers/blocks/spatial/resampling/__init__.py:
--------------------------------------------------------------------------------
1 | from .upsample2d import Upsample2D
2 | from .downsample2d import Downsample2D
3 |
--------------------------------------------------------------------------------
/sd_fused/layers/blocks/spatial/resampling/downsample2d.py:
--------------------------------------------------------------------------------
1 | from __future__ import annotations
2 | from typing import Callable, Optional
3 |
4 | from functools import partial
5 |
6 | import torch.nn.functional as F
7 | from torch import Tensor
8 |
9 | from ....base import Module
10 | from ....basic import Conv2d
11 |
12 |
13 | class Downsample2D(Module):
14 | pad: Callable[[Tensor], Tensor]
15 |
16 | def __init__(
17 | self,
18 | in_channels: int,
19 | out_channels: Optional[int],
20 | *,
21 | padding: int,
22 | ) -> None:
23 | super().__init__()
24 |
25 | self.channels = in_channels
26 | self.out_channels = out_channels
27 | self.padding = padding
28 |
29 | out_channels = out_channels or in_channels
30 | stride = 2
31 |
32 | self.conv = Conv2d(
33 | in_channels,
34 | out_channels,
35 | kernel_size=3,
36 | stride=stride,
37 | padding=padding,
38 | )
39 |
40 | self.pad = lambda x: x
41 | if padding == 0:
42 |             # pad right/bottom by one so the stride-2 conv still halves H and W
43 | self.pad = partial(
44 | F.pad,
45 | pad=(0, 1, 0, 1),
46 | mode="constant",
47 | value=0,
48 | )
49 |
50 | def __call__(self, x: Tensor) -> Tensor:
51 | x = self.pad(x)
52 |
53 | return self.conv(x)
54 |
--------------------------------------------------------------------------------
/sd_fused/layers/blocks/spatial/resampling/upsample2d.py:
--------------------------------------------------------------------------------
1 | from __future__ import annotations
2 | from typing import Optional
3 |
4 | from functools import partial
5 |
6 | import torch.nn.functional as F
7 | from torch import Tensor
8 |
9 | from ....base import Module
10 | from ....basic import Conv2d
11 |
12 |
13 | class Upsample2D(Module):
14 | def __init__(
15 | self,
16 | channels: int,
17 | out_channels: Optional[int] = None,
18 | ) -> None:
19 | super().__init__()
20 |
21 | self.channels = channels
22 | self.out_channels = out_channels
23 |
24 | out_channels = out_channels or channels
25 |
26 | self.conv = Conv2d(channels, out_channels, kernel_size=3, padding=1)
27 | self.upscale = partial(F.interpolate, mode="nearest", scale_factor=2)
28 |
29 | def __call__(self, x: Tensor) -> Tensor:
30 | x = self.upscale(x)
31 |
32 | return self.conv(x)
33 |
--------------------------------------------------------------------------------
/sd_fused/layers/blocks/spatial/resnet.py:
--------------------------------------------------------------------------------
1 | from __future__ import annotations
2 | from typing import Optional
3 |
4 | from torch import Tensor
5 |
6 | from ...external import Rearrange
7 | from ...base import Module, Sequential
8 | from ...activation import SiLU
9 | from ...basic import Conv2d, Linear, Identity
10 | from ..basic import GroupNormSiLUConv2d
11 |
12 |
13 | class ResnetBlock2D(Module):
14 | def __init__(
15 | self,
16 | *,
17 | in_channels: int,
18 | out_channels: Optional[int],
19 | temb_channels: Optional[int],
20 | num_groups: int,
21 | num_out_groups: Optional[int],
22 | ) -> None:
23 | super().__init__()
24 |
25 | self.in_channels = in_channels
26 | self.out_channels = out_channels
27 | self.temb_channels = temb_channels
28 | self.num_groups = num_groups
29 | self.num_out_groups = num_out_groups
30 |
31 | out_channels = out_channels or in_channels
32 | num_out_groups = num_out_groups or num_groups
33 |
34 | if in_channels != out_channels:
35 | self.conv_shortcut = Conv2d(in_channels, out_channels)
36 | else:
37 | self.conv_shortcut = Identity()
38 |
39 | self.pre_process = GroupNormSiLUConv2d(num_groups, in_channels, out_channels, kernel_size=3, padding=1)
40 |
41 | self.post_process = GroupNormSiLUConv2d(
42 | num_out_groups,
43 | out_channels,
44 | out_channels,
45 | kernel_size=3,
46 | padding=1,
47 | )
48 |
49 | if temb_channels is not None:
50 | self.time_emb_proj = Sequential(
51 | SiLU(),
52 | Linear(temb_channels, out_channels),
53 | Rearrange("b c -> b c 1 1"),
54 | )
55 | else:
56 | self.time_emb_proj = None
57 |
58 | def __call__(self, x: Tensor, *, temb: Optional[Tensor] = None) -> Tensor:
59 | xin = self.conv_shortcut(x)
60 |
61 | x = self.pre_process(x)
62 |
63 | if self.time_emb_proj is not None:
64 | assert temb is not None
65 |
66 | temb = self.time_emb_proj(temb)
67 | x = x + temb
68 |
69 | x = self.post_process(x)
70 |
71 | return xin + x
72 |
--------------------------------------------------------------------------------
/sd_fused/layers/blocks/spatial/unet_mid/__init__.py:
--------------------------------------------------------------------------------
1 | from .cross_attention import UNetMidBlock2DCrossAttention
2 | from .self_attention import UNetMidBlock2DSelfAttention
3 |
--------------------------------------------------------------------------------
/sd_fused/layers/blocks/spatial/unet_mid/cross_attention.py:
--------------------------------------------------------------------------------
1 | from __future__ import annotations
2 | from typing import Optional
3 |
4 | from torch import Tensor
5 |
6 | from ....base import Module, ModuleList
7 | from ...transformer import SpatialTransformer
8 | from ..resnet import ResnetBlock2D
9 |
10 |
11 | class UNetMidBlock2DCrossAttention(Module):
12 | def __init__(
13 | self,
14 | *,
15 | in_channels: int,
16 | temb_channels: int,
17 | num_layers: int,
18 | resnet_groups: int,
19 | attn_num_head_channels: int,
20 | cross_attention_dim: int,
21 | ) -> None:
22 | super().__init__()
23 |
24 | self.in_channels = in_channels
25 | self.temb_channels = temb_channels
26 | self.num_layers = num_layers
27 | self.resnet_groups = resnet_groups
28 | self.attn_num_head_channels = attn_num_head_channels
29 | self.cross_attention_dim = cross_attention_dim
30 |
31 | resnet_groups = resnet_groups or min(in_channels // 4, 32)
32 |
33 | self.attentions = ModuleList[SpatialTransformer]()
34 | self.resnets = ModuleList[ResnetBlock2D]()
35 | for i in range(num_layers + 1):
36 | if i > 0:
37 | self.attentions.append(
38 | SpatialTransformer(
39 | in_channels=in_channels,
40 | num_heads=attn_num_head_channels,
41 | head_features=in_channels // attn_num_head_channels,
42 | depth=1,
43 | num_groups=resnet_groups,
44 | context_features=cross_attention_dim,
45 | )
46 | )
47 |
48 | self.resnets.append(
49 | ResnetBlock2D(
50 | in_channels=in_channels,
51 | out_channels=in_channels,
52 | temb_channels=temb_channels,
53 | num_groups=resnet_groups,
54 | num_out_groups=None,
55 | )
56 | )
57 |
58 | def __call__(
59 | self,
60 | x: Tensor,
61 | *,
62 | temb: Optional[Tensor] = None,
63 | context: Optional[Tensor] = None,
64 | weights: Optional[Tensor] = None,
65 | ) -> Tensor:
66 | first_resnet, *rest_resnets = self.resnets
67 |
68 | x = first_resnet(x, temb=temb)
69 |
70 | for attn, resnet in zip(self.attentions, rest_resnets):
71 | x = attn(x, context=context, weights=weights)
72 | x = resnet(x, temb=temb)
73 |
74 | return x
75 |
--------------------------------------------------------------------------------
/sd_fused/layers/blocks/spatial/unet_mid/self_attention.py:
--------------------------------------------------------------------------------
1 | from __future__ import annotations
2 | from typing import Optional
3 |
4 |
5 | from torch import Tensor
6 |
7 | from ....base import Module, ModuleList
8 | from ...attention import SelfAttention
9 | from ..resnet import ResnetBlock2D
10 |
11 |
12 | class UNetMidBlock2DSelfAttention(Module):
13 | def __init__(
14 | self,
15 | *,
16 | in_channels: int,
17 | temb_channels: Optional[int],
18 | num_layers: int,
19 | resnet_groups: int,
20 | attn_num_head_channels: Optional[int],
21 | ) -> None:
22 | super().__init__()
23 |
24 | self.in_channels = in_channels
25 | self.temb_channels = temb_channels
26 | self.num_layers = num_layers
27 | self.resnet_groups = resnet_groups
28 | self.attn_num_head_channels = attn_num_head_channels
29 |
30 | resnet_groups = resnet_groups or min(in_channels // 4, 32)
31 |
32 | self.attentions = ModuleList[SelfAttention]()
33 | self.resnets = ModuleList[ResnetBlock2D]()
34 | for i in range(num_layers + 1):
35 | if i > 0:
36 | self.attentions.append(
37 | SelfAttention(
38 | in_features=in_channels,
39 | num_groups=resnet_groups,
40 | head_features=attn_num_head_channels,
41 | )
42 | )
43 |
44 | self.resnets.append(
45 | ResnetBlock2D(
46 | in_channels=in_channels,
47 | out_channels=in_channels,
48 | temb_channels=temb_channels,
49 | num_groups=resnet_groups,
50 | num_out_groups=None,
51 | )
52 | )
53 |
54 | def __call__(
55 | self,
56 | x: Tensor,
57 | *,
58 | temb: Optional[Tensor] = None,
59 | ) -> Tensor:
60 | first_resnet, *rest_resnets = self.resnets
61 |
62 | x = first_resnet(x, temb=temb)
63 |
64 | for attn, resnet in zip(self.attentions, rest_resnets):
65 | x = attn(x)
66 | x = resnet(x, temb=temb)
67 |
68 | return x
69 |
--------------------------------------------------------------------------------
/sd_fused/layers/blocks/transformer/__init__.py:
--------------------------------------------------------------------------------
1 | from .basic_transformer import BasicTransformer
2 | from .spatial_transformer import SpatialTransformer
3 |
--------------------------------------------------------------------------------
/sd_fused/layers/blocks/transformer/basic_transformer.py:
--------------------------------------------------------------------------------
1 | from __future__ import annotations
2 | from typing import Optional
3 |
4 | from torch import Tensor
5 |
6 | from ...base import Module
7 | from ..attention import CrossAttention
8 | from ..basic import LayerNormGEGLULinear
9 |
10 |
11 | class BasicTransformer(Module):
12 | def __init__(
13 | self,
14 | *,
15 | in_features: int,
16 | num_heads: int,
17 | head_features: int,
18 | context_features: Optional[int],
19 | ):
20 | super().__init__()
21 |
22 | self.in_features = in_features
23 | self.num_heads = num_heads
24 | self.head_features = head_features
25 | self.context_features = context_features
26 |
27 | self.attn1 = CrossAttention(
28 | query_features=in_features,
29 | num_heads=num_heads,
30 | head_features=head_features,
31 | context_features=None,
32 | )
33 | self.attn2 = CrossAttention(
34 | query_features=in_features,
35 | num_heads=num_heads,
36 | head_features=head_features,
37 | context_features=context_features,
38 | )
39 |
40 | self.ff = LayerNormGEGLULinear(in_features, expand=4)
41 |
42 | def __call__(
43 | self,
44 | x: Tensor,
45 | *,
46 | context: Optional[Tensor] = None,
47 | weights: Optional[Tensor] = None,
48 | ) -> Tensor:
49 | x = self.attn1(x)
50 | x = self.attn2(x, context=context, weights=weights)
51 | x = self.ff(x)
52 |
53 | return x
54 |
--------------------------------------------------------------------------------
/sd_fused/layers/blocks/transformer/spatial_transformer.py:
--------------------------------------------------------------------------------
1 | from __future__ import annotations
2 | from typing import Optional
3 |
4 | from torch import Tensor
5 |
6 | from ...external import Rearrange
7 | from ...base import Module, ModuleList
8 | from ...basic import Conv2d
9 | from ..basic import GroupNormConv2d
10 | from .basic_transformer import BasicTransformer
11 |
12 |
13 | class SpatialTransformer(Module):
14 | def __init__(
15 | self,
16 | *,
17 | in_channels: int,
18 | num_heads: int,
19 | head_features: int,
20 | depth: int,
21 | num_groups: int,
22 | context_features: Optional[int],
23 | ):
24 | super().__init__()
25 |
26 | self.in_channels = in_channels
27 | self.num_heads = num_heads
28 | self.head_features = head_features
29 | self.depth = depth
30 | self.num_groups = num_groups
31 | self.context_features = context_features
32 |
33 | inner_dim = num_heads * head_features
34 |
35 | self.proj_in = GroupNormConv2d(num_groups, in_channels, inner_dim)
36 | self.proj_out = Conv2d(inner_dim, in_channels)
37 |
38 | self.transformer_blocks = ModuleList[BasicTransformer]()
39 | for _ in range(depth):
40 | self.transformer_blocks.append(
41 | BasicTransformer(
42 | in_features=inner_dim,
43 | num_heads=num_heads,
44 | head_features=head_features,
45 | context_features=context_features,
46 | )
47 | )
48 |
49 | self.channel_last = Rearrange("B C H W -> B H W C")
50 | self.channel_first = Rearrange("B H W C -> B C H W")
51 |
52 | def __call__(
53 | self,
54 | x: Tensor,
55 | *,
56 | context: Optional[Tensor] = None,
57 | weights: Optional[Tensor] = None,
58 | ) -> Tensor:
59 | B, C, H, W = x.shape
60 |
61 | xin = x
62 |
63 | x = self.proj_in(x)
64 | x = self.channel_last(x)
65 |
66 | for block in self.transformer_blocks:
67 | x = block(x, context=context, weights=weights)
68 | x = self.channel_first(x, H=H, W=W)
69 |
70 | return xin + self.proj_out(x)
71 |
--------------------------------------------------------------------------------
/sd_fused/layers/distribution/__init__.py:
--------------------------------------------------------------------------------
1 | from .diag_gaussian import DiagonalGaussianDistribution
2 |
--------------------------------------------------------------------------------
/sd_fused/layers/distribution/diag_gaussian.py:
--------------------------------------------------------------------------------
1 | from __future__ import annotations
2 | from typing import Optional
3 |
4 | import torch
5 | from torch import Tensor
6 |
7 |
8 | class DiagonalGaussianDistribution:
9 | def __init__(
10 | self,
11 | mean: Tensor,
12 | logvar: Tensor,
13 | ) -> None:
14 | super().__init__()
15 |
16 | self.device = mean.device
17 | self.dtype = mean.dtype
18 |
19 | self.mean = mean
20 |
21 | logvar = logvar.clamp(-30, 20)
22 | self.std = torch.exp(logvar / 2)
23 |
24 | def sample(self, generator: Optional[torch.Generator] = None) -> Tensor:
25 | # TODO use seeds?
26 | noise = torch.randn(self.std.shape, generator=generator, device=self.device, dtype=self.dtype)
27 |
28 | return self.mean + self.std * noise
29 |
--------------------------------------------------------------------------------
/sd_fused/layers/embedding/__init__.py:
--------------------------------------------------------------------------------
1 | from .time_steps import Timesteps
2 | from .time_step_emb import TimestepEmbedding
3 |
--------------------------------------------------------------------------------
/sd_fused/layers/embedding/time_step_emb.py:
--------------------------------------------------------------------------------
1 | from __future__ import annotations
2 |
3 | from ..base import Sequential
4 | from ..basic import Linear, Identity
5 | from ..activation import SiLU
6 |
7 |
8 | class TimestepEmbedding(Sequential):
9 | def __init__(
10 | self,
11 | *,
12 | channel: int,
13 | time_embed_dim: int,
14 | use_silu: bool = True, # !always true
15 | ) -> None:
16 |
17 | self.channel = channel
18 | self.time_embed_dim = time_embed_dim
19 | self.use_silu = use_silu
20 |
21 | layers = (
22 | Linear(channel, time_embed_dim),
23 | SiLU() if use_silu else Identity(), # ? silu always true?
24 | Linear(time_embed_dim, time_embed_dim),
25 | )
26 |
27 | super().__init__(*layers)
28 |
--------------------------------------------------------------------------------
/sd_fused/layers/embedding/time_steps.py:
--------------------------------------------------------------------------------
1 | from __future__ import annotations
2 |
3 | import math
4 |
5 | import torch
6 | from torch import Tensor
7 |
8 | from ..base import Module
9 |
10 |
11 | class Timesteps(Module):
12 | freq: Tensor
13 | amplitude: Tensor
14 | phase: Tensor
15 |
16 | def __init__(
17 | self,
18 | *,
19 | num_channels: int,
20 | flip_sin_to_cos: bool,
21 | downscale_freq_shift: float,
22 | scale: float = 1,
23 | max_period: int = 10_000,
24 | ) -> None:
25 | super().__init__()
26 |
27 | self.num_channels = num_channels
28 | self.flip_sin_to_cos = flip_sin_to_cos
29 | self.downscale_freq_shift = downscale_freq_shift
30 | self.scale = scale
31 | self.max_period = max_period
32 |
33 | assert num_channels % 2 == 0
34 | half_dim = num_channels // 2
35 |
36 | idx = torch.arange(half_dim)
37 |
38 | exponent = -math.log(max_period) * idx
39 | exponent /= half_dim - downscale_freq_shift
40 |
41 | freq = exponent.exp()
42 | freq = torch.cat([freq, freq]).unsqueeze(0)
43 |
44 | amplitude = torch.full((1, num_channels), scale)
45 |
46 | zeros = torch.zeros((1, half_dim))
47 | ones = torch.ones((1, half_dim))
48 | phase = torch.cat([zeros, ones], dim=1) * torch.pi / 2
49 |
50 | if flip_sin_to_cos:
51 | freq = reverse(freq, half_dim)
52 | phase = reverse(phase, half_dim)
53 | amplitude = reverse(amplitude, half_dim)
54 |
55 | # TODO create nn.Buffer
56 | self.freq = freq
57 | self.phase = phase
58 | self.amplitude = amplitude
59 |
60 | def __call__(self, x: Tensor) -> Tensor:
61 | x = x[..., None]
62 |
63 | kwargs = dict(device=self.device, dtype=self.dtype, non_blocking=True)
64 | self.freq = self.freq.to(**kwargs)
65 | self.phase = self.phase.to(**kwargs)
66 | self.amplitude = self.amplitude.to(**kwargs)
67 |
68 | return self.amplitude * torch.sin(x * self.freq + self.phase)
69 |
70 |
71 | def reverse(x: Tensor, half_dim: int) -> Tensor:
72 | return torch.cat([x[:, half_dim:], x[:, :half_dim]], dim=1)
73 |
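A brief, illustrative check of the sinusoidal layout above (assuming the module's default device/dtype are CPU/float32): with `flip_sin_to_cos=False` the first half of the embedding is the sine part and the second half the cosine part.

    import torch

    from sd_fused.layers.embedding import Timesteps

    emb = Timesteps(num_channels=320, flip_sin_to_cos=False, downscale_freq_shift=0.0)

    t = torch.tensor([42.0])
    out = emb(t)  # (1, 320)

    half = 320 // 2
    assert torch.allclose(out[:, :half], torch.sin(t[:, None] * emb.freq[:, :half]), atol=1e-6)
    assert torch.allclose(out[:, half:], torch.cos(t[:, None] * emb.freq[:, half:]), atol=1e-6)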
--------------------------------------------------------------------------------
/sd_fused/layers/external/__init__.py:
--------------------------------------------------------------------------------
1 | from .rearrange import Rearrange
2 |
--------------------------------------------------------------------------------
/sd_fused/layers/external/rearrange.py:
--------------------------------------------------------------------------------
1 | from __future__ import annotations
2 |
3 | from torch import Tensor
4 | from einops import rearrange
5 |
6 | from ..base import Module
7 |
8 |
9 | class Rearrange(Module):
10 | def __init__(self, pattern: str, **axes_length: int) -> None:
11 | self.pattern = pattern
12 | self.axes_length = axes_length
13 |
14 | def __call__(self, x: Tensor, **axes_length: int) -> Tensor:
15 | return rearrange(x, self.pattern, **self.axes_length, **axes_length)
16 |
17 | def make_inverse(self) -> Rearrange:
18 | left, right = self.pattern.split("->")
19 |
20 | new = Rearrange(f"{right} -> {left}")
21 |
22 | return new
23 |
--------------------------------------------------------------------------------
/sd_fused/layers/modifiers/__init__.py:
--------------------------------------------------------------------------------
1 | from .half_weights import HalfWeightsModule, half_weights
2 |
--------------------------------------------------------------------------------
/sd_fused/layers/modifiers/half_weights.py:
--------------------------------------------------------------------------------
1 | from __future__ import annotations
2 | from typing_extensions import Self
3 |
4 | from functools import wraps
5 |
6 | import torch
7 | from torch import Tensor
8 |
9 | from ..base.module import Module
10 |
11 |
12 | class HalfWeightsModule(Module):
13 | use_half_weights: bool = False
14 |
15 | def half_weights(self, use: bool = True) -> Self:
16 | self.use_half_weights = use
17 |
18 | return self.half() if use or self.dtype == torch.float16 else self.float()
19 |
20 |
21 | def half_weights(fun):
22 | @wraps(fun)
23 | def wrapper(self: HalfWeightsModule, *args, **kwargs):
24 | if self.use_half_weights:
25 | self.float()
26 |
27 | args = tuple(a.float() if isinstance(a, Tensor) else a for a in args)
28 | kwargs = {k: v.float() if isinstance(v, Tensor) else v for k, v in kwargs.items()}
29 |
30 | out = fun(self, *args, **kwargs)
31 | self.half()
32 | else:
33 | out = fun(self, *args, **kwargs)
34 |
35 | return out
36 |
37 | return wrapper
38 |
--------------------------------------------------------------------------------
/sd_fused/models/__init__.py:
--------------------------------------------------------------------------------
1 | from .ae_kl import AutoencoderKL
2 | from .unet_conditional import UNet2DConditional
3 |
--------------------------------------------------------------------------------
/sd_fused/models/ae_kl.py:
--------------------------------------------------------------------------------
1 | from __future__ import annotations
2 | from typing_extensions import Self
3 |
4 | from pathlib import Path
5 | import json
6 |
7 | import torch
8 | from torch import Tensor
9 |
10 | from ..layers.base import Module
11 | from ..layers.basic import Conv2d
12 | from ..layers.auto_encoder import Encoder, Decoder
13 | from ..layers.distribution import DiagonalGaussianDistribution
14 | from ..utils.tensors import normalize, denormalize
15 | from .config import VaeConfig
16 | from .convert import diffusers2fused_vae
17 | from .convert.states import debug_state_replacements
18 | from .modifiers import HalfWeightsModel, SplitAttentionModel, FlashAttentionModel, ToMeModel
19 |
20 |
21 | class AutoencoderKL(HalfWeightsModel, SplitAttentionModel, FlashAttentionModel, ToMeModel, Module):
22 | @classmethod
23 | def from_config(cls, path: str | Path) -> Self:
24 | """Creates a model from a (diffusers) config file."""
25 |
26 | path = Path(path)
27 | if path.is_dir():
28 | path /= "config.json"
29 | assert path.suffix == ".json"
30 |
31 | db = json.load(open(path, "r"))
32 | config = VaeConfig(**db)
33 |
34 | return cls(
35 | in_channels=config.in_channels,
36 | out_channels=config.out_channels,
37 | block_out_channels=tuple(config.block_out_channels),
38 | layers_per_block=config.layers_per_block,
39 | latent_channels=config.latent_channels,
40 | norm_num_groups=config.norm_num_groups,
41 | )
42 |
43 | def __init__(
44 | self,
45 | *,
46 | in_channels: int = 3,
47 | out_channels: int = 3,
48 | block_out_channels: tuple[int, ...] = (128, 256, 512, 512),
49 | layers_per_block: int = 2,
50 | latent_channels: int = 4,
51 | norm_num_groups: int = 32,
52 | ) -> None:
53 | super().__init__()
54 |
55 | self.in_channels = in_channels
56 | self.out_channels = out_channels
57 | self.block_out_channels = block_out_channels
58 | self.layers_per_block = layers_per_block
59 | self.latent_channels = latent_channels
60 | self.norm_num_groups = norm_num_groups
61 |
62 | self.encoder = Encoder(
63 | in_channels=in_channels,
64 | out_channels=latent_channels,
65 | block_out_channels=block_out_channels,
66 | layers_per_block=layers_per_block,
67 | resnet_groups=norm_num_groups,
68 | double_z=True,
69 | )
70 |
71 | self.decoder = Decoder(
72 | in_channels=latent_channels,
73 | out_channels=out_channels,
74 | block_out_channels=block_out_channels,
75 | layers_per_block=layers_per_block,
76 | resnet_groups=norm_num_groups,
77 | )
78 |
79 | # TODO very bad names...
80 | self.quant_conv = Conv2d(2 * latent_channels)
81 | self.post_quant_conv = Conv2d(latent_channels)
82 |
83 | def encode(self, x: Tensor) -> DiagonalGaussianDistribution:
84 | """Encode a byte-Tensor into a posterior distribution."""
85 |
86 | x = normalize(x, self.dtype)
87 | x = self.encoder(x)
88 |
89 | moments = self.quant_conv(x)
90 | mean, logvar = moments.chunk(2, dim=1)
91 |
92 | return DiagonalGaussianDistribution(mean, logvar)
93 |
94 | def decode(self, z: Tensor) -> Tensor:
95 | """Decode the latent space into an image."""
96 |
97 | z = self.post_quant_conv(z)
98 | out = self.decoder(z)
99 |
100 | out = denormalize(out)
101 |
102 | return out
103 |
104 | def __call__(self):
105 | raise ValueError("AutoencoderKL is not directly callable; use .encode()/.decode() instead")
106 |
107 | @classmethod
108 | def from_diffusers(cls, path: str | Path) -> Self:
109 | """Load Stable-Diffusion from diffusers checkpoint folder."""
110 |
111 | path = Path(path)
112 | model = cls.from_config(path)
113 |
114 | state_path = next(path.glob("*.bin"))
115 | old_state = torch.load(state_path, map_location="cpu")
116 | replaced_state = diffusers2fused_vae(old_state)
117 |
118 | debug_state_replacements(model.state_dict(), replaced_state)
119 |
120 | model.load_state_dict(replaced_state)
121 |
122 | return model
123 |
--------------------------------------------------------------------------------
/sd_fused/models/config.py:
--------------------------------------------------------------------------------
1 | from __future__ import annotations
2 | from typing import Type
3 |
4 | from dataclasses import dataclass
5 |
6 | from ..layers.blocks.spatial import DownBlock2D, UpBlock2D, CrossAttentionDownBlock2D, CrossAttentionUpBlock2D
7 |
8 |
9 | @dataclass
10 | class VaeConfig:
11 | _class_name: str
12 | _diffusers_version: str
13 | act_fn: str
14 | in_channels: int
15 | latent_channels: int
16 | layers_per_block: int
17 | out_channels: int
18 | sample_size: int
19 | block_out_channels: list[int]
20 | down_block_types: list[str]
21 | up_block_types: list[str]
22 | norm_num_groups: int = 32
23 |
24 | def __post_init__(self) -> None:
25 | assert self._class_name == "AutoencoderKL"
26 | assert self.act_fn == "silu"
27 |
28 | for block in self.down_block_types:
29 | assert block == "DownEncoderBlock2D"
30 |
31 | for block in self.up_block_types:
32 | assert block == "UpDecoderBlock2D"
33 |
34 |
35 | @dataclass
36 | class UnetConfig:
37 | _class_name: str
38 | _diffusers_version: str
39 | act_fn: str
40 | attention_head_dim: int
41 | block_out_channels: list[int]
42 | center_input_sample: bool
43 | cross_attention_dim: int
44 | down_block_types: list[str]
45 | downsample_padding: int
46 | flip_sin_to_cos: bool
47 | freq_shift: int
48 | in_channels: int
49 | layers_per_block: int
50 | mid_block_scale_factor: int
51 | norm_eps: float
52 | norm_num_groups: int
53 | out_channels: int
54 | sample_size: int
55 | up_block_types: list[str]
56 |
57 | def __post_init__(self) -> None:
58 | assert self._class_name == "UNet2DConditionModel"
59 | assert self.act_fn == "silu"
60 |
61 | for block in self.down_block_types:
62 | assert block in ("CrossAttnDownBlock2D", "DownBlock2D")
63 | for block in self.up_block_types:
64 | assert block in ("UpBlock2D", "CrossAttnUpBlock2D")
65 |
66 | @property
67 | def down_blocks(
68 | self,
69 | ) -> tuple[Type[CrossAttentionDownBlock2D] | Type[DownBlock2D], ...]:
70 | def get_block(
71 | block: str,
72 | ) -> Type[CrossAttentionDownBlock2D] | Type[DownBlock2D]:
73 | if block == "CrossAttnDownBlock2D":
74 | return CrossAttentionDownBlock2D
75 | if block == "DownBlock2D":
76 | return DownBlock2D
77 |
78 | raise ValueError(f"Invalid down-block type: {block}")
79 |
80 | return tuple(get_block(block) for block in self.down_block_types)
81 |
82 | @property
83 | def up_blocks(
84 | self,
85 | ) -> tuple[Type[CrossAttentionUpBlock2D] | Type[UpBlock2D], ...]:
86 | def get_block(
87 | block: str,
88 | ) -> Type[CrossAttentionUpBlock2D] | Type[UpBlock2D]:
89 | if block == "CrossAttnUpBlock2D":
90 | return CrossAttentionUpBlock2D
91 | if block == "UpBlock2D":
92 | return UpBlock2D
93 |
94 | raise ValueError(f"Invalid up-block type: {block}")
95 |
96 | return tuple(get_block(block) for block in self.up_block_types)
97 |
--------------------------------------------------------------------------------
/sd_fused/models/convert/__init__.py:
--------------------------------------------------------------------------------
1 | from .vae.sd2diffusers import sd2diffusers as sd2diffusers_vae
2 | from .vae.diffusers2fused import diffusers2fused as diffusers2fused_vae
3 |
4 | from .unet.diffusers2fused import diffusers2fused as diffusers2fused_unet
5 |
--------------------------------------------------------------------------------
/sd_fused/models/convert/states.py:
--------------------------------------------------------------------------------
1 | from __future__ import annotations
2 |
3 | import re
4 |
5 | import torch.nn as nn
6 | from torch import Tensor
7 |
8 |
9 | def replace_state(
10 | old_state: dict[str, Tensor],
11 | replacements: list[tuple[str, str]],
12 | ) -> dict[str, Tensor]:
13 | """Build a new state-dict with keys renamed according to the replacements."""
14 |
15 | state: dict[str, Tensor] = {}
16 | for key in old_state.keys():
17 | new_key = key
18 | for old, new in replacements:
19 | new_key = re.sub(old, new, new_key)
20 |
21 | state[new_key] = old_state[key]
22 |
23 | return state
24 |
25 |
26 | def debug_state_replacements(
27 | state: dict[str, Tensor] | dict[str, nn.Parameter],
28 | replaced_state: dict[str, Tensor],
29 | ) -> None:
30 | good_keys = set(state.keys())
31 | replaced_keys = set(replaced_state.keys())
32 |
33 | delta = good_keys - replaced_keys
34 | if len(delta) != 0:
35 | print("missed replacing some keys")
36 | print("=" * 32)
37 | for key in delta:
38 | print(key)
39 |
40 | delta = replaced_keys - good_keys
41 | if len(delta) != 0:
42 | print("wrongly replaced some keys")
43 | print("=" * 32)
44 | for key in delta:
45 | print(key)
46 |
--------------------------------------------------------------------------------
/sd_fused/models/convert/unet/diffusers2fused.py:
--------------------------------------------------------------------------------
1 | from __future__ import annotations
2 |
3 | from torch import Tensor
4 |
5 | from ..states import replace_state
6 |
7 |
8 | def diffusers2fused(old_state: dict[str, Tensor]) -> dict[str, Tensor]:
9 | """Convert a diffusers checkpoint into an sd-fused one for the UNet."""
10 |
11 | return replace_state(old_state, REPLACEMENTS)
12 |
13 |
14 | # fmt: off
15 | # diffusers to sd-fused replacements
16 | REPLACEMENTS: list[tuple[str, str]] = [
17 | # Cross-attention
18 | (r"transformer_blocks.(\d).norm([12]).(weight|bias)", r"transformer_blocks.\1.attn\2.norm.\3"),
19 | ## FeedForward (norm)
20 | (r"transformer_blocks.(\d).norm3.(weight|bias)", r"transformer_blocks.\1.ff.0.\2"),
21 | ## FeedForward (geglu)
22 | (r"ff.net.0.proj.(weight|bias)", r"ff.1.proj.\1"),
23 | ## FeedForward-Linear
24 | (r"ff.net.2.(weight|bias)", r"ff.2.\1"),
25 | # up/down samplers
26 | (r"(up|down)samplers.0", r"\1sampler"),
27 | # CrossAttention projection
28 | (r"to_out.0.", r"to_out."),
29 | # TimeEmbedding
30 | (r"time_embedding.linear_1.(weight|bias)", r"time_embedding.0.\1"),
31 | (r"time_embedding.linear_2.(weight|bias)", r"time_embedding.2.\1"),
32 | # resnet-blocks pre/post-process
33 | (r"resnets.(\d).norm1.(bias|weight)", r"resnets.\1.pre_process.0.\2"),
34 | (r"resnets.(\d).conv1.(bias|weight)", r"resnets.\1.pre_process.2.\2"),
35 | (r"resnets.(\d).norm2.(bias|weight)", r"resnets.\1.post_process.0.\2"),
36 | (r"resnets.(\d).conv2.(bias|weight)", r"resnets.\1.post_process.2.\2"),
37 | # resnet-time-embedding
38 | (r"time_emb_proj.(bias|weight)", r"time_emb_proj.1.\1"),
39 | # spatial transformer fused
40 | (r"attentions.(\d).norm.(bias|weight)", r"attentions.\1.proj_in.0.\2"),
41 | (r"attentions.(\d).proj_in.(bias|weight)", r"attentions.\1.proj_in.1.\2"),
42 | # post-processing
43 | (r"conv_norm_out.(bias|weight)", r"post_process.0.\1"),
44 | (r"conv_out.(bias|weight)", r"post_process.2.\1"),
45 | # pre-processing
46 | (r'conv_in.(bias|weight)', r'pre_process.\1')
47 | ]
48 |
--------------------------------------------------------------------------------
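An illustrative sketch of what the replacement table above does (keys in the style of diffusers UNet names, tensor values are dummies): replace_state applies every regex pair, in order, to each key of the state-dict and leaves the values untouched.

    import torch
    from sd_fused.models.convert import diffusers2fused_unet

    old_state = {
        "time_embedding.linear_1.weight": torch.zeros(1),
        "down_blocks.0.downsamplers.0.conv.weight": torch.zeros(1),
    }
    new_state = diffusers2fused_unet(old_state)
    # keys become "time_embedding.0.weight" and
    # "down_blocks.0.downsampler.conv.weight"

--------------------------------------------------------------------------------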
/sd_fused/models/convert/vae/diffusers2fused.py:
--------------------------------------------------------------------------------
1 | from __future__ import annotations
2 |
3 | from torch import Tensor
4 |
5 | from ..states import replace_state
6 |
7 | # ! Tensor/Parameter?
8 | def diffusers2fused(old_state: dict[str, Tensor]) -> dict[str, Tensor]:
9 | """Convert a diffusers checkpoint into an sd-fused one for the AutoencoderKL."""
10 |
11 | return replace_state(old_state, REPLACEMENTS)
12 |
13 |
14 | # fmt: off
15 | # diffusers to sd-fused replacements
16 | REPLACEMENTS: list[tuple[str, str]] = [
17 | # up/down samplers
18 | (r"(up|down)samplers.0", r"\1sampler"),
19 | # post_process
20 | (r"(decoder|encoder).conv_norm_out.(bias|weight)", r"\1.post_process.0.\2"),
21 | (r"(decoder|encoder).conv_out.(bias|weight)", r"\1.post_process.2.\2"),
22 | # resnet-blocks pre/post-process
23 | (r"resnets.(\d).norm1.(bias|weight)", r"resnets.\1.pre_process.0.\2"),
24 | (r"resnets.(\d).conv1.(bias|weight)", r"resnets.\1.pre_process.2.\2"),
25 | (r"resnets.(\d).norm2.(bias|weight)", r"resnets.\1.post_process.0.\2"),
26 | (r"resnets.(\d).conv2.(bias|weight)", r"resnets.\1.post_process.2.\2"),
27 | # pre-processing
28 | (r'(encoder|decoder).conv_in.(bias|weight)', r'\1.pre_process.\2')
29 | ]
30 |
--------------------------------------------------------------------------------
/sd_fused/models/convert/vae/sd2diffusers.py:
--------------------------------------------------------------------------------
1 | from __future__ import annotations
2 |
3 | import re
4 | from torch import Tensor
5 |
6 | # TODO make a config from sd?
7 | # TODO work in progress...
8 |
9 |
10 | def sd2diffusers(old_state: dict[str, Tensor]) -> dict[str, Tensor]:
11 | """Convert a Stable-Diffusion checkpoint into a diffusers one for the AutoencoderKL."""
12 |
13 | VAE = "first_stage_model."
14 |
15 | state: dict[str, Tensor] = {}
16 | for key in old_state.keys():
17 | if not key.startswith(VAE):
18 | continue
19 |
20 | new_key = key.replace(VAE, "")
21 |
22 | for old, new in REPLACEMENTS:
23 | new_key = re.sub(old, new, new_key)
24 |
25 | state[new_key] = old_state[key]
26 |
27 | # reshape attentions weights
28 | # needed due to conv -> linear conversion
29 | if re.match(
30 | r"(encoder|decoder).mid_block.attentions.\d.(key|query|value|proj_attn).weight",
31 | new_key,
32 | ):
33 | state[new_key] = state[new_key].squeeze()
34 |
35 | return state
36 |
37 |
38 | # fmt: off
39 | # Stable-diffusion to diffusers replacements
40 | REPLACEMENTS: list[tuple[str, str]] = [
41 | # reverse up-block order
42 | (r"up\.0", r"UP.3"),
43 | (r"up\.1", r"UP.2"),
44 | (r"up\.2", r"UP.1"),
45 | (r"up\.3", r"UP.0"),
46 | (r"UP", "up"), # recover lower-case
47 | # short-cut connection
48 | (r"nin_shortcut", r"conv_shortcut"),
49 | # encoder/decoder up/down-blocks
50 | (
51 | r"(encoder|decoder)\.(down|up)\.(\d)\.block\.(\d)\.(norm\d|conv\d|conv_shortcut)\.(weight|bias)",
52 | r"\1.\2_blocks.\3.resnets.\4.\5.\6",
53 | ),
54 | (
55 | r"(encoder|decoder)\.(down|up)\.(\d)\.(downsample|upsample)\.conv\.(weight|bias)",
56 | r"\1.\2_blocks.\3.\4rs.0.conv.\5",
57 | ),
58 | # mid-blocks
59 | (r"block_1", r"block_0"),
60 | (r"block_2", r"block_1"),
61 | (
62 | r"(encoder|decoder)\.mid\.block_(\d)\.(norm\d|conv\d)\.(weight|bias)",
63 | r"\1.mid_block.resnets.\2.\3.\4",
64 | ),
65 | # mid-block attention
66 | (r"\.q\.", r".query."),
67 | (r"\.k\.", r".key."),
68 | (r"\.v\.", r".value."),
69 | (r"attn_1\.proj_out", r"attn_1.proj_attn"),
70 | (r"attn_1\.norm", r"attn_1.group_norm"),
71 | (
72 | r"(encoder|decoder)\.mid\.attn_1\.(group_norm|query|key|value|proj_attn)\.(weight|bias)",
73 | r"\1.mid_block.attentions.0.\2.\3",
74 | ),
75 | # in-out
76 | (r"(encoder|decoder)\.(norm_out)\.(weight|bias)", r"\1.conv_\2.\3"),
77 | ]
78 |
--------------------------------------------------------------------------------
/sd_fused/models/modifiers/__init__.py:
--------------------------------------------------------------------------------
1 | from .half_weights import HalfWeightsModel
2 | from .split_attention import SplitAttentionModel
3 | from .flash_attention import FlashAttentionModel
4 | from .tome import ToMeModel
5 |
--------------------------------------------------------------------------------
/sd_fused/models/modifiers/flash_attention.py:
--------------------------------------------------------------------------------
1 | from __future__ import annotations
2 | from typing_extensions import Self
3 |
4 | from ...layers.base import Module
5 | from ...layers.blocks.attention import CrossAttention, SelfAttention
6 |
7 |
8 | class FlashAttentionModel(Module):
9 | def flash_attention(self, use: bool = True) -> Self:
10 | """Use xformers flash-attention."""
11 |
12 | for name, module in self.named_modules().items():
13 | if isinstance(module, (CrossAttention, SelfAttention)):
14 | module.use_flash_attention = use
15 |
16 | if use:
17 | module.attention_chunks = None
18 | module.chunk_type = None
19 |
20 | return self
21 |
--------------------------------------------------------------------------------
/sd_fused/models/modifiers/half_weights.py:
--------------------------------------------------------------------------------
1 | from __future__ import annotations
2 | from typing_extensions import Self
3 |
4 | from ...layers.base import Module
5 | from ...layers.blocks.attention import CrossAttention, SelfAttention
6 | from ...layers.modifiers import HalfWeightsModule
7 |
8 |
9 | class HalfWeightsModel(Module):
10 | def half_weights(self, use: bool = True) -> Self:
11 | """Store the weights in half-precision but
12 | compute the forward pass in full precision.
13 | Useful for GPUs that give NaN when used in half-precision.
14 | """
15 |
16 | for name, module in self.named_modules().items():
17 | if isinstance(module, HalfWeightsModule):
18 | module.half_weights(use)
19 |
20 | return self
21 |
--------------------------------------------------------------------------------
/sd_fused/models/modifiers/split_attention.py:
--------------------------------------------------------------------------------
1 | from __future__ import annotations
2 | from typing import Optional
3 | from typing_extensions import Self
4 |
5 | from ...layers.base import Module
6 | from ...utils.typing import Literal
7 | from ...layers.blocks.attention import CrossAttention, SelfAttention
8 | from ...layers.blocks.attention.compute import ChunkType
9 |
10 |
11 | class SplitAttentionModel(Module):
12 | def split_attention(
13 | self,
14 | chunks: Optional[int | Literal["auto"]] = "auto",
15 | chunk_type: Optional[ChunkType] = None,
16 | ) -> Self:
17 | """Split cross/self-attention computation into chunks."""
18 |
19 | for name, module in self.named_modules().items():
20 | if isinstance(module, (CrossAttention, SelfAttention)):
21 | module.attention_chunks = chunks
22 | module.chunk_type = chunk_type
23 |
24 | if chunks is not None:
25 | module.use_flash_attention = False
26 |
27 | return self
28 |
--------------------------------------------------------------------------------
/sd_fused/models/modifiers/tome.py:
--------------------------------------------------------------------------------
1 | from __future__ import annotations
2 | from typing import Optional
3 | from typing_extensions import Self
4 |
5 | from ...layers.base import Module
6 | from ...layers.blocks.attention import CrossAttention, SelfAttention
7 |
8 |
9 | class ToMeModel(Module):
10 | # https://arxiv.org/abs/2210.09461
11 | # https://github.com/facebookresearch/ToMe/blob/main/tome/merge.py
12 |
13 | def tome(self, r: Optional[int | float] = None) -> Self:
14 | """Merge similar tokens."""
15 |
16 | for name, module in self.named_modules().items():
17 | if isinstance(module, (CrossAttention, SelfAttention)):
18 | module.tome_r = r
19 |
20 | return self
21 |
--------------------------------------------------------------------------------
/sd_fused/models/unet_conditional.py:
--------------------------------------------------------------------------------
1 | from __future__ import annotations
2 | from typing import Optional, Type
3 | from typing_extensions import Self
4 |
5 | from pathlib import Path
6 | import json
7 |
8 | import torch
9 | from torch import Tensor
10 |
11 | from ..layers.base import Module, ModuleList
12 | from ..layers.basic import Conv2d
13 | from ..layers.embedding import Timesteps, TimestepEmbedding
14 | from ..layers.blocks.basic import GroupNormSiLUConv2d
15 | from ..layers.blocks.spatial import (
16 | UNetMidBlock2DCrossAttention,
17 | DownBlock2D,
18 | UpBlock2D,
19 | CrossAttentionDownBlock2D,
20 | CrossAttentionUpBlock2D,
21 | )
22 | from ..utils.tensors import to_tensor
23 | from .modifiers import HalfWeightsModel, SplitAttentionModel, FlashAttentionModel, ToMeModel
24 | from .config import UnetConfig
25 | from .convert import diffusers2fused_unet
26 | from .convert.states import debug_state_replacements
27 |
28 |
29 | class UNet2DConditional(HalfWeightsModel, SplitAttentionModel, FlashAttentionModel, ToMeModel, Module):
30 | @classmethod
31 | def from_config(cls, path: str | Path) -> Self:
32 | """Creates a model from a (diffusers) config file."""
33 |
34 | path = Path(path)
35 | if path.is_dir():
36 | path /= "config.json"
37 | assert path.suffix == ".json"
38 |
39 | db = json.load(open(path, "r"))
40 | config = UnetConfig(**db)
41 |
42 | return cls(
43 | in_channels=config.in_channels,
44 | out_channels=config.out_channels,
45 | flip_sin_to_cos=config.flip_sin_to_cos,
46 | freq_shift=config.freq_shift,
47 | down_blocks=config.down_blocks,
48 | up_blocks=config.up_blocks,
49 | block_out_channels=tuple(config.block_out_channels),
50 | layers_per_block=config.layers_per_block,
51 | downsample_padding=config.downsample_padding,
52 | norm_num_groups=config.norm_num_groups,
53 | cross_attention_dim=config.cross_attention_dim,
54 | attention_head_dim=config.attention_head_dim,
55 | )
56 |
57 | def __init__(
58 | self,
59 | *,
60 | in_channels: int = 4,
61 | out_channels: int = 4,
62 | flip_sin_to_cos: bool = True,
63 | freq_shift: int = 0,
64 | down_blocks: tuple[Type[CrossAttentionDownBlock2D] | Type[DownBlock2D], ...] = (
65 | CrossAttentionDownBlock2D,
66 | CrossAttentionDownBlock2D,
67 | CrossAttentionDownBlock2D,
68 | DownBlock2D,
69 | ),
70 | up_blocks: tuple[Type[CrossAttentionUpBlock2D] | Type[UpBlock2D], ...] = (
71 | UpBlock2D,
72 | CrossAttentionUpBlock2D,
73 | CrossAttentionUpBlock2D,
74 | CrossAttentionUpBlock2D,
75 | ),
76 | block_out_channels: tuple[int, ...] = (320, 640, 1280, 1280),
77 | layers_per_block: int = 2,
78 | downsample_padding: int = 1,
79 | norm_num_groups: int = 32,
80 | cross_attention_dim: int = 768,
81 | attention_head_dim: int = 8,
82 | ) -> None:
83 | super().__init__()
84 |
85 | self.in_channels = in_channels
86 | self.out_channels = out_channels
87 | self.flip_sin_to_cos = flip_sin_to_cos
88 | self.freq_shift = freq_shift
89 | self.block_out_channels = block_out_channels
90 | self.layers_per_block = layers_per_block
91 | self.downsample_padding = downsample_padding
92 | self.norm_num_groups = norm_num_groups
93 | self.cross_attention_dim = cross_attention_dim
94 | self.attention_head_dim = attention_head_dim
95 |
96 | time_embed_dim = block_out_channels[0] * 4
97 |
98 | # input
99 | self.pre_process = Conv2d(in_channels, block_out_channels[0], kernel_size=3, padding=1)
100 |
101 | # time
102 | self.time_proj = Timesteps(
103 | num_channels=block_out_channels[0],
104 | flip_sin_to_cos=flip_sin_to_cos,
105 | downscale_freq_shift=freq_shift,
106 | )
107 | timestep_input_dim = block_out_channels[0]
108 |
109 | self.time_embedding = TimestepEmbedding(channel=timestep_input_dim, time_embed_dim=time_embed_dim)
110 |
111 | # down
112 | output_channel = block_out_channels[0]
113 | self.down_blocks = ModuleList[CrossAttentionDownBlock2D | DownBlock2D]()
114 | for i, block in enumerate(down_blocks):
115 | input_channel = output_channel
116 | output_channel = block_out_channels[i]
117 | is_final_block = i == len(block_out_channels) - 1
118 |
119 | if block == CrossAttentionDownBlock2D:
120 | self.down_blocks.append(
121 | CrossAttentionDownBlock2D(
122 | in_channels=input_channel,
123 | out_channels=output_channel,
124 | temb_channels=time_embed_dim,
125 | num_layers=layers_per_block,
126 | resnet_groups=norm_num_groups,
127 | attn_num_head_channels=attention_head_dim,
128 | cross_attention_dim=cross_attention_dim,
129 | downsample_padding=downsample_padding,
130 | add_downsample=not is_final_block,
131 | )
132 | )
133 | elif block == DownBlock2D:
134 | self.down_blocks.append(
135 | DownBlock2D(
136 | in_channels=input_channel,
137 | out_channels=output_channel,
138 | temb_channels=time_embed_dim,
139 | num_layers=layers_per_block,
140 | num_groups=norm_num_groups,
141 | add_downsample=not is_final_block,
142 | downsample_padding=downsample_padding,
143 | )
144 | )
145 |
146 | # mid
147 | self.mid_block = UNetMidBlock2DCrossAttention(
148 | in_channels=block_out_channels[-1],
149 | temb_channels=time_embed_dim,
150 | cross_attention_dim=cross_attention_dim,
151 | attn_num_head_channels=attention_head_dim,
152 | resnet_groups=norm_num_groups,
153 | num_layers=1,
154 | )
155 |
156 | # up
157 | reversed_block_out_channels = tuple(reversed(block_out_channels))
158 | output_channel = reversed_block_out_channels[0]
159 | self.up_blocks = ModuleList[CrossAttentionUpBlock2D | UpBlock2D]()
160 | for i, block in enumerate(up_blocks):
161 | prev_output_channel = output_channel
162 | output_channel = reversed_block_out_channels[i]
163 | input_channel = reversed_block_out_channels[min(i + 1, len(block_out_channels) - 1)]
164 |
165 | is_final_block = i == len(block_out_channels) - 1
166 |
167 | if block == CrossAttentionUpBlock2D:
168 | self.up_blocks.append(
169 | CrossAttentionUpBlock2D(
170 | in_channels=input_channel,
171 | out_channels=output_channel,
172 | prev_output_channel=prev_output_channel,
173 | num_layers=layers_per_block + 1,
174 | temb_channels=time_embed_dim,
175 | add_upsample=not is_final_block,
176 | resnet_groups=norm_num_groups,
177 | cross_attention_dim=cross_attention_dim,
178 | attn_num_head_channels=attention_head_dim,
179 | )
180 | )
181 | elif block == UpBlock2D:
182 | self.up_blocks.append(
183 | UpBlock2D(
184 | in_channels=input_channel,
185 | out_channels=output_channel,
186 | prev_output_channel=prev_output_channel,
187 | num_layers=layers_per_block + 1,
188 | temb_channels=time_embed_dim,
189 | add_upsample=not is_final_block,
190 | resnet_groups=norm_num_groups,
191 | )
192 | )
193 | prev_output_channel = output_channel
194 |
195 | # out
196 | self.post_process = GroupNormSiLUConv2d(
197 | norm_num_groups,
198 | block_out_channels[0],
199 | out_channels,
200 | kernel_size=3,
201 | padding=1,
202 | )
203 |
204 | def __call__(
205 | self,
206 | x: Tensor,
207 | timestep: int | Tensor,
208 | context: Tensor,
209 | weights: Optional[Tensor] = None,
210 | ) -> Tensor:
211 | B, C, H, W = x.shape
212 |
213 | # 1. time embedding
214 | timestep = to_tensor(timestep, device=x.device, dtype=x.dtype)
215 | if timestep.size(0) != B:
216 | assert timestep.size(0) == 1
217 | timestep = timestep.expand(B)
218 |
219 | temb = self.time_proj(timestep)
220 | temb = self.time_embedding(temb)
221 |
222 | # 2. pre-process
223 | x = self.pre_process(x)
224 |
225 | # 3. down
226 | # TODO is it possible to make this a list[list[Tensor]]? or is the number of elements wrong?
227 | all_states: list[Tensor] = [x]
228 | for block in self.down_blocks:
229 | if isinstance(block, CrossAttentionDownBlock2D):
230 | x, states = block(x, temb=temb, context=context, weights=weights)
231 | elif isinstance(block, DownBlock2D):
232 | x, states = block(x, temb=temb)
233 | else:
234 | raise ValueError
235 |
236 | all_states.extend(states)
237 | del states
238 |
239 | # 4. mid
240 | x = self.mid_block(x, temb=temb, context=context, weights=weights)
241 |
242 | # 5. up
243 | for block in self.up_blocks:
244 | assert isinstance(block, (CrossAttentionUpBlock2D, UpBlock2D))
245 |
246 | # ! I don't like the construction of this...
247 | states = [all_states.pop() for _ in range(block.num_layers)]
248 |
249 | if isinstance(block, CrossAttentionUpBlock2D):
250 | x = block(x, states=states, temb=temb, context=context, weights=weights)
251 | elif isinstance(block, UpBlock2D):
252 | x = block(x, states=states, temb=temb)
253 | else:
254 | raise ValueError
255 |
256 | del states
257 | del all_states
258 |
259 | # 6. post-process
260 | x = self.post_process(x)
261 |
262 | return x
263 |
264 | @classmethod
265 | def from_diffusers(cls, path: str | Path) -> Self:
266 | """Load Stable-Diffusion from diffusers checkpoint folder."""
267 |
268 | path = Path(path)
269 | model = cls.from_config(path)
270 |
271 | state_path = next(path.glob("*.bin"))
272 | old_state = torch.load(state_path, map_location="cpu")
273 | replaced_state = diffusers2fused_unet(old_state)
274 |
275 | debug_state_replacements(model.state_dict(), replaced_state)
276 |
277 | model.load_state_dict(replaced_state)
278 |
279 | return model
280 |
--------------------------------------------------------------------------------
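A hypothetical loading sketch (the folder path, tensor shapes and timestep are placeholders): from_diffusers builds the model from config.json, renames the checkpoint keys with diffusers2fused_unet and loads them; the modifier mix-ins (half_weights, split_attention, flash_attention, tome) all return self, so they can be chained.

    import torch
    from sd_fused.models import UNet2DConditional

    unet = UNet2DConditional.from_diffusers("path/to/unet")  # folder with config.json + *.bin
    unet.split_attention("auto")  # chunked attention to reduce peak memory

    latents = torch.randn(1, 4, 64, 64)  # (B, C, H/8, W/8), e.g. a 512x512 image
    context = torch.randn(1, 77, 768)    # text embeddings, cross_attention_dim=768
    noise_pred = unet(latents, timestep=999, context=context)

--------------------------------------------------------------------------------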
/sd_fused/scheduler/__init__.py:
--------------------------------------------------------------------------------
1 | from .ddim import Scheduler, DDIMScheduler
2 |
--------------------------------------------------------------------------------
/sd_fused/scheduler/ddim.py:
--------------------------------------------------------------------------------
1 | from __future__ import annotations
2 | from typing import Optional
3 |
4 | import math
5 |
6 | import torch
7 | from torch import Tensor
8 |
9 | from ..layers.base.types import Device
10 | from ..utils.tensors import to_tensor
11 | from .scheduler import Scheduler
12 |
13 | TRAINED_STEPS = 1_000
14 | BETA_BEGIN = 0.00085
15 | BETA_END = 0.012
16 | POWER = 2
17 |
18 |
19 | class DDIMScheduler(Scheduler):
20 | """Denoising Diffusion Implicit Models scheduler."""
21 |
22 | # https://arxiv.org/abs/2010.02502
23 |
24 | ᾱ: Tensor
25 | ϖ: Tensor
26 | σ: Tensor
27 |
28 | @property # TODO Type
29 | def info(self) -> dict[str, str | int]:
30 | return dict(
31 | name=self.__class__.__qualname__,
32 | steps=self.steps,
33 | skip_timestep=self.skip_timestep,
34 | )
35 |
36 | def __init__(
37 | self,
38 | steps: int,
39 | shape: tuple[int, ...],
40 | seeds: list[int],
41 | strength: Optional[float] = None, # img2img/inpainting
42 | device: Optional[Device] = None,
43 | dtype: torch.dtype = torch.float32,
44 | ) -> None:
45 | super().__init__(steps, shape, seeds, strength, device, dtype)
46 |
47 | assert steps <= TRAINED_STEPS
48 |
49 | # scheduler betas and alphas
50 | β_begin = math.pow(BETA_BEGIN, 1 / POWER)
51 | β_end = math.pow(BETA_END, 1 / POWER)
52 | β = torch.linspace(β_begin, β_end, TRAINED_STEPS).pow(POWER)
53 | β = β.to(torch.float64) # extra-precision
54 |
55 | # increase steps by 1 to account for the last timestep
56 | steps += 1
57 |
58 | # trimmed timesteps for selection
59 | timesteps = torch.linspace(TRAINED_STEPS - 1, 0, steps).ceil().long()
60 |
61 | # cumulative ᾱ, trimmed
62 | α = 1 - β
63 | ᾱ = α.cumprod(dim=0)
64 | ᾱ /= ᾱ.max() # makes last-value=1
65 | ᾱ = ᾱ[timesteps]
66 | ϖ = 1 - ᾱ
67 | del α, β # reminder that these are not used anymore
68 |
69 | # standard deviation, eq (16)
70 | σ = torch.sqrt(ϖ[1:] / ϖ[:-1] * (1 - ᾱ[:-1] / ᾱ[1:]))
71 |
72 | # use device/dtype
73 | self.timesteps = timesteps.to(device=device)
74 | self.ᾱ = ᾱ.to(device=device, dtype=dtype)
75 | self.ϖ = ϖ.to(device=device, dtype=dtype)
76 | self.σ = σ.to(device=device, dtype=dtype)
77 |
78 | def add_noise(
79 | self,
80 | latents: Tensor,
81 | noise: Tensor,
82 | index: int,
83 | ) -> Tensor:
84 | # eq 4
85 | return latents * self.ᾱ[index].sqrt() + noise * self.ϖ[index].sqrt()
86 |
87 | def step(
88 | self,
89 | pred_noise: Tensor,
90 | latents: Tensor,
91 | index: int,
92 | etas: Optional[float | Tensor] = None,
93 | ) -> Tensor:
94 | etas = to_tensor(etas, self.device, self.dtype, add_spatial=True)
95 |
96 | # eq (12) part 1
97 | pred_latent = latents - self.ϖ[index].sqrt() * pred_noise
98 | pred_latent /= self.ᾱ[index].sqrt()
99 |
100 | # eq (12) part 2
101 | temp = 1 - self.ᾱ[index + 1] - self.σ[index].mul(etas).square()
102 | pred_dir = temp.abs().sqrt() * pred_noise
103 |
104 | # eq (12) part 3
105 | noise = self.noise[index] * self.σ[index] * etas
106 |
107 | # full eq (12)
108 | latents = pred_latent * self.ᾱ[index + 1].sqrt() + pred_dir + noise
109 |
110 | return latents
111 |
112 | def __repr__(self) -> str:
113 | name = self.__class__.__qualname__
114 |
115 | return f"{name}(steps={self.steps})"
116 |
--------------------------------------------------------------------------------
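For reference, the terms computed in step above follow eq. (12) of the DDIM paper, with the η-scaled standard deviation of eq. (16); index i corresponds to timestep t and index i+1 to the previous (smaller) timestep. A LaTeX sketch of the standard formulas, not repository content:

    x_{t-1} = \sqrt{\bar\alpha_{t-1}}
              \underbrace{\frac{x_t - \sqrt{1-\bar\alpha_t}\,\epsilon_\theta(x_t)}{\sqrt{\bar\alpha_t}}}_{\text{pred\_latent}}
            + \underbrace{\sqrt{1-\bar\alpha_{t-1}-\sigma_t^2}\,\epsilon_\theta(x_t)}_{\text{pred\_dir}}
            + \underbrace{\sigma_t\,\epsilon_t}_{\text{noise}}

    \sigma_t(\eta) = \eta\,\sqrt{\frac{1-\bar\alpha_{t-1}}{1-\bar\alpha_t}}\,
                          \sqrt{1-\frac{\bar\alpha_t}{\bar\alpha_{t-1}}}

--------------------------------------------------------------------------------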
/sd_fused/scheduler/scheduler.py:
--------------------------------------------------------------------------------
1 | from __future__ import annotations
2 | from typing import Optional
3 |
4 | from abc import ABC, abstractmethod
5 |
6 | import math
7 |
8 | import torch
9 | from torch import Tensor
10 |
11 | from ..layers.base.types import Device
12 | from ..models import UNet2DConditional
13 | from ..utils.tensors import generate_noise
14 |
15 |
16 | class Scheduler(ABC):
17 | """Base-class for all schedulers."""
18 |
19 | timesteps: Tensor
20 |
21 | def __init__(
22 | self,
23 | steps: int,
24 | shape: tuple[int, ...],
25 | seeds: list[int],
26 | strength: Optional[float] = None, # img2img/inpainting
27 | device: Optional[Device] = None,
28 | dtype: torch.dtype = torch.float32,
29 | ) -> None:
30 | if strength is not None:
31 | assert 0 < strength <= 1
32 |
33 | self.steps = steps
34 | self.shape = shape
35 | self.seeds = seeds
36 | self.strength = strength
37 | self.device = device
38 | self.dtype = dtype
39 |
40 | # TODO add sub-seeds
41 | self.noise = generate_noise(shape, seeds, device, dtype, steps)
42 |
43 | @property
44 | def skip_timestep(self) -> int:
45 | """Text-to-Image generation starting timestep."""
46 |
47 | if self.strength is None:
48 | return 0
49 |
50 | return math.ceil(self.steps * (1 - self.strength))
51 |
52 | @abstractmethod
53 | def add_noise(self, latents: Tensor, noise: Tensor, index: int) -> Tensor:
54 | """Add noise for a timestep."""
55 |
56 | def prepare_latents(
57 | self,
58 | image_latents: Optional[Tensor] = None,
59 | mask_latents: Optional[Tensor] = None, # ! old-style inpainting
60 | masked_image_latents: Optional[Tensor] = None, # ! new-style inpainting
61 | ) -> Tensor:
62 | """Prepare initial latents for generation."""
63 |
64 | noise = self.noise[0]
65 |
66 | if image_latents is None:
67 | return noise
68 |
69 | if mask_latents is None and masked_image_latents is None:
70 | return self.add_noise(image_latents, noise, self.skip_timestep)
71 |
72 | # TODO inpainting
73 | raise NotImplementedError
74 |
75 | @abstractmethod
76 | def step(
77 | self,
78 | pred_noise: Tensor,
79 | latents: Tensor,
80 | index: int,
81 | **kwargs, # ! type
82 | ) -> Tensor:
83 | """Advance the latents to the previous timestep."""
84 |
85 | @torch.no_grad()
86 | def pred_noise(
87 | self,
88 | unet: UNet2DConditional,
89 | latents: Tensor,
90 | timestep: int,
91 | context: Tensor,
92 | weights: Optional[Tensor],
93 | scale: Optional[Tensor],
94 | unconditional: bool,
95 | low_ram: bool,
96 | ) -> Tensor:
97 | """Noise prediction for a given timestep."""
98 |
99 | if unconditional:
100 | assert scale is None
101 |
102 | return unet(latents, timestep, context, weights)
103 |
104 | assert scale is not None
105 |
106 | if low_ram:
107 | negative_context, prompt_context = context.chunk(2, dim=0)
108 | if weights is not None:
109 | negative_weight, prompt_weight = weights.chunk(2, dim=0)
110 | else:
111 | negative_weight = prompt_weight = None
112 |
113 | pred_noise_prompt = unet(latents, timestep, prompt_context, prompt_weight)
114 | pred_noise_negative = unet(latents, timestep, negative_context, negative_weight)
115 | else:
116 | latents = torch.cat([latents] * 2, dim=0)
117 |
118 | pred_noise_all = unet(latents, timestep, context, weights)
119 | pred_noise_negative, pred_noise_prompt = pred_noise_all.chunk(2, dim=0)
120 |
121 | scale = scale[:, None, None, None] # add fake channel/spatial dimensions
122 |
123 | noise_pred = pred_noise_negative + (pred_noise_prompt - pred_noise_negative) * scale
124 |
125 | return noise_pred
126 |
--------------------------------------------------------------------------------
/sd_fused/utils/__init__.py:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/tfernd/sd-fused/50570f983cc00dd4bd0dc415d0b515da638311b0/sd_fused/utils/__init__.py
--------------------------------------------------------------------------------
/sd_fused/utils/cuda/__init__.py:
--------------------------------------------------------------------------------
1 | from .clear_cuda import clear_cuda
2 | from .free_memory import free_memory
3 |
--------------------------------------------------------------------------------
/sd_fused/utils/cuda/clear_cuda.py:
--------------------------------------------------------------------------------
1 | from __future__ import annotations
2 |
3 | import gc
4 | import torch
5 |
6 |
7 | def clear_cuda() -> None:
8 | """Run garbage collection and clear the CUDA cache."""
9 |
10 | gc.collect()
11 | if torch.cuda.is_available():
12 | torch.cuda.empty_cache()
13 | torch.cuda.ipc_collect()
14 |
--------------------------------------------------------------------------------
/sd_fused/utils/cuda/free_memory.py:
--------------------------------------------------------------------------------
1 | from __future__ import annotations
2 |
3 | import torch
4 |
5 |
6 | def free_memory() -> int:
7 | """Amount of free CUDA memory available, in bytes."""
8 |
9 | assert torch.cuda.is_available()
10 |
11 | stats = torch.cuda.memory_stats()
12 |
13 | reserved = stats["reserved_bytes.all.current"]
14 | active = stats["active_bytes.all.current"]
15 | free = torch.cuda.mem_get_info()[0] # type: ignore
16 |
17 | free += reserved - active
18 |
19 | return free
20 |
--------------------------------------------------------------------------------
/sd_fused/utils/diverse/__init__.py:
--------------------------------------------------------------------------------
1 | from .product_args import product_args
2 | from .separate import separate
3 | from .single import single
4 | from .to_list import to_list
5 |
--------------------------------------------------------------------------------
/sd_fused/utils/diverse/product_args.py:
--------------------------------------------------------------------------------
1 | from __future__ import annotations
2 | from typing import Optional
3 |
4 | from itertools import product
5 |
6 | from ..typing import MaybeIterable
7 | from .to_list import to_list
8 |
9 | # TODO types!
10 | def product_args(**kwargs: Optional[MaybeIterable]) -> list[dict]:
11 | """All possible combinations of kwargs."""
12 |
13 | args = list(kwargs.values())
14 | keys = list(kwargs.keys())
15 |
16 | args = tuple(map(to_list, args))
17 | perms = list(product(*args))
18 |
19 | return [dict(zip(keys, perm)) for perm in perms]
20 |
--------------------------------------------------------------------------------
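An illustrative call (the values are arbitrary): every keyword may be a single value or an iterable, since to_list wraps scalars and None before the Cartesian product is taken.

    from sd_fused.utils.diverse import product_args

    product_args(prompt=["a cat", "a dog"], scale=7.5, eta=None)
    # -> [{'prompt': 'a cat', 'scale': 7.5, 'eta': None},
    #     {'prompt': 'a dog', 'scale': 7.5, 'eta': None}]

--------------------------------------------------------------------------------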
/sd_fused/utils/diverse/separate.py:
--------------------------------------------------------------------------------
1 | from __future__ import annotations
2 | from typing import Optional
3 |
4 | from torch import Tensor
5 |
6 | from ...utils.typing import TypeVar
7 |
8 | T = TypeVar("T", Tensor, int, float, str)
9 |
10 |
11 | def separate(xs: list[T | None]) -> Optional[list[T]]:
12 | """Separate a list that may contain `None` into a list
13 | that does not contain `None`, or return `None` itself."""
14 |
15 | assert len(xs) >= 1
16 |
17 | if xs[0] is None:
18 | for x in xs:
19 | assert x is None
20 |
21 | return None
22 |
23 | for x in xs:
24 | assert x is not None
25 |
26 | return [x for x in xs if x is not None]
27 |
--------------------------------------------------------------------------------
/sd_fused/utils/diverse/single.py:
--------------------------------------------------------------------------------
1 | from __future__ import annotations
2 | from typing import Optional, overload
3 |
4 | import torch
5 |
6 | from ...utils.typing import TypeVar
7 | from ...layers.base.types import Device
8 |
9 | T = TypeVar("T", int, float, Device, torch.dtype)
10 |
11 |
12 | @overload
13 | def single(x: set[T]) -> T:
14 | ...
15 |
16 |
17 | @overload
18 | def single(x: set[Optional[T]]) -> Optional[T]:
19 | ...
20 |
21 |
22 | def single(x: set[Optional[T]] | set[T]) -> Optional[T]:
23 | """Get the single element of a set."""
24 |
25 | assert len(x) == 1
26 |
27 | return x.pop()
28 |
--------------------------------------------------------------------------------
/sd_fused/utils/diverse/to_list.py:
--------------------------------------------------------------------------------
1 | from __future__ import annotations
2 | from typing import Optional, overload
3 |
4 | from PIL import Image
5 | from pathlib import Path
6 |
7 | from ..typing import MaybeIterable, T
8 |
9 |
10 | @overload
11 | def to_list(x: MaybeIterable[T]) -> list[T]:
12 | ...
13 |
14 |
15 | @overload
16 | def to_list(x: Optional[MaybeIterable[T]]) -> tuple[None] | list[T]:
17 | ...
18 |
19 |
20 | def to_list(x: Optional[MaybeIterable[T]]) -> tuple[None] | list[T]:
21 | """Convert a `MaybeIterable` into a list."""
22 |
23 | if x is None:
24 | return (None,)
25 |
26 | if isinstance(x, (int, float, str, Path, Image.Image)):
27 | return [x] # type: ignore
28 |
29 | return list(x) # type: ignore
30 |
--------------------------------------------------------------------------------
/sd_fused/utils/image/__init__.py:
--------------------------------------------------------------------------------
1 | from .image_base64 import image_base64
2 | from .image2tensor import image2tensor
3 | from .tensor2images import tensor2images
4 | from .open_image import open_image
5 | from .image_size import image_size
6 |
7 | from .types import ImageType, ResizeModes
8 |
--------------------------------------------------------------------------------
/sd_fused/utils/image/image2tensor.py:
--------------------------------------------------------------------------------
1 | from __future__ import annotations
2 | from typing import Optional
3 |
4 | from functools import partial
5 | from PIL import Image
6 |
7 | import math
8 | import numpy as np
9 |
10 | import torch
11 | import torch.nn.functional as F
12 | from torch import Tensor
13 |
14 | from einops import rearrange
15 |
16 | from ...layers.base.types import Device
17 | from .open_image import open_image
18 | from .types import ResizeModes, ImageType
19 |
20 |
21 | def image2tensor(
22 | path: ImageType,
23 | height: Optional[int] = None,
24 | width: Optional[int] = None,
25 | mode: Optional[ResizeModes] = None,
26 | device: Optional[Device] = None,
27 | rescale: Optional[float] = None,
28 | ) -> Tensor:
29 | """Open an image/URL as a PyTorch batched Tensor (B=1 C H W)."""
30 |
31 | img = open_image(path)
32 | resize = partial(img.resize, resample=Image.LANCZOS)
33 |
34 | if height is None or width is None:
35 | assert height is None and width is None
36 |
37 | width, height = img.size
38 | if rescale is not None:
39 | width = math.ceil(width * rescale)
40 | height = math.ceil(height * rescale)
41 |
42 | if mode == "resize" or rescale is not None:
43 | img = resize((width, height))
44 | else:
45 | assert mode is not None
46 |
47 | ar = width / height
48 | src_ar = img.width / img.height
49 |
50 | diff = ar > src_ar
51 | if mode == "resize-pad":
52 | diff = not diff
53 |
54 | w = math.ceil(width if diff else height * src_ar)
55 | h = math.ceil(height if not diff else width / src_ar)
56 |
57 | img = resize((w, h))
58 |
59 | data = torch.from_numpy(np.asarray(img).copy()).to(device)
60 | data = rearrange(data, "H W C -> 1 C H W")
61 |
62 | # crop/padding size
63 | h, w = data.shape[-2:]
64 | dh, dw = height - h, width - w
65 |
66 | pad = (dw // 2, dw - dw // 2, dh // 2, dh - dh // 2)
67 | data = F.pad(data, pad, value=0)
68 |
69 | return data
70 |
--------------------------------------------------------------------------------
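An illustrative call (the file name and sizes are placeholders): the image is opened from a path, URL or PIL.Image, resized according to the mode ("resize", "resize-crop" or "resize-pad") and returned as a batched byte tensor; negative padding in the final F.pad performs the center crop.

    from sd_fused.utils.image import image2tensor

    data = image2tensor("photo.png", height=512, width=768, mode="resize-crop")
    # data.shape == (1, 3, 512, 768), data.dtype == torch.uint8

--------------------------------------------------------------------------------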
/sd_fused/utils/image/image_base64.py:
--------------------------------------------------------------------------------
1 | from __future__ import annotations
2 |
3 | import base64
4 | from io import BytesIO
5 |
6 | from .open_image import open_image
7 | from .types import ImageType
8 |
9 |
10 | def image_base64(path: ImageType) -> str:
11 | """Encode an image as a base64 (JPEG) string."""
12 |
13 | img = open_image(path)
14 |
15 | buffered = BytesIO()
16 | img.save(buffered, format="JPEG")
17 | code = base64.b64encode(buffered.getvalue())
18 | code = code.decode("ascii")
19 |
20 | return f"data:image/jpeg;base64,{code}"
21 |
--------------------------------------------------------------------------------
/sd_fused/utils/image/image_size.py:
--------------------------------------------------------------------------------
1 | from __future__ import annotations
2 |
3 | from .open_image import open_image
4 | from .types import ImageType
5 |
6 |
7 | def image_size(path: ImageType) -> tuple[int, int]:
8 | img = open_image(path)
9 |
10 | width, height = img.size
11 |
12 | return height, width
13 |
--------------------------------------------------------------------------------
/sd_fused/utils/image/open_image.py:
--------------------------------------------------------------------------------
1 | from __future__ import annotations
2 |
3 | from PIL import Image
4 | import requests
5 |
6 | from io import BytesIO
7 | import validators
8 |
9 | from .types import ImageType
10 |
11 |
12 | def open_image(path: ImageType) -> Image.Image:
13 | """Open a path or url as an image."""
14 |
15 | if isinstance(path, Image.Image):
16 | img = path
17 | elif isinstance(path, str) and validators.url(path): # type: ignore
18 | response = requests.get(path)
19 | img = Image.open(BytesIO(response.content))
20 | else:
21 | img = Image.open(path)
22 | img = img.convert("RGB")
23 |
24 | # TODO get alpha channel as mask?
25 |
26 | return img
27 |
--------------------------------------------------------------------------------
/sd_fused/utils/image/tensor2images.py:
--------------------------------------------------------------------------------
1 | from __future__ import annotations
2 |
3 | from PIL import Image
4 |
5 | import torch
6 | from torch import Tensor
7 | from einops import rearrange
8 |
9 |
10 | def tensor2images(data: Tensor) -> list[Image.Image]:
11 | """Creates a list of images according to the batch size."""
12 |
13 | assert data.dtype == torch.uint8
14 |
15 | data = rearrange(data, "B C H W -> B H W C").cpu().numpy()
16 |
17 | return [Image.fromarray(v) for v in data]
18 |
--------------------------------------------------------------------------------
/sd_fused/utils/image/types.py:
--------------------------------------------------------------------------------
1 | from __future__ import annotations
2 | from typing import Union
3 |
4 | from PIL import Image
5 | from pathlib import Path
6 |
7 | from ..typing import Literal
8 |
9 | ResizeModes = Literal["resize", "resize-crop", "resize-pad"]
10 | ImageType = Union[str, Path, Image.Image]
11 |
--------------------------------------------------------------------------------
/sd_fused/utils/parameters/__init__.py:
--------------------------------------------------------------------------------
1 | from .parameters import Parameters
2 | from .parameters_list import ParametersList
3 |
4 | from .batch_parameters import batch_parameters
5 | from .group_parameters import group_parameters
6 |
--------------------------------------------------------------------------------
/sd_fused/utils/parameters/batch_parameters.py:
--------------------------------------------------------------------------------
1 | from __future__ import annotations
2 |
3 | from .parameters import Parameters
4 |
5 |
6 | def batch_parameters(
7 | groups: list[list[Parameters]],
8 | batch_size: int,
9 | ) -> list[list[Parameters]]:
10 | """Separate groups into batches of a specified size."""
11 |
12 | batched_parameters: list[list[Parameters]] = []
13 | for group in groups:
14 | for i in range(0, len(group), batch_size):
15 | s = slice(i, min(i + batch_size, len(group)))
16 |
17 | batched_parameters.append(group[s])
18 |
19 | return batched_parameters
20 |
--------------------------------------------------------------------------------
/sd_fused/utils/parameters/group_parameters.py:
--------------------------------------------------------------------------------
1 | from __future__ import annotations
2 |
3 | from .parameters import Parameters
4 |
5 |
6 | def group_parameters(parameters: list[Parameters]) -> list[list[Parameters]]:
7 | """Group parameters that can share a batch."""
8 |
9 | groups: list[list[Parameters]] = []
10 | for parameter in parameters:
11 | if len(groups) == 0:
12 | groups.append([parameter])
13 | continue
14 |
15 | can_share = False
16 | for group in groups:
17 | can_share = True
18 | for other_parameter in group:
19 | can_share &= parameter.can_share_batch(other_parameter)
20 |
21 | if can_share:
22 | group.append(parameter)
23 | break
24 |
25 | if not can_share:
26 | groups.append([parameter])
27 |
28 | return groups
29 |
--------------------------------------------------------------------------------
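An illustrative grouping (the values are arbitrary): parameters only share a group when steps, strength, height, width and conditioning match (see Parameters.can_share_batch below), so the third entry, with a different height, ends up in its own group.

    from sd_fused.utils.parameters import Parameters, group_parameters

    a = Parameters(steps=30, height=512, width=512, seed=1, negative_prompt="", prompt="a cat", scale=7.5)
    b = Parameters(steps=30, height=512, width=512, seed=2, negative_prompt="", prompt="a dog", scale=7.5)
    c = Parameters(steps=30, height=768, width=512, seed=3, negative_prompt="", prompt="a cat", scale=7.5)

    group_parameters([a, b, c])  # -> [[a, b], [c]]

--------------------------------------------------------------------------------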
/sd_fused/utils/parameters/parameters.py:
--------------------------------------------------------------------------------
1 | from __future__ import annotations
2 | from typing import Optional
3 | from typing_extensions import Self
4 |
5 | from dataclasses import dataclass, field
6 | from PIL.PngImagePlugin import PngInfo
7 |
8 | import torch
9 | from torch import Tensor
10 |
11 | from ...layers.base.types import Device
12 | from ..image import ResizeModes, image2tensor, image_base64, ImageType
13 |
14 | SAVE_ARGS = [
15 | "steps",
16 | "height",
17 | "width",
18 | "seed",
19 | "negative_prompt",
20 | "scale",
21 | "eta",
22 | "sub_seed",
23 | "seed_interpolation",
24 | "prompt",
25 | "strength",
26 | "mode",
27 | "image_base64",
28 | "mask_base64",
29 | ]
30 |
31 |
32 | @dataclass
33 | class Parameters:
34 | """Hold information from a single image generation."""
35 |
36 | steps: int
37 | height: int
38 | width: int
39 | seed: int
40 | negative_prompt: str
41 | scale: Optional[float] = None
42 | eta: Optional[float] = None # DDIM
43 | sub_seed: Optional[int] = None
44 | seed_interpolation: Optional[float] = None
45 | prompt: Optional[str] = None
46 | img: Optional[ImageType] = None
47 | mask: Optional[ImageType] = None
48 | strength: Optional[float] = None
49 | mode: Optional[ResizeModes] = None
50 |
51 | device: Optional[Device] = field(default=None, repr=False)
52 | dtype: Optional[torch.dtype] = field(default=None, repr=False)
53 |
54 | def __post_init__(self) -> None:
55 | assert self.height % 8 == 0
56 | assert self.width % 8 == 0
57 |
58 | if self.img is None:
59 | assert self.mode is None
60 | assert self.strength is None
61 | assert self.mask is None
62 | else:
63 | assert self.mode is not None
64 | assert self.strength is not None
65 |
66 | if self.sub_seed is None:
67 | assert self.seed_interpolation is None
68 | else:
69 | assert self.seed_interpolation is not None
70 |
71 | if self.prompt is None:
72 | assert self.scale is None
73 | else:
74 | assert self.scale is not None
75 |
76 | @property
77 | def unconditional(self) -> bool:
78 | return self.prompt is None
79 |
80 | def can_share_batch(self, other: Self) -> bool:
81 | """Determine if two parameters can share a batch."""
82 |
83 | value = self.steps == other.steps
84 | value &= self.strength == other.strength
85 | value &= self.height == other.height
86 | value &= self.width == other.width
87 | value &= self.unconditional == other.unconditional
88 |
89 | return value
90 |
91 | @property
92 | def image_data(self) -> Optional[Tensor]:
93 | """Image data as a Tensor."""
94 |
95 | if self.img is None or self.mode is None:
96 | return
97 |
98 | return image2tensor(self.img, self.height, self.width, self.mode, self.device)
99 |
100 | @property
101 | def mask_data(self) -> Optional[Tensor]:
102 | """Mask data as a Tensor."""
103 |
104 | if self.mask is None or self.mode is None:
105 | return
106 |
107 | data = image2tensor(self.mask, self.height, self.width, self.mode, self.device)
108 |
109 | # single-channel
110 | data = data.float().mean(dim=1, keepdim=True)
111 |
112 | return data >= 255 / 2 # bool-Tensor
113 |
114 | @property
115 | def image_base64(self) -> Optional[str]:
116 | """Image data as a base64 string."""
117 |
118 | if self.img is None:
119 | return
120 |
121 | return image_base64(self.img)
122 |
123 | @property
124 | def mask_base64(self) -> Optional[str]:
125 | """Mask data as a base64 string."""
126 |
127 | if self.mask is None:
128 | return
129 |
130 | return image_base64(self.mask)
131 |
132 | @property
133 | def png_info(self) -> PngInfo:
134 | """PNG metadata."""
135 |
136 | info = PngInfo()
137 | for key in SAVE_ARGS:
138 | value = getattr(self, key)
139 |
140 | if value is None:
141 | continue
142 | if isinstance(value, (int, float)):
143 | value = str(value)
144 |
145 | info.add_text(f"SD {key}", value)
146 |
147 | return info
148 |
--------------------------------------------------------------------------------
/sd_fused/utils/parameters/parameters_list.py:
--------------------------------------------------------------------------------
1 | from __future__ import annotations
2 | from typing import Iterator, Optional
3 |
4 | from functools import lru_cache
5 |
6 | import torch
7 | from torch import Tensor
8 |
9 | from ...layers.base.types import Device
10 | from ..diverse import separate, single
11 | from .parameters import Parameters
12 |
13 |
14 | class ParametersList:
15 | """Hold information from many image generations."""
16 |
17 | def __init__(self, parameters: list[Parameters]) -> None:
18 | self.parameters = parameters
19 |
20 | def __len__(self) -> int:
21 | return len(self.parameters)
22 |
23 | def __iter__(self) -> Iterator[Parameters]:
24 | return iter(self.parameters)
25 |
26 | @property
27 | def prompts(self) -> Optional[list[str]]:
28 | prompts = [p.prompt for p in self.parameters]
29 | prompts = separate(prompts)
30 |
31 | return prompts
32 |
33 | @property
34 | def negative_prompts(self) -> list[str]:
35 | return [p.negative_prompt for p in self.parameters]
36 |
37 | @property
38 | def unconditional(self) -> bool:
39 | return self.prompts is None
40 |
41 | @property
42 | def seeds(self) -> list[int]:
43 | return [p.seed for p in self.parameters]
44 |
45 | @property
46 | def sub_seeds(self) -> Optional[list[int]]:
47 | sub_seeds = [p.sub_seed for p in self.parameters]
48 | sub_seeds = separate(sub_seeds)
49 |
50 | return sub_seeds
51 |
52 | @property
53 | def seeds_interpolation(self) -> Optional[Tensor]:
54 | seed_interpolations = [p.seed_interpolation for p in self.parameters]
55 | seed_interpolations = separate(seed_interpolations)
56 |
57 | if seed_interpolations is None:
58 | return None
59 |
60 | return torch.tensor(seed_interpolations, device=self.device, dtype=self.dtype)
61 |
62 | @property
63 | def height(self) -> int:
64 | height = set(p.height for p in self.parameters)
65 |
66 | return single(height)
67 |
68 | @property
69 | def width(self) -> int:
70 | width = set(p.width for p in self.parameters)
71 |
72 | return single(width)
73 |
74 | def shape(self, channels: int) -> tuple[int, int, int, int]:
75 | return (len(self), channels, self.height // 8, self.width // 8)
76 |
77 | @property
78 | def steps(self) -> int:
79 | steps = set(p.steps for p in self.parameters)
80 |
81 | return single(steps)
82 |
83 | @property
84 | def strength(self) -> Optional[float]:
85 | strength = set(p.strength for p in self.parameters)
86 |
87 | return single(strength)
88 |
89 | @property
90 | def device(self) -> Optional[Device]:
91 | devices = set(p.device for p in self.parameters)
92 |
93 | return single(devices)
94 |
95 | @property
96 | def dtype(self) -> Optional[torch.dtype]:
97 | dtypes = set(p.dtype for p in self.parameters)
98 |
99 | return single(dtypes)
100 |
101 | @property
102 | def scales(self) -> Optional[Tensor]:
103 | scales = [p.scale for p in self.parameters]
104 | scales = separate(scales)
105 |
106 | if scales is None:
107 | return None
108 |
109 | return torch.tensor(scales, device=self.device, dtype=self.dtype)
110 |
111 | @property
112 | def etas(self) -> Optional[Tensor]:
113 | etas = [p.eta for p in self.parameters]
114 | etas = separate(etas)
115 |
116 | if etas is None:
117 | return None
118 |
119 | return torch.tensor(etas, device=self.device, dtype=self.dtype)
120 |
121 | @property
122 | @lru_cache(None)
123 | def images_data(self) -> Optional[Tensor]:
124 | data = [p.image_data for p in self.parameters]
125 | data = separate(data)
126 |
127 | if data is None:
128 | return None
129 |
130 | return torch.cat(data, dim=0)
131 |
132 | @property
133 | @lru_cache(None)
134 | def masks_data(self) -> Optional[Tensor]:
135 | data = [p.mask_data for p in self.parameters]
136 | data = separate(data)
137 |
138 | if data is None:
139 | return None
140 |
141 | return torch.cat(data, dim=0)
142 |
143 | @property
144 | @lru_cache(None)
145 | def masked_images_data(self) -> Optional[Tensor]:
146 | if self.images_data is None or self.masks_data is None:
147 | return None
148 |
149 | return self.images_data * (~self.masks_data)
150 |
--------------------------------------------------------------------------------
/sd_fused/utils/tensors/__init__.py:
--------------------------------------------------------------------------------
1 | from .generate_noise import generate_noise, random_seeds
2 | from .normalize import normalize, denormalize
3 | from .slerp import slerp
4 | from .to_tensor import to_tensor
5 |
--------------------------------------------------------------------------------
/sd_fused/utils/tensors/generate_noise.py:
--------------------------------------------------------------------------------
1 | from __future__ import annotations
2 | from typing import Optional
3 |
4 | import random
5 |
6 | import torch
7 | from torch import Tensor
8 |
9 | from ...layers.base.types import Device
10 |
11 |
12 | def generate_noise(
13 | shape: tuple[int, ...],
14 | seeds: list[int],
15 | device: Optional[Device] = None,
16 | dtype: Optional[torch.dtype] = None,
17 | repeat: int = 1, # TODO: support repeat
18 | ) -> Tensor:
19 | """Generate random noise with individual seeds per batch."""
20 |
21 | batch_size, *rest = shape
22 | assert len(seeds) == batch_size
23 |
24 | extended_shape = (repeat, batch_size, *rest)
25 |
26 | noise = torch.empty(extended_shape)
27 | for n in range(repeat):
28 | for i, s in enumerate(seeds):
29 | generator = torch.Generator()
30 | generator.manual_seed(s + n)
31 |
32 | noise[n, i] = torch.randn(*rest, generator=generator)
33 |
34 | noise = noise.to(device=device, dtype=dtype)
35 |
36 | if repeat == 1:
37 | return noise[0]
38 |
39 | return noise
40 |
41 |
42 | def random_seeds(size: int) -> list[int]:
43 | """Generate random seeds."""
44 |
45 | return [random.randint(0, 2**32 - 1) for _ in range(size)]
46 |
--------------------------------------------------------------------------------
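An illustrative check of the per-sample seeding (seeds are arbitrary): each batch element gets its own generator, so the same seed reproduces the same noise no matter which batch it lands in.

    import torch
    from sd_fused.utils.tensors import generate_noise

    seeds = [1234, 42]
    noise = generate_noise((2, 4, 64, 64), seeds)       # (2, 4, 64, 64)
    single = generate_noise((1, 4, 64, 64), seeds[:1])  # (1, 4, 64, 64)
    assert torch.equal(noise[:1], single)

--------------------------------------------------------------------------------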
/sd_fused/utils/tensors/normalize.py:
--------------------------------------------------------------------------------
1 | from __future__ import annotations
2 | from typing import Optional
3 |
4 | import torch
5 | from torch import Tensor
6 |
7 |
8 | def normalize(data: Tensor, dtype: Optional[torch.dtype] = None) -> Tensor:
9 | """Normalize a byte-Tensor to the [-1, 1] range."""
10 |
11 | assert data.dtype == torch.uint8
12 |
13 | return data.div(255 / 2).sub(1).to(dtype)
14 |
15 |
16 | def denormalize(data: Tensor) -> Tensor:
17 | """Denormalize a tensor of the range [-1, 1] to a byte-Tensor."""
18 |
19 | assert data.requires_grad == False
20 |
21 | return data.add(1).mul(255 / 2).clamp(0, 255).byte()
22 |
--------------------------------------------------------------------------------
/sd_fused/utils/tensors/slerp.py:
--------------------------------------------------------------------------------
1 | from __future__ import annotations
2 |
3 | import torch
4 | from torch import Tensor
5 |
6 |
7 | # TODO DEBUG
8 | def slerp(a: Tensor, b: Tensor, t: Tensor) -> Tensor:
9 | "Spherical linear interpolation."
10 |
11 | # https://en.wikipedia.org/wiki/Slerp
12 |
13 | # 0 <= t <= 1
14 | assert t.ge(0).all() and t.le(1).all()
15 |
16 | assert a.shape == b.shape
17 | assert t.shape[0] == a.shape[0]
18 | assert a.ndim == 4
19 | assert t.ndim == 1
20 |
21 | t = t[:, None, None, None]
22 |
23 | # ? that's how you normalize?
24 | an = a / a.norm(dim=1, keepdim=True)
25 | bn = b / b.norm(dim=1, keepdim=True)
26 |
27 | Ω = an.mul(bn).sum(1).clamp(-1, 1).acos()
28 |
29 | den = torch.sin(Ω)
30 |
31 | A = torch.sin((1 - t) * Ω)
32 | B = torch.sin(t * Ω)
33 |
34 | return (A * a + B * b) / den
35 |
--------------------------------------------------------------------------------
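The function above implements the standard slerp formula; a LaTeX sketch, with â and b̂ the normalized inputs and the angle taken per batch element:

    \Omega = \arccos\!\left(\langle \hat a, \hat b \rangle\right), \qquad
    \operatorname{slerp}(a, b; t) = \frac{\sin\big((1-t)\,\Omega\big)}{\sin\Omega}\, a
                                  + \frac{\sin\big(t\,\Omega\big)}{\sin\Omega}\, b

--------------------------------------------------------------------------------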
/sd_fused/utils/tensors/to_tensor.py:
--------------------------------------------------------------------------------
1 | from __future__ import annotations
2 | from typing import Optional
3 |
4 | import torch
5 | from torch import Tensor
6 |
7 | from ...layers.base.types import Device
8 |
9 |
10 | def to_tensor(
11 | x: Optional[int | float | Tensor],
12 | device: Optional[Device] = None,
13 | dtype: Optional[torch.dtype] = None,
14 | *,
15 | add_spatial: bool = False,
16 | ) -> Tensor:
17 | """Convert a number (or 1D Tensor) to a Tensor, optionally adding fake channel/spatial dimensions."""
18 |
19 | if x is None:
20 | x = 0
21 |
22 | if isinstance(x, (int, float)):
23 | x = torch.tensor([x], device=device, dtype=dtype)
24 | else:
25 | assert x.ndim == 1
26 | x = x.to(device=device, dtype=dtype)
27 |
28 | if add_spatial:
29 | x = x.view(-1, 1, 1, 1)
30 |
31 | return x
32 |
--------------------------------------------------------------------------------
/sd_fused/utils/typing.py:
--------------------------------------------------------------------------------
1 | from __future__ import annotations
2 | from typing import TypeVar, Union, Iterable
3 | from typing_extensions import Unpack
4 |
5 | import sys
6 |
7 | if sys.version_info >= (3, 8):
8 | from typing import Literal, Final, Protocol
9 | else:
10 | from typing_extensions import Literal, Final, Protocol
11 |
12 | if sys.version_info >= (3, 11):
13 | from typing import TypeVarTuple
14 | else:
15 | from typing_extensions import TypeVarTuple
16 |
17 | from .image import ImageType
18 |
19 | T = TypeVar("T", int, float, str, ImageType)
20 | MaybeIterable = Union[T, Iterable[T]]
21 |
--------------------------------------------------------------------------------
/setup.py:
--------------------------------------------------------------------------------
1 | # type: ignore
2 | from distutils.core import setup
3 |
4 | from pathlib import Path
5 |
6 | from sd_fused import StableDiffusion
7 |
8 | packages = list(set(str(p.parent) for p in Path("sd_fused").rglob("*.py")))
9 |
10 | with open("./requirements.txt") as handle:
11 | requirements = [l.strip() for l in handle.read().splitlines() if l.strip()]
12 |
13 |
14 | setup(
15 | name="sd-fused",
16 | version=StableDiffusion.version,
17 | description="Stable-Diffusion + Fused CUDA kernels",
18 | author="Thales Fernandes",
19 | author_email="thalesfdfernandes@gmail.com",
20 | url="https://github.com/tfernd/sd-fused",
21 | python_requires=">=3.7", # TODO Less?
22 | # packages=find_packages(exclude=["/cuda"]),
23 | packages=packages,
24 | install_requires=requirements,
25 | )
26 |
--------------------------------------------------------------------------------