├── setup.py ├── pyproject.toml ├── k_diffusion ├── __init__.py ├── models │ ├── __init__.py │ ├── flops.py │ ├── flags.py │ ├── axial_rope.py │ ├── image_v1.py │ ├── image_transformer_v1.py │ └── image_transformer_v2.py ├── augmentation.py ├── gns.py ├── evaluation.py ├── external.py ├── config.py ├── layers.py ├── utils.py └── sampling.py ├── .gitignore ├── requirements.txt ├── setup.cfg ├── LICENSE ├── .github └── workflows │ └── python-publish.yml ├── configs ├── config_mnist_transformer.json ├── config_cifar10.json ├── config_mnist.json ├── config_32x32_small.json ├── config_cifar10_transformer.json ├── config_32x32_small_butterflies.json ├── config_oxford_flowers.json └── config_oxford_flowers_shifted_window.json ├── config_from_inference.py ├── make_grid.py ├── convert_for_inference.py ├── sample.py ├── sample_clip_guided.py ├── README.md └── train.py /setup.py: -------------------------------------------------------------------------------- 1 | from setuptools import setup 2 | 3 | 4 | if __name__ == '__main__': 5 | setup() 6 | -------------------------------------------------------------------------------- /pyproject.toml: -------------------------------------------------------------------------------- 1 | [build-system] 2 | requires = ["setuptools"] 3 | build-backend = "setuptools.build_meta" 4 | -------------------------------------------------------------------------------- /k_diffusion/__init__.py: -------------------------------------------------------------------------------- 1 | from . import augmentation, config, evaluation, external, gns, layers, models, sampling, utils 2 | from .layers import Denoiser 3 | -------------------------------------------------------------------------------- /.gitignore: -------------------------------------------------------------------------------- 1 | venv* 2 | __pycache__ 3 | .ipynb_checkpoints 4 | *.pth 5 | *.egg-info 6 | data 7 | *_demo_*.png 8 | wandb/* 9 | *.csv 10 | .env 11 | *.safetensors 12 | -------------------------------------------------------------------------------- /requirements.txt: -------------------------------------------------------------------------------- 1 | accelerate 2 | clean-fid 3 | clip-anytorch 4 | dctorch 5 | einops 6 | jsonmerge 7 | kornia 8 | Pillow 9 | safetensors 10 | scikit-image 11 | scipy 12 | torch>=2.1 13 | torchdiffeq 14 | torchsde 15 | torchvision 16 | tqdm 17 | wandb 18 | -------------------------------------------------------------------------------- /k_diffusion/models/__init__.py: -------------------------------------------------------------------------------- 1 | from . import flops 2 | from .flags import checkpointing, get_checkpointing 3 | from .image_v1 import ImageDenoiserModelV1 4 | from .image_transformer_v1 import ImageTransformerDenoiserModelV1 5 | from .image_transformer_v2 import ImageTransformerDenoiserModelV2 6 | -------------------------------------------------------------------------------- /setup.cfg: -------------------------------------------------------------------------------- 1 | [metadata] 2 | name = k-diffusion 3 | version = 0.2.0.dev0 4 | author = Katherine Crowson 5 | author_email = crowsonkb@gmail.com 6 | url = https://github.com/crowsonkb/k-diffusion 7 | description = Karras et al. 
(2022) diffusion models for PyTorch 8 | long_description = file: README.md 9 | long_description_content_type = text/markdown 10 | license = MIT 11 | 12 | [options] 13 | packages = find: 14 | install_requires = 15 | accelerate 16 | clean-fid 17 | clip-anytorch 18 | dctorch 19 | einops 20 | jsonmerge 21 | kornia 22 | Pillow 23 | safetensors 24 | scikit-image 25 | scipy 26 | torch >= 2.1 27 | torchdiffeq 28 | torchsde 29 | torchvision 30 | tqdm 31 | wandb 32 | -------------------------------------------------------------------------------- /LICENSE: -------------------------------------------------------------------------------- 1 | Copyright (c) 2022 Katherine Crowson 2 | 3 | Permission is hereby granted, free of charge, to any person obtaining a copy 4 | of this software and associated documentation files (the "Software"), to deal 5 | in the Software without restriction, including without limitation the rights 6 | to use, copy, modify, merge, publish, distribute, sublicense, and/or sell 7 | copies of the Software, and to permit persons to whom the Software is 8 | furnished to do so, subject to the following conditions: 9 | 10 | The above copyright notice and this permission notice shall be included in 11 | all copies or substantial portions of the Software. 12 | 13 | THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR 14 | IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, 15 | FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE 16 | AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER 17 | LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, 18 | OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN 19 | THE SOFTWARE. 20 | -------------------------------------------------------------------------------- /.github/workflows/python-publish.yml: -------------------------------------------------------------------------------- 1 | name: Release 2 | 3 | on: 4 | push: 5 | branches: 6 | - master 7 | jobs: 8 | deploy: 9 | runs-on: ubuntu-latest 10 | steps: 11 | - uses: actions/checkout@v2 12 | - uses: actions-ecosystem/action-regex-match@v2 13 | id: regex-match 14 | with: 15 | text: ${{ github.event.head_commit.message }} 16 | regex: '^Release ([^ ]+)' 17 | - name: Set up Python 18 | uses: actions/setup-python@v2 19 | with: 20 | python-version: '3.8' 21 | - name: Install dependencies 22 | run: | 23 | python -m pip install --upgrade pip 24 | pip install setuptools wheel twine 25 | - name: Release 26 | if: ${{ steps.regex-match.outputs.match != '' }} 27 | uses: softprops/action-gh-release@v1 28 | with: 29 | tag_name: v${{ steps.regex-match.outputs.group1 }} 30 | - name: Build and publish 31 | if: ${{ steps.regex-match.outputs.match != '' }} 32 | env: 33 | TWINE_USERNAME: __token__ 34 | TWINE_PASSWORD: ${{ secrets.PYPI_PASSWORD }} 35 | run: | 36 | python setup.py sdist bdist_wheel 37 | twine upload dist/* 38 | -------------------------------------------------------------------------------- /configs/config_mnist_transformer.json: -------------------------------------------------------------------------------- 1 | { 2 | "model": { 3 | "type": "image_transformer_v2", 4 | "input_channels": 1, 5 | "input_size": [28, 28], 6 | "patch_size": [4, 4], 7 | "depths": [8], 8 | "widths": [256], 9 | "loss_config": "karras", 10 | "loss_weighting": "soft-min-snr", 11 | "dropout_rate": 0.05, 12 | "augment_prob": 0.12, 13 | "sigma_data": 0.6162, 14 | "sigma_min": 1e-2, 15 | "sigma_max": 80, 
16 | "sigma_sample_density": { 17 | "type": "cosine-interpolated" 18 | } 19 | }, 20 | "dataset": { 21 | "type": "mnist", 22 | "location": "data", 23 | "num_classes": 10, 24 | "cond_dropout_rate": 0.1 25 | }, 26 | "optimizer": { 27 | "type": "adamw", 28 | "lr": 5e-4, 29 | "betas": [0.9, 0.95], 30 | "eps": 1e-8, 31 | "weight_decay": 1e-4 32 | }, 33 | "lr_sched": { 34 | "type": "constant", 35 | "warmup": 0.0 36 | }, 37 | "ema_sched": { 38 | "type": "inverse", 39 | "power": 0.6667, 40 | "max_value": 0.9999 41 | } 42 | } 43 | -------------------------------------------------------------------------------- /configs/config_cifar10.json: -------------------------------------------------------------------------------- 1 | { 2 | "model": { 3 | "type": "image_v1", 4 | "input_channels": 3, 5 | "input_size": [32, 32], 6 | "patch_size": 1, 7 | "mapping_out": 256, 8 | "depths": [2, 4, 4], 9 | "channels": [128, 256, 512], 10 | "self_attn_depths": [false, true, true], 11 | "has_variance": false, 12 | "loss_config": "karras", 13 | "loss_weighting": "soft-min-snr", 14 | "dropout_rate": 0.05, 15 | "augment_wrapper": true, 16 | "augment_prob": 0.12, 17 | "sigma_data": 0.5, 18 | "sigma_min": 1e-2, 19 | "sigma_max": 80, 20 | "sigma_sample_density": { 21 | "type": "cosine-interpolated" 22 | } 23 | }, 24 | "dataset": { 25 | "type": "cifar10", 26 | "location": "data" 27 | }, 28 | "optimizer": { 29 | "type": "adamw", 30 | "lr": 1e-4, 31 | "betas": [0.95, 0.999], 32 | "eps": 1e-6, 33 | "weight_decay": 1e-3 34 | }, 35 | "lr_sched": { 36 | "type": "constant", 37 | "warmup": 0.0 38 | }, 39 | "ema_sched": { 40 | "type": "inverse", 41 | "power": 0.6667, 42 | "max_value": 0.9999 43 | } 44 | } 45 | -------------------------------------------------------------------------------- /configs/config_mnist.json: -------------------------------------------------------------------------------- 1 | { 2 | "model": { 3 | "type": "image_v1", 4 | "input_channels": 1, 5 | "input_size": [28, 28], 6 | "patch_size": 1, 7 | "mapping_out": 256, 8 | "depths": [2, 4, 4], 9 | "channels": [128, 128, 256], 10 | "self_attn_depths": [false, false, true], 11 | "has_variance": false, 12 | "loss_config": "karras", 13 | "loss_weighting": "soft-min-snr", 14 | "dropout_rate": 0.05, 15 | "augment_wrapper": true, 16 | "augment_prob": 0.12, 17 | "sigma_data": 0.6162, 18 | "sigma_min": 1e-2, 19 | "sigma_max": 80, 20 | "sigma_sample_density": { 21 | "type": "cosine-interpolated" 22 | } 23 | }, 24 | "dataset": { 25 | "type": "mnist", 26 | "location": "data" 27 | }, 28 | "optimizer": { 29 | "type": "adamw", 30 | "lr": 2e-4, 31 | "betas": [0.95, 0.999], 32 | "eps": 1e-6, 33 | "weight_decay": 1e-3 34 | }, 35 | "lr_sched": { 36 | "type": "constant", 37 | "warmup": 0.0 38 | }, 39 | "ema_sched": { 40 | "type": "inverse", 41 | "power": 0.6667, 42 | "max_value": 0.9999 43 | } 44 | } 45 | -------------------------------------------------------------------------------- /config_from_inference.py: -------------------------------------------------------------------------------- 1 | #!/usr/bin/env python3 2 | 3 | """Extracts the configuration file from a slim inference checkpoint.""" 4 | 5 | import argparse 6 | import json 7 | from pathlib import Path 8 | import sys 9 | 10 | import k_diffusion as K 11 | import safetensors.torch as safetorch 12 | 13 | 14 | def main(): 15 | p = argparse.ArgumentParser(description=__doc__, 16 | formatter_class=argparse.ArgumentDefaultsHelpFormatter) 17 | p.add_argument("checkpoint", type=Path, 18 | help="the inference checkpoint to extract the 
configuration from") 19 | p.add_argument("--output", "-o", type=Path, 20 | help="the output configuration file") 21 | args = p.parse_args() 22 | 23 | print(f"Loading inference checkpoint {args.checkpoint}...", file=sys.stderr) 24 | metadata = K.utils.get_safetensors_metadata(args.checkpoint) 25 | if "config" not in metadata: 26 | raise ValueError("No configuration found in checkpoint") 27 | 28 | output_path = args.output or args.checkpoint.with_suffix(".json") 29 | 30 | print(f"Saving configuration to {output_path}...", file=sys.stderr) 31 | output_path.write_text(metadata["config"]) 32 | 33 | 34 | if __name__ == "__main__": 35 | main() 36 | -------------------------------------------------------------------------------- /configs/config_32x32_small.json: -------------------------------------------------------------------------------- 1 | { 2 | "model": { 3 | "type": "image_v1", 4 | "input_channels": 3, 5 | "input_size": [32, 32], 6 | "patch_size": 1, 7 | "mapping_out": 256, 8 | "depths": [2, 4, 4], 9 | "channels": [128, 256, 512], 10 | "self_attn_depths": [false, true, true], 11 | "has_variance": false, 12 | "loss_config": "karras", 13 | "loss_weighting": "soft-min-snr", 14 | "dropout_rate": 0.05, 15 | "augment_wrapper": true, 16 | "augment_prob": 0.12, 17 | "sigma_data": 0.5, 18 | "sigma_min": 1e-2, 19 | "sigma_max": 80, 20 | "sigma_sample_density": { 21 | "type": "cosine-interpolated" 22 | } 23 | }, 24 | "dataset": { 25 | "type": "imagefolder", 26 | "location": "/path/to/dataset" 27 | }, 28 | "optimizer": { 29 | "type": "adamw", 30 | "lr": 1e-4, 31 | "betas": [0.95, 0.999], 32 | "eps": 1e-6, 33 | "weight_decay": 1e-3 34 | }, 35 | "lr_sched": { 36 | "type": "constant", 37 | "warmup": 0.0 38 | }, 39 | "ema_sched": { 40 | "type": "inverse", 41 | "power": 0.6667, 42 | "max_value": 0.9999 43 | } 44 | } 45 | -------------------------------------------------------------------------------- /configs/config_cifar10_transformer.json: -------------------------------------------------------------------------------- 1 | { 2 | "model": { 3 | "type": "image_transformer_v2", 4 | "input_channels": 3, 5 | "input_size": [32, 32], 6 | "patch_size": [2, 2], 7 | "depths": [2, 4], 8 | "widths": [256, 512], 9 | "self_attns": [ 10 | {"type": "global"}, 11 | {"type": "global"} 12 | ], 13 | "loss_config": "karras", 14 | "loss_weighting": "soft-min-snr", 15 | "dropout_rate": 0.05, 16 | "augment_prob": 0.12, 17 | "sigma_data": 0.5, 18 | "sigma_min": 1e-2, 19 | "sigma_max": 80, 20 | "sigma_sample_density": { 21 | "type": "cosine-interpolated" 22 | } 23 | }, 24 | "dataset": { 25 | "type": "cifar10", 26 | "location": "data", 27 | "num_classes": 10, 28 | "cond_dropout_rate": 0.1 29 | }, 30 | "optimizer": { 31 | "type": "adamw", 32 | "lr": 5e-4, 33 | "betas": [0.9, 0.95], 34 | "eps": 1e-8, 35 | "weight_decay": 1e-4 36 | }, 37 | "lr_sched": { 38 | "type": "constant", 39 | "warmup": 0.0 40 | }, 41 | "ema_sched": { 42 | "type": "inverse", 43 | "power": 0.6667, 44 | "max_value": 0.9999 45 | } 46 | } 47 | -------------------------------------------------------------------------------- /configs/config_32x32_small_butterflies.json: -------------------------------------------------------------------------------- 1 | { 2 | "model": { 3 | "type": "image_v1", 4 | "input_channels": 3, 5 | "input_size": [32, 32], 6 | "patch_size": 1, 7 | "mapping_out": 256, 8 | "depths": [2, 4, 4], 9 | "channels": [128, 256, 512], 10 | "self_attn_depths": [false, true, true], 11 | "has_variance": false, 12 | "loss_config": "karras", 13 | 
"loss_weighting": "soft-min-snr", 14 | "dropout_rate": 0.05, 15 | "augment_wrapper": true, 16 | "augment_prob": 0.12, 17 | "sigma_data": 0.5, 18 | "sigma_min": 1e-2, 19 | "sigma_max": 80, 20 | "sigma_sample_density": { 21 | "type": "cosine-interpolated" 22 | } 23 | }, 24 | "dataset": { 25 | "type": "huggingface", 26 | "location": "huggan/smithsonian_butterflies_subset", 27 | "image_key": "image" 28 | }, 29 | "optimizer": { 30 | "type": "adamw", 31 | "lr": 1e-4, 32 | "betas": [0.95, 0.999], 33 | "eps": 1e-6, 34 | "weight_decay": 1e-3 35 | }, 36 | "lr_sched": { 37 | "type": "constant", 38 | "warmup": 0.0 39 | }, 40 | "ema_sched": { 41 | "type": "inverse", 42 | "power": 0.6667, 43 | "max_value": 0.9999 44 | } 45 | } 46 | -------------------------------------------------------------------------------- /k_diffusion/models/flops.py: -------------------------------------------------------------------------------- 1 | from contextlib import contextmanager 2 | import math 3 | import threading 4 | 5 | 6 | state = threading.local() 7 | state.flop_counter = None 8 | 9 | 10 | @contextmanager 11 | def flop_counter(enable=True): 12 | try: 13 | old_flop_counter = state.flop_counter 14 | state.flop_counter = FlopCounter() if enable else None 15 | yield state.flop_counter 16 | finally: 17 | state.flop_counter = old_flop_counter 18 | 19 | 20 | class FlopCounter: 21 | def __init__(self): 22 | self.ops = [] 23 | 24 | def op(self, op, *args, **kwargs): 25 | self.ops.append((op, args, kwargs)) 26 | 27 | @property 28 | def flops(self): 29 | flops = 0 30 | for op, args, kwargs in self.ops: 31 | flops += op(*args, **kwargs) 32 | return flops 33 | 34 | 35 | def op(op, *args, **kwargs): 36 | if getattr(state, "flop_counter", None): 37 | state.flop_counter.op(op, *args, **kwargs) 38 | 39 | 40 | def op_linear(x, weight): 41 | return math.prod(x) * weight[0] 42 | 43 | 44 | def op_attention(q, k, v): 45 | *b, s_q, d_q = q 46 | *b, s_k, d_k = k 47 | *b, s_v, d_v = v 48 | return math.prod(b) * s_q * s_k * (d_q + d_v) 49 | 50 | 51 | def op_natten(q, k, v, kernel_size): 52 | *q_rest, d_q = q 53 | *_, d_v = v 54 | return math.prod(q_rest) * (d_q + d_v) * kernel_size**2 55 | -------------------------------------------------------------------------------- /configs/config_oxford_flowers.json: -------------------------------------------------------------------------------- 1 | { 2 | "model": { 3 | "type": "image_transformer_v2", 4 | "input_channels": 3, 5 | "input_size": [256, 256], 6 | "patch_size": [4, 4], 7 | "depths": [2, 2, 4], 8 | "widths": [128, 256, 512], 9 | "self_attns": [ 10 | {"type": "neighborhood", "d_head": 64, "kernel_size": 7}, 11 | {"type": "neighborhood", "d_head": 64, "kernel_size": 7}, 12 | {"type": "global", "d_head": 64} 13 | ], 14 | "loss_config": "karras", 15 | "loss_weighting": "soft-min-snr", 16 | "dropout_rate": [0.0, 0.0, 0.1], 17 | "mapping_dropout_rate": 0.0, 18 | "augment_prob": 0.0, 19 | "sigma_data": 0.5, 20 | "sigma_min": 1e-2, 21 | "sigma_max": 160, 22 | "sigma_sample_density": { 23 | "type": "cosine-interpolated" 24 | } 25 | }, 26 | "dataset": { 27 | "type": "huggingface", 28 | "location": "nelorth/oxford-flowers", 29 | "image_key": "image" 30 | }, 31 | "optimizer": { 32 | "type": "adamw", 33 | "lr": 5e-4, 34 | "betas": [0.9, 0.95], 35 | "eps": 1e-8, 36 | "weight_decay": 1e-3 37 | }, 38 | "lr_sched": { 39 | "type": "constant", 40 | "warmup": 0.0 41 | }, 42 | "ema_sched": { 43 | "type": "inverse", 44 | "power": 0.75, 45 | "max_value": 0.9999 46 | } 47 | } 48 | 
-------------------------------------------------------------------------------- /configs/config_oxford_flowers_shifted_window.json: -------------------------------------------------------------------------------- 1 | { 2 | "model": { 3 | "type": "image_transformer_v2", 4 | "input_channels": 3, 5 | "input_size": [256, 256], 6 | "patch_size": [4, 4], 7 | "depths": [2, 2, 4], 8 | "widths": [128, 256, 512], 9 | "self_attns": [ 10 | {"type": "shifted-window", "d_head": 64, "window_size": 8}, 11 | {"type": "shifted-window", "d_head": 64, "window_size": 8}, 12 | {"type": "global", "d_head": 64} 13 | ], 14 | "loss_config": "karras", 15 | "loss_weighting": "soft-min-snr", 16 | "dropout_rate": [0.0, 0.0, 0.1], 17 | "mapping_dropout_rate": 0.0, 18 | "augment_prob": 0.0, 19 | "sigma_data": 0.5, 20 | "sigma_min": 1e-2, 21 | "sigma_max": 160, 22 | "sigma_sample_density": { 23 | "type": "cosine-interpolated" 24 | } 25 | }, 26 | "dataset": { 27 | "type": "huggingface", 28 | "location": "nelorth/oxford-flowers", 29 | "image_key": "image" 30 | }, 31 | "optimizer": { 32 | "type": "adamw", 33 | "lr": 5e-4, 34 | "betas": [0.9, 0.95], 35 | "eps": 1e-8, 36 | "weight_decay": 1e-3 37 | }, 38 | "lr_sched": { 39 | "type": "constant", 40 | "warmup": 0.0 41 | }, 42 | "ema_sched": { 43 | "type": "inverse", 44 | "power": 0.75, 45 | "max_value": 0.9999 46 | } 47 | } 48 | -------------------------------------------------------------------------------- /make_grid.py: -------------------------------------------------------------------------------- 1 | #!/usr/bin/env python3 2 | 3 | """Assembles images into a grid.""" 4 | 5 | import argparse 6 | import math 7 | import sys 8 | 9 | from PIL import Image 10 | 11 | 12 | def main(): 13 | p = argparse.ArgumentParser(description=__doc__) 14 | p.add_argument('images', type=str, nargs='+', metavar='image', 15 | help='the input images') 16 | p.add_argument('--output', '-o', type=str, default='out.png', 17 | help='the output image') 18 | p.add_argument('--nrow', type=int, 19 | help='the number of images per row') 20 | args = p.parse_args() 21 | 22 | images = [Image.open(image) for image in args.images] 23 | mode = images[0].mode 24 | size = images[0].size 25 | for image, name in zip(images, args.images): 26 | if image.mode != mode: 27 | print(f'Error: Image {name} had mode {image.mode}, expected {mode}', file=sys.stderr) 28 | sys.exit(1) 29 | if image.size != size: 30 | print(f'Error: Image {name} had size {image.size}, expected {size}', file=sys.stderr) 31 | sys.exit(1) 32 | 33 | n = len(images) 34 | x = args.nrow if args.nrow else math.ceil(n**0.5) 35 | y = math.ceil(n / x) 36 | 37 | output = Image.new(mode, (size[0] * x, size[1] * y)) 38 | for i, image in enumerate(images): 39 | cur_x, cur_y = i % x, i // x 40 | output.paste(image, (size[0] * cur_x, size[1] * cur_y)) 41 | 42 | output.save(args.output) 43 | 44 | 45 | if __name__ == '__main__': 46 | main() 47 | -------------------------------------------------------------------------------- /k_diffusion/models/flags.py: -------------------------------------------------------------------------------- 1 | from contextlib import contextmanager 2 | from functools import update_wrapper 3 | import os 4 | import threading 5 | 6 | import torch 7 | 8 | 9 | def get_use_compile(): 10 | return os.environ.get("K_DIFFUSION_USE_COMPILE", "1") == "1" 11 | 12 | 13 | def get_use_flash_attention_2(): 14 | return os.environ.get("K_DIFFUSION_USE_FLASH_2", "1") == "1" 15 | 16 | 17 | state = threading.local() 18 | state.checkpointing = False 19 | 20 | 21 | 
@contextmanager 22 | def checkpointing(enable=True): 23 | try: 24 | old_checkpointing, state.checkpointing = state.checkpointing, enable 25 | yield 26 | finally: 27 | state.checkpointing = old_checkpointing 28 | 29 | 30 | def get_checkpointing(): 31 | return getattr(state, "checkpointing", False) 32 | 33 | 34 | class compile_wrap: 35 | def __init__(self, function, *args, **kwargs): 36 | self.function = function 37 | self.args = args 38 | self.kwargs = kwargs 39 | self._compiled_function = None 40 | update_wrapper(self, function) 41 | 42 | @property 43 | def compiled_function(self): 44 | if self._compiled_function is not None: 45 | return self._compiled_function 46 | if get_use_compile(): 47 | try: 48 | self._compiled_function = torch.compile(self.function, *self.args, **self.kwargs) 49 | except RuntimeError: 50 | self._compiled_function = self.function 51 | else: 52 | self._compiled_function = self.function 53 | return self._compiled_function 54 | 55 | def __call__(self, *args, **kwargs): 56 | return self.compiled_function(*args, **kwargs) 57 | -------------------------------------------------------------------------------- /convert_for_inference.py: -------------------------------------------------------------------------------- 1 | #!/usr/bin/env python3 2 | 3 | """Converts a k-diffusion training checkpoint to a slim inference checkpoint.""" 4 | 5 | import argparse 6 | import json 7 | from pathlib import Path 8 | import sys 9 | 10 | import torch 11 | import safetensors.torch as safetorch 12 | 13 | 14 | def main(): 15 | p = argparse.ArgumentParser(description=__doc__, 16 | formatter_class=argparse.ArgumentDefaultsHelpFormatter) 17 | p.add_argument("checkpoint", type=Path, 18 | help="the training checkpoint to convert") 19 | p.add_argument("--config", type=Path, 20 | help="override the checkpoint's configuration") 21 | p.add_argument("--output", "-o", type=Path, 22 | help="the output slim checkpoint") 23 | p.add_argument("--dtype", type=str, choices=["fp32", "fp16", "bf16"], default="fp16", 24 | help="the output dtype") 25 | args = p.parse_args() 26 | 27 | print(f"Loading training checkpoint {args.checkpoint}...", file=sys.stderr) 28 | ckpt = torch.load(args.checkpoint, map_location="cpu") 29 | config = ckpt.get("config", None) 30 | model_ema = ckpt["model_ema"] 31 | del ckpt 32 | 33 | if args.config: 34 | config = json.loads(args.config.read_text()) 35 | 36 | if config is None: 37 | raise ValueError("No configuration found in checkpoint and no override provided") 38 | 39 | dtype = {"fp32": torch.float32, "fp16": torch.float16, "bf16": torch.bfloat16}[args.dtype] 40 | model_ema = {k: v.to(dtype) for k, v in model_ema.items()} 41 | 42 | output_path = args.output or args.checkpoint.with_suffix(".safetensors") 43 | metadata = {"config": json.dumps(config, indent=4)} 44 | print(f"Saving inference checkpoint to {output_path}...", file=sys.stderr) 45 | safetorch.save_file(model_ema, output_path, metadata=metadata) 46 | 47 | 48 | if __name__ == "__main__": 49 | main() 50 | -------------------------------------------------------------------------------- /sample.py: -------------------------------------------------------------------------------- 1 | #!/usr/bin/env python3 2 | 3 | """Samples from k-diffusion models.""" 4 | 5 | import argparse 6 | from pathlib import Path 7 | 8 | import accelerate 9 | import safetensors.torch as safetorch 10 | import torch 11 | from tqdm import trange, tqdm 12 | 13 | import k_diffusion as K 14 | 15 | 16 | def main(): 17 | p = 
argparse.ArgumentParser(description=__doc__, 18 | formatter_class=argparse.ArgumentDefaultsHelpFormatter) 19 | p.add_argument('--batch-size', type=int, default=64, 20 | help='the batch size') 21 | p.add_argument('--checkpoint', type=Path, required=True, 22 | help='the checkpoint to use') 23 | p.add_argument('--config', type=Path, 24 | help='the model config') 25 | p.add_argument('-n', type=int, default=64, 26 | help='the number of images to sample') 27 | p.add_argument('--prefix', type=str, default='out', 28 | help='the output prefix') 29 | p.add_argument('--steps', type=int, default=50, 30 | help='the number of denoising steps') 31 | args = p.parse_args() 32 | 33 | config = K.config.load_config(args.config if args.config else args.checkpoint) 34 | model_config = config['model'] 35 | # TODO: allow non-square input sizes 36 | assert len(model_config['input_size']) == 2 and model_config['input_size'][0] == model_config['input_size'][1] 37 | size = model_config['input_size'] 38 | 39 | accelerator = accelerate.Accelerator() 40 | device = accelerator.device 41 | print('Using device:', device, flush=True) 42 | 43 | inner_model = K.config.make_model(config).eval().requires_grad_(False).to(device) 44 | inner_model.load_state_dict(safetorch.load_file(args.checkpoint)) 45 | 46 | accelerator.print('Parameters:', K.utils.n_params(inner_model)) 47 | model = K.Denoiser(inner_model, sigma_data=model_config['sigma_data']) 48 | 49 | sigma_min = model_config['sigma_min'] 50 | sigma_max = model_config['sigma_max'] 51 | 52 | @torch.no_grad() 53 | @K.utils.eval_mode(model) 54 | def run(): 55 | if accelerator.is_local_main_process: 56 | tqdm.write('Sampling...') 57 | sigmas = K.sampling.get_sigmas_karras(args.steps, sigma_min, sigma_max, rho=7., device=device) 58 | def sample_fn(n): 59 | x = torch.randn([n, model_config['input_channels'], size[0], size[1]], device=device) * sigma_max 60 | x_0 = K.sampling.sample_lms(model, x, sigmas, disable=not accelerator.is_local_main_process) 61 | return x_0 62 | x_0 = K.evaluation.compute_features(accelerator, sample_fn, lambda x: x, args.n, args.batch_size) 63 | if accelerator.is_main_process: 64 | for i, out in enumerate(x_0): 65 | filename = f'{args.prefix}_{i:05}.png' 66 | K.utils.to_pil_image(out).save(filename) 67 | 68 | try: 69 | run() 70 | except KeyboardInterrupt: 71 | pass 72 | 73 | 74 | if __name__ == '__main__': 75 | main() 76 | -------------------------------------------------------------------------------- /k_diffusion/models/axial_rope.py: -------------------------------------------------------------------------------- 1 | import math 2 | 3 | import torch 4 | import torch._dynamo 5 | from torch import nn 6 | 7 | from . 
import flags 8 | 9 | if flags.get_use_compile(): 10 | torch._dynamo.config.suppress_errors = True 11 | 12 | 13 | def rotate_half(x): 14 | x1, x2 = x[..., 0::2], x[..., 1::2] 15 | x = torch.stack((-x2, x1), dim=-1) 16 | *shape, d, r = x.shape 17 | return x.view(*shape, d * r) 18 | 19 | 20 | @flags.compile_wrap 21 | def apply_rotary_emb(freqs, t, start_index=0, scale=1.0): 22 | freqs = freqs.to(t) 23 | rot_dim = freqs.shape[-1] 24 | end_index = start_index + rot_dim 25 | assert rot_dim <= t.shape[-1], f"feature dimension {t.shape[-1]} is not of sufficient size to rotate in all the positions {rot_dim}" 26 | t_left, t, t_right = t[..., :start_index], t[..., start_index:end_index], t[..., end_index:] 27 | t = (t * freqs.cos() * scale) + (rotate_half(t) * freqs.sin() * scale) 28 | return torch.cat((t_left, t, t_right), dim=-1) 29 | 30 | 31 | def centers(start, stop, num, dtype=None, device=None): 32 | edges = torch.linspace(start, stop, num + 1, dtype=dtype, device=device) 33 | return (edges[:-1] + edges[1:]) / 2 34 | 35 | 36 | def make_grid(h_pos, w_pos): 37 | grid = torch.stack(torch.meshgrid(h_pos, w_pos, indexing='ij'), dim=-1) 38 | h, w, d = grid.shape 39 | return grid.view(h * w, d) 40 | 41 | 42 | def bounding_box(h, w, pixel_aspect_ratio=1.0): 43 | # Adjusted dimensions 44 | w_adj = w 45 | h_adj = h * pixel_aspect_ratio 46 | 47 | # Adjusted aspect ratio 48 | ar_adj = w_adj / h_adj 49 | 50 | # Determine bounding box based on the adjusted aspect ratio 51 | y_min, y_max, x_min, x_max = -1.0, 1.0, -1.0, 1.0 52 | if ar_adj > 1: 53 | y_min, y_max = -1 / ar_adj, 1 / ar_adj 54 | elif ar_adj < 1: 55 | x_min, x_max = -ar_adj, ar_adj 56 | 57 | return y_min, y_max, x_min, x_max 58 | 59 | 60 | def make_axial_pos(h, w, pixel_aspect_ratio=1.0, align_corners=False, dtype=None, device=None): 61 | y_min, y_max, x_min, x_max = bounding_box(h, w, pixel_aspect_ratio) 62 | if align_corners: 63 | h_pos = torch.linspace(y_min, y_max, h, dtype=dtype, device=device) 64 | w_pos = torch.linspace(x_min, x_max, w, dtype=dtype, device=device) 65 | else: 66 | h_pos = centers(y_min, y_max, h, dtype=dtype, device=device) 67 | w_pos = centers(x_min, x_max, w, dtype=dtype, device=device) 68 | return make_grid(h_pos, w_pos) 69 | 70 | 71 | def freqs_pixel(max_freq=10.0): 72 | def init(shape): 73 | freqs = torch.linspace(1.0, max_freq / 2, shape[-1]) * math.pi 74 | return freqs.log().expand(shape) 75 | return init 76 | 77 | 78 | def freqs_pixel_log(max_freq=10.0): 79 | def init(shape): 80 | log_min = math.log(math.pi) 81 | log_max = math.log(max_freq * math.pi / 2) 82 | return torch.linspace(log_min, log_max, shape[-1]).expand(shape) 83 | return init 84 | 85 | 86 | class AxialRoPE(nn.Module): 87 | def __init__(self, dim, n_heads, start_index=0, freqs_init=freqs_pixel_log(max_freq=10.0)): 88 | super().__init__() 89 | self.n_heads = n_heads 90 | self.start_index = start_index 91 | log_freqs = freqs_init((n_heads, dim // 4)) 92 | self.freqs_h = nn.Parameter(log_freqs.clone()) 93 | self.freqs_w = nn.Parameter(log_freqs.clone()) 94 | 95 | def extra_repr(self): 96 | dim = (self.freqs_h.shape[-1] + self.freqs_w.shape[-1]) * 2 97 | return f"dim={dim}, n_heads={self.n_heads}, start_index={self.start_index}" 98 | 99 | def get_freqs(self, pos): 100 | if pos.shape[-1] != 2: 101 | raise ValueError("input shape must be (..., 2)") 102 | freqs_h = pos[..., None, None, 0] * self.freqs_h.exp() 103 | freqs_w = pos[..., None, None, 1] * self.freqs_w.exp() 104 | freqs = torch.cat((freqs_h, freqs_w), dim=-1).repeat_interleave(2, dim=-1) 105 | 
return freqs.transpose(-2, -3) 106 | 107 | def forward(self, x, pos): 108 | freqs = self.get_freqs(pos) 109 | return apply_rotary_emb(freqs, x, self.start_index) 110 | -------------------------------------------------------------------------------- /k_diffusion/augmentation.py: -------------------------------------------------------------------------------- 1 | from functools import reduce 2 | import math 3 | import operator 4 | 5 | import numpy as np 6 | from skimage import transform 7 | import torch 8 | from torch import nn 9 | 10 | 11 | def translate2d(tx, ty): 12 | mat = [[1, 0, tx], 13 | [0, 1, ty], 14 | [0, 0, 1]] 15 | return torch.tensor(mat, dtype=torch.float32) 16 | 17 | 18 | def scale2d(sx, sy): 19 | mat = [[sx, 0, 0], 20 | [ 0, sy, 0], 21 | [ 0, 0, 1]] 22 | return torch.tensor(mat, dtype=torch.float32) 23 | 24 | 25 | def rotate2d(theta): 26 | mat = [[torch.cos(theta), torch.sin(-theta), 0], 27 | [torch.sin(theta), torch.cos(theta), 0], 28 | [ 0, 0, 1]] 29 | return torch.tensor(mat, dtype=torch.float32) 30 | 31 | 32 | class KarrasAugmentationPipeline: 33 | def __init__(self, a_prob=0.12, a_scale=2**0.2, a_aniso=2**0.2, a_trans=1/8, disable_all=False): 34 | self.a_prob = a_prob 35 | self.a_scale = a_scale 36 | self.a_aniso = a_aniso 37 | self.a_trans = a_trans 38 | self.disable_all = disable_all 39 | 40 | def __call__(self, image): 41 | h, w = image.size 42 | mats = [translate2d(h / 2 - 0.5, w / 2 - 0.5)] 43 | 44 | # x-flip 45 | a0 = torch.randint(2, []).float() 46 | mats.append(scale2d(1 - 2 * a0, 1)) 47 | # y-flip 48 | do = (torch.rand([]) < self.a_prob).float() 49 | a1 = torch.randint(2, []).float() * do 50 | mats.append(scale2d(1, 1 - 2 * a1)) 51 | # scaling 52 | do = (torch.rand([]) < self.a_prob).float() 53 | a2 = torch.randn([]) * do 54 | mats.append(scale2d(self.a_scale ** a2, self.a_scale ** a2)) 55 | # rotation 56 | do = (torch.rand([]) < self.a_prob).float() 57 | a3 = (torch.rand([]) * 2 * math.pi - math.pi) * do 58 | mats.append(rotate2d(-a3)) 59 | # anisotropy 60 | do = (torch.rand([]) < self.a_prob).float() 61 | a4 = (torch.rand([]) * 2 * math.pi - math.pi) * do 62 | a5 = torch.randn([]) * do 63 | mats.append(rotate2d(a4)) 64 | mats.append(scale2d(self.a_aniso ** a5, self.a_aniso ** -a5)) 65 | mats.append(rotate2d(-a4)) 66 | # translation 67 | do = (torch.rand([]) < self.a_prob).float() 68 | a6 = torch.randn([]) * do 69 | a7 = torch.randn([]) * do 70 | mats.append(translate2d(self.a_trans * w * a6, self.a_trans * h * a7)) 71 | 72 | # form the transformation matrix and conditioning vector 73 | mats.append(translate2d(-h / 2 + 0.5, -w / 2 + 0.5)) 74 | mat = reduce(operator.matmul, mats) 75 | cond = torch.stack([a0, a1, a2, a3.cos() - 1, a3.sin(), a5 * a4.cos(), a5 * a4.sin(), a6, a7]) 76 | 77 | # apply the transformation 78 | image_orig = np.array(image, dtype=np.float32) / 255 79 | if image_orig.ndim == 2: 80 | image_orig = image_orig[..., None] 81 | tf = transform.AffineTransform(mat.numpy()) 82 | if not self.disable_all: 83 | image = transform.warp(image_orig, tf.inverse, order=3, mode='reflect', cval=0.5, clip=False, preserve_range=True) 84 | else: 85 | image = image_orig 86 | cond = torch.zeros_like(cond) 87 | image_orig = torch.as_tensor(image_orig).movedim(2, 0) * 2 - 1 88 | image = torch.as_tensor(image).movedim(2, 0) * 2 - 1 89 | return image, image_orig, cond 90 | 91 | 92 | class KarrasAugmentWrapper(nn.Module): 93 | def __init__(self, model): 94 | super().__init__() 95 | self.inner_model = model 96 | 97 | def forward(self, input, sigma, aug_cond=None, 
mapping_cond=None, **kwargs): 98 | if aug_cond is None: 99 | aug_cond = input.new_zeros([input.shape[0], 9]) 100 | if mapping_cond is None: 101 | mapping_cond = aug_cond 102 | else: 103 | mapping_cond = torch.cat([aug_cond, mapping_cond], dim=1) 104 | return self.inner_model(input, sigma, mapping_cond=mapping_cond, **kwargs) 105 | 106 | def param_groups(self, *args, **kwargs): 107 | return self.inner_model.param_groups(*args, **kwargs) 108 | 109 | def set_skip_stages(self, skip_stages): 110 | return self.inner_model.set_skip_stages(skip_stages) 111 | 112 | def set_patch_size(self, patch_size): 113 | return self.inner_model.set_patch_size(patch_size) 114 | -------------------------------------------------------------------------------- /k_diffusion/gns.py: -------------------------------------------------------------------------------- 1 | import torch 2 | from torch import nn 3 | 4 | 5 | class DDPGradientStatsHook: 6 | def __init__(self, ddp_module): 7 | try: 8 | ddp_module.register_comm_hook(self, self._hook_fn) 9 | except AttributeError: 10 | raise ValueError('DDPGradientStatsHook does not support non-DDP wrapped modules') 11 | self._clear_state() 12 | 13 | def _clear_state(self): 14 | self.bucket_sq_norms_small_batch = [] 15 | self.bucket_sq_norms_large_batch = [] 16 | 17 | @staticmethod 18 | def _hook_fn(self, bucket): 19 | buf = bucket.buffer() 20 | self.bucket_sq_norms_small_batch.append(buf.pow(2).sum(dtype=torch.float32)) 21 | fut = torch.distributed.all_reduce(buf, op=torch.distributed.ReduceOp.AVG, async_op=True).get_future() 22 | def callback(fut): 23 | buf = fut.value()[0] 24 | self.bucket_sq_norms_large_batch.append(buf.pow(2).sum(dtype=torch.float32)) 25 | return buf 26 | return fut.then(callback) 27 | 28 | def get_stats(self): 29 | sq_norm_small_batch = sum(self.bucket_sq_norms_small_batch) 30 | sq_norm_large_batch = sum(self.bucket_sq_norms_large_batch) 31 | self._clear_state() 32 | stats = torch.stack([sq_norm_small_batch, sq_norm_large_batch]) 33 | torch.distributed.all_reduce(stats, op=torch.distributed.ReduceOp.AVG) 34 | return stats[0].item(), stats[1].item() 35 | 36 | 37 | class GradientNoiseScale: 38 | """Calculates the gradient noise scale (1 / SNR), or critical batch size, 39 | from _An Empirical Model of Large-Batch Training_, 40 | https://arxiv.org/abs/1812.06162). 41 | 42 | Args: 43 | beta (float): The decay factor for the exponential moving averages used to 44 | calculate the gradient noise scale. 45 | Default: 0.9998 46 | eps (float): Added for numerical stability. 47 | Default: 1e-8 48 | """ 49 | 50 | def __init__(self, beta=0.9998, eps=1e-8): 51 | self.beta = beta 52 | self.eps = eps 53 | self.ema_sq_norm = 0. 54 | self.ema_var = 0. 55 | self.beta_cumprod = 1. 56 | self.gradient_noise_scale = float('nan') 57 | 58 | def state_dict(self): 59 | """Returns the state of the object as a :class:`dict`.""" 60 | return dict(self.__dict__.items()) 61 | 62 | def load_state_dict(self, state_dict): 63 | """Loads the object's state. 64 | Args: 65 | state_dict (dict): object state. Should be an object returned 66 | from a call to :meth:`state_dict`. 67 | """ 68 | self.__dict__.update(state_dict) 69 | 70 | def update(self, sq_norm_small_batch, sq_norm_large_batch, n_small_batch, n_large_batch): 71 | """Updates the state with a new batch's gradient statistics, and returns the 72 | current gradient noise scale. 73 | 74 | Args: 75 | sq_norm_small_batch (float): The mean of the squared 2-norms of microbatch or 76 | per sample gradients. 
77 |             sq_norm_large_batch (float): The squared 2-norm of the mean of the microbatch or
78 |                 per sample gradients.
79 |             n_small_batch (int): The batch size of the individual microbatch or per sample
80 |                 gradients (1 if per sample).
81 |             n_large_batch (int): The total batch size of the mean of the microbatch or
82 |                 per sample gradients.
83 |         """
84 |         est_sq_norm = (n_large_batch * sq_norm_large_batch - n_small_batch * sq_norm_small_batch) / (n_large_batch - n_small_batch)
85 |         est_var = (sq_norm_small_batch - sq_norm_large_batch) / (1 / n_small_batch - 1 / n_large_batch)
86 |         self.ema_sq_norm = self.beta * self.ema_sq_norm + (1 - self.beta) * est_sq_norm
87 |         self.ema_var = self.beta * self.ema_var + (1 - self.beta) * est_var
88 |         self.beta_cumprod *= self.beta
89 |         self.gradient_noise_scale = max(self.ema_var, self.eps) / max(self.ema_sq_norm, self.eps)
90 |         return self.gradient_noise_scale
91 | 
92 |     def get_gns(self):
93 |         """Returns the current gradient noise scale."""
94 |         return self.gradient_noise_scale
95 | 
96 |     def get_stats(self):
97 |         """Returns the current (debiased) estimates of the squared mean gradient
98 |         and gradient variance."""
99 |         return self.ema_sq_norm / (1 - self.beta_cumprod), self.ema_var / (1 - self.beta_cumprod)
100 | 
--------------------------------------------------------------------------------
/sample_clip_guided.py:
--------------------------------------------------------------------------------
1 | #!/usr/bin/env python3
2 | 
3 | """CLIP guided sampling from k-diffusion models."""
4 | 
5 | import argparse
6 | 
7 | import accelerate
8 | import clip
9 | from kornia import augmentation as KA
10 | from resize_right import resize
11 | import safetensors.torch as safetorch
12 | import torch
13 | from torch.nn import functional as F
14 | from torchvision import transforms
15 | from tqdm import trange, tqdm
16 | 
17 | import k_diffusion as K
18 | 
19 | 
20 | def spherical_dist_loss(x, y):
21 |     x = F.normalize(x, dim=-1)
22 |     y = F.normalize(y, dim=-1)
23 |     return (x - y).norm(dim=-1).div(2).arcsin().pow(2).mul(2)
24 | 
25 | 
26 | def make_cond_model_fn(model, cond_fn):
27 |     def model_fn(x, sigma, **kwargs):
28 |         with torch.enable_grad():
29 |             x = x.detach().requires_grad_()
30 |             denoised = model(x, sigma, **kwargs)
31 |             cond_grad = cond_fn(x, sigma, denoised=denoised, **kwargs).detach()
32 |             cond_denoised = denoised.detach() + cond_grad * K.utils.append_dims(sigma**2, x.ndim)
33 |         return cond_denoised
34 |     return model_fn
35 | 
36 | 
37 | def make_static_thresh_model_fn(model, value=1.):
38 |     def model_fn(x, sigma, **kwargs):
39 |         return model(x, sigma, **kwargs).clamp(-value, value)
40 |     return model_fn
41 | 
42 | 
43 | def main():
44 |     p = argparse.ArgumentParser(description=__doc__,
45 |                                 formatter_class=argparse.ArgumentDefaultsHelpFormatter)
46 |     p.add_argument('prompt', type=str,
47 |                    help='the prompt to use')
48 |     p.add_argument('--batch-size', type=int, default=16,
49 |                    help='the batch size')
50 |     p.add_argument('--checkpoint', type=str, required=True,
51 |                    help='the checkpoint to use')
52 |     p.add_argument('--clip-guidance-scale', '-cgs', type=float, default=500.,
53 |                    help='the CLIP guidance scale')
54 |     p.add_argument('--clip-model', type=str, default='ViT-B/16', choices=clip.available_models(),
55 |                    help='the CLIP model to use')
56 |     p.add_argument('--config', type=str,
57 |                    help='the model config')
58 |     p.add_argument('-n', type=int, default=64,
59 |                    help='the number of images to sample')
60 |     p.add_argument('--prefix', type=str, default='out',
61 |                    help='the 
output prefix') 62 | p.add_argument('--steps', type=int, default=100, 63 | help='the number of denoising steps') 64 | args = p.parse_args() 65 | 66 | config = K.config.load_config(args.config if args.config else args.checkpoint) 67 | model_config = config['model'] 68 | # TODO: allow non-square input sizes 69 | assert len(model_config['input_size']) == 2 and model_config['input_size'][0] == model_config['input_size'][1] 70 | size = model_config['input_size'] 71 | 72 | accelerator = accelerate.Accelerator() 73 | device = accelerator.device 74 | print('Using device:', device, flush=True) 75 | 76 | inner_model = K.config.make_model(config).eval().requires_grad_(False).to(device) 77 | inner_model.load_state_dict(safetorch.load_file(args.checkpoint)) 78 | 79 | accelerator.print('Parameters:', K.utils.n_params(inner_model)) 80 | model = K.Denoiser(inner_model, sigma_data=model_config['sigma_data']) 81 | 82 | sigma_min = model_config['sigma_min'] 83 | sigma_max = model_config['sigma_max'] 84 | 85 | clip_model = clip.load(args.clip_model, device=device)[0].eval().requires_grad_(False) 86 | clip_normalize = transforms.Normalize(mean=(0.48145466, 0.4578275, 0.40821073), 87 | std=(0.26862954, 0.26130258, 0.27577711)) 88 | clip_size = (clip_model.visual.input_resolution, clip_model.visual.input_resolution) 89 | aug = KA.RandomAffine(0, (1/14, 1/14), p=1, padding_mode='border') 90 | 91 | def get_image_embed(x): 92 | if x.shape[2:4] != clip_size: 93 | x = resize(x, out_shape=clip_size, pad_mode='reflect') 94 | x = clip_normalize(x) 95 | x = clip_model.encode_image(x).float() 96 | return F.normalize(x) 97 | 98 | target_embed = F.normalize(clip_model.encode_text(clip.tokenize(args.prompt, truncate=True).to(device)).float()) 99 | 100 | def cond_fn(x, t, denoised): 101 | image_embed = get_image_embed(aug(denoised.add(1).div(2))) 102 | loss = spherical_dist_loss(image_embed, target_embed).sum() * args.clip_guidance_scale 103 | grad = -torch.autograd.grad(loss, x)[0] 104 | return grad 105 | 106 | model_fn = make_cond_model_fn(model, cond_fn) 107 | model_fn = make_static_thresh_model_fn(model_fn) 108 | 109 | @torch.no_grad() 110 | @K.utils.eval_mode(model) 111 | def run(): 112 | if accelerator.is_local_main_process: 113 | tqdm.write('Sampling...') 114 | sigmas = K.sampling.get_sigmas_karras(args.steps, sigma_min, sigma_max, rho=7., device=device) 115 | def sample_fn(n): 116 | x = torch.randn([n, model_config['input_channels'], size[0], size[1]], device=device) * sigmas[0] 117 | x_0 = K.sampling.sample_dpmpp_2s_ancestral(model_fn, x, sigmas, eta=1., disable=not accelerator.is_local_main_process) 118 | return x_0 119 | x_0 = K.evaluation.compute_features(accelerator, sample_fn, lambda x: x, args.n, args.batch_size) 120 | if accelerator.is_main_process: 121 | for i, out in enumerate(x_0): 122 | filename = f'{args.prefix}_{i:05}.png' 123 | K.utils.to_pil_image(out).save(filename) 124 | 125 | try: 126 | run() 127 | except KeyboardInterrupt: 128 | pass 129 | 130 | 131 | if __name__ == '__main__': 132 | main() 133 | -------------------------------------------------------------------------------- /k_diffusion/evaluation.py: -------------------------------------------------------------------------------- 1 | import math 2 | import os 3 | from pathlib import Path 4 | 5 | from cleanfid.inception_torchscript import InceptionV3W 6 | import clip 7 | import torch 8 | from torch import nn 9 | from torch.nn import functional as F 10 | from torchvision import transforms 11 | from tqdm.auto import trange 12 | 13 | from . 
import utils 14 | 15 | 16 | class InceptionV3FeatureExtractor(nn.Module): 17 | def __init__(self, device='cpu'): 18 | super().__init__() 19 | path = Path(os.environ.get('XDG_CACHE_HOME', Path.home() / '.cache')) / 'k-diffusion' 20 | url = 'https://nvlabs-fi-cdn.nvidia.com/stylegan2-ada-pytorch/pretrained/metrics/inception-2015-12-05.pt' 21 | digest = 'f58cb9b6ec323ed63459aa4fb441fe750cfe39fafad6da5cb504a16f19e958f4' 22 | utils.download_file(path / 'inception-2015-12-05.pt', url, digest) 23 | self.model = InceptionV3W(str(path), resize_inside=False).to(device) 24 | self.size = (299, 299) 25 | 26 | def forward(self, x): 27 | x = F.interpolate(x, self.size, mode='bicubic', align_corners=False, antialias=True) 28 | if x.shape[1] == 1: 29 | x = torch.cat([x] * 3, dim=1) 30 | x = (x * 127.5 + 127.5).clamp(0, 255) 31 | return self.model(x) 32 | 33 | 34 | class CLIPFeatureExtractor(nn.Module): 35 | def __init__(self, name='ViT-B/16', device='cpu'): 36 | super().__init__() 37 | self.model = clip.load(name, device=device)[0].eval().requires_grad_(False) 38 | self.normalize = transforms.Normalize(mean=(0.48145466, 0.4578275, 0.40821073), 39 | std=(0.26862954, 0.26130258, 0.27577711)) 40 | self.size = self.model.visual.input_resolution, self.model.visual.input_resolution 41 | 42 | @classmethod 43 | def available_models(cls): 44 | return clip.available_models() 45 | 46 | def forward(self, x): 47 | x = (x + 1) / 2 48 | x = F.interpolate(x, self.size, mode='bicubic', align_corners=False, antialias=True) 49 | if x.shape[1] == 1: 50 | x = torch.cat([x] * 3, dim=1) 51 | x = self.normalize(x) 52 | x = self.model.encode_image(x).float() 53 | x = F.normalize(x) * x.shape[-1] ** 0.5 54 | return x 55 | 56 | 57 | class DINOv2FeatureExtractor(nn.Module): 58 | def __init__(self, name='vitl14', device='cpu'): 59 | super().__init__() 60 | self.model = torch.hub.load('facebookresearch/dinov2', 'dinov2_' + name).to(device).eval().requires_grad_(False) 61 | self.normalize = transforms.Normalize(mean=(0.485, 0.456, 0.406), std=(0.229, 0.224, 0.225)) 62 | self.size = 224, 224 63 | 64 | @classmethod 65 | def available_models(cls): 66 | return ['vits14', 'vitb14', 'vitl14', 'vitg14'] 67 | 68 | def forward(self, x): 69 | x = (x + 1) / 2 70 | x = F.interpolate(x, self.size, mode='bicubic', align_corners=False, antialias=True) 71 | if x.shape[1] == 1: 72 | x = torch.cat([x] * 3, dim=1) 73 | x = self.normalize(x) 74 | with torch.cuda.amp.autocast(dtype=torch.float16): 75 | x = self.model(x).float() 76 | x = F.normalize(x) * x.shape[-1] ** 0.5 77 | return x 78 | 79 | 80 | def compute_features(accelerator, sample_fn, extractor_fn, n, batch_size): 81 | n_per_proc = math.ceil(n / accelerator.num_processes) 82 | feats_all = [] 83 | try: 84 | for i in trange(0, n_per_proc, batch_size, disable=not accelerator.is_main_process): 85 | cur_batch_size = min(n - i, batch_size) 86 | samples = sample_fn(cur_batch_size)[:cur_batch_size] 87 | feats_all.append(accelerator.gather(extractor_fn(samples))) 88 | except StopIteration: 89 | pass 90 | return torch.cat(feats_all)[:n] 91 | 92 | 93 | def polynomial_kernel(x, y): 94 | d = x.shape[-1] 95 | dot = x @ y.transpose(-2, -1) 96 | return (dot / d + 1) ** 3 97 | 98 | 99 | def squared_mmd(x, y, kernel=polynomial_kernel): 100 | m = x.shape[-2] 101 | n = y.shape[-2] 102 | kxx = kernel(x, x) 103 | kyy = kernel(y, y) 104 | kxy = kernel(x, y) 105 | kxx_sum = kxx.sum([-1, -2]) - kxx.diagonal(dim1=-1, dim2=-2).sum(-1) 106 | kyy_sum = kyy.sum([-1, -2]) - kyy.diagonal(dim1=-1, dim2=-2).sum(-1) 107 | kxy_sum = 
kxy.sum([-1, -2]) 108 | term_1 = kxx_sum / m / (m - 1) 109 | term_2 = kyy_sum / n / (n - 1) 110 | term_3 = kxy_sum * 2 / m / n 111 | return term_1 + term_2 - term_3 112 | 113 | 114 | @utils.tf32_mode(matmul=False) 115 | def kid(x, y, max_size=5000): 116 | x_size, y_size = x.shape[0], y.shape[0] 117 | n_partitions = math.ceil(max(x_size / max_size, y_size / max_size)) 118 | total_mmd = x.new_zeros([]) 119 | for i in range(n_partitions): 120 | cur_x = x[round(i * x_size / n_partitions):round((i + 1) * x_size / n_partitions)] 121 | cur_y = y[round(i * y_size / n_partitions):round((i + 1) * y_size / n_partitions)] 122 | total_mmd = total_mmd + squared_mmd(cur_x, cur_y) 123 | return total_mmd / n_partitions 124 | 125 | 126 | class _MatrixSquareRootEig(torch.autograd.Function): 127 | @staticmethod 128 | def forward(ctx, a): 129 | vals, vecs = torch.linalg.eigh(a) 130 | ctx.save_for_backward(vals, vecs) 131 | return vecs @ vals.abs().sqrt().diag_embed() @ vecs.transpose(-2, -1) 132 | 133 | @staticmethod 134 | def backward(ctx, grad_output): 135 | vals, vecs = ctx.saved_tensors 136 | d = vals.abs().sqrt().unsqueeze(-1).repeat_interleave(vals.shape[-1], -1) 137 | vecs_t = vecs.transpose(-2, -1) 138 | return vecs @ (vecs_t @ grad_output @ vecs / (d + d.transpose(-2, -1))) @ vecs_t 139 | 140 | 141 | def sqrtm_eig(a): 142 | if a.ndim < 2: 143 | raise RuntimeError('tensor of matrices must have at least 2 dimensions') 144 | if a.shape[-2] != a.shape[-1]: 145 | raise RuntimeError('tensor must be batches of square matrices') 146 | return _MatrixSquareRootEig.apply(a) 147 | 148 | 149 | @utils.tf32_mode(matmul=False) 150 | def fid(x, y, eps=1e-8): 151 | x_mean = x.mean(dim=0) 152 | y_mean = y.mean(dim=0) 153 | mean_term = (x_mean - y_mean).pow(2).sum() 154 | x_cov = torch.cov(x.T) 155 | y_cov = torch.cov(y.T) 156 | eps_eye = torch.eye(x_cov.shape[0], device=x_cov.device, dtype=x_cov.dtype) * eps 157 | x_cov = x_cov + eps_eye 158 | y_cov = y_cov + eps_eye 159 | x_cov_sqrt = sqrtm_eig(x_cov) 160 | cov_term = torch.trace(x_cov + y_cov - 2 * sqrtm_eig(x_cov_sqrt @ y_cov @ x_cov_sqrt)) 161 | return mean_term + cov_term 162 | -------------------------------------------------------------------------------- /k_diffusion/external.py: -------------------------------------------------------------------------------- 1 | import math 2 | 3 | import torch 4 | from torch import nn 5 | 6 | from . import sampling, utils 7 | 8 | 9 | class VDenoiser(nn.Module): 10 | """A v-diffusion-pytorch model wrapper for k-diffusion.""" 11 | 12 | def __init__(self, inner_model): 13 | super().__init__() 14 | self.inner_model = inner_model 15 | self.sigma_data = 1. 
16 | 17 | def get_scalings(self, sigma): 18 | c_skip = self.sigma_data ** 2 / (sigma ** 2 + self.sigma_data ** 2) 19 | c_out = -sigma * self.sigma_data / (sigma ** 2 + self.sigma_data ** 2) ** 0.5 20 | c_in = 1 / (sigma ** 2 + self.sigma_data ** 2) ** 0.5 21 | return c_skip, c_out, c_in 22 | 23 | def sigma_to_t(self, sigma): 24 | return sigma.atan() / math.pi * 2 25 | 26 | def t_to_sigma(self, t): 27 | return (t * math.pi / 2).tan() 28 | 29 | def loss(self, input, noise, sigma, **kwargs): 30 | c_skip, c_out, c_in = [utils.append_dims(x, input.ndim) for x in self.get_scalings(sigma)] 31 | noised_input = input + noise * utils.append_dims(sigma, input.ndim) 32 | model_output = self.inner_model(noised_input * c_in, self.sigma_to_t(sigma), **kwargs) 33 | target = (input - c_skip * noised_input) / c_out 34 | return (model_output - target).pow(2).flatten(1).mean(1) 35 | 36 | def forward(self, input, sigma, **kwargs): 37 | c_skip, c_out, c_in = [utils.append_dims(x, input.ndim) for x in self.get_scalings(sigma)] 38 | return self.inner_model(input * c_in, self.sigma_to_t(sigma), **kwargs) * c_out + input * c_skip 39 | 40 | 41 | class DiscreteSchedule(nn.Module): 42 | """A mapping between continuous noise levels (sigmas) and a list of discrete noise 43 | levels.""" 44 | 45 | def __init__(self, sigmas, quantize): 46 | super().__init__() 47 | self.register_buffer('sigmas', sigmas) 48 | self.register_buffer('log_sigmas', sigmas.log()) 49 | self.quantize = quantize 50 | 51 | @property 52 | def sigma_min(self): 53 | return self.sigmas[0] 54 | 55 | @property 56 | def sigma_max(self): 57 | return self.sigmas[-1] 58 | 59 | def get_sigmas(self, n=None): 60 | if n is None: 61 | return sampling.append_zero(self.sigmas.flip(0)) 62 | t_max = len(self.sigmas) - 1 63 | t = torch.linspace(t_max, 0, n, device=self.sigmas.device) 64 | return sampling.append_zero(self.t_to_sigma(t)) 65 | 66 | def sigma_to_t(self, sigma, quantize=None): 67 | quantize = self.quantize if quantize is None else quantize 68 | log_sigma = sigma.log() 69 | dists = log_sigma - self.log_sigmas[:, None] 70 | if quantize: 71 | return dists.abs().argmin(dim=0).view(sigma.shape) 72 | low_idx = dists.ge(0).cumsum(dim=0).argmax(dim=0).clamp(max=self.log_sigmas.shape[0] - 2) 73 | high_idx = low_idx + 1 74 | low, high = self.log_sigmas[low_idx], self.log_sigmas[high_idx] 75 | w = (low - log_sigma) / (low - high) 76 | w = w.clamp(0, 1) 77 | t = (1 - w) * low_idx + w * high_idx 78 | return t.view(sigma.shape) 79 | 80 | def t_to_sigma(self, t): 81 | t = t.float() 82 | low_idx, high_idx, w = t.floor().long(), t.ceil().long(), t.frac() 83 | log_sigma = (1 - w) * self.log_sigmas[low_idx] + w * self.log_sigmas[high_idx] 84 | return log_sigma.exp() 85 | 86 | 87 | class DiscreteEpsDDPMDenoiser(DiscreteSchedule): 88 | """A wrapper for discrete schedule DDPM models that output eps (the predicted 89 | noise).""" 90 | 91 | def __init__(self, model, alphas_cumprod, quantize): 92 | super().__init__(((1 - alphas_cumprod) / alphas_cumprod) ** 0.5, quantize) 93 | self.inner_model = model 94 | self.sigma_data = 1. 
95 | 96 | def get_scalings(self, sigma): 97 | c_out = -sigma 98 | c_in = 1 / (sigma ** 2 + self.sigma_data ** 2) ** 0.5 99 | return c_out, c_in 100 | 101 | def get_eps(self, *args, **kwargs): 102 | return self.inner_model(*args, **kwargs) 103 | 104 | def loss(self, input, noise, sigma, **kwargs): 105 | c_out, c_in = [utils.append_dims(x, input.ndim) for x in self.get_scalings(sigma)] 106 | noised_input = input + noise * utils.append_dims(sigma, input.ndim) 107 | eps = self.get_eps(noised_input * c_in, self.sigma_to_t(sigma), **kwargs) 108 | return (eps - noise).pow(2).flatten(1).mean(1) 109 | 110 | def forward(self, input, sigma, **kwargs): 111 | c_out, c_in = [utils.append_dims(x, input.ndim) for x in self.get_scalings(sigma)] 112 | eps = self.get_eps(input * c_in, self.sigma_to_t(sigma), **kwargs) 113 | return input + eps * c_out 114 | 115 | 116 | class OpenAIDenoiser(DiscreteEpsDDPMDenoiser): 117 | """A wrapper for OpenAI diffusion models.""" 118 | 119 | def __init__(self, model, diffusion, quantize=False, has_learned_sigmas=True, device='cpu'): 120 | alphas_cumprod = torch.tensor(diffusion.alphas_cumprod, device=device, dtype=torch.float32) 121 | super().__init__(model, alphas_cumprod, quantize=quantize) 122 | self.has_learned_sigmas = has_learned_sigmas 123 | 124 | def get_eps(self, *args, **kwargs): 125 | model_output = self.inner_model(*args, **kwargs) 126 | if self.has_learned_sigmas: 127 | return model_output.chunk(2, dim=1)[0] 128 | return model_output 129 | 130 | 131 | class CompVisDenoiser(DiscreteEpsDDPMDenoiser): 132 | """A wrapper for CompVis diffusion models.""" 133 | 134 | def __init__(self, model, quantize=False, device='cpu'): 135 | super().__init__(model, model.alphas_cumprod, quantize=quantize) 136 | 137 | def get_eps(self, *args, **kwargs): 138 | return self.inner_model.apply_model(*args, **kwargs) 139 | 140 | 141 | class DiscreteVDDPMDenoiser(DiscreteSchedule): 142 | """A wrapper for discrete schedule DDPM models that output v.""" 143 | 144 | def __init__(self, model, alphas_cumprod, quantize): 145 | super().__init__(((1 - alphas_cumprod) / alphas_cumprod) ** 0.5, quantize) 146 | self.inner_model = model 147 | self.sigma_data = 1. 
148 | 
149 |     def get_scalings(self, sigma):
150 |         c_skip = self.sigma_data ** 2 / (sigma ** 2 + self.sigma_data ** 2)
151 |         c_out = -sigma * self.sigma_data / (sigma ** 2 + self.sigma_data ** 2) ** 0.5
152 |         c_in = 1 / (sigma ** 2 + self.sigma_data ** 2) ** 0.5
153 |         return c_skip, c_out, c_in
154 | 
155 |     def get_v(self, *args, **kwargs):
156 |         return self.inner_model(*args, **kwargs)
157 | 
158 |     def loss(self, input, noise, sigma, **kwargs):
159 |         c_skip, c_out, c_in = [utils.append_dims(x, input.ndim) for x in self.get_scalings(sigma)]
160 |         noised_input = input + noise * utils.append_dims(sigma, input.ndim)
161 |         model_output = self.get_v(noised_input * c_in, self.sigma_to_t(sigma), **kwargs)
162 |         target = (input - c_skip * noised_input) / c_out
163 |         return (model_output - target).pow(2).flatten(1).mean(1)
164 | 
165 |     def forward(self, input, sigma, **kwargs):
166 |         c_skip, c_out, c_in = [utils.append_dims(x, input.ndim) for x in self.get_scalings(sigma)]
167 |         return self.get_v(input * c_in, self.sigma_to_t(sigma), **kwargs) * c_out + input * c_skip
168 | 
169 | 
170 | class CompVisVDenoiser(DiscreteVDDPMDenoiser):
171 |     """A wrapper for CompVis diffusion models that output v."""
172 | 
173 |     def __init__(self, model, quantize=False, device='cpu'):
174 |         super().__init__(model, model.alphas_cumprod, quantize=quantize)
175 | 
176 |     def get_v(self, x, t, cond, **kwargs):
177 |         return self.inner_model.apply_model(x, t, cond)
178 | 
--------------------------------------------------------------------------------
/README.md:
--------------------------------------------------------------------------------
1 | # k-diffusion
2 | 
3 | [![DOI](https://zenodo.org/badge/DOI/10.5281/zenodo.10284390.svg)](https://doi.org/10.5281/zenodo.10284390)
4 | 
5 | An implementation of [Elucidating the Design Space of Diffusion-Based Generative Models](https://arxiv.org/abs/2206.00364) (Karras et al., 2022) for PyTorch, with enhancements and additional features, such as improved sampling algorithms and transformer-based diffusion models.
6 | 
7 | ## Hourglass diffusion transformer
8 | 
9 | `k-diffusion` contains a new model type, `image_transformer_v2`, that uses ideas from [Hourglass Transformer](https://arxiv.org/abs/2110.13711) and [DiT](https://arxiv.org/abs/2212.09748).
10 | 
11 | ### Requirements
12 | 
13 | To use the new model type you will need to install custom CUDA kernels:
14 | 
15 | * [NATTEN](https://github.com/SHI-Labs/NATTEN/tree/main) for the sparse (neighborhood) attention used at low levels of the hierarchy. There is a [shifted window attention](https://arxiv.org/abs/2103.14030) version of the model type which does not require a custom CUDA kernel, but it does not perform as well and is slower to train and inference.
16 | 
17 | * [FlashAttention-2](https://github.com/Dao-AILab/flash-attention) for global attention. It will fall back to plain PyTorch if it is not installed.
18 | 
19 | Also, you should make sure your PyTorch installation is capable of using `torch.compile()`. It will fall back to eager mode if `torch.compile()` is not available, but it will be slower and use more memory in training.
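
Both fallbacks can also be forced by hand: `k_diffusion/models/flags.py` (shown earlier in this repository) reads two environment variables, so a minimal sketch for opting out of `torch.compile()` and FlashAttention-2 on a problematic setup looks like this:

```python
# A minimal sketch: flags.py treats any value other than "1" as "off". It is safest to
# set these before importing k_diffusion, since some modules check the flags at import time.
import os

os.environ["K_DIFFUSION_USE_COMPILE"] = "0"  # fall back to eager mode instead of torch.compile()
os.environ["K_DIFFUSION_USE_FLASH_2"] = "0"  # fall back to plain PyTorch attention

import k_diffusion as K
```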
20 |
21 | ### Usage
22 |
23 | #### Demo
24 |
25 | To train a 256x256 RGB model on [Oxford Flowers](https://www.robots.ox.ac.uk/~vgg/data/flowers) without installing custom CUDA kernels, install [Hugging Face Datasets](https://huggingface.co/docs/datasets/index):
26 |
27 | ```sh
28 | pip install datasets
29 | ```
30 |
31 | and run:
32 |
33 | ```sh
34 | python train.py --config configs/config_oxford_flowers_shifted_window.json --name flowers_demo_001 --evaluate-n 0 --batch-size 32 --sample-n 36 --mixed-precision bf16
35 | ```
36 |
37 | If you run out of memory, try adding `--checkpointing` or reducing the batch size. If you are using an older GPU (pre-Ampere), omit `--mixed-precision bf16` to train in FP32. It is not recommended to train in FP16.
38 |
39 | If you have NATTEN installed and working (preferred), you can train with neighborhood attention instead of shifted window attention by specifying `--config configs/config_oxford_flowers.json`.
40 |
41 | #### Config file
42 |
43 | In the `"model"` key of the config file:
44 |
45 | 1. Set the `"type"` key to `"image_transformer_v2"`.
46 |
47 | 1. The base patch size is set by the `"patch_size"` key, like `"patch_size": [4, 4]`.
48 |
49 | 1. Model depth for each level of the hierarchy is specified by the `"depths"` config key, like `"depths": [2, 2, 4]`. This constructs a model with two transformer layers at the first level (4x4 patches), followed by two at the second level (8x8 patches), followed by four at the highest level (16x16 patches), followed by two more at the second level, followed by two more at the first level.
50 |
51 | 1. Model width for each level of the hierarchy is specified by the `"widths"` config key, like `"widths": [192, 384, 768]`. The widths must be multiples of the attention head dimension.
52 |
53 | 1. The self-attention mechanism for each level of the hierarchy is specified by the `"self_attns"` config key, like:
54 |
55 | ```json
56 | "self_attns": [
57 | {"type": "neighborhood", "d_head": 64, "kernel_size": 7},
58 | {"type": "neighborhood", "d_head": 64, "kernel_size": 7},
59 | {"type": "global", "d_head": 64}
60 | ]
61 | ```
62 |
63 | If not specified, all levels of the hierarchy except for the highest use neighborhood attention with 64 dim heads and a 7x7 kernel. The highest level uses global attention with 64 dim heads. So the token count at every level but the highest can be very large.
64 |
65 | 1. As a fallback if you or your users cannot use NATTEN, you can also train a model with [shifted window attention](https://arxiv.org/abs/2103.14030) at the low levels of the hierarchy. Shifted window attention does not perform as well as neighborhood attention and it is slower to train and inference, but it does not require custom CUDA kernels. Specify it like:
66 |
67 | ```json
68 | "self_attns": [
69 | {"type": "shifted-window", "d_head": 64, "window_size": 8},
70 | {"type": "shifted-window", "d_head": 64, "window_size": 8},
71 | {"type": "global", "d_head": 64}
72 | ]
73 | ```
74 |
75 | The window size at each level must evenly divide the image size at that level. Models trained with one attention type must be fine-tuned to be used with a different type.
76 |
77 | #### Inference
78 |
79 | TODO: write this section
80 |
81 | ## Installation
82 |
83 | `k-diffusion` can be installed via PyPI (`pip install k-diffusion`) but it will not include training and inference scripts, only library code that others can depend on. To run the training and inference scripts, clone this repository and run `pip install -e <path to repository>`.
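If you only installed the PyPI package, the library code can be used directly from Python. Below is a minimal sketch (not from the repository) of wrapping a model with the Karras preconditioner and sampling from it, assuming the `get_sigmas_karras` and `sample_dpmpp_2m` helpers in `k_diffusion.sampling`; the toy inner model is a placeholder for a real denoiser network:

```python
# Minimal sketch: wrap an inner model with the Karras preconditioner and sample from it.
import torch
from torch import nn

import k_diffusion as K


class ToyModel(nn.Module):
    """Placeholder inner model; a real one predicts the denoising target from (x, sigma)."""

    def forward(self, x, sigma):
        return torch.zeros_like(x)


model = K.layers.Denoiser(ToyModel(), sigma_data=0.5)                    # preconditioned wrapper
sigmas = K.sampling.get_sigmas_karras(50, sigma_min=1e-2, sigma_max=80)  # Karras noise schedule
x = torch.randn([1, 3, 64, 64]) * sigmas[0]                              # start from noise at sigma_max
samples = K.sampling.sample_dpmpp_2m(model, x, sigmas)                   # DPM-Solver++(2M) sampler
print(samples.shape)
```

This `Denoiser` wrapper is the same one `config.make_denoiser_wrapper` returns for the default loss config, so a trained model can be sampled from in the same way.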
84 |
85 | ## Training
86 |
87 | To train models:
88 |
89 | ```sh
90 | $ ./train.py --config CONFIG_FILE --name RUN_NAME
91 | ```
92 |
93 | For instance, to train a model on MNIST:
94 |
95 | ```sh
96 | $ ./train.py --config configs/config_mnist_transformer.json --name RUN_NAME
97 | ```
98 |
99 | The configuration file allows you to specify the dataset type. Currently supported types are `"imagefolder"` (finds all images in that folder and its subfolders, recursively), `"cifar10"` (CIFAR-10), and `"mnist"` (MNIST). `"huggingface"` ([Hugging Face Datasets](https://huggingface.co/docs/datasets/index)) is also supported.
100 |
101 | Multi-GPU and multi-node training is supported with [Hugging Face Accelerate](https://huggingface.co/docs/accelerate/index). You can configure Accelerate by running:
102 |
103 | ```sh
104 | $ accelerate config
105 | ```
106 |
107 | then running:
108 |
109 | ```sh
110 | $ accelerate launch train.py --config CONFIG_FILE --name RUN_NAME
111 | ```
112 |
113 | ## Enhancements/additional features
114 |
115 | - k-diffusion supports a highly efficient hierarchical transformer model type.
116 |
117 | - k-diffusion supports a soft version of [Min-SNR loss weighting](https://arxiv.org/abs/2303.09556) for improved training at high resolutions with fewer hyperparameters than the loss weighting used in Karras et al. (2022).
118 |
119 | - k-diffusion has wrappers for [v-diffusion-pytorch](https://github.com/crowsonkb/v-diffusion-pytorch), [OpenAI diffusion](https://github.com/openai/guided-diffusion), and [CompVis diffusion](https://github.com/CompVis/latent-diffusion) models allowing them to be used with its samplers and ODE/SDEs.
120 |
121 | - k-diffusion implements [DPM-Solver](https://arxiv.org/abs/2206.00927), which produces higher quality samples at the same number of function evaluations as Karras Algorithm 2, as well as supporting adaptive step size control. [DPM-Solver++(2S) and (2M)](https://arxiv.org/abs/2211.01095) are also implemented, for improved quality with low numbers of steps.
122 |
123 | - k-diffusion supports [CLIP](https://openai.com/blog/clip/) guided sampling from unconditional diffusion models (see `sample_clip_guided.py`).
124 |
125 | - k-diffusion supports log likelihood calculation (not a variational lower bound) for native models and all wrapped models.
126 |
127 | - k-diffusion can calculate, during training, the [FID](https://papers.nips.cc/paper/2017/file/8a1d694707eb0fefe65871369074926d-Paper.pdf) and [KID](https://arxiv.org/abs/1801.01401) vs the training set.
128 |
129 | - k-diffusion can calculate, during training, the gradient noise scale (1 / SNR) from _An Empirical Model of Large-Batch Training_ (https://arxiv.org/abs/1812.06162).
130 |
131 | ## To do
132 |
133 | - Latent diffusion
134 |
--------------------------------------------------------------------------------
/k_diffusion/models/image_v1.py:
--------------------------------------------------------------------------------
1 | import math
2 |
3 | import torch
4 | from torch import nn
5 | from torch.nn import functional as F
6 |
7 | from ..
import layers, utils 8 | 9 | 10 | def orthogonal_(module): 11 | nn.init.orthogonal_(module.weight) 12 | return module 13 | 14 | 15 | class ResConvBlock(layers.ConditionedResidualBlock): 16 | def __init__(self, feats_in, c_in, c_mid, c_out, group_size=32, dropout_rate=0.): 17 | skip = None if c_in == c_out else orthogonal_(nn.Conv2d(c_in, c_out, 1, bias=False)) 18 | super().__init__( 19 | layers.AdaGN(feats_in, c_in, max(1, c_in // group_size)), 20 | nn.GELU(), 21 | nn.Conv2d(c_in, c_mid, 3, padding=1), 22 | nn.Dropout2d(dropout_rate, inplace=True), 23 | layers.AdaGN(feats_in, c_mid, max(1, c_mid // group_size)), 24 | nn.GELU(), 25 | nn.Conv2d(c_mid, c_out, 3, padding=1), 26 | nn.Dropout2d(dropout_rate, inplace=True), 27 | skip=skip) 28 | nn.init.zeros_(self.main[-2].weight) 29 | nn.init.zeros_(self.main[-2].bias) 30 | 31 | 32 | class DBlock(layers.ConditionedSequential): 33 | def __init__(self, n_layers, feats_in, c_in, c_mid, c_out, group_size=32, head_size=64, dropout_rate=0., downsample=False, self_attn=False, cross_attn=False, c_enc=0): 34 | modules = [nn.Identity()] 35 | for i in range(n_layers): 36 | my_c_in = c_in if i == 0 else c_mid 37 | my_c_out = c_mid if i < n_layers - 1 else c_out 38 | modules.append(ResConvBlock(feats_in, my_c_in, c_mid, my_c_out, group_size, dropout_rate)) 39 | if self_attn: 40 | norm = lambda c_in: layers.AdaGN(feats_in, c_in, max(1, my_c_out // group_size)) 41 | modules.append(layers.SelfAttention2d(my_c_out, max(1, my_c_out // head_size), norm, dropout_rate)) 42 | if cross_attn: 43 | norm = lambda c_in: layers.AdaGN(feats_in, c_in, max(1, my_c_out // group_size)) 44 | modules.append(layers.CrossAttention2d(my_c_out, c_enc, max(1, my_c_out // head_size), norm, dropout_rate)) 45 | super().__init__(*modules) 46 | self.set_downsample(downsample) 47 | 48 | def set_downsample(self, downsample): 49 | self[0] = layers.Downsample2d() if downsample else nn.Identity() 50 | return self 51 | 52 | 53 | class UBlock(layers.ConditionedSequential): 54 | def __init__(self, n_layers, feats_in, c_in, c_mid, c_out, group_size=32, head_size=64, dropout_rate=0., upsample=False, self_attn=False, cross_attn=False, c_enc=0): 55 | modules = [] 56 | for i in range(n_layers): 57 | my_c_in = c_in if i == 0 else c_mid 58 | my_c_out = c_mid if i < n_layers - 1 else c_out 59 | modules.append(ResConvBlock(feats_in, my_c_in, c_mid, my_c_out, group_size, dropout_rate)) 60 | if self_attn: 61 | norm = lambda c_in: layers.AdaGN(feats_in, c_in, max(1, my_c_out // group_size)) 62 | modules.append(layers.SelfAttention2d(my_c_out, max(1, my_c_out // head_size), norm, dropout_rate)) 63 | if cross_attn: 64 | norm = lambda c_in: layers.AdaGN(feats_in, c_in, max(1, my_c_out // group_size)) 65 | modules.append(layers.CrossAttention2d(my_c_out, c_enc, max(1, my_c_out // head_size), norm, dropout_rate)) 66 | modules.append(nn.Identity()) 67 | super().__init__(*modules) 68 | self.set_upsample(upsample) 69 | 70 | def forward(self, input, cond, skip=None): 71 | if skip is not None: 72 | input = torch.cat([input, skip], dim=1) 73 | return super().forward(input, cond) 74 | 75 | def set_upsample(self, upsample): 76 | self[-1] = layers.Upsample2d() if upsample else nn.Identity() 77 | return self 78 | 79 | 80 | class MappingNet(nn.Sequential): 81 | def __init__(self, feats_in, feats_out, n_layers=2): 82 | layers = [] 83 | for i in range(n_layers): 84 | layers.append(orthogonal_(nn.Linear(feats_in if i == 0 else feats_out, feats_out))) 85 | layers.append(nn.GELU()) 86 | super().__init__(*layers) 87 | 88 | 89 | 
class ImageDenoiserModelV1(nn.Module): 90 | def __init__(self, c_in, feats_in, depths, channels, self_attn_depths, cross_attn_depths=None, mapping_cond_dim=0, unet_cond_dim=0, cross_cond_dim=0, dropout_rate=0., patch_size=1, skip_stages=0, has_variance=False): 91 | super().__init__() 92 | self.c_in = c_in 93 | self.channels = channels 94 | self.unet_cond_dim = unet_cond_dim 95 | self.patch_size = patch_size 96 | self.has_variance = has_variance 97 | self.timestep_embed = layers.FourierFeatures(1, feats_in) 98 | if mapping_cond_dim > 0: 99 | self.mapping_cond = nn.Linear(mapping_cond_dim, feats_in, bias=False) 100 | self.mapping = MappingNet(feats_in, feats_in) 101 | self.proj_in = nn.Conv2d((c_in + unet_cond_dim) * self.patch_size ** 2, channels[max(0, skip_stages - 1)], 1) 102 | self.proj_out = nn.Conv2d(channels[max(0, skip_stages - 1)], c_in * self.patch_size ** 2 + (1 if self.has_variance else 0), 1) 103 | nn.init.zeros_(self.proj_out.weight) 104 | nn.init.zeros_(self.proj_out.bias) 105 | if cross_cond_dim == 0: 106 | cross_attn_depths = [False] * len(self_attn_depths) 107 | d_blocks, u_blocks = [], [] 108 | for i in range(len(depths)): 109 | my_c_in = channels[max(0, i - 1)] 110 | d_blocks.append(DBlock(depths[i], feats_in, my_c_in, channels[i], channels[i], downsample=i > skip_stages, self_attn=self_attn_depths[i], cross_attn=cross_attn_depths[i], c_enc=cross_cond_dim, dropout_rate=dropout_rate)) 111 | for i in range(len(depths)): 112 | my_c_in = channels[i] * 2 if i < len(depths) - 1 else channels[i] 113 | my_c_out = channels[max(0, i - 1)] 114 | u_blocks.append(UBlock(depths[i], feats_in, my_c_in, channels[i], my_c_out, upsample=i > skip_stages, self_attn=self_attn_depths[i], cross_attn=cross_attn_depths[i], c_enc=cross_cond_dim, dropout_rate=dropout_rate)) 115 | self.u_net = layers.UNet(d_blocks, reversed(u_blocks), skip_stages=skip_stages) 116 | 117 | def param_groups(self, base_lr=2e-4): 118 | wd_names = [] 119 | for name, _ in self.named_parameters(): 120 | if name.startswith("mapping") or name.startswith("u_net"): 121 | if name.endswith(".weight"): 122 | wd_names.append(name) 123 | wd, no_wd = [], [] 124 | for name, param in self.named_parameters(): 125 | if name in wd_names: 126 | wd.append(param) 127 | else: 128 | no_wd.append(param) 129 | groups = [ 130 | {"params": wd, "lr": base_lr}, 131 | {"params": no_wd, "lr": base_lr, "weight_decay": 0.0}, 132 | ] 133 | return groups 134 | 135 | def forward(self, input, sigma, mapping_cond=None, unet_cond=None, cross_cond=None, cross_cond_padding=None, return_variance=False): 136 | c_noise = sigma.log() / 4 137 | timestep_embed = self.timestep_embed(utils.append_dims(c_noise, 2)) 138 | mapping_cond_embed = torch.zeros_like(timestep_embed) if mapping_cond is None else self.mapping_cond(mapping_cond) 139 | mapping_out = self.mapping(timestep_embed + mapping_cond_embed) 140 | cond = {'cond': mapping_out} 141 | if unet_cond is not None: 142 | input = torch.cat([input, unet_cond], dim=1) 143 | if cross_cond is not None: 144 | cond['cross'] = cross_cond 145 | cond['cross_padding'] = cross_cond_padding 146 | if self.patch_size > 1: 147 | input = F.pixel_unshuffle(input, self.patch_size) 148 | input = self.proj_in(input) 149 | input = self.u_net(input, cond) 150 | input = self.proj_out(input) 151 | if self.has_variance: 152 | input, logvar = input[:, :-1], input[:, -1].flatten(1).mean(1) 153 | if self.patch_size > 1: 154 | input = F.pixel_shuffle(input, self.patch_size) 155 | if self.has_variance and return_variance: 156 | return input, 
logvar 157 | return input 158 | 159 | def set_skip_stages(self, skip_stages): 160 | self.proj_in = nn.Conv2d(self.proj_in.in_channels, self.channels[max(0, skip_stages - 1)], 1) 161 | self.proj_out = nn.Conv2d(self.channels[max(0, skip_stages - 1)], self.proj_out.out_channels, 1) 162 | nn.init.zeros_(self.proj_out.weight) 163 | nn.init.zeros_(self.proj_out.bias) 164 | self.u_net.skip_stages = skip_stages 165 | for i, block in enumerate(self.u_net.d_blocks): 166 | block.set_downsample(i > skip_stages) 167 | for i, block in enumerate(reversed(self.u_net.u_blocks)): 168 | block.set_upsample(i > skip_stages) 169 | return self 170 | 171 | def set_patch_size(self, patch_size): 172 | self.patch_size = patch_size 173 | self.proj_in = nn.Conv2d((self.c_in + self.unet_cond_dim) * self.patch_size ** 2, self.channels[max(0, self.u_net.skip_stages - 1)], 1) 174 | self.proj_out = nn.Conv2d(self.channels[max(0, self.u_net.skip_stages - 1)], self.c_in * self.patch_size ** 2 + (1 if self.has_variance else 0), 1) 175 | nn.init.zeros_(self.proj_out.weight) 176 | nn.init.zeros_(self.proj_out.bias) 177 | -------------------------------------------------------------------------------- /k_diffusion/config.py: -------------------------------------------------------------------------------- 1 | from functools import partial 2 | import json 3 | import math 4 | from pathlib import Path 5 | 6 | from jsonmerge import merge 7 | 8 | from . import augmentation, layers, models, utils 9 | 10 | 11 | def round_to_power_of_two(x, tol): 12 | approxs = [] 13 | for i in range(math.ceil(math.log2(x))): 14 | mult = 2**i 15 | approxs.append(round(x / mult) * mult) 16 | for approx in reversed(approxs): 17 | error = abs((approx - x) / x) 18 | if error <= tol: 19 | return approx 20 | return approxs[0] 21 | 22 | 23 | def load_config(path_or_dict): 24 | defaults_image_v1 = { 25 | 'model': { 26 | 'patch_size': 1, 27 | 'augment_wrapper': True, 28 | 'mapping_cond_dim': 0, 29 | 'unet_cond_dim': 0, 30 | 'cross_cond_dim': 0, 31 | 'cross_attn_depths': None, 32 | 'skip_stages': 0, 33 | 'has_variance': False, 34 | }, 35 | 'optimizer': { 36 | 'type': 'adamw', 37 | 'lr': 1e-4, 38 | 'betas': [0.95, 0.999], 39 | 'eps': 1e-6, 40 | 'weight_decay': 1e-3, 41 | }, 42 | } 43 | defaults_image_transformer_v1 = { 44 | 'model': { 45 | 'd_ff': 0, 46 | 'augment_wrapper': False, 47 | 'skip_stages': 0, 48 | 'has_variance': False, 49 | }, 50 | 'optimizer': { 51 | 'type': 'adamw', 52 | 'lr': 5e-4, 53 | 'betas': [0.9, 0.99], 54 | 'eps': 1e-8, 55 | 'weight_decay': 1e-4, 56 | }, 57 | } 58 | defaults_image_transformer_v2 = { 59 | 'model': { 60 | 'mapping_width': 256, 61 | 'mapping_depth': 2, 62 | 'mapping_d_ff': None, 63 | 'mapping_cond_dim': 0, 64 | 'mapping_dropout_rate': 0., 65 | 'd_ffs': None, 66 | 'self_attns': None, 67 | 'dropout_rate': None, 68 | 'augment_wrapper': False, 69 | 'skip_stages': 0, 70 | 'has_variance': False, 71 | }, 72 | 'optimizer': { 73 | 'type': 'adamw', 74 | 'lr': 5e-4, 75 | 'betas': [0.9, 0.99], 76 | 'eps': 1e-8, 77 | 'weight_decay': 1e-4, 78 | }, 79 | } 80 | defaults = { 81 | 'model': { 82 | 'sigma_data': 1., 83 | 'dropout_rate': 0., 84 | 'augment_prob': 0., 85 | 'loss_config': 'karras', 86 | 'loss_weighting': 'karras', 87 | 'loss_scales': 1, 88 | }, 89 | 'dataset': { 90 | 'type': 'imagefolder', 91 | 'num_classes': 0, 92 | 'cond_dropout_rate': 0.1, 93 | }, 94 | 'optimizer': { 95 | 'type': 'adamw', 96 | 'lr': 1e-4, 97 | 'betas': [0.9, 0.999], 98 | 'eps': 1e-8, 99 | 'weight_decay': 1e-4, 100 | }, 101 | 'lr_sched': { 102 | 'type': 'constant', 
103 | 'warmup': 0., 104 | }, 105 | 'ema_sched': { 106 | 'type': 'inverse', 107 | 'power': 0.6667, 108 | 'max_value': 0.9999 109 | }, 110 | } 111 | if not isinstance(path_or_dict, dict): 112 | file = Path(path_or_dict) 113 | if file.suffix == '.safetensors': 114 | metadata = utils.get_safetensors_metadata(file) 115 | config = json.loads(metadata['config']) 116 | else: 117 | config = json.loads(file.read_text()) 118 | else: 119 | config = path_or_dict 120 | if config['model']['type'] == 'image_v1': 121 | config = merge(defaults_image_v1, config) 122 | elif config['model']['type'] == 'image_transformer_v1': 123 | config = merge(defaults_image_transformer_v1, config) 124 | if not config['model']['d_ff']: 125 | config['model']['d_ff'] = round_to_power_of_two(config['model']['width'] * 8 / 3, tol=0.05) 126 | elif config['model']['type'] == 'image_transformer_v2': 127 | config = merge(defaults_image_transformer_v2, config) 128 | if not config['model']['mapping_d_ff']: 129 | config['model']['mapping_d_ff'] = config['model']['mapping_width'] * 3 130 | if not config['model']['d_ffs']: 131 | d_ffs = [] 132 | for width in config['model']['widths']: 133 | d_ffs.append(width * 3) 134 | config['model']['d_ffs'] = d_ffs 135 | if not config['model']['self_attns']: 136 | self_attns = [] 137 | default_neighborhood = {"type": "neighborhood", "d_head": 64, "kernel_size": 7} 138 | default_global = {"type": "global", "d_head": 64} 139 | for i in range(len(config['model']['widths'])): 140 | self_attns.append(default_neighborhood if i < len(config['model']['widths']) - 1 else default_global) 141 | config['model']['self_attns'] = self_attns 142 | if config['model']['dropout_rate'] is None: 143 | config['model']['dropout_rate'] = [0.0] * len(config['model']['widths']) 144 | elif isinstance(config['model']['dropout_rate'], float): 145 | config['model']['dropout_rate'] = [config['model']['dropout_rate']] * len(config['model']['widths']) 146 | return merge(defaults, config) 147 | 148 | 149 | def make_model(config): 150 | dataset_config = config['dataset'] 151 | num_classes = dataset_config['num_classes'] 152 | config = config['model'] 153 | if config['type'] == 'image_v1': 154 | model = models.ImageDenoiserModelV1( 155 | config['input_channels'], 156 | config['mapping_out'], 157 | config['depths'], 158 | config['channels'], 159 | config['self_attn_depths'], 160 | config['cross_attn_depths'], 161 | patch_size=config['patch_size'], 162 | dropout_rate=config['dropout_rate'], 163 | mapping_cond_dim=config['mapping_cond_dim'] + (9 if config['augment_wrapper'] else 0), 164 | unet_cond_dim=config['unet_cond_dim'], 165 | cross_cond_dim=config['cross_cond_dim'], 166 | skip_stages=config['skip_stages'], 167 | has_variance=config['has_variance'], 168 | ) 169 | if config['augment_wrapper']: 170 | model = augmentation.KarrasAugmentWrapper(model) 171 | elif config['type'] == 'image_transformer_v1': 172 | model = models.ImageTransformerDenoiserModelV1( 173 | n_layers=config['depth'], 174 | d_model=config['width'], 175 | d_ff=config['d_ff'], 176 | in_features=config['input_channels'], 177 | out_features=config['input_channels'], 178 | patch_size=config['patch_size'], 179 | num_classes=num_classes + 1 if num_classes else 0, 180 | dropout=config['dropout_rate'], 181 | sigma_data=config['sigma_data'], 182 | ) 183 | elif config['type'] == 'image_transformer_v2': 184 | assert len(config['widths']) == len(config['depths']) 185 | assert len(config['widths']) == len(config['d_ffs']) 186 | assert len(config['widths']) == 
len(config['self_attns']) 187 | assert len(config['widths']) == len(config['dropout_rate']) 188 | levels = [] 189 | for depth, width, d_ff, self_attn, dropout in zip(config['depths'], config['widths'], config['d_ffs'], config['self_attns'], config['dropout_rate']): 190 | if self_attn['type'] == 'global': 191 | self_attn = models.image_transformer_v2.GlobalAttentionSpec(self_attn.get('d_head', 64)) 192 | elif self_attn['type'] == 'neighborhood': 193 | self_attn = models.image_transformer_v2.NeighborhoodAttentionSpec(self_attn.get('d_head', 64), self_attn.get('kernel_size', 7)) 194 | elif self_attn['type'] == 'shifted-window': 195 | self_attn = models.image_transformer_v2.ShiftedWindowAttentionSpec(self_attn.get('d_head', 64), self_attn['window_size']) 196 | elif self_attn['type'] == 'none': 197 | self_attn = models.image_transformer_v2.NoAttentionSpec() 198 | else: 199 | raise ValueError(f'unsupported self attention type {self_attn["type"]}') 200 | levels.append(models.image_transformer_v2.LevelSpec(depth, width, d_ff, self_attn, dropout)) 201 | mapping = models.image_transformer_v2.MappingSpec(config['mapping_depth'], config['mapping_width'], config['mapping_d_ff'], config['mapping_dropout_rate']) 202 | model = models.ImageTransformerDenoiserModelV2( 203 | levels=levels, 204 | mapping=mapping, 205 | in_channels=config['input_channels'], 206 | out_channels=config['input_channels'], 207 | patch_size=config['patch_size'], 208 | num_classes=num_classes + 1 if num_classes else 0, 209 | mapping_cond_dim=config['mapping_cond_dim'], 210 | ) 211 | else: 212 | raise ValueError(f'unsupported model type {config["type"]}') 213 | return model 214 | 215 | 216 | def make_denoiser_wrapper(config): 217 | config = config['model'] 218 | sigma_data = config.get('sigma_data', 1.) 219 | has_variance = config.get('has_variance', False) 220 | loss_config = config.get('loss_config', 'karras') 221 | if loss_config == 'karras': 222 | weighting = config.get('loss_weighting', 'karras') 223 | scales = config.get('loss_scales', 1) 224 | if not has_variance: 225 | return partial(layers.Denoiser, sigma_data=sigma_data, weighting=weighting, scales=scales) 226 | return partial(layers.DenoiserWithVariance, sigma_data=sigma_data, weighting=weighting) 227 | if loss_config == 'simple': 228 | if has_variance: 229 | raise ValueError('Simple loss config does not support a variance output') 230 | return partial(layers.SimpleLossDenoiser, sigma_data=sigma_data) 231 | raise ValueError('Unknown loss config type') 232 | 233 | 234 | def make_sample_density(config): 235 | sd_config = config['sigma_sample_density'] 236 | sigma_data = config['sigma_data'] 237 | if sd_config['type'] == 'lognormal': 238 | loc = sd_config['mean'] if 'mean' in sd_config else sd_config['loc'] 239 | scale = sd_config['std'] if 'std' in sd_config else sd_config['scale'] 240 | return partial(utils.rand_log_normal, loc=loc, scale=scale) 241 | if sd_config['type'] == 'loglogistic': 242 | loc = sd_config['loc'] if 'loc' in sd_config else math.log(sigma_data) 243 | scale = sd_config['scale'] if 'scale' in sd_config else 0.5 244 | min_value = sd_config['min_value'] if 'min_value' in sd_config else 0. 
245 | max_value = sd_config['max_value'] if 'max_value' in sd_config else float('inf') 246 | return partial(utils.rand_log_logistic, loc=loc, scale=scale, min_value=min_value, max_value=max_value) 247 | if sd_config['type'] == 'loguniform': 248 | min_value = sd_config['min_value'] if 'min_value' in sd_config else config['sigma_min'] 249 | max_value = sd_config['max_value'] if 'max_value' in sd_config else config['sigma_max'] 250 | return partial(utils.rand_log_uniform, min_value=min_value, max_value=max_value) 251 | if sd_config['type'] in {'v-diffusion', 'cosine'}: 252 | min_value = sd_config['min_value'] if 'min_value' in sd_config else 1e-3 253 | max_value = sd_config['max_value'] if 'max_value' in sd_config else 1e3 254 | return partial(utils.rand_v_diffusion, sigma_data=sigma_data, min_value=min_value, max_value=max_value) 255 | if sd_config['type'] == 'split-lognormal': 256 | loc = sd_config['mean'] if 'mean' in sd_config else sd_config['loc'] 257 | scale_1 = sd_config['std_1'] if 'std_1' in sd_config else sd_config['scale_1'] 258 | scale_2 = sd_config['std_2'] if 'std_2' in sd_config else sd_config['scale_2'] 259 | return partial(utils.rand_split_log_normal, loc=loc, scale_1=scale_1, scale_2=scale_2) 260 | if sd_config['type'] == 'cosine-interpolated': 261 | min_value = sd_config.get('min_value', min(config['sigma_min'], 1e-3)) 262 | max_value = sd_config.get('max_value', max(config['sigma_max'], 1e3)) 263 | image_d = sd_config.get('image_d', max(config['input_size'])) 264 | noise_d_low = sd_config.get('noise_d_low', 32) 265 | noise_d_high = sd_config.get('noise_d_high', max(config['input_size'])) 266 | return partial(utils.rand_cosine_interpolated, image_d=image_d, noise_d_low=noise_d_low, noise_d_high=noise_d_high, sigma_data=sigma_data, min_value=min_value, max_value=max_value) 267 | 268 | raise ValueError('Unknown sample density type') 269 | -------------------------------------------------------------------------------- /k_diffusion/layers.py: -------------------------------------------------------------------------------- 1 | from functools import lru_cache, reduce 2 | import math 3 | 4 | from dctorch import functional as df 5 | from einops import rearrange, repeat 6 | import torch 7 | from torch import nn 8 | from torch.nn import functional as F 9 | 10 | from . import sampling, utils 11 | 12 | 13 | # Helper functions 14 | 15 | 16 | def dct(x): 17 | if x.ndim == 3: 18 | return df.dct(x) 19 | if x.ndim == 4: 20 | return df.dct2(x) 21 | if x.ndim == 5: 22 | return df.dct3(x) 23 | raise ValueError(f'Unsupported dimensionality {x.ndim}') 24 | 25 | 26 | @lru_cache 27 | def freq_weight_1d(n, scales=0, dtype=None, device=None): 28 | ramp = torch.linspace(0.5 / n, 0.5, n, dtype=dtype, device=device) 29 | weights = -torch.log2(ramp) 30 | if scales >= 1: 31 | weights = torch.clamp_max(weights, scales) 32 | return weights 33 | 34 | 35 | @lru_cache 36 | def freq_weight_nd(shape, scales=0, dtype=None, device=None): 37 | indexers = [[slice(None) if i == j else None for j in range(len(shape))] for i in range(len(shape))] 38 | weights = [freq_weight_1d(n, scales, dtype, device)[ix] for n, ix in zip(shape, indexers)] 39 | return reduce(torch.minimum, weights) 40 | 41 | 42 | # Karras et al. preconditioned denoiser 43 | 44 | 45 | class Denoiser(nn.Module): 46 | """A Karras et al. 
preconditioner for denoising diffusion models.""" 47 | 48 | def __init__(self, inner_model, sigma_data=1., weighting='karras', scales=1): 49 | super().__init__() 50 | self.inner_model = inner_model 51 | self.sigma_data = sigma_data 52 | self.scales = scales 53 | if callable(weighting): 54 | self.weighting = weighting 55 | if weighting == 'karras': 56 | self.weighting = torch.ones_like 57 | elif weighting == 'soft-min-snr': 58 | self.weighting = self._weighting_soft_min_snr 59 | elif weighting == 'snr': 60 | self.weighting = self._weighting_snr 61 | else: 62 | raise ValueError(f'Unknown weighting type {weighting}') 63 | 64 | def _weighting_soft_min_snr(self, sigma): 65 | return (sigma * self.sigma_data) ** 2 / (sigma ** 2 + self.sigma_data ** 2) ** 2 66 | 67 | def _weighting_snr(self, sigma): 68 | return self.sigma_data ** 2 / (sigma ** 2 + self.sigma_data ** 2) 69 | 70 | def get_scalings(self, sigma): 71 | c_skip = self.sigma_data ** 2 / (sigma ** 2 + self.sigma_data ** 2) 72 | c_out = sigma * self.sigma_data / (sigma ** 2 + self.sigma_data ** 2) ** 0.5 73 | c_in = 1 / (sigma ** 2 + self.sigma_data ** 2) ** 0.5 74 | return c_skip, c_out, c_in 75 | 76 | def loss(self, input, noise, sigma, **kwargs): 77 | c_skip, c_out, c_in = [utils.append_dims(x, input.ndim) for x in self.get_scalings(sigma)] 78 | c_weight = self.weighting(sigma) 79 | noised_input = input + noise * utils.append_dims(sigma, input.ndim) 80 | model_output = self.inner_model(noised_input * c_in, sigma, **kwargs) 81 | target = (input - c_skip * noised_input) / c_out 82 | if self.scales == 1: 83 | return ((model_output - target) ** 2).flatten(1).mean(1) * c_weight 84 | sq_error = dct(model_output - target) ** 2 85 | f_weight = freq_weight_nd(sq_error.shape[2:], self.scales, dtype=sq_error.dtype, device=sq_error.device) 86 | return (sq_error * f_weight).flatten(1).mean(1) * c_weight 87 | 88 | def forward(self, input, sigma, **kwargs): 89 | c_skip, c_out, c_in = [utils.append_dims(x, input.ndim) for x in self.get_scalings(sigma)] 90 | return self.inner_model(input * c_in, sigma, **kwargs) * c_out + input * c_skip 91 | 92 | 93 | class DenoiserWithVariance(Denoiser): 94 | def loss(self, input, noise, sigma, **kwargs): 95 | c_skip, c_out, c_in = [utils.append_dims(x, input.ndim) for x in self.get_scalings(sigma)] 96 | noised_input = input + noise * utils.append_dims(sigma, input.ndim) 97 | model_output, logvar = self.inner_model(noised_input * c_in, sigma, return_variance=True, **kwargs) 98 | logvar = utils.append_dims(logvar, model_output.ndim) 99 | target = (input - c_skip * noised_input) / c_out 100 | losses = ((model_output - target) ** 2 / logvar.exp() + logvar) / 2 101 | return losses.flatten(1).mean(1) 102 | 103 | 104 | class SimpleLossDenoiser(Denoiser): 105 | """L_simple with the Karras et al. 
preconditioner.""" 106 | 107 | def loss(self, input, noise, sigma, **kwargs): 108 | noised_input = input + noise * utils.append_dims(sigma, input.ndim) 109 | denoised = self(noised_input, sigma, **kwargs) 110 | eps = sampling.to_d(noised_input, sigma, denoised) 111 | return (eps - noise).pow(2).flatten(1).mean(1) 112 | 113 | 114 | # Residual blocks 115 | 116 | class ResidualBlock(nn.Module): 117 | def __init__(self, *main, skip=None): 118 | super().__init__() 119 | self.main = nn.Sequential(*main) 120 | self.skip = skip if skip else nn.Identity() 121 | 122 | def forward(self, input): 123 | return self.main(input) + self.skip(input) 124 | 125 | 126 | # Noise level (and other) conditioning 127 | 128 | class ConditionedModule(nn.Module): 129 | pass 130 | 131 | 132 | class UnconditionedModule(ConditionedModule): 133 | def __init__(self, module): 134 | super().__init__() 135 | self.module = module 136 | 137 | def forward(self, input, cond=None): 138 | return self.module(input) 139 | 140 | 141 | class ConditionedSequential(nn.Sequential, ConditionedModule): 142 | def forward(self, input, cond): 143 | for module in self: 144 | if isinstance(module, ConditionedModule): 145 | input = module(input, cond) 146 | else: 147 | input = module(input) 148 | return input 149 | 150 | 151 | class ConditionedResidualBlock(ConditionedModule): 152 | def __init__(self, *main, skip=None): 153 | super().__init__() 154 | self.main = ConditionedSequential(*main) 155 | self.skip = skip if skip else nn.Identity() 156 | 157 | def forward(self, input, cond): 158 | skip = self.skip(input, cond) if isinstance(self.skip, ConditionedModule) else self.skip(input) 159 | return self.main(input, cond) + skip 160 | 161 | 162 | class AdaGN(ConditionedModule): 163 | def __init__(self, feats_in, c_out, num_groups, eps=1e-5, cond_key='cond'): 164 | super().__init__() 165 | self.num_groups = num_groups 166 | self.eps = eps 167 | self.cond_key = cond_key 168 | self.mapper = nn.Linear(feats_in, c_out * 2) 169 | nn.init.zeros_(self.mapper.weight) 170 | nn.init.zeros_(self.mapper.bias) 171 | 172 | def forward(self, input, cond): 173 | weight, bias = self.mapper(cond[self.cond_key]).chunk(2, dim=-1) 174 | input = F.group_norm(input, self.num_groups, eps=self.eps) 175 | return torch.addcmul(utils.append_dims(bias, input.ndim), input, utils.append_dims(weight, input.ndim) + 1) 176 | 177 | 178 | # Attention 179 | 180 | 181 | class SelfAttention2d(ConditionedModule): 182 | def __init__(self, c_in, n_head, norm, dropout_rate=0.): 183 | super().__init__() 184 | assert c_in % n_head == 0 185 | self.norm_in = norm(c_in) 186 | self.n_head = n_head 187 | self.qkv_proj = nn.Conv2d(c_in, c_in * 3, 1) 188 | self.out_proj = nn.Conv2d(c_in, c_in, 1) 189 | self.dropout = nn.Dropout(dropout_rate) 190 | nn.init.zeros_(self.out_proj.weight) 191 | nn.init.zeros_(self.out_proj.bias) 192 | 193 | def forward(self, input, cond): 194 | n, c, h, w = input.shape 195 | qkv = self.qkv_proj(self.norm_in(input, cond)) 196 | qkv = qkv.view([n, self.n_head * 3, c // self.n_head, h * w]).transpose(2, 3) 197 | q, k, v = qkv.chunk(3, dim=1) 198 | y = F.scaled_dot_product_attention(q, k, v, dropout_p=self.dropout.p) 199 | y = y.transpose(2, 3).contiguous().view([n, c, h, w]) 200 | return input + self.out_proj(y) 201 | 202 | 203 | class CrossAttention2d(ConditionedModule): 204 | def __init__(self, c_dec, c_enc, n_head, norm_dec, dropout_rate=0., 205 | cond_key='cross', cond_key_padding='cross_padding'): 206 | super().__init__() 207 | assert c_dec % n_head == 0 208 | 
self.cond_key = cond_key 209 | self.cond_key_padding = cond_key_padding 210 | self.norm_enc = nn.LayerNorm(c_enc) 211 | self.norm_dec = norm_dec(c_dec) 212 | self.n_head = n_head 213 | self.q_proj = nn.Conv2d(c_dec, c_dec, 1) 214 | self.kv_proj = nn.Linear(c_enc, c_dec * 2) 215 | self.out_proj = nn.Conv2d(c_dec, c_dec, 1) 216 | self.dropout = nn.Dropout(dropout_rate) 217 | nn.init.zeros_(self.out_proj.weight) 218 | nn.init.zeros_(self.out_proj.bias) 219 | 220 | def forward(self, input, cond): 221 | n, c, h, w = input.shape 222 | q = self.q_proj(self.norm_dec(input, cond)) 223 | q = q.view([n, self.n_head, c // self.n_head, h * w]).transpose(2, 3) 224 | kv = self.kv_proj(self.norm_enc(cond[self.cond_key])) 225 | kv = kv.view([n, -1, self.n_head * 2, c // self.n_head]).transpose(1, 2) 226 | k, v = kv.chunk(2, dim=1) 227 | attn_mask = (cond[self.cond_key_padding][:, None, None, :]) * -10000 228 | y = F.scaled_dot_product_attention(q, k, v, attn_mask, dropout_p=self.dropout.p) 229 | y = y.transpose(2, 3).contiguous().view([n, c, h, w]) 230 | return input + self.out_proj(y) 231 | 232 | 233 | # Downsampling/upsampling 234 | 235 | _kernels = { 236 | 'linear': 237 | [1 / 8, 3 / 8, 3 / 8, 1 / 8], 238 | 'cubic': 239 | [-0.01171875, -0.03515625, 0.11328125, 0.43359375, 240 | 0.43359375, 0.11328125, -0.03515625, -0.01171875], 241 | 'lanczos3': 242 | [0.003689131001010537, 0.015056144446134567, -0.03399861603975296, 243 | -0.066637322306633, 0.13550527393817902, 0.44638532400131226, 244 | 0.44638532400131226, 0.13550527393817902, -0.066637322306633, 245 | -0.03399861603975296, 0.015056144446134567, 0.003689131001010537] 246 | } 247 | _kernels['bilinear'] = _kernels['linear'] 248 | _kernels['bicubic'] = _kernels['cubic'] 249 | 250 | 251 | class Downsample2d(nn.Module): 252 | def __init__(self, kernel='linear', pad_mode='reflect'): 253 | super().__init__() 254 | self.pad_mode = pad_mode 255 | kernel_1d = torch.tensor([_kernels[kernel]]) 256 | self.pad = kernel_1d.shape[1] // 2 - 1 257 | self.register_buffer('kernel', kernel_1d.T @ kernel_1d) 258 | 259 | def forward(self, x): 260 | x = F.pad(x, (self.pad,) * 4, self.pad_mode) 261 | weight = x.new_zeros([x.shape[1], x.shape[1], self.kernel.shape[0], self.kernel.shape[1]]) 262 | indices = torch.arange(x.shape[1], device=x.device) 263 | weight[indices, indices] = self.kernel.to(weight) 264 | return F.conv2d(x, weight, stride=2) 265 | 266 | 267 | class Upsample2d(nn.Module): 268 | def __init__(self, kernel='linear', pad_mode='reflect'): 269 | super().__init__() 270 | self.pad_mode = pad_mode 271 | kernel_1d = torch.tensor([_kernels[kernel]]) * 2 272 | self.pad = kernel_1d.shape[1] // 2 - 1 273 | self.register_buffer('kernel', kernel_1d.T @ kernel_1d) 274 | 275 | def forward(self, x): 276 | x = F.pad(x, ((self.pad + 1) // 2,) * 4, self.pad_mode) 277 | weight = x.new_zeros([x.shape[1], x.shape[1], self.kernel.shape[0], self.kernel.shape[1]]) 278 | indices = torch.arange(x.shape[1], device=x.device) 279 | weight[indices, indices] = self.kernel.to(weight) 280 | return F.conv_transpose2d(x, weight, stride=2, padding=self.pad * 2 + 1) 281 | 282 | 283 | # Embeddings 284 | 285 | class FourierFeatures(nn.Module): 286 | def __init__(self, in_features, out_features, std=1.): 287 | super().__init__() 288 | assert out_features % 2 == 0 289 | self.register_buffer('weight', torch.randn([out_features // 2, in_features]) * std) 290 | 291 | def forward(self, input): 292 | f = 2 * math.pi * input @ self.weight.T 293 | return torch.cat([f.cos(), f.sin()], dim=-1) 294 | 295 | 296 
| # U-Nets 297 | 298 | class UNet(ConditionedModule): 299 | def __init__(self, d_blocks, u_blocks, skip_stages=0): 300 | super().__init__() 301 | self.d_blocks = nn.ModuleList(d_blocks) 302 | self.u_blocks = nn.ModuleList(u_blocks) 303 | self.skip_stages = skip_stages 304 | 305 | def forward(self, input, cond): 306 | skips = [] 307 | for block in self.d_blocks[self.skip_stages:]: 308 | input = block(input, cond) 309 | skips.append(input) 310 | for i, (block, skip) in enumerate(zip(self.u_blocks, reversed(skips))): 311 | input = block(input, cond, skip if i > 0 else None) 312 | return input 313 | -------------------------------------------------------------------------------- /k_diffusion/models/image_transformer_v1.py: -------------------------------------------------------------------------------- 1 | """k-diffusion transformer diffusion models, version 1.""" 2 | 3 | import math 4 | 5 | from einops import rearrange 6 | import torch 7 | from torch import nn 8 | import torch._dynamo 9 | from torch.nn import functional as F 10 | 11 | from . import flags 12 | from .. import layers 13 | from .axial_rope import AxialRoPE, make_axial_pos 14 | 15 | if flags.get_use_compile(): 16 | torch._dynamo.config.suppress_errors = True 17 | 18 | 19 | def zero_init(layer): 20 | nn.init.zeros_(layer.weight) 21 | if layer.bias is not None: 22 | nn.init.zeros_(layer.bias) 23 | return layer 24 | 25 | 26 | def checkpoint_helper(function, *args, **kwargs): 27 | if flags.get_checkpointing(): 28 | kwargs.setdefault("use_reentrant", True) 29 | return torch.utils.checkpoint.checkpoint(function, *args, **kwargs) 30 | else: 31 | return function(*args, **kwargs) 32 | 33 | 34 | def tag_param(param, tag): 35 | if not hasattr(param, "_tags"): 36 | param._tags = set([tag]) 37 | else: 38 | param._tags.add(tag) 39 | return param 40 | 41 | 42 | def tag_module(module, tag): 43 | for param in module.parameters(): 44 | tag_param(param, tag) 45 | return module 46 | 47 | 48 | def apply_wd(module): 49 | for name, param in module.named_parameters(): 50 | if name.endswith("weight"): 51 | tag_param(param, "wd") 52 | return module 53 | 54 | 55 | def filter_params(function, module): 56 | for param in module.parameters(): 57 | tags = getattr(param, "_tags", set()) 58 | if function(tags): 59 | yield param 60 | 61 | 62 | def scaled_dot_product_attention(q, k, v, attn_mask=None, dropout_p=0.0): 63 | if flags.get_use_flash_attention_2() and attn_mask is None: 64 | try: 65 | from flash_attn import flash_attn_func 66 | q_ = q.transpose(-3, -2) 67 | k_ = k.transpose(-3, -2) 68 | v_ = v.transpose(-3, -2) 69 | o_ = flash_attn_func(q_, k_, v_, dropout_p=dropout_p) 70 | return o_.transpose(-3, -2) 71 | except (ImportError, RuntimeError): 72 | pass 73 | return F.scaled_dot_product_attention(q, k, v, attn_mask, dropout_p=dropout_p) 74 | 75 | 76 | @flags.compile_wrap 77 | def geglu(x): 78 | a, b = x.chunk(2, dim=-1) 79 | return a * F.gelu(b) 80 | 81 | 82 | @flags.compile_wrap 83 | def rms_norm(x, scale, eps): 84 | dtype = torch.promote_types(x.dtype, torch.float32) 85 | mean_sq = torch.mean(x.to(dtype)**2, dim=-1, keepdim=True) 86 | scale = scale.to(dtype) * torch.rsqrt(mean_sq + eps) 87 | return x * scale.to(x.dtype) 88 | 89 | 90 | class GEGLU(nn.Module): 91 | def forward(self, x): 92 | return geglu(x) 93 | 94 | 95 | class RMSNorm(nn.Module): 96 | def __init__(self, param_shape, eps=1e-6): 97 | super().__init__() 98 | self.eps = eps 99 | self.scale = nn.Parameter(torch.ones(param_shape)) 100 | 101 | def extra_repr(self): 102 | return 
f"shape={tuple(self.scale.shape)}, eps={self.eps}" 103 | 104 | def forward(self, x): 105 | return rms_norm(x, self.scale, self.eps) 106 | 107 | 108 | class QKNorm(nn.Module): 109 | def __init__(self, n_heads, eps=1e-6, max_scale=100.0): 110 | super().__init__() 111 | self.eps = eps 112 | self.max_scale = math.log(max_scale) 113 | self.scale = nn.Parameter(torch.full((n_heads,), math.log(10.0))) 114 | self.proj_() 115 | 116 | def extra_repr(self): 117 | return f"n_heads={self.scale.shape[0]}, eps={self.eps}" 118 | 119 | @torch.no_grad() 120 | def proj_(self): 121 | """Modify the scale in-place so it doesn't get "stuck" with zero gradient if it's clamped 122 | to the max value.""" 123 | self.scale.clamp_(max=self.max_scale) 124 | 125 | def forward(self, x): 126 | self.proj_() 127 | scale = torch.exp(0.5 * self.scale - 0.25 * math.log(x.shape[-1])) 128 | return rms_norm(x, scale[:, None, None], self.eps) 129 | 130 | 131 | class AdaRMSNorm(nn.Module): 132 | def __init__(self, features, cond_features, eps=1e-6): 133 | super().__init__() 134 | self.eps = eps 135 | self.linear = apply_wd(zero_init(nn.Linear(cond_features, features, bias=False))) 136 | tag_module(self.linear, "mapping") 137 | 138 | def extra_repr(self): 139 | return f"eps={self.eps}," 140 | 141 | def forward(self, x, cond): 142 | return rms_norm(x, self.linear(cond) + 1, self.eps) 143 | 144 | 145 | class SelfAttentionBlock(nn.Module): 146 | def __init__(self, d_model, d_head, dropout=0.0): 147 | super().__init__() 148 | self.d_head = d_head 149 | self.n_heads = d_model // d_head 150 | self.norm = AdaRMSNorm(d_model, d_model) 151 | self.qkv_proj = apply_wd(nn.Linear(d_model, d_model * 3, bias=False)) 152 | self.qk_norm = QKNorm(self.n_heads) 153 | self.pos_emb = AxialRoPE(d_head, self.n_heads) 154 | self.dropout = nn.Dropout(dropout) 155 | self.out_proj = apply_wd(zero_init(nn.Linear(d_model, d_model, bias=False))) 156 | 157 | def extra_repr(self): 158 | return f"d_head={self.d_head}," 159 | 160 | def forward(self, x, pos, attn_mask, cond): 161 | skip = x 162 | x = self.norm(x, cond) 163 | q, k, v = self.qkv_proj(x).chunk(3, dim=-1) 164 | q = rearrange(q, "n l (h e) -> n h l e", e=self.d_head) 165 | k = rearrange(k, "n l (h e) -> n h l e", e=self.d_head) 166 | v = rearrange(v, "n l (h e) -> n h l e", e=self.d_head) 167 | q = self.pos_emb(self.qk_norm(q), pos) 168 | k = self.pos_emb(self.qk_norm(k), pos) 169 | x = scaled_dot_product_attention(q, k, v, attn_mask) 170 | x = rearrange(x, "n h l e -> n l (h e)") 171 | x = self.dropout(x) 172 | x = self.out_proj(x) 173 | return x + skip 174 | 175 | 176 | class FeedForwardBlock(nn.Module): 177 | def __init__(self, d_model, d_ff, dropout=0.0): 178 | super().__init__() 179 | self.norm = AdaRMSNorm(d_model, d_model) 180 | self.up_proj = apply_wd(nn.Linear(d_model, d_ff * 2, bias=False)) 181 | self.act = GEGLU() 182 | self.dropout = nn.Dropout(dropout) 183 | self.down_proj = apply_wd(zero_init(nn.Linear(d_ff, d_model, bias=False))) 184 | 185 | def forward(self, x, cond): 186 | skip = x 187 | x = self.norm(x, cond) 188 | x = self.up_proj(x) 189 | x = self.act(x) 190 | x = self.dropout(x) 191 | x = self.down_proj(x) 192 | return x + skip 193 | 194 | 195 | class TransformerBlock(nn.Module): 196 | def __init__(self, d_model, d_ff, d_head, dropout=0.0): 197 | super().__init__() 198 | self.self_attn = SelfAttentionBlock(d_model, d_head, dropout=dropout) 199 | self.ff = FeedForwardBlock(d_model, d_ff, dropout=dropout) 200 | 201 | def forward(self, x, pos, attn_mask, cond): 202 | x = 
checkpoint_helper(self.self_attn, x, pos, attn_mask, cond) 203 | x = checkpoint_helper(self.ff, x, cond) 204 | return x 205 | 206 | 207 | class Patching(nn.Module): 208 | def __init__(self, features, patch_size): 209 | super().__init__() 210 | self.features = features 211 | self.patch_size = patch_size 212 | self.d_out = features * patch_size[0] * patch_size[1] 213 | 214 | def extra_repr(self): 215 | return f"features={self.features}, patch_size={self.patch_size!r}" 216 | 217 | def forward(self, x, pixel_aspect_ratio=1.0): 218 | *_, h, w = x.shape 219 | h_out = h // self.patch_size[0] 220 | w_out = w // self.patch_size[1] 221 | if h % self.patch_size[0] != 0 or w % self.patch_size[1] != 0: 222 | raise ValueError(f"Image size {h}x{w} is not divisible by patch size {self.patch_size[0]}x{self.patch_size[1]}") 223 | x = rearrange(x, "... c (h i) (w j) -> ... (h w) (c i j)", i=self.patch_size[0], j=self.patch_size[1]) 224 | pixel_aspect_ratio = pixel_aspect_ratio * self.patch_size[0] / self.patch_size[1] 225 | pos = make_axial_pos(h_out, w_out, pixel_aspect_ratio, device=x.device) 226 | return x, pos 227 | 228 | 229 | class Unpatching(nn.Module): 230 | def __init__(self, features, patch_size): 231 | super().__init__() 232 | self.features = features 233 | self.patch_size = patch_size 234 | self.d_in = features * patch_size[0] * patch_size[1] 235 | 236 | def extra_repr(self): 237 | return f"features={self.features}, patch_size={self.patch_size!r}" 238 | 239 | def forward(self, x, h, w): 240 | h_in = h // self.patch_size[0] 241 | w_in = w // self.patch_size[1] 242 | x = rearrange(x, "... (h w) (c i j) -> ... c (h i) (w j)", h=h_in, w=w_in, i=self.patch_size[0], j=self.patch_size[1]) 243 | return x 244 | 245 | 246 | class MappingFeedForwardBlock(nn.Module): 247 | def __init__(self, d_model, d_ff, dropout=0.0): 248 | super().__init__() 249 | self.norm = RMSNorm(d_model) 250 | self.up_proj = apply_wd(nn.Linear(d_model, d_ff * 2, bias=False)) 251 | self.act = GEGLU() 252 | self.dropout = nn.Dropout(dropout) 253 | self.down_proj = apply_wd(zero_init(nn.Linear(d_ff, d_model, bias=False))) 254 | 255 | def forward(self, x): 256 | skip = x 257 | x = self.norm(x) 258 | x = self.up_proj(x) 259 | x = self.act(x) 260 | x = self.dropout(x) 261 | x = self.down_proj(x) 262 | return x + skip 263 | 264 | 265 | class MappingNetwork(nn.Module): 266 | def __init__(self, n_layers, d_model, d_ff, dropout=0.0): 267 | super().__init__() 268 | self.in_norm = RMSNorm(d_model) 269 | self.blocks = nn.ModuleList([MappingFeedForwardBlock(d_model, d_ff, dropout=dropout) for _ in range(n_layers)]) 270 | self.out_norm = RMSNorm(d_model) 271 | 272 | def forward(self, x): 273 | x = self.in_norm(x) 274 | for block in self.blocks: 275 | x = block(x) 276 | x = self.out_norm(x) 277 | return x 278 | 279 | 280 | class ImageTransformerDenoiserModelV1(nn.Module): 281 | def __init__(self, n_layers, d_model, d_ff, in_features, out_features, patch_size, num_classes=0, dropout=0.0, sigma_data=1.0): 282 | super().__init__() 283 | self.sigma_data = sigma_data 284 | self.num_classes = num_classes 285 | self.patch_in = Patching(in_features, patch_size) 286 | self.patch_out = Unpatching(out_features, patch_size) 287 | 288 | self.time_emb = layers.FourierFeatures(1, d_model) 289 | self.time_in_proj = nn.Linear(d_model, d_model, bias=False) 290 | self.aug_emb = layers.FourierFeatures(9, d_model) 291 | self.aug_in_proj = nn.Linear(d_model, d_model, bias=False) 292 | self.class_emb = nn.Embedding(num_classes, d_model) if num_classes else None 293 | 
self.mapping = tag_module(MappingNetwork(2, d_model, d_ff, dropout=dropout), "mapping") 294 | 295 | self.in_proj = nn.Linear(self.patch_in.d_out, d_model, bias=False) 296 | self.blocks = nn.ModuleList([TransformerBlock(d_model, d_ff, 64, dropout=dropout) for _ in range(n_layers)]) 297 | self.out_norm = RMSNorm(d_model) 298 | self.out_proj = zero_init(nn.Linear(d_model, self.patch_out.d_in, bias=False)) 299 | 300 | def proj_(self): 301 | for block in self.blocks: 302 | block.self_attn.qk_norm.proj_() 303 | 304 | def param_groups(self, base_lr=5e-4, mapping_lr_scale=1 / 3): 305 | wd = filter_params(lambda tags: "wd" in tags and "mapping" not in tags, self) 306 | no_wd = filter_params(lambda tags: "wd" not in tags and "mapping" not in tags, self) 307 | mapping_wd = filter_params(lambda tags: "wd" in tags and "mapping" in tags, self) 308 | mapping_no_wd = filter_params(lambda tags: "wd" not in tags and "mapping" in tags, self) 309 | groups = [ 310 | {"params": list(wd), "lr": base_lr}, 311 | {"params": list(no_wd), "lr": base_lr, "weight_decay": 0.0}, 312 | {"params": list(mapping_wd), "lr": base_lr * mapping_lr_scale}, 313 | {"params": list(mapping_no_wd), "lr": base_lr * mapping_lr_scale, "weight_decay": 0.0} 314 | ] 315 | return groups 316 | 317 | def forward(self, x, sigma, aug_cond=None, class_cond=None): 318 | # Patching 319 | *_, h, w = x.shape 320 | x, pos = self.patch_in(x) 321 | attn_mask = None 322 | x = self.in_proj(x) 323 | 324 | # Mapping network 325 | if class_cond is None and self.class_emb is not None: 326 | raise ValueError("class_cond must be specified if num_classes > 0") 327 | 328 | c_noise = torch.log(sigma) / 4 329 | time_emb = self.time_in_proj(self.time_emb(c_noise[..., None])) 330 | aug_cond = x.new_zeros([x.shape[0], 9]) if aug_cond is None else aug_cond 331 | aug_emb = self.aug_in_proj(self.aug_emb(aug_cond)) 332 | class_emb = self.class_emb(class_cond) if self.class_emb is not None else 0 333 | cond = self.mapping(time_emb + aug_emb + class_emb).unsqueeze(-2) 334 | 335 | # Transformer 336 | for block in self.blocks: 337 | x = block(x, pos, attn_mask, cond) 338 | 339 | # Unpatching 340 | x = self.out_norm(x) 341 | x = self.out_proj(x) 342 | x = self.patch_out(x, h, w) 343 | 344 | return x 345 | -------------------------------------------------------------------------------- /k_diffusion/utils.py: -------------------------------------------------------------------------------- 1 | from contextlib import contextmanager 2 | import hashlib 3 | import math 4 | from pathlib import Path 5 | import shutil 6 | import threading 7 | import time 8 | import urllib 9 | import warnings 10 | 11 | from PIL import Image 12 | import safetensors 13 | import torch 14 | from torch import nn, optim 15 | from torch.utils import data 16 | from torchvision.transforms import functional as TF 17 | 18 | 19 | def from_pil_image(x): 20 | """Converts from a PIL image to a tensor.""" 21 | x = TF.to_tensor(x) 22 | if x.ndim == 2: 23 | x = x[..., None] 24 | return x * 2 - 1 25 | 26 | 27 | def to_pil_image(x): 28 | """Converts from a tensor to a PIL image.""" 29 | if x.ndim == 4: 30 | assert x.shape[0] == 1 31 | x = x[0] 32 | if x.shape[0] == 1: 33 | x = x[0] 34 | return TF.to_pil_image((x.clamp(-1, 1) + 1) / 2) 35 | 36 | 37 | def hf_datasets_augs_helper(examples, transform, image_key, mode='RGB'): 38 | """Apply passed in transforms for HuggingFace Datasets.""" 39 | images = [transform(image.convert(mode)) for image in examples[image_key]] 40 | return {image_key: images} 41 | 42 | 43 | def 
append_dims(x, target_dims): 44 | """Appends dimensions to the end of a tensor until it has target_dims dimensions.""" 45 | dims_to_append = target_dims - x.ndim 46 | if dims_to_append < 0: 47 | raise ValueError(f'input has {x.ndim} dims but target_dims is {target_dims}, which is less') 48 | return x[(...,) + (None,) * dims_to_append] 49 | 50 | 51 | def n_params(module): 52 | """Returns the number of trainable parameters in a module.""" 53 | return sum(p.numel() for p in module.parameters()) 54 | 55 | 56 | def download_file(path, url, digest=None): 57 | """Downloads a file if it does not exist, optionally checking its SHA-256 hash.""" 58 | path = Path(path) 59 | path.parent.mkdir(parents=True, exist_ok=True) 60 | if not path.exists(): 61 | with urllib.request.urlopen(url) as response, open(path, 'wb') as f: 62 | shutil.copyfileobj(response, f) 63 | if digest is not None: 64 | file_digest = hashlib.sha256(open(path, 'rb').read()).hexdigest() 65 | if digest != file_digest: 66 | raise OSError(f'hash of {path} (url: {url}) failed to validate') 67 | return path 68 | 69 | 70 | @contextmanager 71 | def train_mode(model, mode=True): 72 | """A context manager that places a model into training mode and restores 73 | the previous mode on exit.""" 74 | modes = [module.training for module in model.modules()] 75 | try: 76 | yield model.train(mode) 77 | finally: 78 | for i, module in enumerate(model.modules()): 79 | module.training = modes[i] 80 | 81 | 82 | def eval_mode(model): 83 | """A context manager that places a model into evaluation mode and restores 84 | the previous mode on exit.""" 85 | return train_mode(model, False) 86 | 87 | 88 | @torch.no_grad() 89 | def ema_update(model, averaged_model, decay): 90 | """Incorporates updated model parameters into an exponential moving averaged 91 | version of a model. It should be called after each optimizer step.""" 92 | model_params = dict(model.named_parameters()) 93 | averaged_params = dict(averaged_model.named_parameters()) 94 | assert model_params.keys() == averaged_params.keys() 95 | 96 | for name, param in model_params.items(): 97 | averaged_params[name].lerp_(param, 1 - decay) 98 | 99 | model_buffers = dict(model.named_buffers()) 100 | averaged_buffers = dict(averaged_model.named_buffers()) 101 | assert model_buffers.keys() == averaged_buffers.keys() 102 | 103 | for name, buf in model_buffers.items(): 104 | averaged_buffers[name].copy_(buf) 105 | 106 | 107 | class EMAWarmup: 108 | """Implements an EMA warmup using an inverse decay schedule. 109 | If inv_gamma=1 and power=1, implements a simple average. inv_gamma=1, power=2/3 are 110 | good values for models you plan to train for a million or more steps (reaches decay 111 | factor 0.999 at 31.6K steps, 0.9999 at 1M steps), inv_gamma=1, power=3/4 for models 112 | you plan to train for less (reaches decay factor 0.999 at 10K steps, 0.9999 at 113 | 215.4k steps). 114 | Args: 115 | inv_gamma (float): Inverse multiplicative factor of EMA warmup. Default: 1. 116 | power (float): Exponential factor of EMA warmup. Default: 1. 117 | min_value (float): The minimum EMA decay rate. Default: 0. 118 | max_value (float): The maximum EMA decay rate. Default: 1. 119 | start_at (int): The epoch to start averaging at. Default: 0. 120 | last_epoch (int): The index of last epoch. Default: 0. 
121 | """ 122 | 123 | def __init__(self, inv_gamma=1., power=1., min_value=0., max_value=1., start_at=0, 124 | last_epoch=0): 125 | self.inv_gamma = inv_gamma 126 | self.power = power 127 | self.min_value = min_value 128 | self.max_value = max_value 129 | self.start_at = start_at 130 | self.last_epoch = last_epoch 131 | 132 | def state_dict(self): 133 | """Returns the state of the class as a :class:`dict`.""" 134 | return dict(self.__dict__.items()) 135 | 136 | def load_state_dict(self, state_dict): 137 | """Loads the class's state. 138 | Args: 139 | state_dict (dict): scaler state. Should be an object returned 140 | from a call to :meth:`state_dict`. 141 | """ 142 | self.__dict__.update(state_dict) 143 | 144 | def get_value(self): 145 | """Gets the current EMA decay rate.""" 146 | epoch = max(0, self.last_epoch - self.start_at) 147 | value = 1 - (1 + epoch / self.inv_gamma) ** -self.power 148 | return 0. if epoch < 0 else min(self.max_value, max(self.min_value, value)) 149 | 150 | def step(self): 151 | """Updates the step count.""" 152 | self.last_epoch += 1 153 | 154 | 155 | class InverseLR(optim.lr_scheduler._LRScheduler): 156 | """Implements an inverse decay learning rate schedule with an optional exponential 157 | warmup. When last_epoch=-1, sets initial lr as lr. 158 | inv_gamma is the number of steps/epochs required for the learning rate to decay to 159 | (1 / 2)**power of its original value. 160 | Args: 161 | optimizer (Optimizer): Wrapped optimizer. 162 | inv_gamma (float): Inverse multiplicative factor of learning rate decay. Default: 1. 163 | power (float): Exponential factor of learning rate decay. Default: 1. 164 | warmup (float): Exponential warmup factor (0 <= warmup < 1, 0 to disable) 165 | Default: 0. 166 | min_lr (float): The minimum learning rate. Default: 0. 167 | last_epoch (int): The index of last epoch. Default: -1. 168 | verbose (bool): If ``True``, prints a message to stdout for 169 | each update. Default: ``False``. 170 | """ 171 | 172 | def __init__(self, optimizer, inv_gamma=1., power=1., warmup=0., min_lr=0., 173 | last_epoch=-1, verbose=False): 174 | self.inv_gamma = inv_gamma 175 | self.power = power 176 | if not 0. <= warmup < 1: 177 | raise ValueError('Invalid value for warmup') 178 | self.warmup = warmup 179 | self.min_lr = min_lr 180 | super().__init__(optimizer, last_epoch, verbose) 181 | 182 | def get_lr(self): 183 | if not self._get_lr_called_within_step: 184 | warnings.warn("To get the last learning rate computed by the scheduler, " 185 | "please use `get_last_lr()`.") 186 | 187 | return self._get_closed_form_lr() 188 | 189 | def _get_closed_form_lr(self): 190 | warmup = 1 - self.warmup ** (self.last_epoch + 1) 191 | lr_mult = (1 + self.last_epoch / self.inv_gamma) ** -self.power 192 | return [warmup * max(self.min_lr, base_lr * lr_mult) 193 | for base_lr in self.base_lrs] 194 | 195 | 196 | class ExponentialLR(optim.lr_scheduler._LRScheduler): 197 | """Implements an exponential learning rate schedule with an optional exponential 198 | warmup. When last_epoch=-1, sets initial lr as lr. Decays the learning rate 199 | continuously by decay (default 0.5) every num_steps steps. 200 | Args: 201 | optimizer (Optimizer): Wrapped optimizer. 202 | num_steps (float): The number of steps to decay the learning rate by decay in. 203 | decay (float): The factor by which to decay the learning rate every num_steps 204 | steps. Default: 0.5. 205 | warmup (float): Exponential warmup factor (0 <= warmup < 1, 0 to disable) 206 | Default: 0. 
207 | min_lr (float): The minimum learning rate. Default: 0. 208 | last_epoch (int): The index of last epoch. Default: -1. 209 | verbose (bool): If ``True``, prints a message to stdout for 210 | each update. Default: ``False``. 211 | """ 212 | 213 | def __init__(self, optimizer, num_steps, decay=0.5, warmup=0., min_lr=0., 214 | last_epoch=-1, verbose=False): 215 | self.num_steps = num_steps 216 | self.decay = decay 217 | if not 0. <= warmup < 1: 218 | raise ValueError('Invalid value for warmup') 219 | self.warmup = warmup 220 | self.min_lr = min_lr 221 | super().__init__(optimizer, last_epoch, verbose) 222 | 223 | def get_lr(self): 224 | if not self._get_lr_called_within_step: 225 | warnings.warn("To get the last learning rate computed by the scheduler, " 226 | "please use `get_last_lr()`.") 227 | 228 | return self._get_closed_form_lr() 229 | 230 | def _get_closed_form_lr(self): 231 | warmup = 1 - self.warmup ** (self.last_epoch + 1) 232 | lr_mult = (self.decay ** (1 / self.num_steps)) ** self.last_epoch 233 | return [warmup * max(self.min_lr, base_lr * lr_mult) 234 | for base_lr in self.base_lrs] 235 | 236 | 237 | class ConstantLRWithWarmup(optim.lr_scheduler._LRScheduler): 238 | """Implements a constant learning rate schedule with an optional exponential 239 | warmup. When last_epoch=-1, sets initial lr as lr. 240 | Args: 241 | optimizer (Optimizer): Wrapped optimizer. 242 | warmup (float): Exponential warmup factor (0 <= warmup < 1, 0 to disable) 243 | Default: 0. 244 | last_epoch (int): The index of last epoch. Default: -1. 245 | verbose (bool): If ``True``, prints a message to stdout for 246 | each update. Default: ``False``. 247 | """ 248 | 249 | def __init__(self, optimizer, warmup=0., last_epoch=-1, verbose=False): 250 | if not 0. <= warmup < 1: 251 | raise ValueError('Invalid value for warmup') 252 | self.warmup = warmup 253 | super().__init__(optimizer, last_epoch, verbose) 254 | 255 | def get_lr(self): 256 | if not self._get_lr_called_within_step: 257 | warnings.warn("To get the last learning rate computed by the scheduler, " 258 | "please use `get_last_lr()`.") 259 | 260 | return self._get_closed_form_lr() 261 | 262 | def _get_closed_form_lr(self): 263 | warmup = 1 - self.warmup ** (self.last_epoch + 1) 264 | return [warmup * base_lr for base_lr in self.base_lrs] 265 | 266 | 267 | def stratified_uniform(shape, group=0, groups=1, dtype=None, device=None): 268 | """Draws stratified samples from a uniform distribution.""" 269 | if groups <= 0: 270 | raise ValueError(f"groups must be positive, got {groups}") 271 | if group < 0 or group >= groups: 272 | raise ValueError(f"group must be in [0, {groups})") 273 | n = shape[-1] * groups 274 | offsets = torch.arange(group, n, groups, dtype=dtype, device=device) 275 | u = torch.rand(shape, dtype=dtype, device=device) 276 | return (offsets + u) / n 277 | 278 | 279 | stratified_settings = threading.local() 280 | 281 | 282 | @contextmanager 283 | def enable_stratified(group=0, groups=1, disable=False): 284 | """A context manager that enables stratified sampling.""" 285 | try: 286 | stratified_settings.disable = disable 287 | stratified_settings.group = group 288 | stratified_settings.groups = groups 289 | yield 290 | finally: 291 | del stratified_settings.disable 292 | del stratified_settings.group 293 | del stratified_settings.groups 294 | 295 | 296 | @contextmanager 297 | def enable_stratified_accelerate(accelerator, disable=False): 298 | """A context manager that enables stratified sampling, distributing the strata across 299 | all 
processes and gradient accumulation steps using settings from Hugging Face Accelerate.""" 300 | try: 301 | rank = accelerator.process_index 302 | world_size = accelerator.num_processes 303 | acc_steps = accelerator.gradient_state.num_steps 304 | acc_step = accelerator.step % acc_steps 305 | group = rank * acc_steps + acc_step 306 | groups = world_size * acc_steps 307 | with enable_stratified(group, groups, disable=disable): 308 | yield 309 | finally: 310 | pass 311 | 312 | 313 | def stratified_with_settings(shape, dtype=None, device=None): 314 | """Draws stratified samples from a uniform distribution, using settings from a context 315 | manager.""" 316 | if not hasattr(stratified_settings, 'disable') or stratified_settings.disable: 317 | return torch.rand(shape, dtype=dtype, device=device) 318 | return stratified_uniform( 319 | shape, stratified_settings.group, stratified_settings.groups, dtype=dtype, device=device 320 | ) 321 | 322 | 323 | def rand_log_normal(shape, loc=0., scale=1., device='cpu', dtype=torch.float32): 324 | """Draws samples from an lognormal distribution.""" 325 | u = stratified_with_settings(shape, device=device, dtype=dtype) * (1 - 2e-7) + 1e-7 326 | return torch.distributions.Normal(loc, scale).icdf(u).exp() 327 | 328 | 329 | def rand_log_logistic(shape, loc=0., scale=1., min_value=0., max_value=float('inf'), device='cpu', dtype=torch.float32): 330 | """Draws samples from an optionally truncated log-logistic distribution.""" 331 | min_value = torch.as_tensor(min_value, device=device, dtype=torch.float64) 332 | max_value = torch.as_tensor(max_value, device=device, dtype=torch.float64) 333 | min_cdf = min_value.log().sub(loc).div(scale).sigmoid() 334 | max_cdf = max_value.log().sub(loc).div(scale).sigmoid() 335 | u = stratified_with_settings(shape, device=device, dtype=torch.float64) * (max_cdf - min_cdf) + min_cdf 336 | return u.logit().mul(scale).add(loc).exp().to(dtype) 337 | 338 | 339 | def rand_log_uniform(shape, min_value, max_value, device='cpu', dtype=torch.float32): 340 | """Draws samples from an log-uniform distribution.""" 341 | min_value = math.log(min_value) 342 | max_value = math.log(max_value) 343 | return (stratified_with_settings(shape, device=device, dtype=dtype) * (max_value - min_value) + min_value).exp() 344 | 345 | 346 | def rand_v_diffusion(shape, sigma_data=1., min_value=0., max_value=float('inf'), device='cpu', dtype=torch.float32): 347 | """Draws samples from a truncated v-diffusion training timestep distribution.""" 348 | min_cdf = math.atan(min_value / sigma_data) * 2 / math.pi 349 | max_cdf = math.atan(max_value / sigma_data) * 2 / math.pi 350 | u = stratified_with_settings(shape, device=device, dtype=dtype) * (max_cdf - min_cdf) + min_cdf 351 | return torch.tan(u * math.pi / 2) * sigma_data 352 | 353 | 354 | def rand_cosine_interpolated(shape, image_d, noise_d_low, noise_d_high, sigma_data=1., min_value=1e-3, max_value=1e3, device='cpu', dtype=torch.float32): 355 | """Draws samples from an interpolated cosine timestep distribution (from simple diffusion).""" 356 | 357 | def logsnr_schedule_cosine(t, logsnr_min, logsnr_max): 358 | t_min = math.atan(math.exp(-0.5 * logsnr_max)) 359 | t_max = math.atan(math.exp(-0.5 * logsnr_min)) 360 | return -2 * torch.log(torch.tan(t_min + t * (t_max - t_min))) 361 | 362 | def logsnr_schedule_cosine_shifted(t, image_d, noise_d, logsnr_min, logsnr_max): 363 | shift = 2 * math.log(noise_d / image_d) 364 | return logsnr_schedule_cosine(t, logsnr_min - shift, logsnr_max - shift) + shift 365 | 366 | def 
logsnr_schedule_cosine_interpolated(t, image_d, noise_d_low, noise_d_high, logsnr_min, logsnr_max): 367 | logsnr_low = logsnr_schedule_cosine_shifted(t, image_d, noise_d_low, logsnr_min, logsnr_max) 368 | logsnr_high = logsnr_schedule_cosine_shifted(t, image_d, noise_d_high, logsnr_min, logsnr_max) 369 | return torch.lerp(logsnr_low, logsnr_high, t) 370 | 371 | logsnr_min = -2 * math.log(min_value / sigma_data) 372 | logsnr_max = -2 * math.log(max_value / sigma_data) 373 | u = stratified_with_settings(shape, device=device, dtype=dtype) 374 | logsnr = logsnr_schedule_cosine_interpolated(u, image_d, noise_d_low, noise_d_high, logsnr_min, logsnr_max) 375 | return torch.exp(-logsnr / 2) * sigma_data 376 | 377 | 378 | def rand_split_log_normal(shape, loc, scale_1, scale_2, device='cpu', dtype=torch.float32): 379 | """Draws samples from a split lognormal distribution.""" 380 | n = torch.randn(shape, device=device, dtype=dtype).abs() 381 | u = torch.rand(shape, device=device, dtype=dtype) 382 | n_left = n * -scale_1 + loc 383 | n_right = n * scale_2 + loc 384 | ratio = scale_1 / (scale_1 + scale_2) 385 | return torch.where(u < ratio, n_left, n_right).exp() 386 | 387 | 388 | class FolderOfImages(data.Dataset): 389 | """Recursively finds all images in a directory. It does not support 390 | classes/targets.""" 391 | 392 | IMG_EXTENSIONS = {'.jpg', '.jpeg', '.png', '.ppm', '.bmp', '.pgm', '.tif', '.tiff', '.webp'} 393 | 394 | def __init__(self, root, transform=None): 395 | super().__init__() 396 | self.root = Path(root) 397 | self.transform = nn.Identity() if transform is None else transform 398 | self.paths = sorted(path for path in self.root.rglob('*') if path.suffix.lower() in self.IMG_EXTENSIONS) 399 | 400 | def __repr__(self): 401 | return f'FolderOfImages(root="{self.root}", len: {len(self)})' 402 | 403 | def __len__(self): 404 | return len(self.paths) 405 | 406 | def __getitem__(self, key): 407 | path = self.paths[key] 408 | with open(path, 'rb') as f: 409 | image = Image.open(f).convert('RGB') 410 | image = self.transform(image) 411 | return image, 412 | 413 | 414 | class CSVLogger: 415 | def __init__(self, filename, columns): 416 | self.filename = Path(filename) 417 | self.columns = columns 418 | if self.filename.exists(): 419 | self.file = open(self.filename, 'a') 420 | else: 421 | self.file = open(self.filename, 'w') 422 | self.write(*self.columns) 423 | 424 | def write(self, *args): 425 | print(*args, sep=',', file=self.file, flush=True) 426 | 427 | 428 | @contextmanager 429 | def tf32_mode(cudnn=None, matmul=None): 430 | """A context manager that sets whether TF32 is allowed on cuDNN or matmul.""" 431 | cudnn_old = torch.backends.cudnn.allow_tf32 432 | matmul_old = torch.backends.cuda.matmul.allow_tf32 433 | try: 434 | if cudnn is not None: 435 | torch.backends.cudnn.allow_tf32 = cudnn 436 | if matmul is not None: 437 | torch.backends.cuda.matmul.allow_tf32 = matmul 438 | yield 439 | finally: 440 | if cudnn is not None: 441 | torch.backends.cudnn.allow_tf32 = cudnn_old 442 | if matmul is not None: 443 | torch.backends.cuda.matmul.allow_tf32 = matmul_old 444 | 445 | 446 | def get_safetensors_metadata(path): 447 | """Retrieves the metadata from a safetensors file.""" 448 | return safetensors.safe_open(path, "pt").metadata() 449 | 450 | 451 | def ema_update_dict(values, updates, decay): 452 | for k, v in updates.items(): 453 | if k not in values: 454 | values[k] = v 455 | else: 456 | values[k] *= decay 457 | values[k] += (1 - decay) * v 458 | return values 459 | 
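The EMA-warmup and stratified-sampling helpers above are easiest to see in use. The following sketch is illustrative only and is not part of utils.py; `model`, `model_ema`, the optimizer, and the toy loss are placeholder assumptions standing in for a real training setup:

import copy

import torch
from torch import nn, optim

import k_diffusion as K

model = nn.Linear(4, 4)
model_ema = copy.deepcopy(model)  # starts as an exact copy of the online model
opt = optim.AdamW(model.parameters(), lr=1e-3)
ema_sched = K.utils.EMAWarmup(power=2 / 3, max_value=0.9999)

for step in range(100):
    opt.zero_grad()
    loss = model(torch.randn(8, 4)).pow(2).mean()  # stand-in for the real training loss
    loss.backward()
    opt.step()
    # decay = 1 - (1 + step / inv_gamma) ** -power, clamped to [min_value, max_value];
    # with inv_gamma=1, power=2/3 it reaches ~0.999 after roughly 31.6K steps, as the
    # EMAWarmup docstring notes
    K.utils.ema_update(model, model_ema, ema_sched.get_value())
    ema_sched.step()

# stratified_uniform with groups=1 and shape[-1] = 4 places one sample in each quarter
# of [0, 1), which lowers the variance of noise-level draws compared to plain torch.rand
u = K.utils.stratified_uniform([4])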
-------------------------------------------------------------------------------- /train.py: -------------------------------------------------------------------------------- 1 | #!/usr/bin/env python3 2 | 3 | """Trains Karras et al. (2022) diffusion models.""" 4 | 5 | import argparse 6 | from copy import deepcopy 7 | from functools import partial 8 | import importlib.util 9 | import math 10 | import json 11 | from pathlib import Path 12 | import time 13 | 14 | import accelerate 15 | import safetensors.torch as safetorch 16 | import torch 17 | import torch._dynamo 18 | from torch import distributed as dist 19 | from torch import multiprocessing as mp 20 | from torch import optim 21 | from torch.utils import data, flop_counter 22 | from torchvision import datasets, transforms, utils 23 | from tqdm.auto import tqdm 24 | 25 | import k_diffusion as K 26 | 27 | 28 | def ensure_distributed(): 29 | if not dist.is_initialized(): 30 | dist.init_process_group(world_size=1, rank=0, store=dist.HashStore()) 31 | 32 | 33 | def main(): 34 | p = argparse.ArgumentParser(description=__doc__, 35 | formatter_class=argparse.ArgumentDefaultsHelpFormatter) 36 | p.add_argument('--batch-size', type=int, default=64, 37 | help='the batch size') 38 | p.add_argument('--checkpointing', action='store_true', 39 | help='enable gradient checkpointing') 40 | p.add_argument('--clip-model', type=str, default='ViT-B/16', 41 | choices=K.evaluation.CLIPFeatureExtractor.available_models(), 42 | help='the CLIP model to use to evaluate') 43 | p.add_argument('--compile', action='store_true', 44 | help='compile the model') 45 | p.add_argument('--config', type=str, required=True, 46 | help='the configuration file') 47 | p.add_argument('--demo-every', type=int, default=500, 48 | help='save a demo grid every this many steps') 49 | p.add_argument('--dinov2-model', type=str, default='vitl14', 50 | choices=K.evaluation.DINOv2FeatureExtractor.available_models(), 51 | help='the DINOv2 model to use to evaluate') 52 | p.add_argument('--end-step', type=int, default=None, 53 | help='the step to end training at') 54 | p.add_argument('--evaluate-every', type=int, default=10000, 55 | help='evaluate every this many steps') 56 | p.add_argument('--evaluate-n', type=int, default=2000, 57 | help='the number of samples to draw to evaluate') 58 | p.add_argument('--evaluate-only', action='store_true', 59 | help='evaluate instead of training') 60 | p.add_argument('--evaluate-with', type=str, default='inception', 61 | choices=['inception', 'clip', 'dinov2'], 62 | help='the feature extractor to use for evaluation') 63 | p.add_argument('--gns', action='store_true', 64 | help='measure the gradient noise scale (DDP only, disables stratified sampling)') 65 | p.add_argument('--grad-accum-steps', type=int, default=1, 66 | help='the number of gradient accumulation steps') 67 | p.add_argument('--lr', type=float, 68 | help='the learning rate') 69 | p.add_argument('--mixed-precision', type=str, 70 | help='the mixed precision type') 71 | p.add_argument('--name', type=str, default='model', 72 | help='the name of the run') 73 | p.add_argument('--num-workers', type=int, default=8, 74 | help='the number of data loader workers') 75 | p.add_argument('--reset-ema', action='store_true', 76 | help='reset the EMA') 77 | p.add_argument('--resume', type=str, 78 | help='the checkpoint to resume from') 79 | p.add_argument('--resume-inference', type=str, 80 | help='the inference checkpoint to resume from') 81 | p.add_argument('--sample-n', type=int, default=64, 82 | help='the number of 
images to sample for demo grids') 83 | p.add_argument('--save-every', type=int, default=10000, 84 | help='save every this many steps') 85 | p.add_argument('--seed', type=int, 86 | help='the random seed') 87 | p.add_argument('--start-method', type=str, default='spawn', 88 | choices=['fork', 'forkserver', 'spawn'], 89 | help='the multiprocessing start method') 90 | p.add_argument('--wandb-entity', type=str, 91 | help='the wandb entity name') 92 | p.add_argument('--wandb-group', type=str, 93 | help='the wandb group name') 94 | p.add_argument('--wandb-project', type=str, 95 | help='the wandb project name (specify this to enable wandb)') 96 | p.add_argument('--wandb-save-model', action='store_true', 97 | help='save model to wandb') 98 | args = p.parse_args() 99 | 100 | mp.set_start_method(args.start_method) 101 | torch.backends.cuda.matmul.allow_tf32 = True 102 | try: 103 | torch._dynamo.config.automatic_dynamic_shapes = False 104 | except AttributeError: 105 | pass 106 | 107 | config = K.config.load_config(args.config) 108 | model_config = config['model'] 109 | dataset_config = config['dataset'] 110 | opt_config = config['optimizer'] 111 | sched_config = config['lr_sched'] 112 | ema_sched_config = config['ema_sched'] 113 | 114 | # TODO: allow non-square input sizes 115 | assert len(model_config['input_size']) == 2 and model_config['input_size'][0] == model_config['input_size'][1] 116 | size = model_config['input_size'] 117 | 118 | accelerator = accelerate.Accelerator(gradient_accumulation_steps=args.grad_accum_steps, mixed_precision=args.mixed_precision) 119 | ensure_distributed() 120 | device = accelerator.device 121 | unwrap = accelerator.unwrap_model 122 | print(f'Process {accelerator.process_index} using device: {device}', flush=True) 123 | accelerator.wait_for_everyone() 124 | if accelerator.is_main_process: 125 | print(f'World size: {accelerator.num_processes}', flush=True) 126 | print(f'Batch size: {args.batch_size * accelerator.num_processes}', flush=True) 127 | 128 | if args.seed is not None: 129 | seeds = torch.randint(-2 ** 63, 2 ** 63 - 1, [accelerator.num_processes], generator=torch.Generator().manual_seed(args.seed)) 130 | torch.manual_seed(seeds[accelerator.process_index]) 131 | demo_gen = torch.Generator().manual_seed(torch.randint(-2 ** 63, 2 ** 63 - 1, ()).item()) 132 | elapsed = 0.0 133 | 134 | inner_model = K.config.make_model(config) 135 | inner_model_ema = deepcopy(inner_model) 136 | 137 | if args.compile: 138 | inner_model.compile() 139 | # inner_model_ema.compile() 140 | 141 | if accelerator.is_main_process: 142 | print(f'Parameters: {K.utils.n_params(inner_model):,}') 143 | 144 | # If logging to wandb, initialize the run 145 | use_wandb = accelerator.is_main_process and args.wandb_project 146 | if use_wandb: 147 | import wandb 148 | log_config = vars(args) 149 | log_config['config'] = config 150 | log_config['parameters'] = K.utils.n_params(inner_model) 151 | wandb.init(project=args.wandb_project, entity=args.wandb_entity, group=args.wandb_group, config=log_config, save_code=True) 152 | 153 | lr = opt_config['lr'] if args.lr is None else args.lr 154 | groups = inner_model.param_groups(lr) 155 | if opt_config['type'] == 'adamw': 156 | opt = optim.AdamW(groups, 157 | lr=lr, 158 | betas=tuple(opt_config['betas']), 159 | eps=opt_config['eps'], 160 | weight_decay=opt_config['weight_decay']) 161 | elif opt_config['type'] == 'adam8bit': 162 | import bitsandbytes as bnb 163 | opt = bnb.optim.Adam8bit(groups, 164 | lr=lr, 165 | betas=tuple(opt_config['betas']), 166 | 
eps=opt_config['eps'], 167 | weight_decay=opt_config['weight_decay']) 168 | elif opt_config['type'] == 'sgd': 169 | opt = optim.SGD(groups, 170 | lr=lr, 171 | momentum=opt_config.get('momentum', 0.), 172 | nesterov=opt_config.get('nesterov', False), 173 | weight_decay=opt_config.get('weight_decay', 0.)) 174 | else: 175 | raise ValueError('Invalid optimizer type') 176 | 177 | if sched_config['type'] == 'inverse': 178 | sched = K.utils.InverseLR(opt, 179 | inv_gamma=sched_config['inv_gamma'], 180 | power=sched_config['power'], 181 | warmup=sched_config['warmup']) 182 | elif sched_config['type'] == 'exponential': 183 | sched = K.utils.ExponentialLR(opt, 184 | num_steps=sched_config['num_steps'], 185 | decay=sched_config['decay'], 186 | warmup=sched_config['warmup']) 187 | elif sched_config['type'] == 'constant': 188 | sched = K.utils.ConstantLRWithWarmup(opt, warmup=sched_config['warmup']) 189 | else: 190 | raise ValueError('Invalid schedule type') 191 | 192 | assert ema_sched_config['type'] == 'inverse' 193 | ema_sched = K.utils.EMAWarmup(power=ema_sched_config['power'], 194 | max_value=ema_sched_config['max_value']) 195 | ema_stats = {} 196 | 197 | tf = transforms.Compose([ 198 | transforms.Resize(size[0], interpolation=transforms.InterpolationMode.BICUBIC), 199 | transforms.CenterCrop(size[0]), 200 | K.augmentation.KarrasAugmentationPipeline(model_config['augment_prob'], disable_all=model_config['augment_prob'] == 0), 201 | ]) 202 | 203 | if dataset_config['type'] == 'imagefolder': 204 | train_set = K.utils.FolderOfImages(dataset_config['location'], transform=tf) 205 | elif dataset_config['type'] == 'imagefolder-class': 206 | train_set = datasets.ImageFolder(dataset_config['location'], transform=tf) 207 | elif dataset_config['type'] == 'cifar10': 208 | train_set = datasets.CIFAR10(dataset_config['location'], train=True, download=True, transform=tf) 209 | elif dataset_config['type'] == 'mnist': 210 | train_set = datasets.MNIST(dataset_config['location'], train=True, download=True, transform=tf) 211 | elif dataset_config['type'] == 'huggingface': 212 | from datasets import load_dataset 213 | train_set = load_dataset(dataset_config['location']) 214 | train_set.set_transform(partial(K.utils.hf_datasets_augs_helper, transform=tf, image_key=dataset_config['image_key'])) 215 | train_set = train_set['train'] 216 | elif dataset_config['type'] == 'custom': 217 | location = (Path(args.config).parent / dataset_config['location']).resolve() 218 | spec = importlib.util.spec_from_file_location('custom_dataset', location) 219 | module = importlib.util.module_from_spec(spec) 220 | spec.loader.exec_module(module) 221 | get_dataset = getattr(module, dataset_config.get('get_dataset', 'get_dataset')) 222 | custom_dataset_config = dataset_config.get('config', {}) 223 | train_set = get_dataset(custom_dataset_config, transform=tf) 224 | else: 225 | raise ValueError('Invalid dataset type') 226 | 227 | if accelerator.is_main_process: 228 | try: 229 | print(f'Number of items in dataset: {len(train_set):,}') 230 | except TypeError: 231 | pass 232 | 233 | image_key = dataset_config.get('image_key', 0) 234 | num_classes = dataset_config.get('num_classes', 0) 235 | cond_dropout_rate = dataset_config.get('cond_dropout_rate', 0.1) 236 | class_key = dataset_config.get('class_key', 1) 237 | 238 | train_dl = data.DataLoader(train_set, args.batch_size, shuffle=True, drop_last=True, 239 | num_workers=args.num_workers, persistent_workers=True, pin_memory=True) 240 | 241 | inner_model, inner_model_ema, opt, train_dl = 
accelerator.prepare(inner_model, inner_model_ema, opt, train_dl) 242 | 243 | with torch.no_grad(), K.models.flops.flop_counter() as fc: 244 | x = torch.zeros([1, model_config['input_channels'], size[0], size[1]], device=device) 245 | sigma = torch.ones([1], device=device) 246 | extra_args = {} 247 | if getattr(unwrap(inner_model), "num_classes", 0): 248 | extra_args['class_cond'] = torch.zeros([1], dtype=torch.long, device=device) 249 | inner_model(x, sigma, **extra_args) 250 | if accelerator.is_main_process: 251 | print(f"Forward pass GFLOPs: {fc.flops / 1_000_000_000:,.3f}", flush=True) 252 | 253 | if use_wandb: 254 | wandb.watch(inner_model) 255 | if accelerator.num_processes == 1: 256 | args.gns = False 257 | if args.gns: 258 | gns_stats_hook = K.gns.DDPGradientStatsHook(inner_model) 259 | gns_stats = K.gns.GradientNoiseScale() 260 | else: 261 | gns_stats = None 262 | sigma_min = model_config['sigma_min'] 263 | sigma_max = model_config['sigma_max'] 264 | sample_density = K.config.make_sample_density(model_config) 265 | 266 | model = K.config.make_denoiser_wrapper(config)(inner_model) 267 | model_ema = K.config.make_denoiser_wrapper(config)(inner_model_ema) 268 | 269 | state_path = Path(f'{args.name}_state.json') 270 | 271 | if state_path.exists() or args.resume: 272 | if args.resume: 273 | ckpt_path = args.resume 274 | if not args.resume: 275 | state = json.load(open(state_path)) 276 | ckpt_path = state['latest_checkpoint'] 277 | if accelerator.is_main_process: 278 | print(f'Resuming from {ckpt_path}...') 279 | ckpt = torch.load(ckpt_path, map_location='cpu') 280 | unwrap(model.inner_model).load_state_dict(ckpt['model']) 281 | unwrap(model_ema.inner_model).load_state_dict(ckpt['model_ema']) 282 | opt.load_state_dict(ckpt['opt']) 283 | sched.load_state_dict(ckpt['sched']) 284 | ema_sched.load_state_dict(ckpt['ema_sched']) 285 | ema_stats = ckpt.get('ema_stats', ema_stats) 286 | epoch = ckpt['epoch'] + 1 287 | step = ckpt['step'] + 1 288 | if args.gns and ckpt.get('gns_stats', None) is not None: 289 | gns_stats.load_state_dict(ckpt['gns_stats']) 290 | demo_gen.set_state(ckpt['demo_gen']) 291 | elapsed = ckpt.get('elapsed', 0.0) 292 | 293 | del ckpt 294 | else: 295 | epoch = 0 296 | step = 0 297 | 298 | if args.reset_ema: 299 | unwrap(model.inner_model).load_state_dict(unwrap(model_ema.inner_model).state_dict()) 300 | ema_sched = K.utils.EMAWarmup(power=ema_sched_config['power'], 301 | max_value=ema_sched_config['max_value']) 302 | ema_stats = {} 303 | 304 | if args.resume_inference: 305 | if accelerator.is_main_process: 306 | print(f'Loading {args.resume_inference}...') 307 | ckpt = safetorch.load_file(args.resume_inference) 308 | unwrap(model.inner_model).load_state_dict(ckpt) 309 | unwrap(model_ema.inner_model).load_state_dict(ckpt) 310 | del ckpt 311 | 312 | evaluate_enabled = args.evaluate_every > 0 and args.evaluate_n > 0 313 | metrics_log = None 314 | if evaluate_enabled: 315 | if args.evaluate_with == 'inception': 316 | extractor = K.evaluation.InceptionV3FeatureExtractor(device=device) 317 | elif args.evaluate_with == 'clip': 318 | extractor = K.evaluation.CLIPFeatureExtractor(args.clip_model, device=device) 319 | elif args.evaluate_with == 'dinov2': 320 | extractor = K.evaluation.DINOv2FeatureExtractor(args.dinov2_model, device=device) 321 | else: 322 | raise ValueError('Invalid evaluation feature extractor') 323 | train_iter = iter(train_dl) 324 | if accelerator.is_main_process: 325 | print('Computing features for reals...') 326 | reals_features = 
K.evaluation.compute_features(accelerator, lambda x: next(train_iter)[image_key][1], extractor, args.evaluate_n, args.batch_size) 327 | if accelerator.is_main_process and not args.evaluate_only: 328 | metrics_log = K.utils.CSVLogger(f'{args.name}_metrics.csv', ['step', 'time', 'loss', 'fid', 'kid']) 329 | del train_iter 330 | 331 | cfg_scale = 1. 332 | 333 | def make_cfg_model_fn(model): 334 | def cfg_model_fn(x, sigma, class_cond): 335 | x_in = torch.cat([x, x]) 336 | sigma_in = torch.cat([sigma, sigma]) 337 | class_uncond = torch.full_like(class_cond, num_classes) 338 | class_cond_in = torch.cat([class_uncond, class_cond]) 339 | out = model(x_in, sigma_in, class_cond=class_cond_in) 340 | out_uncond, out_cond = out.chunk(2) 341 | return out_uncond + (out_cond - out_uncond) * cfg_scale 342 | if cfg_scale != 1: 343 | return cfg_model_fn 344 | return model 345 | 346 | @torch.no_grad() 347 | @K.utils.eval_mode(model_ema) 348 | def demo(): 349 | if accelerator.is_main_process: 350 | tqdm.write('Sampling...') 351 | filename = f'{args.name}_demo_{step:08}.png' 352 | n_per_proc = math.ceil(args.sample_n / accelerator.num_processes) 353 | x = torch.randn([accelerator.num_processes, n_per_proc, model_config['input_channels'], size[0], size[1]], generator=demo_gen).to(device) 354 | dist.broadcast(x, 0) 355 | x = x[accelerator.process_index] * sigma_max 356 | model_fn, extra_args = model_ema, {} 357 | if num_classes: 358 | class_cond = torch.randint(0, num_classes, [accelerator.num_processes, n_per_proc], generator=demo_gen).to(device) 359 | dist.broadcast(class_cond, 0) 360 | extra_args['class_cond'] = class_cond[accelerator.process_index] 361 | model_fn = make_cfg_model_fn(model_ema) 362 | sigmas = K.sampling.get_sigmas_karras(50, sigma_min, sigma_max, rho=7., device=device) 363 | x_0 = K.sampling.sample_dpmpp_2m_sde(model_fn, x, sigmas, extra_args=extra_args, eta=0.0, solver_type='heun', disable=not accelerator.is_main_process) 364 | x_0 = accelerator.gather(x_0)[:args.sample_n] 365 | if accelerator.is_main_process: 366 | grid = utils.make_grid(x_0, nrow=math.ceil(args.sample_n ** 0.5), padding=0) 367 | K.utils.to_pil_image(grid).save(filename) 368 | if use_wandb: 369 | wandb.log({'demo_grid': wandb.Image(filename)}, step=step) 370 | 371 | @torch.no_grad() 372 | @K.utils.eval_mode(model_ema) 373 | def evaluate(): 374 | if not evaluate_enabled: 375 | return 376 | if accelerator.is_main_process: 377 | tqdm.write('Evaluating...') 378 | sigmas = K.sampling.get_sigmas_karras(50, sigma_min, sigma_max, rho=7., device=device) 379 | def sample_fn(n): 380 | x = torch.randn([n, model_config['input_channels'], size[0], size[1]], device=device) * sigma_max 381 | model_fn, extra_args = model_ema, {} 382 | if num_classes: 383 | extra_args['class_cond'] = torch.randint(0, num_classes, [n], device=device) 384 | model_fn = make_cfg_model_fn(model_ema) 385 | x_0 = K.sampling.sample_dpmpp_2m_sde(model_fn, x, sigmas, extra_args=extra_args, eta=0.0, solver_type='heun', disable=True) 386 | return x_0 387 | fakes_features = K.evaluation.compute_features(accelerator, sample_fn, extractor, args.evaluate_n, args.batch_size) 388 | if accelerator.is_main_process: 389 | fid = K.evaluation.fid(fakes_features, reals_features) 390 | kid = K.evaluation.kid(fakes_features, reals_features) 391 | print(f'FID: {fid.item():g}, KID: {kid.item():g}') 392 | if accelerator.is_main_process and metrics_log is not None: 393 | metrics_log.write(step, elapsed, ema_stats['loss'], fid.item(), kid.item()) 394 | if use_wandb: 395 | 
wandb.log({'FID': fid.item(), 'KID': kid.item()}, step=step) 396 | 397 | def save(): 398 | accelerator.wait_for_everyone() 399 | filename = f'{args.name}_{step:08}.pth' 400 | if accelerator.is_main_process: 401 | tqdm.write(f'Saving to {filename}...') 402 | inner_model = unwrap(model.inner_model) 403 | inner_model_ema = unwrap(model_ema.inner_model) 404 | obj = { 405 | 'config': config, 406 | 'model': inner_model.state_dict(), 407 | 'model_ema': inner_model_ema.state_dict(), 408 | 'opt': opt.state_dict(), 409 | 'sched': sched.state_dict(), 410 | 'ema_sched': ema_sched.state_dict(), 411 | 'epoch': epoch, 412 | 'step': step, 413 | 'gns_stats': gns_stats.state_dict() if gns_stats is not None else None, 414 | 'ema_stats': ema_stats, 415 | 'demo_gen': demo_gen.get_state(), 416 | 'elapsed': elapsed, 417 | } 418 | accelerator.save(obj, filename) 419 | if accelerator.is_main_process: 420 | state_obj = {'latest_checkpoint': filename} 421 | json.dump(state_obj, open(state_path, 'w')) 422 | if args.wandb_save_model and use_wandb: 423 | wandb.save(filename) 424 | 425 | if args.evaluate_only: 426 | if not evaluate_enabled: 427 | raise ValueError('--evaluate-only requested but evaluation is disabled') 428 | evaluate() 429 | return 430 | 431 | losses_since_last_print = [] 432 | 433 | try: 434 | while True: 435 | for batch in tqdm(train_dl, smoothing=0.1, disable=not accelerator.is_main_process): 436 | if device.type == 'cuda': 437 | start_timer = torch.cuda.Event(enable_timing=True) 438 | end_timer = torch.cuda.Event(enable_timing=True) 439 | torch.cuda.synchronize() 440 | start_timer.record() 441 | else: 442 | start_timer = time.time() 443 | 444 | with accelerator.accumulate(model): 445 | reals, _, aug_cond = batch[image_key] 446 | class_cond, extra_args = None, {} 447 | if num_classes: 448 | class_cond = batch[class_key] 449 | drop = torch.rand(class_cond.shape, device=class_cond.device) 450 | class_cond.masked_fill_(drop < cond_dropout_rate, num_classes) 451 | extra_args['class_cond'] = class_cond 452 | noise = torch.randn_like(reals) 453 | with K.utils.enable_stratified_accelerate(accelerator, disable=args.gns): 454 | sigma = sample_density([reals.shape[0]], device=device) 455 | with K.models.checkpointing(args.checkpointing): 456 | losses = model.loss(reals, noise, sigma, aug_cond=aug_cond, **extra_args) 457 | loss = accelerator.gather(losses).mean().item() 458 | losses_since_last_print.append(loss) 459 | accelerator.backward(losses.mean()) 460 | if args.gns: 461 | sq_norm_small_batch, sq_norm_large_batch = gns_stats_hook.get_stats() 462 | gns_stats.update(sq_norm_small_batch, sq_norm_large_batch, reals.shape[0], reals.shape[0] * accelerator.num_processes) 463 | if accelerator.sync_gradients: 464 | accelerator.clip_grad_norm_(model.parameters(), 1.) 
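# Note: gradient clipping is guarded by sync_gradients, so it is the norm of the fully
# accumulated gradient that gets clipped, once per accumulation cycle rather than per
# micro-batch. Likewise, the loss EMA below uses ema_decay ** (1 / grad_accum_steps) so
# that a full cycle of micro-batches applies a total decay of ema_decay, and the model
# EMA weights themselves are only updated on cycle boundaries.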
465 | opt.step() 466 | sched.step() 467 | opt.zero_grad() 468 | 469 | ema_decay = ema_sched.get_value() 470 | K.utils.ema_update_dict(ema_stats, {'loss': loss}, ema_decay ** (1 / args.grad_accum_steps)) 471 | if accelerator.sync_gradients: 472 | K.utils.ema_update(model, model_ema, ema_decay) 473 | ema_sched.step() 474 | 475 | if device.type == 'cuda': 476 | end_timer.record() 477 | torch.cuda.synchronize() 478 | elapsed += start_timer.elapsed_time(end_timer) / 1000 479 | else: 480 | elapsed += time.time() - start_timer 481 | 482 | if step % 25 == 0: 483 | loss_disp = sum(losses_since_last_print) / len(losses_since_last_print) 484 | losses_since_last_print.clear() 485 | avg_loss = ema_stats['loss'] 486 | if accelerator.is_main_process: 487 | if args.gns: 488 | tqdm.write(f'Epoch: {epoch}, step: {step}, loss: {loss_disp:g}, avg loss: {avg_loss:g}, gns: {gns_stats.get_gns():g}') 489 | else: 490 | tqdm.write(f'Epoch: {epoch}, step: {step}, loss: {loss_disp:g}, avg loss: {avg_loss:g}') 491 | 492 | if use_wandb: 493 | log_dict = { 494 | 'epoch': epoch, 495 | 'loss': loss, 496 | 'lr': sched.get_last_lr()[0], 497 | 'ema_decay': ema_decay, 498 | } 499 | if args.gns: 500 | log_dict['gradient_noise_scale'] = gns_stats.get_gns() 501 | wandb.log(log_dict, step=step) 502 | 503 | step += 1 504 | 505 | if step % args.demo_every == 0: 506 | demo() 507 | 508 | if evaluate_enabled and step > 0 and step % args.evaluate_every == 0: 509 | evaluate() 510 | 511 | if step == args.end_step or (step > 0 and step % args.save_every == 0): 512 | save() 513 | 514 | if step == args.end_step: 515 | if accelerator.is_main_process: 516 | tqdm.write('Done!') 517 | return 518 | 519 | epoch += 1 520 | except KeyboardInterrupt: 521 | pass 522 | 523 | 524 | if __name__ == '__main__': 525 | main() 526 | -------------------------------------------------------------------------------- /k_diffusion/models/image_transformer_v2.py: -------------------------------------------------------------------------------- 1 | """k-diffusion transformer diffusion models, version 2.""" 2 | 3 | from dataclasses import dataclass 4 | from functools import lru_cache, reduce 5 | import math 6 | from typing import Union 7 | 8 | from einops import rearrange 9 | import torch 10 | from torch import nn 11 | import torch._dynamo 12 | from torch.nn import functional as F 13 | 14 | from . import flags, flops 15 | from .. import layers 16 | from .axial_rope import make_axial_pos 17 | 18 | 19 | try: 20 | import natten 21 | except ImportError: 22 | natten = None 23 | 24 | try: 25 | import flash_attn 26 | except ImportError: 27 | flash_attn = None 28 | 29 | 30 | if flags.get_use_compile(): 31 | torch._dynamo.config.cache_size_limit = max(64, torch._dynamo.config.cache_size_limit) 32 | torch._dynamo.config.suppress_errors = True 33 | 34 | 35 | # Helpers 36 | 37 | def zero_init(layer): 38 | nn.init.zeros_(layer.weight) 39 | if layer.bias is not None: 40 | nn.init.zeros_(layer.bias) 41 | return layer 42 | 43 | 44 | def checkpoint(function, *args, **kwargs): 45 | if flags.get_checkpointing(): 46 | kwargs.setdefault("use_reentrant", True) 47 | return torch.utils.checkpoint.checkpoint(function, *args, **kwargs) 48 | else: 49 | return function(*args, **kwargs) 50 | 51 | 52 | def downscale_pos(pos): 53 | pos = rearrange(pos, "... (h nh) (w nw) e -> ... 
h w (nh nw) e", nh=2, nw=2) 54 | return torch.mean(pos, dim=-2) 55 | 56 | 57 | # Param tags 58 | 59 | def tag_param(param, tag): 60 | if not hasattr(param, "_tags"): 61 | param._tags = set([tag]) 62 | else: 63 | param._tags.add(tag) 64 | return param 65 | 66 | 67 | def tag_module(module, tag): 68 | for param in module.parameters(): 69 | tag_param(param, tag) 70 | return module 71 | 72 | 73 | def apply_wd(module): 74 | for name, param in module.named_parameters(): 75 | if name.endswith("weight"): 76 | tag_param(param, "wd") 77 | return module 78 | 79 | 80 | def filter_params(function, module): 81 | for param in module.parameters(): 82 | tags = getattr(param, "_tags", set()) 83 | if function(tags): 84 | yield param 85 | 86 | 87 | # Kernels 88 | 89 | @flags.compile_wrap 90 | def linear_geglu(x, weight, bias=None): 91 | x = x @ weight.mT 92 | if bias is not None: 93 | x = x + bias 94 | x, gate = x.chunk(2, dim=-1) 95 | return x * F.gelu(gate) 96 | 97 | 98 | @flags.compile_wrap 99 | def rms_norm(x, scale, eps): 100 | dtype = reduce(torch.promote_types, (x.dtype, scale.dtype, torch.float32)) 101 | mean_sq = torch.mean(x.to(dtype)**2, dim=-1, keepdim=True) 102 | scale = scale.to(dtype) * torch.rsqrt(mean_sq + eps) 103 | return x * scale.to(x.dtype) 104 | 105 | 106 | @flags.compile_wrap 107 | def scale_for_cosine_sim(q, k, scale, eps): 108 | dtype = reduce(torch.promote_types, (q.dtype, k.dtype, scale.dtype, torch.float32)) 109 | sum_sq_q = torch.sum(q.to(dtype)**2, dim=-1, keepdim=True) 110 | sum_sq_k = torch.sum(k.to(dtype)**2, dim=-1, keepdim=True) 111 | sqrt_scale = torch.sqrt(scale.to(dtype)) 112 | scale_q = sqrt_scale * torch.rsqrt(sum_sq_q + eps) 113 | scale_k = sqrt_scale * torch.rsqrt(sum_sq_k + eps) 114 | return q * scale_q.to(q.dtype), k * scale_k.to(k.dtype) 115 | 116 | 117 | @flags.compile_wrap 118 | def scale_for_cosine_sim_qkv(qkv, scale, eps): 119 | q, k, v = qkv.unbind(2) 120 | q, k = scale_for_cosine_sim(q, k, scale[:, None], eps) 121 | return torch.stack((q, k, v), dim=2) 122 | 123 | 124 | # Layers 125 | 126 | class Linear(nn.Linear): 127 | def forward(self, x): 128 | flops.op(flops.op_linear, x.shape, self.weight.shape) 129 | return super().forward(x) 130 | 131 | 132 | class LinearGEGLU(nn.Linear): 133 | def __init__(self, in_features, out_features, bias=True): 134 | super().__init__(in_features, out_features * 2, bias=bias) 135 | self.out_features = out_features 136 | 137 | def forward(self, x): 138 | flops.op(flops.op_linear, x.shape, self.weight.shape) 139 | return linear_geglu(x, self.weight, self.bias) 140 | 141 | 142 | class RMSNorm(nn.Module): 143 | def __init__(self, shape, eps=1e-6): 144 | super().__init__() 145 | self.eps = eps 146 | self.scale = nn.Parameter(torch.ones(shape)) 147 | 148 | def extra_repr(self): 149 | return f"shape={tuple(self.scale.shape)}, eps={self.eps}" 150 | 151 | def forward(self, x): 152 | return rms_norm(x, self.scale, self.eps) 153 | 154 | 155 | class AdaRMSNorm(nn.Module): 156 | def __init__(self, features, cond_features, eps=1e-6): 157 | super().__init__() 158 | self.eps = eps 159 | self.linear = apply_wd(zero_init(Linear(cond_features, features, bias=False))) 160 | tag_module(self.linear, "mapping") 161 | 162 | def extra_repr(self): 163 | return f"eps={self.eps}," 164 | 165 | def forward(self, x, cond): 166 | return rms_norm(x, self.linear(cond)[:, None, None, :] + 1, self.eps) 167 | 168 | 169 | # Rotary position embeddings 170 | 171 | @flags.compile_wrap 172 | def apply_rotary_emb(x, theta, conj=False): 173 | out_dtype = x.dtype 174 | 
dtype = reduce(torch.promote_types, (x.dtype, theta.dtype, torch.float32)) 175 | d = theta.shape[-1] 176 | assert d * 2 <= x.shape[-1] 177 | x1, x2, x3 = x[..., :d], x[..., d : d * 2], x[..., d * 2 :] 178 | x1, x2, theta = x1.to(dtype), x2.to(dtype), theta.to(dtype) 179 | cos, sin = torch.cos(theta), torch.sin(theta) 180 | sin = -sin if conj else sin 181 | y1 = x1 * cos - x2 * sin 182 | y2 = x2 * cos + x1 * sin 183 | y1, y2 = y1.to(out_dtype), y2.to(out_dtype) 184 | return torch.cat((y1, y2, x3), dim=-1) 185 | 186 | 187 | @flags.compile_wrap 188 | def _apply_rotary_emb_inplace(x, theta, conj): 189 | dtype = reduce(torch.promote_types, (x.dtype, theta.dtype, torch.float32)) 190 | d = theta.shape[-1] 191 | assert d * 2 <= x.shape[-1] 192 | x1, x2 = x[..., :d], x[..., d : d * 2] 193 | x1_, x2_, theta = x1.to(dtype), x2.to(dtype), theta.to(dtype) 194 | cos, sin = torch.cos(theta), torch.sin(theta) 195 | sin = -sin if conj else sin 196 | y1 = x1_ * cos - x2_ * sin 197 | y2 = x2_ * cos + x1_ * sin 198 | x1.copy_(y1) 199 | x2.copy_(y2) 200 | 201 | 202 | class ApplyRotaryEmbeddingInplace(torch.autograd.Function): 203 | generate_vmap_rule = True 204 | 205 | @staticmethod 206 | def setup_context(ctx, inputs, output): 207 | x, theta, conj = inputs 208 | ctx.mark_dirty(x) 209 | ctx.save_for_backward(theta) 210 | ctx.save_for_forward(theta) 211 | ctx.conj = conj 212 | 213 | @staticmethod 214 | def forward(x, theta, conj): 215 | _apply_rotary_emb_inplace(x, theta, conj) 216 | return x 217 | 218 | @staticmethod 219 | def backward(ctx, grad_output): 220 | theta, = ctx.saved_tensors 221 | grad_output = ApplyRotaryEmbeddingInplace.apply(grad_output.clone(), theta, not ctx.conj) 222 | return grad_output, None, None 223 | 224 | @staticmethod 225 | def jvp(ctx, grad_input, _, __): 226 | theta, = ctx.saved_tensors 227 | return ApplyRotaryEmbeddingInplace.apply(grad_input, theta, ctx.conj) 228 | 229 | 230 | def apply_rotary_emb_(x, theta): 231 | return ApplyRotaryEmbeddingInplace.apply(x, theta, False) 232 | 233 | 234 | class AxialRoPE(nn.Module): 235 | def __init__(self, dim, n_heads): 236 | super().__init__() 237 | log_min = math.log(math.pi) 238 | log_max = math.log(10.0 * math.pi) 239 | freqs = torch.linspace(log_min, log_max, n_heads * dim // 4 + 1)[:-1].exp() 240 | self.register_buffer("freqs", freqs.view(dim // 4, n_heads).T.contiguous()) 241 | 242 | def extra_repr(self): 243 | return f"dim={self.freqs.shape[1] * 4}, n_heads={self.freqs.shape[0]}" 244 | 245 | def forward(self, pos): 246 | theta_h = pos[..., None, 0:1] * self.freqs.to(pos.dtype) 247 | theta_w = pos[..., None, 1:2] * self.freqs.to(pos.dtype) 248 | return torch.cat((theta_h, theta_w), dim=-1) 249 | 250 | 251 | # Shifted window attention 252 | 253 | def window(window_size, x): 254 | *b, h, w, c = x.shape 255 | x = torch.reshape( 256 | x, 257 | (*b, h // window_size, window_size, w // window_size, window_size, c), 258 | ) 259 | x = torch.permute( 260 | x, 261 | (*range(len(b)), -5, -3, -4, -2, -1), 262 | ) 263 | return x 264 | 265 | 266 | def unwindow(x): 267 | *b, h, w, wh, ww, c = x.shape 268 | x = torch.permute(x, (*range(len(b)), -5, -3, -4, -2, -1)) 269 | x = torch.reshape(x, (*b, h * wh, w * ww, c)) 270 | return x 271 | 272 | 273 | def shifted_window(window_size, window_shift, x): 274 | x = torch.roll(x, shifts=(window_shift, window_shift), dims=(-2, -3)) 275 | windows = window(window_size, x) 276 | return windows 277 | 278 | 279 | def shifted_unwindow(window_shift, x): 280 | x = unwindow(x) 281 | x = torch.roll(x, shifts=(-window_shift, 
-window_shift), dims=(-2, -3)) 282 | return x 283 | 284 | 285 | @lru_cache 286 | def make_shifted_window_masks(n_h_w, n_w_w, w_h, w_w, shift, device=None): 287 | ph_coords = torch.arange(n_h_w, device=device) 288 | pw_coords = torch.arange(n_w_w, device=device) 289 | h_coords = torch.arange(w_h, device=device) 290 | w_coords = torch.arange(w_w, device=device) 291 | patch_h, patch_w, q_h, q_w, k_h, k_w = torch.meshgrid( 292 | ph_coords, 293 | pw_coords, 294 | h_coords, 295 | w_coords, 296 | h_coords, 297 | w_coords, 298 | indexing="ij", 299 | ) 300 | is_top_patch = patch_h == 0 301 | is_left_patch = patch_w == 0 302 | q_above_shift = q_h < shift 303 | k_above_shift = k_h < shift 304 | q_left_of_shift = q_w < shift 305 | k_left_of_shift = k_w < shift 306 | m_corner = ( 307 | is_left_patch 308 | & is_top_patch 309 | & (q_left_of_shift == k_left_of_shift) 310 | & (q_above_shift == k_above_shift) 311 | ) 312 | m_left = is_left_patch & ~is_top_patch & (q_left_of_shift == k_left_of_shift) 313 | m_top = ~is_left_patch & is_top_patch & (q_above_shift == k_above_shift) 314 | m_rest = ~is_left_patch & ~is_top_patch 315 | m = m_corner | m_left | m_top | m_rest 316 | return m 317 | 318 | 319 | def apply_window_attention(window_size, window_shift, q, k, v, scale=None): 320 | # prep windows and masks 321 | q_windows = shifted_window(window_size, window_shift, q) 322 | k_windows = shifted_window(window_size, window_shift, k) 323 | v_windows = shifted_window(window_size, window_shift, v) 324 | b, heads, h, w, wh, ww, d_head = q_windows.shape 325 | mask = make_shifted_window_masks(h, w, wh, ww, window_shift, device=q.device) 326 | q_seqs = torch.reshape(q_windows, (b, heads, h, w, wh * ww, d_head)) 327 | k_seqs = torch.reshape(k_windows, (b, heads, h, w, wh * ww, d_head)) 328 | v_seqs = torch.reshape(v_windows, (b, heads, h, w, wh * ww, d_head)) 329 | mask = torch.reshape(mask, (h, w, wh * ww, wh * ww)) 330 | 331 | # do the attention here 332 | flops.op(flops.op_attention, q_seqs.shape, k_seqs.shape, v_seqs.shape) 333 | qkv = F.scaled_dot_product_attention(q_seqs, k_seqs, v_seqs, mask, scale=scale) 334 | 335 | # unwindow 336 | qkv = torch.reshape(qkv, (b, heads, h, w, wh, ww, d_head)) 337 | return shifted_unwindow(window_shift, qkv) 338 | 339 | 340 | # Transformer layers 341 | 342 | 343 | def use_flash_2(x): 344 | if not flags.get_use_flash_attention_2(): 345 | return False 346 | if flash_attn is None: 347 | return False 348 | if x.device.type != "cuda": 349 | return False 350 | if x.dtype not in (torch.float16, torch.bfloat16): 351 | return False 352 | return True 353 | 354 | 355 | class SelfAttentionBlock(nn.Module): 356 | def __init__(self, d_model, d_head, cond_features, dropout=0.0): 357 | super().__init__() 358 | self.d_head = d_head 359 | self.n_heads = d_model // d_head 360 | self.norm = AdaRMSNorm(d_model, cond_features) 361 | self.qkv_proj = apply_wd(Linear(d_model, d_model * 3, bias=False)) 362 | self.scale = nn.Parameter(torch.full([self.n_heads], 10.0)) 363 | self.pos_emb = AxialRoPE(d_head // 2, self.n_heads) 364 | self.dropout = nn.Dropout(dropout) 365 | self.out_proj = apply_wd(zero_init(Linear(d_model, d_model, bias=False))) 366 | 367 | def extra_repr(self): 368 | return f"d_head={self.d_head}," 369 | 370 | def forward(self, x, pos, cond): 371 | skip = x 372 | x = self.norm(x, cond) 373 | qkv = self.qkv_proj(x) 374 | pos = rearrange(pos, "... h w e -> ... 
(h w) e").to(qkv.dtype) 375 | theta = self.pos_emb(pos) 376 | if use_flash_2(qkv): 377 | qkv = rearrange(qkv, "n h w (t nh e) -> n (h w) t nh e", t=3, e=self.d_head) 378 | qkv = scale_for_cosine_sim_qkv(qkv, self.scale, 1e-6) 379 | theta = torch.stack((theta, theta, torch.zeros_like(theta)), dim=-3) 380 | qkv = apply_rotary_emb_(qkv, theta) 381 | flops_shape = qkv.shape[-5], qkv.shape[-2], qkv.shape[-4], qkv.shape[-1] 382 | flops.op(flops.op_attention, flops_shape, flops_shape, flops_shape) 383 | x = flash_attn.flash_attn_qkvpacked_func(qkv, softmax_scale=1.0) 384 | x = rearrange(x, "n (h w) nh e -> n h w (nh e)", h=skip.shape[-3], w=skip.shape[-2]) 385 | else: 386 | q, k, v = rearrange(qkv, "n h w (t nh e) -> t n nh (h w) e", t=3, e=self.d_head) 387 | q, k = scale_for_cosine_sim(q, k, self.scale[:, None, None], 1e-6) 388 | theta = theta.movedim(-2, -3) 389 | q = apply_rotary_emb_(q, theta) 390 | k = apply_rotary_emb_(k, theta) 391 | flops.op(flops.op_attention, q.shape, k.shape, v.shape) 392 | x = F.scaled_dot_product_attention(q, k, v, scale=1.0) 393 | x = rearrange(x, "n nh (h w) e -> n h w (nh e)", h=skip.shape[-3], w=skip.shape[-2]) 394 | x = self.dropout(x) 395 | x = self.out_proj(x) 396 | return x + skip 397 | 398 | 399 | class NeighborhoodSelfAttentionBlock(nn.Module): 400 | def __init__(self, d_model, d_head, cond_features, kernel_size, dropout=0.0): 401 | super().__init__() 402 | self.d_head = d_head 403 | self.n_heads = d_model // d_head 404 | self.kernel_size = kernel_size 405 | self.norm = AdaRMSNorm(d_model, cond_features) 406 | self.qkv_proj = apply_wd(Linear(d_model, d_model * 3, bias=False)) 407 | self.scale = nn.Parameter(torch.full([self.n_heads], 10.0)) 408 | self.pos_emb = AxialRoPE(d_head // 2, self.n_heads) 409 | self.dropout = nn.Dropout(dropout) 410 | self.out_proj = apply_wd(zero_init(Linear(d_model, d_model, bias=False))) 411 | 412 | def extra_repr(self): 413 | return f"d_head={self.d_head}, kernel_size={self.kernel_size}" 414 | 415 | def forward(self, x, pos, cond): 416 | skip = x 417 | x = self.norm(x, cond) 418 | qkv = self.qkv_proj(x) 419 | if natten is None: 420 | raise ModuleNotFoundError("natten is required for neighborhood attention") 421 | if natten.has_fused_na(): 422 | q, k, v = rearrange(qkv, "n h w (t nh e) -> t n h w nh e", t=3, e=self.d_head) 423 | q, k = scale_for_cosine_sim(q, k, self.scale[:, None], 1e-6) 424 | theta = self.pos_emb(pos) 425 | q = apply_rotary_emb_(q, theta) 426 | k = apply_rotary_emb_(k, theta) 427 | flops.op(flops.op_natten, q.shape, k.shape, v.shape, self.kernel_size) 428 | x = natten.functional.na2d(q, k, v, self.kernel_size, scale=1.0) 429 | x = rearrange(x, "n h w nh e -> n h w (nh e)") 430 | else: 431 | q, k, v = rearrange(qkv, "n h w (t nh e) -> t n nh h w e", t=3, e=self.d_head) 432 | q, k = scale_for_cosine_sim(q, k, self.scale[:, None, None, None], 1e-6) 433 | theta = self.pos_emb(pos).movedim(-2, -4) 434 | q = apply_rotary_emb_(q, theta) 435 | k = apply_rotary_emb_(k, theta) 436 | flops.op(flops.op_natten, q.shape, k.shape, v.shape, self.kernel_size) 437 | qk = natten.functional.na2d_qk(q, k, self.kernel_size) 438 | a = torch.softmax(qk, dim=-1).to(v.dtype) 439 | x = natten.functional.na2d_av(a, v, self.kernel_size) 440 | x = rearrange(x, "n nh h w e -> n h w (nh e)") 441 | x = self.dropout(x) 442 | x = self.out_proj(x) 443 | return x + skip 444 | 445 | 446 | class ShiftedWindowSelfAttentionBlock(nn.Module): 447 | def __init__(self, d_model, d_head, cond_features, window_size, window_shift, dropout=0.0): 448 | 
super().__init__() 449 | self.d_head = d_head 450 | self.n_heads = d_model // d_head 451 | self.window_size = window_size 452 | self.window_shift = window_shift 453 | self.norm = AdaRMSNorm(d_model, cond_features) 454 | self.qkv_proj = apply_wd(Linear(d_model, d_model * 3, bias=False)) 455 | self.scale = nn.Parameter(torch.full([self.n_heads], 10.0)) 456 | self.pos_emb = AxialRoPE(d_head // 2, self.n_heads) 457 | self.dropout = nn.Dropout(dropout) 458 | self.out_proj = apply_wd(zero_init(Linear(d_model, d_model, bias=False))) 459 | 460 | def extra_repr(self): 461 | return f"d_head={self.d_head}, window_size={self.window_size}, window_shift={self.window_shift}" 462 | 463 | def forward(self, x, pos, cond): 464 | skip = x 465 | x = self.norm(x, cond) 466 | qkv = self.qkv_proj(x) 467 | q, k, v = rearrange(qkv, "n h w (t nh e) -> t n nh h w e", t=3, e=self.d_head) 468 | q, k = scale_for_cosine_sim(q, k, self.scale[:, None, None, None], 1e-6) 469 | theta = self.pos_emb(pos).movedim(-2, -4) 470 | q = apply_rotary_emb_(q, theta) 471 | k = apply_rotary_emb_(k, theta) 472 | x = apply_window_attention(self.window_size, self.window_shift, q, k, v, scale=1.0) 473 | x = rearrange(x, "n nh h w e -> n h w (nh e)") 474 | x = self.dropout(x) 475 | x = self.out_proj(x) 476 | return x + skip 477 | 478 | 479 | class FeedForwardBlock(nn.Module): 480 | def __init__(self, d_model, d_ff, cond_features, dropout=0.0): 481 | super().__init__() 482 | self.norm = AdaRMSNorm(d_model, cond_features) 483 | self.up_proj = apply_wd(LinearGEGLU(d_model, d_ff, bias=False)) 484 | self.dropout = nn.Dropout(dropout) 485 | self.down_proj = apply_wd(zero_init(Linear(d_ff, d_model, bias=False))) 486 | 487 | def forward(self, x, cond): 488 | skip = x 489 | x = self.norm(x, cond) 490 | x = self.up_proj(x) 491 | x = self.dropout(x) 492 | x = self.down_proj(x) 493 | return x + skip 494 | 495 | 496 | class GlobalTransformerLayer(nn.Module): 497 | def __init__(self, d_model, d_ff, d_head, cond_features, dropout=0.0): 498 | super().__init__() 499 | self.self_attn = SelfAttentionBlock(d_model, d_head, cond_features, dropout=dropout) 500 | self.ff = FeedForwardBlock(d_model, d_ff, cond_features, dropout=dropout) 501 | 502 | def forward(self, x, pos, cond): 503 | x = checkpoint(self.self_attn, x, pos, cond) 504 | x = checkpoint(self.ff, x, cond) 505 | return x 506 | 507 | 508 | class NeighborhoodTransformerLayer(nn.Module): 509 | def __init__(self, d_model, d_ff, d_head, cond_features, kernel_size, dropout=0.0): 510 | super().__init__() 511 | self.self_attn = NeighborhoodSelfAttentionBlock(d_model, d_head, cond_features, kernel_size, dropout=dropout) 512 | self.ff = FeedForwardBlock(d_model, d_ff, cond_features, dropout=dropout) 513 | 514 | def forward(self, x, pos, cond): 515 | x = checkpoint(self.self_attn, x, pos, cond) 516 | x = checkpoint(self.ff, x, cond) 517 | return x 518 | 519 | 520 | class ShiftedWindowTransformerLayer(nn.Module): 521 | def __init__(self, d_model, d_ff, d_head, cond_features, window_size, index, dropout=0.0): 522 | super().__init__() 523 | window_shift = window_size // 2 if index % 2 == 1 else 0 524 | self.self_attn = ShiftedWindowSelfAttentionBlock(d_model, d_head, cond_features, window_size, window_shift, dropout=dropout) 525 | self.ff = FeedForwardBlock(d_model, d_ff, cond_features, dropout=dropout) 526 | 527 | def forward(self, x, pos, cond): 528 | x = checkpoint(self.self_attn, x, pos, cond) 529 | x = checkpoint(self.ff, x, cond) 530 | return x 531 | 532 | 533 | class NoAttentionTransformerLayer(nn.Module): 
534 | def __init__(self, d_model, d_ff, cond_features, dropout=0.0): 535 | super().__init__() 536 | self.ff = FeedForwardBlock(d_model, d_ff, cond_features, dropout=dropout) 537 | 538 | def forward(self, x, pos, cond): 539 | x = checkpoint(self.ff, x, cond) 540 | return x 541 | 542 | 543 | class Level(nn.ModuleList): 544 | def forward(self, x, *args, **kwargs): 545 | for layer in self: 546 | x = layer(x, *args, **kwargs) 547 | return x 548 | 549 | 550 | # Mapping network 551 | 552 | class MappingFeedForwardBlock(nn.Module): 553 | def __init__(self, d_model, d_ff, dropout=0.0): 554 | super().__init__() 555 | self.norm = RMSNorm(d_model) 556 | self.up_proj = apply_wd(LinearGEGLU(d_model, d_ff, bias=False)) 557 | self.dropout = nn.Dropout(dropout) 558 | self.down_proj = apply_wd(zero_init(Linear(d_ff, d_model, bias=False))) 559 | 560 | def forward(self, x): 561 | skip = x 562 | x = self.norm(x) 563 | x = self.up_proj(x) 564 | x = self.dropout(x) 565 | x = self.down_proj(x) 566 | return x + skip 567 | 568 | 569 | class MappingNetwork(nn.Module): 570 | def __init__(self, n_layers, d_model, d_ff, dropout=0.0): 571 | super().__init__() 572 | self.in_norm = RMSNorm(d_model) 573 | self.blocks = nn.ModuleList([MappingFeedForwardBlock(d_model, d_ff, dropout=dropout) for _ in range(n_layers)]) 574 | self.out_norm = RMSNorm(d_model) 575 | 576 | def forward(self, x): 577 | x = self.in_norm(x) 578 | for block in self.blocks: 579 | x = block(x) 580 | x = self.out_norm(x) 581 | return x 582 | 583 | 584 | # Token merging and splitting 585 | 586 | class TokenMerge(nn.Module): 587 | def __init__(self, in_features, out_features, patch_size=(2, 2)): 588 | super().__init__() 589 | self.h = patch_size[0] 590 | self.w = patch_size[1] 591 | self.proj = apply_wd(Linear(in_features * self.h * self.w, out_features, bias=False)) 592 | 593 | def forward(self, x): 594 | x = rearrange(x, "... (h nh) (w nw) e -> ... h w (nh nw e)", nh=self.h, nw=self.w) 595 | return self.proj(x) 596 | 597 | 598 | class TokenSplitWithoutSkip(nn.Module): 599 | def __init__(self, in_features, out_features, patch_size=(2, 2)): 600 | super().__init__() 601 | self.h = patch_size[0] 602 | self.w = patch_size[1] 603 | self.proj = apply_wd(Linear(in_features, out_features * self.h * self.w, bias=False)) 604 | 605 | def forward(self, x): 606 | x = self.proj(x) 607 | return rearrange(x, "... h w (nh nw e) -> ... (h nh) (w nw) e", nh=self.h, nw=self.w) 608 | 609 | 610 | class TokenSplit(nn.Module): 611 | def __init__(self, in_features, out_features, patch_size=(2, 2)): 612 | super().__init__() 613 | self.h = patch_size[0] 614 | self.w = patch_size[1] 615 | self.proj = apply_wd(Linear(in_features, out_features * self.h * self.w, bias=False)) 616 | self.fac = nn.Parameter(torch.ones(1) * 0.5) 617 | 618 | def forward(self, x, skip): 619 | x = self.proj(x) 620 | x = rearrange(x, "... h w (nh nw e) -> ... 
(h nh) (w nw) e", nh=self.h, nw=self.w) 621 | return torch.lerp(skip, x, self.fac.to(x.dtype)) 622 | 623 | 624 | # Configuration 625 | 626 | @dataclass 627 | class GlobalAttentionSpec: 628 | d_head: int 629 | 630 | 631 | @dataclass 632 | class NeighborhoodAttentionSpec: 633 | d_head: int 634 | kernel_size: int 635 | 636 | 637 | @dataclass 638 | class ShiftedWindowAttentionSpec: 639 | d_head: int 640 | window_size: int 641 | 642 | 643 | @dataclass 644 | class NoAttentionSpec: 645 | pass 646 | 647 | 648 | @dataclass 649 | class LevelSpec: 650 | depth: int 651 | width: int 652 | d_ff: int 653 | self_attn: Union[GlobalAttentionSpec, NeighborhoodAttentionSpec, ShiftedWindowAttentionSpec, NoAttentionSpec] 654 | dropout: float 655 | 656 | 657 | @dataclass 658 | class MappingSpec: 659 | depth: int 660 | width: int 661 | d_ff: int 662 | dropout: float 663 | 664 | 665 | # Model class 666 | 667 | class ImageTransformerDenoiserModelV2(nn.Module): 668 | def __init__(self, levels, mapping, in_channels, out_channels, patch_size, num_classes=0, mapping_cond_dim=0): 669 | super().__init__() 670 | self.num_classes = num_classes 671 | 672 | self.patch_in = TokenMerge(in_channels, levels[0].width, patch_size) 673 | 674 | self.time_emb = layers.FourierFeatures(1, mapping.width) 675 | self.time_in_proj = Linear(mapping.width, mapping.width, bias=False) 676 | self.aug_emb = layers.FourierFeatures(9, mapping.width) 677 | self.aug_in_proj = Linear(mapping.width, mapping.width, bias=False) 678 | self.class_emb = nn.Embedding(num_classes, mapping.width) if num_classes else None 679 | self.mapping_cond_in_proj = Linear(mapping_cond_dim, mapping.width, bias=False) if mapping_cond_dim else None 680 | self.mapping = tag_module(MappingNetwork(mapping.depth, mapping.width, mapping.d_ff, dropout=mapping.dropout), "mapping") 681 | 682 | self.down_levels, self.up_levels = nn.ModuleList(), nn.ModuleList() 683 | for i, spec in enumerate(levels): 684 | if isinstance(spec.self_attn, GlobalAttentionSpec): 685 | layer_factory = lambda _: GlobalTransformerLayer(spec.width, spec.d_ff, spec.self_attn.d_head, mapping.width, dropout=spec.dropout) 686 | elif isinstance(spec.self_attn, NeighborhoodAttentionSpec): 687 | layer_factory = lambda _: NeighborhoodTransformerLayer(spec.width, spec.d_ff, spec.self_attn.d_head, mapping.width, spec.self_attn.kernel_size, dropout=spec.dropout) 688 | elif isinstance(spec.self_attn, ShiftedWindowAttentionSpec): 689 | layer_factory = lambda i: ShiftedWindowTransformerLayer(spec.width, spec.d_ff, spec.self_attn.d_head, mapping.width, spec.self_attn.window_size, i, dropout=spec.dropout) 690 | elif isinstance(spec.self_attn, NoAttentionSpec): 691 | layer_factory = lambda _: NoAttentionTransformerLayer(spec.width, spec.d_ff, mapping.width, dropout=spec.dropout) 692 | else: 693 | raise ValueError(f"unsupported self attention spec {spec.self_attn}") 694 | 695 | if i < len(levels) - 1: 696 | self.down_levels.append(Level([layer_factory(i) for i in range(spec.depth)])) 697 | self.up_levels.append(Level([layer_factory(i + spec.depth) for i in range(spec.depth)])) 698 | else: 699 | self.mid_level = Level([layer_factory(i) for i in range(spec.depth)]) 700 | 701 | self.merges = nn.ModuleList([TokenMerge(spec_1.width, spec_2.width) for spec_1, spec_2 in zip(levels[:-1], levels[1:])]) 702 | self.splits = nn.ModuleList([TokenSplit(spec_2.width, spec_1.width) for spec_1, spec_2 in zip(levels[:-1], levels[1:])]) 703 | 704 | self.out_norm = RMSNorm(levels[0].width) 705 | self.patch_out = 
TokenSplitWithoutSkip(levels[0].width, out_channels, patch_size) 706 | nn.init.zeros_(self.patch_out.proj.weight) 707 | 708 | def param_groups(self, base_lr=5e-4, mapping_lr_scale=1 / 3): 709 | wd = filter_params(lambda tags: "wd" in tags and "mapping" not in tags, self) 710 | no_wd = filter_params(lambda tags: "wd" not in tags and "mapping" not in tags, self) 711 | mapping_wd = filter_params(lambda tags: "wd" in tags and "mapping" in tags, self) 712 | mapping_no_wd = filter_params(lambda tags: "wd" not in tags and "mapping" in tags, self) 713 | groups = [ 714 | {"params": list(wd), "lr": base_lr}, 715 | {"params": list(no_wd), "lr": base_lr, "weight_decay": 0.0}, 716 | {"params": list(mapping_wd), "lr": base_lr * mapping_lr_scale}, 717 | {"params": list(mapping_no_wd), "lr": base_lr * mapping_lr_scale, "weight_decay": 0.0} 718 | ] 719 | return groups 720 | 721 | def forward(self, x, sigma, aug_cond=None, class_cond=None, mapping_cond=None): 722 | # Patching 723 | x = x.movedim(-3, -1) 724 | x = self.patch_in(x) 725 | # TODO: pixel aspect ratio for nonsquare patches 726 | pos = make_axial_pos(x.shape[-3], x.shape[-2], device=x.device).view(x.shape[-3], x.shape[-2], 2) 727 | 728 | # Mapping network 729 | if class_cond is None and self.class_emb is not None: 730 | raise ValueError("class_cond must be specified if num_classes > 0") 731 | if mapping_cond is None and self.mapping_cond_in_proj is not None: 732 | raise ValueError("mapping_cond must be specified if mapping_cond_dim > 0") 733 | 734 | c_noise = torch.log(sigma) / 4 735 | time_emb = self.time_in_proj(self.time_emb(c_noise[..., None])) 736 | aug_cond = x.new_zeros([x.shape[0], 9]) if aug_cond is None else aug_cond 737 | aug_emb = self.aug_in_proj(self.aug_emb(aug_cond)) 738 | class_emb = self.class_emb(class_cond) if self.class_emb is not None else 0 739 | mapping_emb = self.mapping_cond_in_proj(mapping_cond) if self.mapping_cond_in_proj is not None else 0 740 | cond = self.mapping(time_emb + aug_emb + class_emb + mapping_emb) 741 | 742 | # Hourglass transformer 743 | skips, poses = [], [] 744 | for down_level, merge in zip(self.down_levels, self.merges): 745 | x = down_level(x, pos, cond) 746 | skips.append(x) 747 | poses.append(pos) 748 | x = merge(x) 749 | pos = downscale_pos(pos) 750 | 751 | x = self.mid_level(x, pos, cond) 752 | 753 | for up_level, split, skip, pos in reversed(list(zip(self.up_levels, self.splits, skips, poses))): 754 | x = split(x, skip) 755 | x = up_level(x, pos, cond) 756 | 757 | # Unpatching 758 | x = self.out_norm(x) 759 | x = self.patch_out(x) 760 | x = x.movedim(-1, -3) 761 | 762 | return x 763 | -------------------------------------------------------------------------------- /k_diffusion/sampling.py: -------------------------------------------------------------------------------- 1 | import math 2 | 3 | from scipy import integrate 4 | import torch 5 | from torch import nn 6 | from torchdiffeq import odeint 7 | import torchsde 8 | from tqdm.auto import trange, tqdm 9 | 10 | from . import utils 11 | 12 | 13 | def append_zero(x): 14 | return torch.cat([x, x.new_zeros([1])]) 15 | 16 | 17 | def get_sigmas_karras(n, sigma_min, sigma_max, rho=7., device='cpu'): 18 | """Constructs the noise schedule of Karras et al. 
(2022).""" 19 | ramp = torch.linspace(0, 1, n) 20 | min_inv_rho = sigma_min ** (1 / rho) 21 | max_inv_rho = sigma_max ** (1 / rho) 22 | sigmas = (max_inv_rho + ramp * (min_inv_rho - max_inv_rho)) ** rho 23 | return append_zero(sigmas).to(device) 24 | 25 | 26 | def get_sigmas_exponential(n, sigma_min, sigma_max, device='cpu'): 27 | """Constructs an exponential noise schedule.""" 28 | sigmas = torch.linspace(math.log(sigma_max), math.log(sigma_min), n, device=device).exp() 29 | return append_zero(sigmas) 30 | 31 | 32 | def get_sigmas_polyexponential(n, sigma_min, sigma_max, rho=1., device='cpu'): 33 | """Constructs an polynomial in log sigma noise schedule.""" 34 | ramp = torch.linspace(1, 0, n, device=device) ** rho 35 | sigmas = torch.exp(ramp * (math.log(sigma_max) - math.log(sigma_min)) + math.log(sigma_min)) 36 | return append_zero(sigmas) 37 | 38 | 39 | def get_sigmas_vp(n, beta_d=19.9, beta_min=0.1, eps_s=1e-3, device='cpu'): 40 | """Constructs a continuous VP noise schedule.""" 41 | t = torch.linspace(1, eps_s, n, device=device) 42 | sigmas = torch.sqrt(torch.exp(beta_d * t ** 2 / 2 + beta_min * t) - 1) 43 | return append_zero(sigmas) 44 | 45 | 46 | def to_d(x, sigma, denoised): 47 | """Converts a denoiser output to a Karras ODE derivative.""" 48 | return (x - denoised) / utils.append_dims(sigma, x.ndim) 49 | 50 | 51 | def get_ancestral_step(sigma_from, sigma_to, eta=1.): 52 | """Calculates the noise level (sigma_down) to step down to and the amount 53 | of noise to add (sigma_up) when doing an ancestral sampling step.""" 54 | if not eta: 55 | return sigma_to, 0. 56 | sigma_up = min(sigma_to, eta * (sigma_to ** 2 * (sigma_from ** 2 - sigma_to ** 2) / sigma_from ** 2) ** 0.5) 57 | sigma_down = (sigma_to ** 2 - sigma_up ** 2) ** 0.5 58 | return sigma_down, sigma_up 59 | 60 | 61 | def default_noise_sampler(x): 62 | return lambda sigma, sigma_next: torch.randn_like(x) 63 | 64 | 65 | class BatchedBrownianTree: 66 | """A wrapper around torchsde.BrownianTree that enables batches of entropy.""" 67 | 68 | def __init__(self, x, t0, t1, seed=None, **kwargs): 69 | t0, t1, self.sign = self.sort(t0, t1) 70 | w0 = kwargs.get('w0', torch.zeros_like(x)) 71 | if seed is None: 72 | seed = torch.randint(0, 2 ** 63 - 1, []).item() 73 | self.batched = True 74 | try: 75 | assert len(seed) == x.shape[0] 76 | w0 = w0[0] 77 | except TypeError: 78 | seed = [seed] 79 | self.batched = False 80 | self.trees = [torchsde.BrownianTree(t0, w0, t1, entropy=s, **kwargs) for s in seed] 81 | 82 | @staticmethod 83 | def sort(a, b): 84 | return (a, b, 1) if a < b else (b, a, -1) 85 | 86 | def __call__(self, t0, t1): 87 | t0, t1, sign = self.sort(t0, t1) 88 | w = torch.stack([tree(t0, t1) for tree in self.trees]) * (self.sign * sign) 89 | return w if self.batched else w[0] 90 | 91 | 92 | class BrownianTreeNoiseSampler: 93 | """A noise sampler backed by a torchsde.BrownianTree. 94 | 95 | Args: 96 | x (Tensor): The tensor whose shape, device and dtype to use to generate 97 | random samples. 98 | sigma_min (float): The low end of the valid interval. 99 | sigma_max (float): The high end of the valid interval. 100 | seed (int or List[int]): The random seed. If a list of seeds is 101 | supplied instead of a single integer, then the noise sampler will 102 | use one BrownianTree per batch item, each with its own seed. 103 | transform (callable): A function that maps sigma to the sampler's 104 | internal timestep. 
105 | """ 106 | 107 | def __init__(self, x, sigma_min, sigma_max, seed=None, transform=lambda x: x): 108 | self.transform = transform 109 | t0, t1 = self.transform(torch.as_tensor(sigma_min)), self.transform(torch.as_tensor(sigma_max)) 110 | self.tree = BatchedBrownianTree(x, t0, t1, seed) 111 | 112 | def __call__(self, sigma, sigma_next): 113 | t0, t1 = self.transform(torch.as_tensor(sigma)), self.transform(torch.as_tensor(sigma_next)) 114 | return self.tree(t0, t1) / (t1 - t0).abs().sqrt() 115 | 116 | 117 | @torch.no_grad() 118 | def sample_euler(model, x, sigmas, extra_args=None, callback=None, disable=None, s_churn=0., s_tmin=0., s_tmax=float('inf'), s_noise=1.): 119 | """Implements Algorithm 2 (Euler steps) from Karras et al. (2022).""" 120 | extra_args = {} if extra_args is None else extra_args 121 | s_in = x.new_ones([x.shape[0]]) 122 | for i in trange(len(sigmas) - 1, disable=disable): 123 | gamma = min(s_churn / (len(sigmas) - 1), 2 ** 0.5 - 1) if s_tmin <= sigmas[i] <= s_tmax else 0. 124 | eps = torch.randn_like(x) * s_noise 125 | sigma_hat = sigmas[i] * (gamma + 1) 126 | if gamma > 0: 127 | x = x + eps * (sigma_hat ** 2 - sigmas[i] ** 2) ** 0.5 128 | denoised = model(x, sigma_hat * s_in, **extra_args) 129 | d = to_d(x, sigma_hat, denoised) 130 | if callback is not None: 131 | callback({'x': x, 'i': i, 'sigma': sigmas[i], 'sigma_hat': sigma_hat, 'denoised': denoised}) 132 | dt = sigmas[i + 1] - sigma_hat 133 | # Euler method 134 | x = x + d * dt 135 | return x 136 | 137 | 138 | @torch.no_grad() 139 | def sample_euler_ancestral(model, x, sigmas, extra_args=None, callback=None, disable=None, eta=1., s_noise=1., noise_sampler=None): 140 | """Ancestral sampling with Euler method steps.""" 141 | extra_args = {} if extra_args is None else extra_args 142 | noise_sampler = default_noise_sampler(x) if noise_sampler is None else noise_sampler 143 | s_in = x.new_ones([x.shape[0]]) 144 | for i in trange(len(sigmas) - 1, disable=disable): 145 | denoised = model(x, sigmas[i] * s_in, **extra_args) 146 | sigma_down, sigma_up = get_ancestral_step(sigmas[i], sigmas[i + 1], eta=eta) 147 | if callback is not None: 148 | callback({'x': x, 'i': i, 'sigma': sigmas[i], 'sigma_hat': sigmas[i], 'denoised': denoised}) 149 | d = to_d(x, sigmas[i], denoised) 150 | # Euler method 151 | dt = sigma_down - sigmas[i] 152 | x = x + d * dt 153 | if sigmas[i + 1] > 0: 154 | x = x + noise_sampler(sigmas[i], sigmas[i + 1]) * s_noise * sigma_up 155 | return x 156 | 157 | 158 | @torch.no_grad() 159 | def sample_heun(model, x, sigmas, extra_args=None, callback=None, disable=None, s_churn=0., s_tmin=0., s_tmax=float('inf'), s_noise=1.): 160 | """Implements Algorithm 2 (Heun steps) from Karras et al. (2022).""" 161 | extra_args = {} if extra_args is None else extra_args 162 | s_in = x.new_ones([x.shape[0]]) 163 | for i in trange(len(sigmas) - 1, disable=disable): 164 | gamma = min(s_churn / (len(sigmas) - 1), 2 ** 0.5 - 1) if s_tmin <= sigmas[i] <= s_tmax else 0. 
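# Note (added descriptive comment): when gamma > 0, sigma_hat = sigmas[i] * (gamma + 1) and the fresh noise added below lifts x to noise level sigma_hat before the solver step ("churn", Karras et al. 2022, Algorithm 2); with the default s_churn=0. the step is fully deterministic.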
165 | eps = torch.randn_like(x) * s_noise 166 | sigma_hat = sigmas[i] * (gamma + 1) 167 | if gamma > 0: 168 | x = x + eps * (sigma_hat ** 2 - sigmas[i] ** 2) ** 0.5 169 | denoised = model(x, sigma_hat * s_in, **extra_args) 170 | d = to_d(x, sigma_hat, denoised) 171 | if callback is not None: 172 | callback({'x': x, 'i': i, 'sigma': sigmas[i], 'sigma_hat': sigma_hat, 'denoised': denoised}) 173 | dt = sigmas[i + 1] - sigma_hat 174 | if sigmas[i + 1] == 0: 175 | # Euler method 176 | x = x + d * dt 177 | else: 178 | # Heun's method 179 | x_2 = x + d * dt 180 | denoised_2 = model(x_2, sigmas[i + 1] * s_in, **extra_args) 181 | d_2 = to_d(x_2, sigmas[i + 1], denoised_2) 182 | d_prime = (d + d_2) / 2 183 | x = x + d_prime * dt 184 | return x 185 | 186 | 187 | @torch.no_grad() 188 | def sample_dpm_2(model, x, sigmas, extra_args=None, callback=None, disable=None, s_churn=0., s_tmin=0., s_tmax=float('inf'), s_noise=1.): 189 | """A sampler inspired by DPM-Solver-2 and Algorithm 2 from Karras et al. (2022).""" 190 | extra_args = {} if extra_args is None else extra_args 191 | s_in = x.new_ones([x.shape[0]]) 192 | for i in trange(len(sigmas) - 1, disable=disable): 193 | gamma = min(s_churn / (len(sigmas) - 1), 2 ** 0.5 - 1) if s_tmin <= sigmas[i] <= s_tmax else 0. 194 | eps = torch.randn_like(x) * s_noise 195 | sigma_hat = sigmas[i] * (gamma + 1) 196 | if gamma > 0: 197 | x = x + eps * (sigma_hat ** 2 - sigmas[i] ** 2) ** 0.5 198 | denoised = model(x, sigma_hat * s_in, **extra_args) 199 | d = to_d(x, sigma_hat, denoised) 200 | if callback is not None: 201 | callback({'x': x, 'i': i, 'sigma': sigmas[i], 'sigma_hat': sigma_hat, 'denoised': denoised}) 202 | if sigmas[i + 1] == 0: 203 | # Euler method 204 | dt = sigmas[i + 1] - sigma_hat 205 | x = x + d * dt 206 | else: 207 | # DPM-Solver-2 208 | sigma_mid = sigma_hat.log().lerp(sigmas[i + 1].log(), 0.5).exp() 209 | dt_1 = sigma_mid - sigma_hat 210 | dt_2 = sigmas[i + 1] - sigma_hat 211 | x_2 = x + d * dt_1 212 | denoised_2 = model(x_2, sigma_mid * s_in, **extra_args) 213 | d_2 = to_d(x_2, sigma_mid, denoised_2) 214 | x = x + d_2 * dt_2 215 | return x 216 | 217 | 218 | @torch.no_grad() 219 | def sample_dpm_2_ancestral(model, x, sigmas, extra_args=None, callback=None, disable=None, eta=1., s_noise=1., noise_sampler=None): 220 | """Ancestral sampling with DPM-Solver second-order steps.""" 221 | extra_args = {} if extra_args is None else extra_args 222 | noise_sampler = default_noise_sampler(x) if noise_sampler is None else noise_sampler 223 | s_in = x.new_ones([x.shape[0]]) 224 | for i in trange(len(sigmas) - 1, disable=disable): 225 | denoised = model(x, sigmas[i] * s_in, **extra_args) 226 | sigma_down, sigma_up = get_ancestral_step(sigmas[i], sigmas[i + 1], eta=eta) 227 | if callback is not None: 228 | callback({'x': x, 'i': i, 'sigma': sigmas[i], 'sigma_hat': sigmas[i], 'denoised': denoised}) 229 | d = to_d(x, sigmas[i], denoised) 230 | if sigma_down == 0: 231 | # Euler method 232 | dt = sigma_down - sigmas[i] 233 | x = x + d * dt 234 | else: 235 | # DPM-Solver-2 236 | sigma_mid = sigmas[i].log().lerp(sigma_down.log(), 0.5).exp() 237 | dt_1 = sigma_mid - sigmas[i] 238 | dt_2 = sigma_down - sigmas[i] 239 | x_2 = x + d * dt_1 240 | denoised_2 = model(x_2, sigma_mid * s_in, **extra_args) 241 | d_2 = to_d(x_2, sigma_mid, denoised_2) 242 | x = x + d_2 * dt_2 243 | x = x + noise_sampler(sigmas[i], sigmas[i + 1]) * s_noise * sigma_up 244 | return x 245 | 246 | 247 | def linear_multistep_coeff(order, t, i, j): 248 | if order - 1 > i: 249 | raise ValueError(f'Order 
{order} too high for step {i}') 250 | def fn(tau): 251 | prod = 1. 252 | for k in range(order): 253 | if j == k: 254 | continue 255 | prod *= (tau - t[i - k]) / (t[i - j] - t[i - k]) 256 | return prod 257 | return integrate.quad(fn, t[i], t[i + 1], epsrel=1e-4)[0] 258 | 259 | 260 | @torch.no_grad() 261 | def sample_lms(model, x, sigmas, extra_args=None, callback=None, disable=None, order=4): 262 | extra_args = {} if extra_args is None else extra_args 263 | s_in = x.new_ones([x.shape[0]]) 264 | sigmas_cpu = sigmas.detach().cpu().numpy() 265 | ds = [] 266 | for i in trange(len(sigmas) - 1, disable=disable): 267 | denoised = model(x, sigmas[i] * s_in, **extra_args) 268 | d = to_d(x, sigmas[i], denoised) 269 | ds.append(d) 270 | if len(ds) > order: 271 | ds.pop(0) 272 | if callback is not None: 273 | callback({'x': x, 'i': i, 'sigma': sigmas[i], 'sigma_hat': sigmas[i], 'denoised': denoised}) 274 | cur_order = min(i + 1, order) 275 | coeffs = [linear_multistep_coeff(cur_order, sigmas_cpu, i, j) for j in range(cur_order)] 276 | x = x + sum(coeff * d for coeff, d in zip(coeffs, reversed(ds))) 277 | return x 278 | 279 | 280 | @torch.no_grad() 281 | def log_likelihood(model, x, sigma_min, sigma_max, extra_args=None, atol=1e-4, rtol=1e-4): 282 | extra_args = {} if extra_args is None else extra_args 283 | s_in = x.new_ones([x.shape[0]]) 284 | v = torch.randint_like(x, 2) * 2 - 1 285 | fevals = 0 286 | def ode_fn(sigma, x): 287 | nonlocal fevals 288 | with torch.enable_grad(): 289 | x = x[0].detach().requires_grad_() 290 | denoised = model(x, sigma * s_in, **extra_args) 291 | d = to_d(x, sigma, denoised) 292 | fevals += 1 293 | grad = torch.autograd.grad((d * v).sum(), x)[0] 294 | d_ll = (v * grad).flatten(1).sum(1) 295 | return d.detach(), d_ll 296 | x_min = x, x.new_zeros([x.shape[0]]) 297 | t = x.new_tensor([sigma_min, sigma_max]) 298 | sol = odeint(ode_fn, x_min, t, atol=atol, rtol=rtol, method='dopri5') 299 | latent, delta_ll = sol[0][-1], sol[1][-1] 300 | ll_prior = torch.distributions.Normal(0, sigma_max).log_prob(latent).flatten(1).sum(1) 301 | return ll_prior + delta_ll, {'fevals': fevals} 302 | 303 | 304 | class PIDStepSizeController: 305 | """A PID controller for ODE adaptive step size control.""" 306 | def __init__(self, h, pcoeff, icoeff, dcoeff, order=1, accept_safety=0.81, eps=1e-8): 307 | self.h = h 308 | self.b1 = (pcoeff + icoeff + dcoeff) / order 309 | self.b2 = -(pcoeff + 2 * dcoeff) / order 310 | self.b3 = dcoeff / order 311 | self.accept_safety = accept_safety 312 | self.eps = eps 313 | self.errs = [] 314 | 315 | def limiter(self, x): 316 | return 1 + math.atan(x - 1) 317 | 318 | def propose_step(self, error): 319 | inv_error = 1 / (float(error) + self.eps) 320 | if not self.errs: 321 | self.errs = [inv_error, inv_error, inv_error] 322 | self.errs[0] = inv_error 323 | factor = self.errs[0] ** self.b1 * self.errs[1] ** self.b2 * self.errs[2] ** self.b3 324 | factor = self.limiter(factor) 325 | accept = factor >= self.accept_safety 326 | if accept: 327 | self.errs[2] = self.errs[1] 328 | self.errs[1] = self.errs[0] 329 | self.h *= factor 330 | return accept 331 | 332 | 333 | class DPMSolver(nn.Module): 334 | """DPM-Solver. 
See https://arxiv.org/abs/2206.00927.""" 335 | 336 | def __init__(self, model, extra_args=None, eps_callback=None, info_callback=None): 337 | super().__init__() 338 | self.model = model 339 | self.extra_args = {} if extra_args is None else extra_args 340 | self.eps_callback = eps_callback 341 | self.info_callback = info_callback 342 | 343 | def t(self, sigma): 344 | return -sigma.log() 345 | 346 | def sigma(self, t): 347 | return t.neg().exp() 348 | 349 | def eps(self, eps_cache, key, x, t, *args, **kwargs): 350 | if key in eps_cache: 351 | return eps_cache[key], eps_cache 352 | sigma = self.sigma(t) * x.new_ones([x.shape[0]]) 353 | eps = (x - self.model(x, sigma, *args, **self.extra_args, **kwargs)) / self.sigma(t) 354 | if self.eps_callback is not None: 355 | self.eps_callback() 356 | return eps, {key: eps, **eps_cache} 357 | 358 | def dpm_solver_1_step(self, x, t, t_next, eps_cache=None): 359 | eps_cache = {} if eps_cache is None else eps_cache 360 | h = t_next - t 361 | eps, eps_cache = self.eps(eps_cache, 'eps', x, t) 362 | x_1 = x - self.sigma(t_next) * h.expm1() * eps 363 | return x_1, eps_cache 364 | 365 | def dpm_solver_2_step(self, x, t, t_next, r1=1 / 2, eps_cache=None): 366 | eps_cache = {} if eps_cache is None else eps_cache 367 | h = t_next - t 368 | eps, eps_cache = self.eps(eps_cache, 'eps', x, t) 369 | s1 = t + r1 * h 370 | u1 = x - self.sigma(s1) * (r1 * h).expm1() * eps 371 | eps_r1, eps_cache = self.eps(eps_cache, 'eps_r1', u1, s1) 372 | x_2 = x - self.sigma(t_next) * h.expm1() * eps - self.sigma(t_next) / (2 * r1) * h.expm1() * (eps_r1 - eps) 373 | return x_2, eps_cache 374 | 375 | def dpm_solver_3_step(self, x, t, t_next, r1=1 / 3, r2=2 / 3, eps_cache=None): 376 | eps_cache = {} if eps_cache is None else eps_cache 377 | h = t_next - t 378 | eps, eps_cache = self.eps(eps_cache, 'eps', x, t) 379 | s1 = t + r1 * h 380 | s2 = t + r2 * h 381 | u1 = x - self.sigma(s1) * (r1 * h).expm1() * eps 382 | eps_r1, eps_cache = self.eps(eps_cache, 'eps_r1', u1, s1) 383 | u2 = x - self.sigma(s2) * (r2 * h).expm1() * eps - self.sigma(s2) * (r2 / r1) * ((r2 * h).expm1() / (r2 * h) - 1) * (eps_r1 - eps) 384 | eps_r2, eps_cache = self.eps(eps_cache, 'eps_r2', u2, s2) 385 | x_3 = x - self.sigma(t_next) * h.expm1() * eps - self.sigma(t_next) / r2 * (h.expm1() / h - 1) * (eps_r2 - eps) 386 | return x_3, eps_cache 387 | 388 | def dpm_solver_fast(self, x, t_start, t_end, nfe, eta=0., s_noise=1., noise_sampler=None): 389 | noise_sampler = default_noise_sampler(x) if noise_sampler is None else noise_sampler 390 | if not t_end > t_start and eta: 391 | raise ValueError('eta must be 0 for reverse sampling') 392 | 393 | m = math.floor(nfe / 3) + 1 394 | ts = torch.linspace(t_start, t_end, m + 1, device=x.device) 395 | 396 | if nfe % 3 == 0: 397 | orders = [3] * (m - 2) + [2, 1] 398 | else: 399 | orders = [3] * (m - 1) + [nfe % 3] 400 | 401 | for i in range(len(orders)): 402 | eps_cache = {} 403 | t, t_next = ts[i], ts[i + 1] 404 | if eta: 405 | sd, su = get_ancestral_step(self.sigma(t), self.sigma(t_next), eta) 406 | t_next_ = torch.minimum(t_end, self.t(sd)) 407 | su = (self.sigma(t_next) ** 2 - self.sigma(t_next_) ** 2) ** 0.5 408 | else: 409 | t_next_, su = t_next, 0. 
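# Note (added descriptive comment): in both branches the deterministic update below runs only to t_next_; when eta > 0, the remaining gap down to sigma(t_next) is refilled after the step with su * s_noise * fresh noise.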
410 | 411 | eps, eps_cache = self.eps(eps_cache, 'eps', x, t) 412 | denoised = x - self.sigma(t) * eps 413 | if self.info_callback is not None: 414 | self.info_callback({'x': x, 'i': i, 't': ts[i], 't_up': t, 'denoised': denoised}) 415 | 416 | if orders[i] == 1: 417 | x, eps_cache = self.dpm_solver_1_step(x, t, t_next_, eps_cache=eps_cache) 418 | elif orders[i] == 2: 419 | x, eps_cache = self.dpm_solver_2_step(x, t, t_next_, eps_cache=eps_cache) 420 | else: 421 | x, eps_cache = self.dpm_solver_3_step(x, t, t_next_, eps_cache=eps_cache) 422 | 423 | x = x + su * s_noise * noise_sampler(self.sigma(t), self.sigma(t_next)) 424 | 425 | return x 426 | 427 | def dpm_solver_adaptive(self, x, t_start, t_end, order=3, rtol=0.05, atol=0.0078, h_init=0.05, pcoeff=0., icoeff=1., dcoeff=0., accept_safety=0.81, eta=0., s_noise=1., noise_sampler=None): 428 | noise_sampler = default_noise_sampler(x) if noise_sampler is None else noise_sampler 429 | if order not in {2, 3}: 430 | raise ValueError('order should be 2 or 3') 431 | forward = t_end > t_start 432 | if not forward and eta: 433 | raise ValueError('eta must be 0 for reverse sampling') 434 | h_init = abs(h_init) * (1 if forward else -1) 435 | atol = torch.tensor(atol) 436 | rtol = torch.tensor(rtol) 437 | s = t_start 438 | x_prev = x 439 | accept = True 440 | pid = PIDStepSizeController(h_init, pcoeff, icoeff, dcoeff, 1.5 if eta else order, accept_safety) 441 | info = {'steps': 0, 'nfe': 0, 'n_accept': 0, 'n_reject': 0} 442 | 443 | while s < t_end - 1e-5 if forward else s > t_end + 1e-5: 444 | eps_cache = {} 445 | t = torch.minimum(t_end, s + pid.h) if forward else torch.maximum(t_end, s + pid.h) 446 | if eta: 447 | sd, su = get_ancestral_step(self.sigma(s), self.sigma(t), eta) 448 | t_ = torch.minimum(t_end, self.t(sd)) 449 | su = (self.sigma(t) ** 2 - self.sigma(t_) ** 2) ** 0.5 450 | else: 451 | t_, su = t, 0. 452 | 453 | eps, eps_cache = self.eps(eps_cache, 'eps', x, s) 454 | denoised = x - self.sigma(s) * eps 455 | 456 | if order == 2: 457 | x_low, eps_cache = self.dpm_solver_1_step(x, s, t_, eps_cache=eps_cache) 458 | x_high, eps_cache = self.dpm_solver_2_step(x, s, t_, eps_cache=eps_cache) 459 | else: 460 | x_low, eps_cache = self.dpm_solver_2_step(x, s, t_, r1=1 / 3, eps_cache=eps_cache) 461 | x_high, eps_cache = self.dpm_solver_3_step(x, s, t_, eps_cache=eps_cache) 462 | delta = torch.maximum(atol, rtol * torch.maximum(x_low.abs(), x_prev.abs())) 463 | error = torch.linalg.norm((x_low - x_high) / delta) / x.numel() ** 0.5 464 | accept = pid.propose_step(error) 465 | if accept: 466 | x_prev = x_low 467 | x = x_high + su * s_noise * noise_sampler(self.sigma(s), self.sigma(t)) 468 | s = t 469 | info['n_accept'] += 1 470 | else: 471 | info['n_reject'] += 1 472 | info['nfe'] += order 473 | info['steps'] += 1 474 | 475 | if self.info_callback is not None: 476 | self.info_callback({'x': x, 'i': info['steps'] - 1, 't': s, 't_up': s, 'denoised': denoised, 'error': error, 'h': pid.h, **info}) 477 | 478 | return x, info 479 | 480 | 481 | @torch.no_grad() 482 | def sample_dpm_fast(model, x, sigma_min, sigma_max, n, extra_args=None, callback=None, disable=None, eta=0., s_noise=1., noise_sampler=None): 483 | """DPM-Solver-Fast (fixed step size). 
See https://arxiv.org/abs/2206.00927.""" 484 | if sigma_min <= 0 or sigma_max <= 0: 485 | raise ValueError('sigma_min and sigma_max must not be 0') 486 | with tqdm(total=n, disable=disable) as pbar: 487 | dpm_solver = DPMSolver(model, extra_args, eps_callback=pbar.update) 488 | if callback is not None: 489 | dpm_solver.info_callback = lambda info: callback({'sigma': dpm_solver.sigma(info['t']), 'sigma_hat': dpm_solver.sigma(info['t_up']), **info}) 490 | return dpm_solver.dpm_solver_fast(x, dpm_solver.t(torch.tensor(sigma_max)), dpm_solver.t(torch.tensor(sigma_min)), n, eta, s_noise, noise_sampler) 491 | 492 | 493 | @torch.no_grad() 494 | def sample_dpm_adaptive(model, x, sigma_min, sigma_max, extra_args=None, callback=None, disable=None, order=3, rtol=0.05, atol=0.0078, h_init=0.05, pcoeff=0., icoeff=1., dcoeff=0., accept_safety=0.81, eta=0., s_noise=1., noise_sampler=None, return_info=False): 495 | """DPM-Solver-12 and 23 (adaptive step size). See https://arxiv.org/abs/2206.00927.""" 496 | if sigma_min <= 0 or sigma_max <= 0: 497 | raise ValueError('sigma_min and sigma_max must not be 0') 498 | with tqdm(disable=disable) as pbar: 499 | dpm_solver = DPMSolver(model, extra_args, eps_callback=pbar.update) 500 | if callback is not None: 501 | dpm_solver.info_callback = lambda info: callback({'sigma': dpm_solver.sigma(info['t']), 'sigma_hat': dpm_solver.sigma(info['t_up']), **info}) 502 | x, info = dpm_solver.dpm_solver_adaptive(x, dpm_solver.t(torch.tensor(sigma_max)), dpm_solver.t(torch.tensor(sigma_min)), order, rtol, atol, h_init, pcoeff, icoeff, dcoeff, accept_safety, eta, s_noise, noise_sampler) 503 | if return_info: 504 | return x, info 505 | return x 506 | 507 | 508 | @torch.no_grad() 509 | def sample_dpmpp_2s_ancestral(model, x, sigmas, extra_args=None, callback=None, disable=None, eta=1., s_noise=1., noise_sampler=None): 510 | """Ancestral sampling with DPM-Solver++(2S) second-order steps.""" 511 | extra_args = {} if extra_args is None else extra_args 512 | noise_sampler = default_noise_sampler(x) if noise_sampler is None else noise_sampler 513 | s_in = x.new_ones([x.shape[0]]) 514 | sigma_fn = lambda t: t.neg().exp() 515 | t_fn = lambda sigma: sigma.log().neg() 516 | 517 | for i in trange(len(sigmas) - 1, disable=disable): 518 | denoised = model(x, sigmas[i] * s_in, **extra_args) 519 | sigma_down, sigma_up = get_ancestral_step(sigmas[i], sigmas[i + 1], eta=eta) 520 | if callback is not None: 521 | callback({'x': x, 'i': i, 'sigma': sigmas[i], 'sigma_hat': sigmas[i], 'denoised': denoised}) 522 | if sigma_down == 0: 523 | # Euler method 524 | d = to_d(x, sigmas[i], denoised) 525 | dt = sigma_down - sigmas[i] 526 | x = x + d * dt 527 | else: 528 | # DPM-Solver++(2S) 529 | t, t_next = t_fn(sigmas[i]), t_fn(sigma_down) 530 | r = 1 / 2 531 | h = t_next - t 532 | s = t + r * h 533 | x_2 = (sigma_fn(s) / sigma_fn(t)) * x - (-h * r).expm1() * denoised 534 | denoised_2 = model(x_2, sigma_fn(s) * s_in, **extra_args) 535 | x = (sigma_fn(t_next) / sigma_fn(t)) * x - (-h).expm1() * denoised_2 536 | # Noise addition 537 | if sigmas[i + 1] > 0: 538 | x = x + noise_sampler(sigmas[i], sigmas[i + 1]) * s_noise * sigma_up 539 | return x 540 | 541 | 542 | @torch.no_grad() 543 | def sample_dpmpp_sde(model, x, sigmas, extra_args=None, callback=None, disable=None, eta=1., s_noise=1., noise_sampler=None, r=1 / 2): 544 | """DPM-Solver++ (stochastic).""" 545 | sigma_min, sigma_max = sigmas[sigmas > 0].min(), sigmas.max() 546 | noise_sampler = BrownianTreeNoiseSampler(x, sigma_min, sigma_max) if noise_sampler 
is None else noise_sampler 547 | extra_args = {} if extra_args is None else extra_args 548 | s_in = x.new_ones([x.shape[0]]) 549 | sigma_fn = lambda t: t.neg().exp() 550 | t_fn = lambda sigma: sigma.log().neg() 551 | 552 | for i in trange(len(sigmas) - 1, disable=disable): 553 | denoised = model(x, sigmas[i] * s_in, **extra_args) 554 | if callback is not None: 555 | callback({'x': x, 'i': i, 'sigma': sigmas[i], 'sigma_hat': sigmas[i], 'denoised': denoised}) 556 | if sigmas[i + 1] == 0: 557 | # Euler method 558 | d = to_d(x, sigmas[i], denoised) 559 | dt = sigmas[i + 1] - sigmas[i] 560 | x = x + d * dt 561 | else: 562 | # DPM-Solver++ 563 | t, t_next = t_fn(sigmas[i]), t_fn(sigmas[i + 1]) 564 | h = t_next - t 565 | s = t + h * r 566 | fac = 1 / (2 * r) 567 | 568 | # Step 1 569 | sd, su = get_ancestral_step(sigma_fn(t), sigma_fn(s), eta) 570 | s_ = t_fn(sd) 571 | x_2 = (sigma_fn(s_) / sigma_fn(t)) * x - (t - s_).expm1() * denoised 572 | x_2 = x_2 + noise_sampler(sigma_fn(t), sigma_fn(s)) * s_noise * su 573 | denoised_2 = model(x_2, sigma_fn(s) * s_in, **extra_args) 574 | 575 | # Step 2 576 | sd, su = get_ancestral_step(sigma_fn(t), sigma_fn(t_next), eta) 577 | t_next_ = t_fn(sd) 578 | denoised_d = (1 - fac) * denoised + fac * denoised_2 579 | x = (sigma_fn(t_next_) / sigma_fn(t)) * x - (t - t_next_).expm1() * denoised_d 580 | x = x + noise_sampler(sigma_fn(t), sigma_fn(t_next)) * s_noise * su 581 | return x 582 | 583 | 584 | @torch.no_grad() 585 | def sample_dpmpp_2m(model, x, sigmas, extra_args=None, callback=None, disable=None): 586 | """DPM-Solver++(2M).""" 587 | extra_args = {} if extra_args is None else extra_args 588 | s_in = x.new_ones([x.shape[0]]) 589 | sigma_fn = lambda t: t.neg().exp() 590 | t_fn = lambda sigma: sigma.log().neg() 591 | old_denoised = None 592 | 593 | for i in trange(len(sigmas) - 1, disable=disable): 594 | denoised = model(x, sigmas[i] * s_in, **extra_args) 595 | if callback is not None: 596 | callback({'x': x, 'i': i, 'sigma': sigmas[i], 'sigma_hat': sigmas[i], 'denoised': denoised}) 597 | t, t_next = t_fn(sigmas[i]), t_fn(sigmas[i + 1]) 598 | h = t_next - t 599 | if old_denoised is None or sigmas[i + 1] == 0: 600 | x = (sigma_fn(t_next) / sigma_fn(t)) * x - (-h).expm1() * denoised 601 | else: 602 | h_last = t - t_fn(sigmas[i - 1]) 603 | r = h_last / h 604 | denoised_d = (1 + 1 / (2 * r)) * denoised - (1 / (2 * r)) * old_denoised 605 | x = (sigma_fn(t_next) / sigma_fn(t)) * x - (-h).expm1() * denoised_d 606 | old_denoised = denoised 607 | return x 608 | 609 | 610 | @torch.no_grad() 611 | def sample_dpmpp_2m_sde(model, x, sigmas, extra_args=None, callback=None, disable=None, eta=1., s_noise=1., noise_sampler=None, solver_type='midpoint'): 612 | """DPM-Solver++(2M) SDE.""" 613 | 614 | if solver_type not in {'heun', 'midpoint'}: 615 | raise ValueError('solver_type must be \'heun\' or \'midpoint\'') 616 | 617 | sigma_min, sigma_max = sigmas[sigmas > 0].min(), sigmas.max() 618 | noise_sampler = BrownianTreeNoiseSampler(x, sigma_min, sigma_max) if noise_sampler is None else noise_sampler 619 | extra_args = {} if extra_args is None else extra_args 620 | s_in = x.new_ones([x.shape[0]]) 621 | 622 | old_denoised = None 623 | h_last = None 624 | 625 | for i in trange(len(sigmas) - 1, disable=disable): 626 | denoised = model(x, sigmas[i] * s_in, **extra_args) 627 | if callback is not None: 628 | callback({'x': x, 'i': i, 'sigma': sigmas[i], 'sigma_hat': sigmas[i], 'denoised': denoised}) 629 | if sigmas[i + 1] == 0: 630 | # Denoising step 631 | x = denoised 632 | else: 633 | 
# DPM-Solver++(2M) SDE 634 | t, s = -sigmas[i].log(), -sigmas[i + 1].log() 635 | h = s - t 636 | eta_h = eta * h 637 | 638 | x = sigmas[i + 1] / sigmas[i] * (-eta_h).exp() * x + (-h - eta_h).expm1().neg() * denoised 639 | 640 | if old_denoised is not None: 641 | r = h_last / h 642 | if solver_type == 'heun': 643 | x = x + ((-h - eta_h).expm1().neg() / (-h - eta_h) + 1) * (1 / r) * (denoised - old_denoised) 644 | elif solver_type == 'midpoint': 645 | x = x + 0.5 * (-h - eta_h).expm1().neg() * (1 / r) * (denoised - old_denoised) 646 | 647 | if eta: 648 | x = x + noise_sampler(sigmas[i], sigmas[i + 1]) * sigmas[i + 1] * (-2 * eta_h).expm1().neg().sqrt() * s_noise 649 | 650 | old_denoised = denoised 651 | h_last = h 652 | return x 653 | 654 | 655 | @torch.no_grad() 656 | def sample_dpmpp_3m_sde(model, x, sigmas, extra_args=None, callback=None, disable=None, eta=1., s_noise=1., noise_sampler=None): 657 | """DPM-Solver++(3M) SDE.""" 658 | 659 | sigma_min, sigma_max = sigmas[sigmas > 0].min(), sigmas.max() 660 | noise_sampler = BrownianTreeNoiseSampler(x, sigma_min, sigma_max) if noise_sampler is None else noise_sampler 661 | extra_args = {} if extra_args is None else extra_args 662 | s_in = x.new_ones([x.shape[0]]) 663 | 664 | denoised_1, denoised_2 = None, None 665 | h_1, h_2 = None, None 666 | 667 | for i in trange(len(sigmas) - 1, disable=disable): 668 | denoised = model(x, sigmas[i] * s_in, **extra_args) 669 | if callback is not None: 670 | callback({'x': x, 'i': i, 'sigma': sigmas[i], 'sigma_hat': sigmas[i], 'denoised': denoised}) 671 | if sigmas[i + 1] == 0: 672 | # Denoising step 673 | x = denoised 674 | else: 675 | t, s = -sigmas[i].log(), -sigmas[i + 1].log() 676 | h = s - t 677 | h_eta = h * (eta + 1) 678 | 679 | x = torch.exp(-h_eta) * x + (-h_eta).expm1().neg() * denoised 680 | 681 | if h_2 is not None: 682 | r0 = h_1 / h 683 | r1 = h_2 / h 684 | d1_0 = (denoised - denoised_1) / r0 685 | d1_1 = (denoised_1 - denoised_2) / r1 686 | d1 = d1_0 + (d1_0 - d1_1) * r0 / (r0 + r1) 687 | d2 = (d1_0 - d1_1) / (r0 + r1) 688 | phi_2 = h_eta.neg().expm1() / h_eta + 1 689 | phi_3 = phi_2 / h_eta - 0.5 690 | x = x + phi_2 * d1 - phi_3 * d2 691 | elif h_1 is not None: 692 | r = h_1 / h 693 | d = (denoised - denoised_1) / r 694 | phi_2 = h_eta.neg().expm1() / h_eta + 1 695 | x = x + phi_2 * d 696 | 697 | if eta: 698 | x = x + noise_sampler(sigmas[i], sigmas[i + 1]) * sigmas[i + 1] * (-2 * h * eta).expm1().neg().sqrt() * s_noise 699 | 700 | denoised_1, denoised_2 = denoised, denoised_1 701 | h_1, h_2 = h, h_1 702 | return x 703 | --------------------------------------------------------------------------------
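Usage note (a minimal sketch, not part of the repository): the noise schedules and samplers defined in k_diffusion/sampling.py compose directly. The sketch below assumes that denoiser is an already-constructed callable following this file's convention, i.e. it takes (x, sigma * s_in) and returns the denoised image (for example, a model wrapped with k_diffusion.Denoiser); the batch shape, step count, and sigma range are illustrative values only.

import torch
import k_diffusion as K

denoiser = ...  # assumed to exist: an (x, sigma) -> denoised callable, e.g. K.Denoiser(inner_model, sigma_data=...)
sigma_min, sigma_max = 1e-2, 80.0
x = torch.randn([4, 3, 32, 32]) * sigma_max  # start from pure noise at the highest noise level
sigmas = K.sampling.get_sigmas_karras(50, sigma_min, sigma_max, rho=7.0)  # 50-step Karras schedule, appended with 0
samples = K.sampling.sample_dpmpp_2m(denoiser, x, sigmas)  # most samplers in this file share the (model, x, sigmas) call shape

Samplers such as sample_dpm_fast and sample_dpm_adaptive instead take sigma_min and sigma_max directly and build their own step sequence internally.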