├── .gitignore
├── .gitmodules
├── LICENSE
├── README.md
├── cog.yaml
├── configs
│   ├── config_32x32_small.json
│   ├── config_32x32_small_butterflies.json
│   ├── config_cifar10.json
│   └── config_mnist.json
├── k_diffusion
│   ├── __init__.py
│   ├── augmentation.py
│   ├── config.py
│   ├── evaluation.py
│   ├── external.py
│   ├── gns.py
│   ├── layers.py
│   ├── models
│   │   ├── __init__.py
│   │   └── image_v1.py
│   ├── sampling.py
│   └── utils.py
├── make_grid.py
├── predict.py
├── pyproject.toml
├── requirements.txt
├── sample.py
├── sample_clip_guided.py
├── setup.cfg
└── train.py
/.gitignore: -------------------------------------------------------------------------------- 1 | venv* 2 | __pycache__ 3 | .ipynb_checkpoints 4 | *.pth 5 | *.egg-info 6 | data 7 | *_demo_*.png 8 | wandb/* 9 | *.csv -------------------------------------------------------------------------------- /.gitmodules: -------------------------------------------------------------------------------- 1 | [submodule "guided-diffusion"] 2 | path = guided-diffusion 3 | url = https://github.com/crowsonkb/guided-diffusion 4 | [submodule "latent-diffusion"] 5 | path = latent-diffusion 6 | url = https://github.com/CompVis/latent-diffusion.git 7 | -------------------------------------------------------------------------------- /LICENSE: -------------------------------------------------------------------------------- 1 | Copyright (c) 2022 Katherine Crowson 2 | 3 | Permission is hereby granted, free of charge, to any person obtaining a copy 4 | of this software and associated documentation files (the "Software"), to deal 5 | in the Software without restriction, including without limitation the rights 6 | to use, copy, modify, merge, publish, distribute, sublicense, and/or sell 7 | copies of the Software, and to permit persons to whom the Software is 8 | furnished to do so, subject to the following conditions: 9 | 10 | The above copyright notice and this permission notice shall be included in 11 | all copies or substantial portions of the Software. 12 | 13 | THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR 14 | IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, 15 | FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE 16 | AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER 17 | LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, 18 | OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN 19 | THE SOFTWARE. 20 | -------------------------------------------------------------------------------- /README.md: -------------------------------------------------------------------------------- 1 | # k-diffusion 2 | 3 | An implementation of [Elucidating the Design Space of Diffusion-Based Generative Models](https://arxiv.org/abs/2206.00364) (Karras et al., 2022) for PyTorch. The patching method in [Improving Diffusion Model Efficiency Through Patching](https://arxiv.org/abs/2207.04316) is implemented as well. 4 | 5 | ## Training: 6 | 7 | To train models: 8 | 9 | ```sh 10 | $ ./train.py --config CONFIG_FILE --name RUN_NAME 11 | ``` 12 | 13 | For instance, to train a model on MNIST: 14 | 15 | ```sh 16 | $ ./train.py --config configs/config_mnist.json --name RUN_NAME 17 | ``` 18 | 19 | The configuration file allows you to specify the dataset type. Currently supported types are `"imagefolder"` (a folder with one subfolder per image class; the classes are currently ignored), `"cifar10"` (CIFAR-10), `"mnist"` (MNIST), and `"huggingface"` (Hugging Face Datasets, as used in `configs/config_32x32_small_butterflies.json`).
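For example, the data source is selected by the `"dataset"` block of the configuration file. Below is a minimal sketch of that block for a custom image folder; the location is a placeholder path, a complete config also carries the `model`, `optimizer`, `lr_sched`, and `ema_sched` blocks shown in the files under `configs/`, and keys you leave out are merged with the defaults in `load_config` (see `k_diffusion/config.py`):

```json
{
  "dataset": {
    "type": "imagefolder",
    "location": "/path/to/dataset"
  }
}
```

Point `"location"` at the root of your image folder; its subfolders are treated as classes, which, as noted above, are currently ignored.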
20 | 21 | Multi-GPU and multi-node training is supported with [Hugging Face Accelerate](https://huggingface.co/docs/accelerate/index). You can configure Accelerate by running: 22 | 23 | ```sh 24 | $ accelerate config 25 | ``` 26 | 27 | on all nodes, then running: 28 | 29 | ```sh 30 | $ accelerate launch train.py --config CONFIG_FILE --name RUN_NAME 31 | ``` 32 | 33 | on all nodes. 34 | 35 | ## Enhancements/additional features: 36 | 37 | - k-diffusion models support progressive growing. 38 | 39 | - k-diffusion implements a sampler inspired by [DPM-Solver](https://arxiv.org/abs/2206.00927) and Karras et al. (2022) Algorithm 2 that produces higher quality samples at the same number of function evaluations as Karras Algorithm 2. It also implements a [linear multistep](https://en.wikipedia.org/wiki/Linear_multistep_method#Adams–Bashforth_methods) sampler (comparable to [PLMS](https://arxiv.org/abs/2202.09778)). 40 | 41 | - k-diffusion supports [CLIP](https://openai.com/blog/clip/) guided sampling from unconditional diffusion models (see `sample_clip_guided.py`). 42 | 43 | - k-diffusion has wrappers for [v-diffusion-pytorch](https://github.com/crowsonkb/v-diffusion-pytorch), [OpenAI diffusion](https://github.com/openai/guided-diffusion), and [CompVis diffusion](https://github.com/CompVis/latent-diffusion) models allowing them to be used with its samplers and ODE/SDE. 44 | 45 | - k-diffusion supports log likelihood calculation (not a variational lower bound) for native models and all wrapped models. 46 | 47 | ## To do: 48 | 49 | - Anything except unconditional image diffusion models 50 | 51 | - Latent diffusion 52 | -------------------------------------------------------------------------------- /cog.yaml: -------------------------------------------------------------------------------- 1 | build: 2 | gpu: true 3 | 4 | system_packages: 5 | - "libsm6" 6 | - "libxext6" 7 | - "libglib2.0-0" 8 | 9 | # python version in the form '3.8' or '3.8.12' 10 | python_version: "3.8" 11 | 12 | python_packages: 13 | - "torch==1.9.0" 14 | - "torchvision==0.10.0" 15 | - "accelerate==0.11.0" 16 | - "clean-fid==0.1.26" 17 | - "einops==0.4.1" 18 | - "jsonmerge==1.8.0" 19 | - "kornia==0.6.6" 20 | - "lpips==0.1.4" 21 | - "Pillow==9.2.0" 22 | - "pytorch-lightning==1.5" 23 | - "opencv-python==4.6.0.66" 24 | - "open-clip-torch==1.3.0" 25 | - "omegaconf==2.1.1" 26 | - "resize-right==0.0.2" 27 | - "scikit-image==0.19.3" 28 | - "scipy==1.8.1" 29 | - "streamlit==0.73.1" 30 | - "torch-fidelity==0.3.0" 31 | - "torchdiffeq==0.2.3" 32 | - "transformers==4.19.2" 33 | - "tqdm==4.64.0" 34 | - "wandb==0.12.21" 35 | 36 | 37 | run: 38 | - "pip install git+https://github.com/openai/CLIP" 39 | - "pip install git+https://github.com/crowsonkb/guided-diffusion" 40 | - "git clone https://github.com/CompVis/taming-transformers.git && cd taming-transformers && pip install -e . && cd .."
41 | #- "mkdir -p /root/.cache/k-diffusion; wget -O /root/.cache/k-diffusion/256x256_diffusion_uncond.pt https://models.nmb.ai/disco/256x256_diffusion_uncond.pt" 42 | - "mkdir -p /root/.cache/torch/hub/checkpoints; wget -O /root/.cache/torch/hub/checkpoints/vgg16-397923af.pth https://models.nmb.ai/disco/vgg16-397923af.pth" 43 | - "mkdir -p /root/.cache/k-diffusion; wget --quiet -O /root/.cache/k-diffusion/txt2img-f8-large-jack000-finetuned-fp16.ckpt https://models.nmb.ai/majesty/txt2img-f8-large-jack000-finetuned-fp16.ckpt" 44 | - "mkdir -p /root/.cache/clip; wget --quiet -O /root/.cache/clip/vit_b_32-laion2b_e16-af8dbd0c.pth https://models.nmb.ai/clip/vit_b_32-laion2b_e16-af8dbd0c.pth" 45 | 46 | # predict.py defines how predictions are run on your model 47 | predict: "predict.py:Predictor" 48 | image: r8.im/nightmareai/k-diffusion -------------------------------------------------------------------------------- /configs/config_32x32_small.json: -------------------------------------------------------------------------------- 1 | { 2 | "model": { 3 | "type": "image_v1", 4 | "input_channels": 3, 5 | "input_size": [32, 32], 6 | "patch_size": 1, 7 | "mapping_out": 256, 8 | "depths": [2, 4, 4], 9 | "channels": [128, 256, 512], 10 | "self_attn_depths": [false, true, true], 11 | "dropout_rate": 0.05, 12 | "augment_prob": 0.12, 13 | "sigma_data": 0.5, 14 | "sigma_min": 1e-2, 15 | "sigma_max": 80, 16 | "sigma_sample_density": { 17 | "type": "lognormal", 18 | "mean": -1.2, 19 | "std": 1.2 20 | } 21 | }, 22 | "dataset": { 23 | "type": "imagefolder", 24 | "location": "/path/to/dataset" 25 | }, 26 | "optimizer": { 27 | "type": "adamw", 28 | "lr": 1e-4, 29 | "betas": [0.95, 0.999], 30 | "eps": 1e-6, 31 | "weight_decay": 1e-3 32 | }, 33 | "lr_sched": { 34 | "type": "inverse", 35 | "inv_gamma": 20000.0, 36 | "power": 1.0, 37 | "warmup": 0.99 38 | }, 39 | "ema_sched": { 40 | "type": "inverse", 41 | "power": 0.6667, 42 | "max_value": 0.9999 43 | } 44 | } 45 | -------------------------------------------------------------------------------- /configs/config_32x32_small_butterflies.json: -------------------------------------------------------------------------------- 1 | { 2 | "model": { 3 | "type": "image_v1", 4 | "input_channels": 3, 5 | "input_size": [32, 32], 6 | "patch_size": 1, 7 | "mapping_out": 256, 8 | "depths": [2, 4, 4], 9 | "channels": [128, 256, 512], 10 | "self_attn_depths": [false, true, true], 11 | "dropout_rate": 0.05, 12 | "augment_prob": 0.12, 13 | "sigma_data": 0.5, 14 | "sigma_min": 1e-2, 15 | "sigma_max": 80, 16 | "sigma_sample_density": { 17 | "type": "lognormal", 18 | "mean": -1.2, 19 | "std": 1.2 20 | } 21 | }, 22 | "dataset": { 23 | "type": "huggingface", 24 | "location": "huggan/smithsonian_butterflies_subset", 25 | "image_key": "image" 26 | }, 27 | "optimizer": { 28 | "type": "adamw", 29 | "lr": 1e-4, 30 | "betas": [0.95, 0.999], 31 | "eps": 1e-6, 32 | "weight_decay": 1e-3 33 | }, 34 | "lr_sched": { 35 | "type": "inverse", 36 | "inv_gamma": 20000.0, 37 | "power": 1.0, 38 | "warmup": 0.99 39 | }, 40 | "ema_sched": { 41 | "type": "inverse", 42 | "power": 0.6667, 43 | "max_value": 0.9999 44 | } 45 | } 46 | -------------------------------------------------------------------------------- /configs/config_cifar10.json: -------------------------------------------------------------------------------- 1 | { 2 | "model": { 3 | "type": "image_v1", 4 | "input_channels": 3, 5 | "input_size": [32, 32], 6 | "patch_size": 1, 7 | "mapping_out": 256, 8 | "depths": [2, 4, 4], 9 | "channels": [128, 256, 
512], 10 | "self_attn_depths": [false, true, true], 11 | "dropout_rate": 0.05, 12 | "augment_prob": 0.12, 13 | "sigma_data": 0.5, 14 | "sigma_min": 1e-2, 15 | "sigma_max": 80, 16 | "sigma_sample_density": { 17 | "type": "lognormal", 18 | "mean": -1.2, 19 | "std": 1.2 20 | } 21 | }, 22 | "dataset": { 23 | "type": "cifar10", 24 | "location": "data" 25 | }, 26 | "optimizer": { 27 | "type": "adamw", 28 | "lr": 1e-4, 29 | "betas": [0.95, 0.999], 30 | "eps": 1e-6, 31 | "weight_decay": 1e-3 32 | }, 33 | "lr_sched": { 34 | "type": "inverse", 35 | "inv_gamma": 20000.0, 36 | "power": 1.0, 37 | "warmup": 0.99 38 | }, 39 | "ema_sched": { 40 | "type": "inverse", 41 | "power": 0.6667, 42 | "max_value": 0.9999 43 | } 44 | } 45 | -------------------------------------------------------------------------------- /configs/config_mnist.json: -------------------------------------------------------------------------------- 1 | { 2 | "model": { 3 | "type": "image_v1", 4 | "input_channels": 1, 5 | "input_size": [28, 28], 6 | "patch_size": 1, 7 | "mapping_out": 256, 8 | "depths": [2, 4, 4], 9 | "channels": [128, 128, 256], 10 | "self_attn_depths": [false, false, true], 11 | "dropout_rate": 0.05, 12 | "augment_prob": 0.12, 13 | "sigma_data": 0.6162, 14 | "sigma_min": 1e-2, 15 | "sigma_max": 80, 16 | "sigma_sample_density": { 17 | "type": "lognormal", 18 | "mean": -1.2, 19 | "std": 1.2 20 | } 21 | }, 22 | "dataset": { 23 | "type": "mnist", 24 | "location": "data" 25 | }, 26 | "optimizer": { 27 | "type": "adamw", 28 | "lr": 2e-4, 29 | "betas": [0.95, 0.999], 30 | "eps": 1e-6, 31 | "weight_decay": 1e-3 32 | }, 33 | "lr_sched": { 34 | "type": "inverse", 35 | "inv_gamma": 20000.0, 36 | "power": 1.0, 37 | "warmup": 0.99 38 | }, 39 | "ema_sched": { 40 | "type": "inverse", 41 | "power": 0.6667, 42 | "max_value": 0.9999 43 | } 44 | } 45 | -------------------------------------------------------------------------------- /k_diffusion/__init__.py: -------------------------------------------------------------------------------- 1 | from . 
import augmentation, config, evaluation, external, gns, layers, models, sampling, utils 2 | from .layers import Denoiser 3 | -------------------------------------------------------------------------------- /k_diffusion/augmentation.py: -------------------------------------------------------------------------------- 1 | from functools import reduce 2 | import math 3 | import operator 4 | 5 | import numpy as np 6 | from skimage import transform 7 | import torch 8 | from torch import nn 9 | 10 | 11 | def translate2d(tx, ty): 12 | mat = [[1, 0, tx], 13 | [0, 1, ty], 14 | [0, 0, 1]] 15 | return torch.tensor(mat, dtype=torch.float32) 16 | 17 | 18 | def scale2d(sx, sy): 19 | mat = [[sx, 0, 0], 20 | [ 0, sy, 0], 21 | [ 0, 0, 1]] 22 | return torch.tensor(mat, dtype=torch.float32) 23 | 24 | 25 | def rotate2d(theta): 26 | mat = [[torch.cos(theta), torch.sin(-theta), 0], 27 | [torch.sin(theta), torch.cos(theta), 0], 28 | [ 0, 0, 1]] 29 | return torch.tensor(mat, dtype=torch.float32) 30 | 31 | 32 | class KarrasAugmentationPipeline: 33 | def __init__(self, a_prob=0.12, a_scale=2**0.2, a_aniso=2**0.2, a_trans=1/8): 34 | self.a_prob = a_prob 35 | self.a_scale = a_scale 36 | self.a_aniso = a_aniso 37 | self.a_trans = a_trans 38 | 39 | def __call__(self, image): 40 | h, w = image.size 41 | mats = [translate2d(h / 2 - 0.5, w / 2 - 0.5)] 42 | 43 | # x-flip 44 | a0 = torch.randint(2, []).float() 45 | mats.append(scale2d(1 - 2 * a0, 1)) 46 | # y-flip 47 | do = (torch.rand([]) < self.a_prob).float() 48 | a1 = torch.randint(2, []).float() * do 49 | mats.append(scale2d(1, 1 - 2 * a1)) 50 | # scaling 51 | do = (torch.rand([]) < self.a_prob).float() 52 | a2 = torch.randn([]) * do 53 | mats.append(scale2d(self.a_scale ** a2, self.a_scale ** a2)) 54 | # rotation 55 | do = (torch.rand([]) < self.a_prob).float() 56 | a3 = (torch.rand([]) * 2 * math.pi - math.pi) * do 57 | mats.append(rotate2d(-a3)) 58 | # anisotropy 59 | do = (torch.rand([]) < self.a_prob).float() 60 | a4 = (torch.rand([]) * 2 * math.pi - math.pi) * do 61 | a5 = torch.randn([]) * do 62 | mats.append(rotate2d(a4)) 63 | mats.append(scale2d(self.a_aniso ** a5, self.a_aniso ** -a5)) 64 | mats.append(rotate2d(-a4)) 65 | # translation 66 | do = (torch.rand([]) < self.a_prob).float() 67 | a6 = torch.randn([]) * do 68 | a7 = torch.randn([]) * do 69 | mats.append(translate2d(self.a_trans * w * a6, self.a_trans * h * a7)) 70 | 71 | # form the transformation matrix and conditioning vector 72 | mats.append(translate2d(-h / 2 + 0.5, -w / 2 + 0.5)) 73 | mat = reduce(operator.matmul, mats) 74 | cond = torch.stack([a0, a1, a2, a3.cos() - 1, a3.sin(), a5 * a4.cos(), a5 * a4.sin(), a6, a7]) 75 | 76 | # apply the transformation 77 | image_orig = np.array(image, dtype=np.float32) / 255 78 | if image_orig.ndim == 2: 79 | image_orig = image_orig[..., None] 80 | tf = transform.AffineTransform(mat.numpy()) 81 | image = transform.warp(image_orig, tf.inverse, order=3, mode='reflect', cval=0.5, clip=False, preserve_range=True) 82 | image_orig = torch.as_tensor(image_orig).movedim(2, 0) * 2 - 1 83 | image = torch.as_tensor(image).movedim(2, 0) * 2 - 1 84 | return image, image_orig, cond 85 | 86 | 87 | class KarrasAugmentWrapper(nn.Module): 88 | def __init__(self, model): 89 | super().__init__() 90 | self.inner_model = model 91 | 92 | def forward(self, input, sigma, aug_cond=None, mapping_cond=None, **kwargs): 93 | if aug_cond is None: 94 | aug_cond = input.new_zeros([input.shape[0], 9]) 95 | if mapping_cond is None: 96 | mapping_cond = aug_cond 97 | else: 98 | mapping_cond = 
torch.cat([aug_cond, mapping_cond], dim=1) 99 | return self.inner_model(input, sigma, mapping_cond=mapping_cond, **kwargs) 100 | 101 | def set_skip_stages(self, skip_stages): 102 | return self.inner_model.set_skip_stages(skip_stages) 103 | 104 | def set_patch_size(self, patch_size): 105 | return self.inner_model.set_patch_size(patch_size) 106 | -------------------------------------------------------------------------------- /k_diffusion/config.py: -------------------------------------------------------------------------------- 1 | from functools import partial 2 | import json 3 | 4 | from jsonmerge import merge 5 | 6 | from . import augmentation, models, utils 7 | 8 | 9 | def load_config(file): 10 | defaults = { 11 | 'model': { 12 | 'sigma_data': 1., 13 | 'patch_size': 1, 14 | 'dropout_rate': 0., 15 | 'augment_prob': 0., 16 | 'mapping_cond_dim': 0, 17 | 'unet_cond_dim': 0, 18 | 'cross_cond_dim': 0, 19 | 'cross_attn_depths': None, 20 | 'skip_stages': 0, 21 | }, 22 | 'dataset': { 23 | 'type': 'imagefolder', 24 | }, 25 | 'optimizer': { 26 | 'type': 'adamw', 27 | 'lr': 1e-4, 28 | 'betas': [0.95, 0.999], 29 | 'eps': 1e-6, 30 | 'weight_decay': 1e-3, 31 | }, 32 | 'lr_sched': { 33 | 'type': 'inverse', 34 | 'inv_gamma': 20000., 35 | 'power': 1., 36 | 'warmup': 0.99, 37 | }, 38 | 'ema_sched': { 39 | 'type': 'inverse', 40 | 'power': 0.6667, 41 | 'max_value': 0.9999 42 | }, 43 | } 44 | config = json.load(file) 45 | return merge(defaults, config) 46 | 47 | 48 | def make_model(config): 49 | config = config['model'] 50 | assert config['type'] == 'image_v1' 51 | model = models.ImageDenoiserModelV1( 52 | config['input_channels'], 53 | config['mapping_out'], 54 | config['depths'], 55 | config['channels'], 56 | config['self_attn_depths'], 57 | config['cross_attn_depths'], 58 | patch_size=config['patch_size'], 59 | dropout_rate=config['dropout_rate'], 60 | mapping_cond_dim=config['mapping_cond_dim'] + 9, 61 | unet_cond_dim=config['unet_cond_dim'], 62 | cross_cond_dim=config['cross_cond_dim'], 63 | skip_stages=config['skip_stages'], 64 | ) 65 | model = augmentation.KarrasAugmentWrapper(model) 66 | return model 67 | 68 | 69 | def make_sample_density(config): 70 | config = config['sigma_sample_density'] 71 | if config['type'] == 'lognormal': 72 | loc = config['mean'] if 'mean' in config else config['loc'] 73 | scale = config['std'] if 'std' in config else config['scale'] 74 | return partial(utils.rand_log_normal, loc=loc, scale=scale) 75 | if config['type'] == 'loglogistic': 76 | loc = config['loc'] 77 | scale = config['scale'] 78 | min_value = config['min_value'] if 'min_value' in config else 0. 
79 | max_value = config['max_value'] if 'max_value' in config else float('inf') 80 | return partial(utils.rand_log_logistic, loc=loc, scale=scale, min_value=min_value, max_value=max_value) 81 | if config['type'] == 'loguniform': 82 | min_value = config['min_value'] 83 | max_value = config['max_value'] 84 | return partial(utils.rand_log_uniform, min_value=min_value, max_value=max_value) 85 | raise ValueError('Unknown sample density type') 86 | -------------------------------------------------------------------------------- /k_diffusion/evaluation.py: -------------------------------------------------------------------------------- 1 | import math 2 | import os 3 | from pathlib import Path 4 | 5 | from cleanfid.inception_torchscript import InceptionV3W 6 | import clip 7 | from resize_right import resize 8 | import torch 9 | from torch import nn 10 | from torch.nn import functional as F 11 | from torchvision import transforms 12 | from tqdm.auto import trange 13 | 14 | from . import utils 15 | 16 | 17 | class InceptionV3FeatureExtractor(nn.Module): 18 | def __init__(self, device='cpu'): 19 | super().__init__() 20 | path = Path(os.environ.get('XDG_CACHE_HOME', Path.home() / '.cache')) / 'k-diffusion' 21 | url = 'https://nvlabs-fi-cdn.nvidia.com/stylegan2-ada-pytorch/pretrained/metrics/inception-2015-12-05.pt' 22 | digest = 'f58cb9b6ec323ed63459aa4fb441fe750cfe39fafad6da5cb504a16f19e958f4' 23 | utils.download_file(path / 'inception-2015-12-05.pt', url, digest) 24 | self.model = InceptionV3W(str(path), resize_inside=False).to(device) 25 | self.size = (299, 299) 26 | 27 | def forward(self, x): 28 | if x.shape[2:4] != self.size: 29 | x = resize(x, out_shape=self.size, pad_mode='reflect') 30 | if x.shape[1] == 1: 31 | x = torch.cat([x] * 3, dim=1) 32 | x = (x * 127.5 + 127.5).clamp(0, 255) 33 | return self.model(x) 34 | 35 | 36 | class CLIPFeatureExtractor(nn.Module): 37 | def __init__(self, name='ViT-L/14@336px', device='cpu'): 38 | super().__init__() 39 | self.model = clip.load(name, device=device)[0].eval().requires_grad_(False) 40 | self.normalize = transforms.Normalize(mean=(0.48145466, 0.4578275, 0.40821073), 41 | std=(0.26862954, 0.26130258, 0.27577711)) 42 | self.size = (self.model.visual.input_resolution, self.model.visual.input_resolution) 43 | 44 | def forward(self, x): 45 | if x.shape[2:4] != self.size: 46 | x = resize(x.add(1).div(2), out_shape=self.size, pad_mode='reflect').clamp(0, 1) 47 | x = self.normalize(x) 48 | x = self.model.encode_image(x).float() 49 | x = F.normalize(x) * x.shape[1] ** 0.5 50 | return x 51 | 52 | 53 | def compute_features(accelerator, sample_fn, extractor_fn, n, batch_size): 54 | n_per_proc = math.ceil(n / accelerator.num_processes) 55 | feats_all = [] 56 | for i in trange(0, n_per_proc, batch_size, disable=not accelerator.is_main_process): 57 | cur_batch_size = min(n - i, batch_size) 58 | samples = sample_fn(cur_batch_size)[:cur_batch_size] 59 | feats_all.append(accelerator.gather(extractor_fn(samples))) 60 | return torch.cat(feats_all)[:n] 61 | 62 | 63 | def polynomial_kernel(x, y): 64 | d = x.shape[-1] 65 | dot = x @ y.transpose(-2, -1) 66 | return (dot / d + 1) ** 3 67 | 68 | 69 | def kid(x, y, kernel=polynomial_kernel): 70 | m = x.shape[-2] 71 | n = y.shape[-2] 72 | kxx = kernel(x, x) 73 | kyy = kernel(y, y) 74 | kxy = kernel(x, y) 75 | kxx_sum = kxx.sum([-1, -2]) - kxx.diagonal(dim1=-1, dim2=-2).sum(-1) 76 | kyy_sum = kyy.sum([-1, -2]) - kyy.diagonal(dim1=-1, dim2=-2).sum(-1) 77 | kxy_sum = kxy.sum([-1, -2]) 78 | term_1 = kxx_sum / m / (m - 1) 79 | 
term_2 = kyy_sum / n / (n - 1) 80 | term_3 = kxy_sum * 2 / m / n 81 | return term_1 + term_2 - term_3 82 | 83 | 84 | class _MatrixSquareRootEig(torch.autograd.Function): 85 | @staticmethod 86 | def forward(ctx, a): 87 | vals, vecs = torch.linalg.eigh(a) 88 | ctx.save_for_backward(vals, vecs) 89 | return vecs @ vals.abs().sqrt().diag_embed() @ vecs.transpose(-2, -1) 90 | 91 | @staticmethod 92 | def backward(ctx, grad_output): 93 | vals, vecs = ctx.saved_tensors 94 | d = vals.abs().sqrt().unsqueeze(-1).repeat_interleave(vals.shape[-1], -1) 95 | vecs_t = vecs.transpose(-2, -1) 96 | return vecs @ (vecs_t @ grad_output @ vecs / (d + d.transpose(-2, -1))) @ vecs_t 97 | 98 | 99 | def sqrtm_eig(a): 100 | if a.ndim < 2: 101 | raise RuntimeError('tensor of matrices must have at least 2 dimensions') 102 | if a.shape[-2] != a.shape[-1]: 103 | raise RuntimeError('tensor must be batches of square matrices') 104 | return _MatrixSquareRootEig.apply(a) 105 | 106 | 107 | def fid(x, y, eps=1e-8): 108 | x_mean = x.mean(dim=0) 109 | y_mean = y.mean(dim=0) 110 | mean_term = (x_mean - y_mean).pow(2).sum() 111 | x_cov = torch.cov(x.T) 112 | y_cov = torch.cov(y.T) 113 | eps_eye = torch.eye(x_cov.shape[0], device=x_cov.device, dtype=x_cov.dtype) * eps 114 | x_cov = x_cov + eps_eye 115 | y_cov = y_cov + eps_eye 116 | x_cov_sqrt = sqrtm_eig(x_cov) 117 | cov_term = torch.trace(x_cov + y_cov - 2 * sqrtm_eig(x_cov_sqrt @ y_cov @ x_cov_sqrt)) 118 | return mean_term + cov_term 119 | -------------------------------------------------------------------------------- /k_diffusion/external.py: -------------------------------------------------------------------------------- 1 | import math 2 | 3 | import torch 4 | from torch import nn 5 | 6 | from . import sampling, utils 7 | 8 | 9 | class VDenoiser(nn.Module): 10 | """A v-diffusion-pytorch model wrapper for k-diffusion.""" 11 | 12 | def __init__(self, inner_model): 13 | super().__init__() 14 | self.inner_model = inner_model 15 | self.sigma_data = 1. 
16 | 17 | def get_scalings(self, sigma): 18 | c_skip = self.sigma_data ** 2 / (sigma ** 2 + self.sigma_data ** 2) 19 | c_out = -sigma * self.sigma_data / (sigma ** 2 + self.sigma_data ** 2) ** 0.5 20 | c_in = 1 / (sigma ** 2 + self.sigma_data ** 2) ** 0.5 21 | return c_skip, c_out, c_in 22 | 23 | def sigma_to_t(self, sigma): 24 | return sigma.atan() / math.pi * 2 25 | 26 | def t_to_sigma(self, t): 27 | return (t * math.pi / 2).tan() 28 | 29 | def loss(self, input, noise, sigma, **kwargs): 30 | c_skip, c_out, c_in = [utils.append_dims(x, input.ndim) for x in self.get_scalings(sigma)] 31 | noised_input = input + noise * utils.append_dims(sigma, input.ndim) 32 | model_output = self.inner_model(noised_input * c_in, self.sigma_to_t(sigma), **kwargs) 33 | target = (input - c_skip * noised_input) / c_out 34 | return (model_output - target).pow(2).flatten(1).mean(1) 35 | 36 | def forward(self, input, sigma, **kwargs): 37 | c_skip, c_out, c_in = [utils.append_dims(x, input.ndim) for x in self.get_scalings(sigma)] 38 | return self.inner_model(input * c_in, self.sigma_to_t(sigma), **kwargs) * c_out + input * c_skip 39 | 40 | 41 | class DiscreteSchedule(nn.Module): 42 | """A mapping between continuous noise levels (sigmas) and a list of discrete noise 43 | levels.""" 44 | 45 | def __init__(self, sigmas, quantize): 46 | super().__init__() 47 | self.register_buffer('sigmas', sigmas) 48 | self.quantize = quantize 49 | 50 | def get_sigmas(self, n=None): 51 | if n is None: 52 | return sampling.append_zero(self.sigmas.flip(0)) 53 | t_max = len(self.sigmas) - 1 54 | t = torch.linspace(t_max, 0, n, device=self.sigmas.device) 55 | return sampling.append_zero(self.t_to_sigma(t)) 56 | 57 | def sigma_to_t(self, sigma, quantize=None): 58 | quantize = self.quantize if quantize is None else quantize 59 | dists = torch.abs(sigma - self.sigmas[:, None]) 60 | if quantize: 61 | return torch.argmin(dists, dim=0).view(sigma.shape) 62 | low_idx, high_idx = torch.sort(torch.topk(dists, dim=0, k=2, largest=False).indices, dim=0)[0] 63 | low, high = self.sigmas[low_idx], self.sigmas[high_idx] 64 | w = (low - sigma) / (low - high) 65 | w = w.clamp(0, 1) 66 | t = (1 - w) * low_idx + w * high_idx 67 | return t.view(sigma.shape) 68 | 69 | def t_to_sigma(self, t): 70 | t = t.float() 71 | low_idx, high_idx, w = t.floor().long(), t.ceil().long(), t.frac() 72 | return (1 - w) * self.sigmas[low_idx] + w * self.sigmas[high_idx] 73 | 74 | 75 | class DiscreteEpsDDPMDenoiser(DiscreteSchedule): 76 | """A wrapper for discrete schedule DDPM models that output eps (the predicted 77 | noise).""" 78 | 79 | def __init__(self, model, alphas_cumprod, quantize): 80 | super().__init__(((1 - alphas_cumprod) / alphas_cumprod) ** 0.5, quantize) 81 | self.inner_model = model 82 | self.sigma_data = 1. 
83 | 84 | def get_scalings(self, sigma): 85 | c_out = -sigma 86 | c_in = 1 / (sigma ** 2 + self.sigma_data ** 2) ** 0.5 87 | return c_out, c_in 88 | 89 | def get_eps(self, *args, **kwargs): 90 | return self.inner_model(*args, **kwargs) 91 | 92 | def loss(self, input, noise, sigma, **kwargs): 93 | c_out, c_in = [utils.append_dims(x, input.ndim) for x in self.get_scalings(sigma)] 94 | noised_input = input + noise * utils.append_dims(sigma, input.ndim) 95 | eps = self.get_eps(noised_input * c_in, self.sigma_to_t(sigma), **kwargs) 96 | return (eps - noise).pow(2).flatten(1).mean(1) 97 | 98 | def forward(self, input, sigma, **kwargs): 99 | c_out, c_in = [utils.append_dims(x, input.ndim) for x in self.get_scalings(sigma)] 100 | eps = self.get_eps(input * c_in, self.sigma_to_t(sigma), **kwargs) 101 | return input + eps * c_out 102 | 103 | 104 | class OpenAIDenoiser(DiscreteEpsDDPMDenoiser): 105 | """A wrapper for OpenAI diffusion models.""" 106 | 107 | def __init__(self, model, diffusion, quantize=False, has_learned_sigmas=True, device='cpu'): 108 | alphas_cumprod = torch.tensor(diffusion.alphas_cumprod, device=device, dtype=torch.float32) 109 | super().__init__(model, alphas_cumprod, quantize=quantize) 110 | self.has_learned_sigmas = has_learned_sigmas 111 | 112 | def get_eps(self, *args, **kwargs): 113 | model_output = self.inner_model(*args, **kwargs) 114 | if self.has_learned_sigmas: 115 | return model_output.chunk(2, dim=1)[0] 116 | return model_output 117 | 118 | 119 | class CompVisDenoiser(DiscreteEpsDDPMDenoiser): 120 | """A wrapper for CompVis diffusion models.""" 121 | 122 | def __init__(self, model, quantize=False, device='cpu'): 123 | super().__init__(model, model.alphas_cumprod, quantize=quantize) 124 | 125 | def get_eps(self, *args, **kwargs): 126 | return self.inner_model.apply_model(*args, **kwargs) 127 | -------------------------------------------------------------------------------- /k_diffusion/gns.py: -------------------------------------------------------------------------------- 1 | import torch 2 | from torch import nn 3 | 4 | 5 | class DDPGradientStatsHook: 6 | def __init__(self, ddp_module): 7 | try: 8 | ddp_module.register_comm_hook(self, self._hook_fn) 9 | except AttributeError: 10 | raise ValueError('DDPGradientStatsHook does not support non-DDP wrapped modules') 11 | self._clear_state() 12 | 13 | def _clear_state(self): 14 | self.bucket_sq_norms_small_batch = [] 15 | self.bucket_sq_norms_large_batch = [] 16 | 17 | @staticmethod 18 | def _hook_fn(self, bucket): 19 | buf = bucket.buffer() 20 | self.bucket_sq_norms_small_batch.append(buf.pow(2).sum()) 21 | fut = torch.distributed.all_reduce(buf, op=torch.distributed.ReduceOp.AVG, async_op=True).get_future() 22 | def callback(fut): 23 | buf = fut.value()[0] 24 | self.bucket_sq_norms_large_batch.append(buf.pow(2).sum()) 25 | return buf 26 | return fut.then(callback) 27 | 28 | def get_stats(self): 29 | sq_norm_small_batch = sum(self.bucket_sq_norms_small_batch) 30 | sq_norm_large_batch = sum(self.bucket_sq_norms_large_batch) 31 | self._clear_state() 32 | return torch.stack([sq_norm_small_batch, sq_norm_large_batch])[None] 33 | 34 | 35 | class GradientNoiseScale: 36 | """Calculates the gradient noise scale (1 / SNR), or critical batch size, 37 | from _An Empirical Model of Large-Batch Training_, 38 | https://arxiv.org/abs/1812.06162). 39 | 40 | Args: 41 | beta (float): The decay factor for the exponential moving averages used to 42 | calculate the gradient noise scale. 
43 | Default: 0.9998 44 | eps (float): Added for numerical stability. 45 | Default: 1e-8 46 | """ 47 | 48 | def __init__(self, beta=0.9998, eps=1e-8): 49 | self.beta = beta 50 | self.eps = eps 51 | self.ema_sq_norm = 0. 52 | self.ema_var = 0. 53 | self.beta_cumprod = 1. 54 | self.gradient_noise_scale = float('nan') 55 | 56 | def state_dict(self): 57 | """Returns the state of the object as a :class:`dict`.""" 58 | return dict(self.__dict__.items()) 59 | 60 | def load_state_dict(self, state_dict): 61 | """Loads the object's state. 62 | Args: 63 | state_dict (dict): object state. Should be an object returned 64 | from a call to :meth:`state_dict`. 65 | """ 66 | self.__dict__.update(state_dict) 67 | 68 | def update(self, sq_norm_small_batch, sq_norm_large_batch, n_small_batch, n_large_batch): 69 | """Updates the state with a new batch's gradient statistics, and returns the 70 | current gradient noise scale. 71 | 72 | Args: 73 | sq_norm_small_batch (float): The mean of the squared 2-norms of microbatch or 74 | per sample gradients. 75 | sq_norm_large_batch (float): The squared 2-norm of the mean of the microbatch or 76 | per sample gradients. 77 | n_small_batch (int): The batch size of the individual microbatch or per sample 78 | gradients (1 if per sample). 79 | n_large_batch (int): The total batch size of the mean of the microbatch or 80 | per sample gradients. 81 | """ 82 | est_sq_norm = (n_large_batch * sq_norm_large_batch - n_small_batch * sq_norm_small_batch) / (n_large_batch - n_small_batch) 83 | est_var = (sq_norm_small_batch - sq_norm_large_batch) / (1 / n_small_batch - 1 / n_large_batch) 84 | self.ema_sq_norm = self.beta * self.ema_sq_norm + (1 - self.beta) * est_sq_norm 85 | self.ema_var = self.beta * self.ema_var + (1 - self.beta) * est_var 86 | self.beta_cumprod *= self.beta 87 | self.gradient_noise_scale = max(self.ema_var, self.eps) / max(self.ema_sq_norm, self.eps) 88 | return self.gradient_noise_scale 89 | 90 | def get_gns(self): 91 | """Returns the current gradient noise scale.""" 92 | return self.gradient_noise_scale 93 | 94 | def get_stats(self): 95 | """Returns the current (debiased) estimates of the squared mean gradient 96 | and gradient variance.""" 97 | return self.ema_sq_norm / (1 - self.beta_cumprod), self.ema_var / (1 - self.beta_cumprod) 98 | -------------------------------------------------------------------------------- /k_diffusion/layers.py: -------------------------------------------------------------------------------- 1 | import math 2 | 3 | from einops import rearrange, repeat 4 | import torch 5 | from torch import nn 6 | from torch.nn import functional as F 7 | 8 | from . import utils 9 | 10 | # Karras et al. preconditioned denoiser 11 | 12 | class Denoiser(nn.Module): 13 | """A Karras et al. 
preconditioner for denoising diffusion models.""" 14 | 15 | def __init__(self, inner_model, sigma_data=1.): 16 | super().__init__() 17 | self.inner_model = inner_model 18 | self.sigma_data = sigma_data 19 | 20 | def get_scalings(self, sigma): 21 | c_skip = self.sigma_data ** 2 / (sigma ** 2 + self.sigma_data ** 2) 22 | c_out = sigma * self.sigma_data / (sigma ** 2 + self.sigma_data ** 2) ** 0.5 23 | c_in = 1 / (sigma ** 2 + self.sigma_data ** 2) ** 0.5 24 | return c_skip, c_out, c_in 25 | 26 | def loss(self, input, noise, sigma, **kwargs): 27 | c_skip, c_out, c_in = [utils.append_dims(x, input.ndim) for x in self.get_scalings(sigma)] 28 | noised_input = input + noise * utils.append_dims(sigma, input.ndim) 29 | model_output = self.inner_model(noised_input * c_in, sigma, **kwargs) 30 | target = (input - c_skip * noised_input) / c_out 31 | return (model_output - target).pow(2).flatten(1).mean(1) 32 | 33 | def forward(self, input, sigma, **kwargs): 34 | c_skip, c_out, c_in = [utils.append_dims(x, input.ndim) for x in self.get_scalings(sigma)] 35 | return self.inner_model(input * c_in, sigma, **kwargs) * c_out + input * c_skip 36 | 37 | 38 | # Residual blocks 39 | 40 | class ResidualBlock(nn.Module): 41 | def __init__(self, *main, skip=None): 42 | super().__init__() 43 | self.main = nn.Sequential(*main) 44 | self.skip = skip if skip else nn.Identity() 45 | 46 | def forward(self, input): 47 | return self.main(input) + self.skip(input) 48 | 49 | 50 | # Noise level (and other) conditioning 51 | 52 | class ConditionedModule(nn.Module): 53 | pass 54 | 55 | 56 | class UnconditionedModule(ConditionedModule): 57 | def __init__(self, module): 58 | self.module = module 59 | 60 | def forward(self, input, cond): 61 | return self.module(input) 62 | 63 | 64 | class ConditionedSequential(nn.Sequential, ConditionedModule): 65 | def forward(self, input, cond): 66 | for module in self: 67 | if isinstance(module, ConditionedModule): 68 | input = module(input, cond) 69 | else: 70 | input = module(input) 71 | return input 72 | 73 | 74 | class ConditionedResidualBlock(ConditionedModule): 75 | def __init__(self, *main, skip=None): 76 | super().__init__() 77 | self.main = ConditionedSequential(*main) 78 | self.skip = skip if skip else nn.Identity() 79 | 80 | def forward(self, input, cond): 81 | skip = self.skip(input, cond) if isinstance(self.skip, ConditionedModule) else self.skip(input) 82 | return self.main(input, cond) + skip 83 | 84 | 85 | class AdaGN(ConditionedModule): 86 | def __init__(self, feats_in, c_out, num_groups, eps=1e-5, cond_key='cond'): 87 | super().__init__() 88 | self.num_groups = num_groups 89 | self.eps = eps 90 | self.cond_key = cond_key 91 | self.mapper = nn.Linear(feats_in, c_out * 2) 92 | 93 | def forward(self, input, cond): 94 | weight, bias = self.mapper(cond[self.cond_key]).chunk(2, dim=-1) 95 | input = F.group_norm(input, self.num_groups, eps=self.eps) 96 | return torch.addcmul(utils.append_dims(bias, input.ndim), input, utils.append_dims(weight, input.ndim) + 1) 97 | 98 | 99 | # Attention 100 | 101 | class SelfAttention2d(ConditionedModule): 102 | def __init__(self, c_in, n_head, norm, dropout_rate=0.): 103 | super().__init__() 104 | assert c_in % n_head == 0 105 | self.norm_in = norm(c_in) 106 | self.n_head = n_head 107 | self.qkv_proj = nn.Conv2d(c_in, c_in * 3, 1) 108 | self.out_proj = nn.Conv2d(c_in, c_in, 1) 109 | self.dropout = nn.Dropout(dropout_rate) 110 | 111 | def forward(self, input, cond): 112 | n, c, h, w = input.shape 113 | qkv = self.qkv_proj(self.norm_in(input, cond)) 
114 | qkv = qkv.view([n, self.n_head * 3, c // self.n_head, h * w]).transpose(2, 3) 115 | q, k, v = qkv.chunk(3, dim=1) 116 | scale = k.shape[3] ** -0.25 117 | att = ((q * scale) @ (k.transpose(2, 3) * scale)).softmax(3) 118 | att = self.dropout(att) 119 | y = (att @ v).transpose(2, 3).contiguous().view([n, c, h, w]) 120 | return input + self.out_proj(y) 121 | 122 | 123 | class CrossAttention2d(ConditionedModule): 124 | def __init__(self, c_dec, c_enc, n_head, norm_dec, dropout_rate=0., 125 | cond_key='cross', cond_key_padding='cross_padding'): 126 | super().__init__() 127 | assert c_dec % n_head == 0 128 | self.cond_key = cond_key 129 | self.cond_key_padding = cond_key_padding 130 | self.norm_enc = nn.LayerNorm(c_enc) 131 | self.norm_dec = norm_dec(c_dec) 132 | self.n_head = n_head 133 | self.q_proj = nn.Conv2d(c_dec, c_dec, 1) 134 | self.kv_proj = nn.Linear(c_enc, c_dec * 2) 135 | self.out_proj = nn.Conv2d(c_dec, c_dec, 1) 136 | self.dropout = nn.Dropout(dropout_rate) 137 | 138 | def forward(self, input, cond): 139 | n, c, h, w = input.shape 140 | q = self.q_proj(self.norm_dec(input, cond)) 141 | q = q.view([n, self.n_head, c // self.n_head, h * w]).transpose(2, 3) 142 | kv = self.kv_proj(self.norm_enc(cond[self.cond_key])) 143 | kv = kv.view([n, -1, self.n_head * 2, c // self.n_head]).transpose(1, 2) 144 | k, v = kv.chunk(2, dim=1) 145 | scale = k.shape[3] ** -0.25 146 | att = ((q * scale) @ (k.transpose(2, 3) * scale)) 147 | att = att - (cond[self.cond_key_padding][:, None, None, :]) * 10000 148 | att = att.softmax(3) 149 | att = self.dropout(att) 150 | y = (att @ v).transpose(2, 3) 151 | y = y.contiguous().view([n, c, h, w]) 152 | return input + self.out_proj(y) 153 | 154 | 155 | # Downsampling/upsampling 156 | 157 | _kernels = { 158 | 'linear': 159 | [1 / 8, 3 / 8, 3 / 8, 1 / 8], 160 | 'cubic': 161 | [-0.01171875, -0.03515625, 0.11328125, 0.43359375, 162 | 0.43359375, 0.11328125, -0.03515625, -0.01171875], 163 | 'lanczos3': 164 | [0.003689131001010537, 0.015056144446134567, -0.03399861603975296, 165 | -0.066637322306633, 0.13550527393817902, 0.44638532400131226, 166 | 0.44638532400131226, 0.13550527393817902, -0.066637322306633, 167 | -0.03399861603975296, 0.015056144446134567, 0.003689131001010537] 168 | } 169 | _kernels['bilinear'] = _kernels['linear'] 170 | _kernels['bicubic'] = _kernels['cubic'] 171 | 172 | 173 | class Downsample2d(nn.Module): 174 | def __init__(self, kernel='linear', pad_mode='reflect'): 175 | super().__init__() 176 | self.pad_mode = pad_mode 177 | kernel_1d = torch.tensor([_kernels[kernel]]) 178 | self.pad = kernel_1d.shape[1] // 2 - 1 179 | self.register_buffer('kernel', kernel_1d.T @ kernel_1d) 180 | 181 | def forward(self, x): 182 | x = F.pad(x, (self.pad,) * 4, self.pad_mode) 183 | weight = x.new_zeros([x.shape[1], x.shape[1], self.kernel.shape[0], self.kernel.shape[1]]) 184 | indices = torch.arange(x.shape[1], device=x.device) 185 | weight[indices, indices] = self.kernel.to(weight) 186 | return F.conv2d(x, weight, stride=2) 187 | 188 | 189 | class Upsample2d(nn.Module): 190 | def __init__(self, kernel='linear', pad_mode='reflect'): 191 | super().__init__() 192 | self.pad_mode = pad_mode 193 | kernel_1d = torch.tensor([_kernels[kernel]]) * 2 194 | self.pad = kernel_1d.shape[1] // 2 - 1 195 | self.register_buffer('kernel', kernel_1d.T @ kernel_1d) 196 | 197 | def forward(self, x): 198 | x = F.pad(x, ((self.pad + 1) // 2,) * 4, self.pad_mode) 199 | weight = x.new_zeros([x.shape[1], x.shape[1], self.kernel.shape[0], self.kernel.shape[1]]) 200 | indices = 
torch.arange(x.shape[1], device=x.device) 201 | weight[indices, indices] = self.kernel.to(weight) 202 | return F.conv_transpose2d(x, weight, stride=2, padding=self.pad * 2 + 1) 203 | 204 | 205 | # Embeddings 206 | 207 | class FourierFeatures(nn.Module): 208 | def __init__(self, in_features, out_features, std=1.): 209 | super().__init__() 210 | assert out_features % 2 == 0 211 | self.register_buffer('weight', torch.randn([out_features // 2, in_features]) * std) 212 | 213 | def forward(self, input): 214 | f = 2 * math.pi * input @ self.weight.T 215 | return torch.cat([f.cos(), f.sin()], dim=-1) 216 | 217 | 218 | # U-Nets 219 | 220 | class UNet(ConditionedModule): 221 | def __init__(self, d_blocks, u_blocks, skip_stages=0): 222 | super().__init__() 223 | self.d_blocks = nn.ModuleList(d_blocks) 224 | self.u_blocks = nn.ModuleList(u_blocks) 225 | self.skip_stages = skip_stages 226 | 227 | def forward(self, input, cond): 228 | skips = [] 229 | for block in self.d_blocks[self.skip_stages:]: 230 | input = block(input, cond) 231 | skips.append(input) 232 | for i, (block, skip) in enumerate(zip(self.u_blocks, reversed(skips))): 233 | input = block(input, cond, skip if i > 0 else None) 234 | return input 235 | -------------------------------------------------------------------------------- /k_diffusion/models/__init__.py: -------------------------------------------------------------------------------- 1 | from .image_v1 import ImageDenoiserModelV1 2 | -------------------------------------------------------------------------------- /k_diffusion/models/image_v1.py: -------------------------------------------------------------------------------- 1 | import math 2 | 3 | import torch 4 | from torch import nn 5 | from torch.nn import functional as F 6 | 7 | from .. 
import layers, utils 8 | 9 | 10 | class ResConvBlock(layers.ConditionedResidualBlock): 11 | def __init__(self, feats_in, c_in, c_mid, c_out, group_size=32, dropout_rate=0.): 12 | skip = None if c_in == c_out else nn.Conv2d(c_in, c_out, 1, bias=False) 13 | super().__init__( 14 | layers.AdaGN(feats_in, c_in, max(1, c_in // group_size)), 15 | nn.GELU(), 16 | nn.Conv2d(c_in, c_mid, 3, padding=1), 17 | nn.Dropout2d(dropout_rate, inplace=True), 18 | layers.AdaGN(feats_in, c_mid, max(1, c_mid // group_size)), 19 | nn.GELU(), 20 | nn.Conv2d(c_mid, c_out, 3, padding=1), 21 | nn.Dropout2d(dropout_rate, inplace=True), 22 | skip=skip) 23 | 24 | 25 | class DBlock(layers.ConditionedSequential): 26 | def __init__(self, n_layers, feats_in, c_in, c_mid, c_out, group_size=32, head_size=64, dropout_rate=0., downsample=False, self_attn=False, cross_attn=False, c_enc=0): 27 | modules = [nn.Identity()] 28 | for i in range(n_layers): 29 | my_c_in = c_in if i == 0 else c_mid 30 | my_c_out = c_mid if i < n_layers - 1 else c_out 31 | modules.append(ResConvBlock(feats_in, my_c_in, c_mid, my_c_out, group_size, dropout_rate)) 32 | if self_attn: 33 | norm = lambda c_in: layers.AdaGN(feats_in, c_in, max(1, my_c_out // group_size)) 34 | modules.append(layers.SelfAttention2d(my_c_out, max(1, my_c_out // head_size), norm, dropout_rate)) 35 | if cross_attn: 36 | norm = lambda c_in: layers.AdaGN(feats_in, c_in, max(1, my_c_out // group_size)) 37 | modules.append(layers.CrossAttention2d(my_c_out, c_enc, max(1, my_c_out // head_size), norm, dropout_rate)) 38 | super().__init__(*modules) 39 | self.set_downsample(downsample) 40 | 41 | def set_downsample(self, downsample): 42 | self[0] = layers.Downsample2d() if downsample else nn.Identity() 43 | return self 44 | 45 | 46 | class UBlock(layers.ConditionedSequential): 47 | def __init__(self, n_layers, feats_in, c_in, c_mid, c_out, group_size=32, head_size=64, dropout_rate=0., upsample=False, self_attn=False, cross_attn=False, c_enc=0): 48 | modules = [] 49 | for i in range(n_layers): 50 | my_c_in = c_in if i == 0 else c_mid 51 | my_c_out = c_mid if i < n_layers - 1 else c_out 52 | modules.append(ResConvBlock(feats_in, my_c_in, c_mid, my_c_out, group_size, dropout_rate)) 53 | if self_attn: 54 | norm = lambda c_in: layers.AdaGN(feats_in, c_in, max(1, my_c_out // group_size)) 55 | modules.append(layers.SelfAttention2d(my_c_out, max(1, my_c_out // head_size), norm, dropout_rate)) 56 | if cross_attn: 57 | norm = lambda c_in: layers.AdaGN(feats_in, c_in, max(1, my_c_out // group_size)) 58 | modules.append(layers.CrossAttention2d(my_c_out, c_enc, max(1, my_c_out // head_size), norm, dropout_rate)) 59 | modules.append(nn.Identity()) 60 | super().__init__(*modules) 61 | self.set_upsample(upsample) 62 | 63 | def forward(self, input, cond, skip=None): 64 | if skip is not None: 65 | input = torch.cat([input, skip], dim=1) 66 | return super().forward(input, cond) 67 | 68 | def set_upsample(self, upsample): 69 | self[-1] = layers.Upsample2d() if upsample else nn.Identity() 70 | return self 71 | 72 | 73 | class MappingNet(nn.Sequential): 74 | def __init__(self, feats_in, feats_out, n_layers=2): 75 | layers = [] 76 | for i in range(n_layers): 77 | layers.append(nn.Linear(feats_in if i == 0 else feats_out, feats_out)) 78 | layers.append(nn.GELU()) 79 | super().__init__(*layers) 80 | for layer in self: 81 | if isinstance(layer, nn.Linear): 82 | nn.init.orthogonal_(layer.weight) 83 | 84 | 85 | class ImageDenoiserModelV1(nn.Module): 86 | def __init__(self, c_in, feats_in, depths, channels, 
self_attn_depths, cross_attn_depths=None, mapping_cond_dim=0, unet_cond_dim=0, cross_cond_dim=0, dropout_rate=0., patch_size=1, skip_stages=0): 87 | super().__init__() 88 | self.c_in = c_in 89 | self.channels = channels 90 | self.unet_cond_dim = unet_cond_dim 91 | self.patch_size = patch_size 92 | self.timestep_embed = layers.FourierFeatures(1, feats_in) 93 | if mapping_cond_dim > 0: 94 | self.mapping_cond = nn.Linear(mapping_cond_dim, feats_in, bias=False) 95 | self.mapping = MappingNet(feats_in, feats_in) 96 | self.proj_in = nn.Conv2d((c_in + unet_cond_dim) * self.patch_size ** 2, channels[max(0, skip_stages - 1)], 1) 97 | self.proj_out = nn.Conv2d(channels[max(0, skip_stages - 1)], c_in * self.patch_size ** 2, 1) 98 | nn.init.zeros_(self.proj_out.weight) 99 | nn.init.zeros_(self.proj_out.bias) 100 | if cross_cond_dim == 0: 101 | cross_attn_depths = [False] * len(self_attn_depths) 102 | d_blocks, u_blocks = [], [] 103 | for i in range(len(depths)): 104 | my_c_in = channels[max(0, i - 1)] 105 | d_blocks.append(DBlock(depths[i], feats_in, my_c_in, channels[i], channels[i], downsample=i > skip_stages, self_attn=self_attn_depths[i], cross_attn=cross_attn_depths[i], c_enc=cross_cond_dim, dropout_rate=dropout_rate)) 106 | for i in range(len(depths)): 107 | my_c_in = channels[i] * 2 if i < len(depths) - 1 else channels[i] 108 | my_c_out = channels[max(0, i - 1)] 109 | u_blocks.append(UBlock(depths[i], feats_in, my_c_in, channels[i], my_c_out, upsample=i > skip_stages, self_attn=self_attn_depths[i], cross_attn=cross_attn_depths[i], c_enc=cross_cond_dim, dropout_rate=dropout_rate)) 110 | self.u_net = layers.UNet(d_blocks, reversed(u_blocks), skip_stages=skip_stages) 111 | 112 | def forward(self, input, sigma, mapping_cond=None, unet_cond=None, cross_cond=None, cross_cond_padding=None): 113 | c_noise = sigma.log() / 4 114 | timestep_embed = self.timestep_embed(utils.append_dims(c_noise, 2)) 115 | mapping_cond_embed = torch.zeros_like(timestep_embed) if mapping_cond is None else self.mapping_cond(mapping_cond) 116 | mapping_out = self.mapping(timestep_embed + mapping_cond_embed) 117 | cond = {'cond': mapping_out} 118 | if unet_cond is not None: 119 | input = torch.cat([input, unet_cond], dim=1) 120 | if cross_cond is not None: 121 | cond['cross'] = cross_cond 122 | cond['cross_padding'] = cross_cond_padding 123 | if self.patch_size > 1: 124 | input = F.pixel_unshuffle(input, self.patch_size) 125 | input = self.proj_in(input) 126 | input = self.u_net(input, cond) 127 | input = self.proj_out(input) 128 | if self.patch_size > 1: 129 | input = F.pixel_shuffle(input, self.patch_size) 130 | return input 131 | 132 | def set_skip_stages(self, skip_stages): 133 | self.proj_in = nn.Conv2d(self.proj_in.in_channels, self.channels[max(0, skip_stages - 1)], 1) 134 | self.proj_out = nn.Conv2d(self.channels[max(0, skip_stages - 1)], self.proj_out.out_channels, 1) 135 | nn.init.zeros_(self.proj_out.weight) 136 | nn.init.zeros_(self.proj_out.bias) 137 | self.u_net.skip_stages = skip_stages 138 | for i, block in enumerate(self.u_net.d_blocks): 139 | block.set_downsample(i > skip_stages) 140 | for i, block in enumerate(reversed(self.u_net.u_blocks)): 141 | block.set_upsample(i > skip_stages) 142 | return self 143 | 144 | def set_patch_size(self, patch_size): 145 | self.patch_size = patch_size 146 | self.proj_in = nn.Conv2d((self.c_in + self.unet_cond_dim) * self.patch_size ** 2, self.channels[max(0, self.u_net.skip_stages - 1)], 1) 147 | self.proj_out = nn.Conv2d(self.channels[max(0, self.u_net.skip_stages - 1)], 
self.c_in * self.patch_size ** 2, 1) 148 | nn.init.zeros_(self.proj_out.weight) 149 | nn.init.zeros_(self.proj_out.bias) 150 | -------------------------------------------------------------------------------- /k_diffusion/sampling.py: -------------------------------------------------------------------------------- 1 | import math 2 | 3 | from scipy import integrate 4 | import torch 5 | from torchdiffeq import odeint 6 | from tqdm.auto import trange, tqdm 7 | 8 | from . import utils 9 | 10 | 11 | def append_zero(x): 12 | return torch.cat([x, x.new_zeros([1])]) 13 | 14 | 15 | def get_sigmas_karras(n, sigma_min, sigma_max, rho=7., device='cpu'): 16 | """Constructs the noise schedule of Karras et al. (2022).""" 17 | ramp = torch.linspace(0, 1, n) 18 | min_inv_rho = sigma_min ** (1 / rho) 19 | max_inv_rho = sigma_max ** (1 / rho) 20 | sigmas = (max_inv_rho + ramp * (min_inv_rho - max_inv_rho)) ** rho 21 | return append_zero(sigmas).to(device) 22 | 23 | 24 | def get_sigmas_exponential(n, sigma_min, sigma_max, device='cpu'): 25 | """Constructs an exponential noise schedule.""" 26 | sigmas = torch.linspace(math.log(sigma_max), math.log(sigma_min), n, device=device).exp() 27 | return append_zero(sigmas) 28 | 29 | 30 | def get_sigmas_vp(n, beta_d=19.9, beta_min=0.1, eps_s=1e-3, device='cpu'): 31 | """Constructs a continuous VP noise schedule.""" 32 | t = torch.linspace(1, eps_s, n, device=device) 33 | sigmas = torch.sqrt(torch.exp(beta_d * t ** 2 / 2 + beta_min * t) - 1) 34 | return append_zero(sigmas) 35 | 36 | 37 | def to_d(x, sigma, denoised): 38 | """Converts a denoiser output to a Karras ODE derivative.""" 39 | return (x - denoised) / utils.append_dims(sigma, x.ndim) 40 | 41 | 42 | def get_ancestral_step(sigma_from, sigma_to): 43 | """Calculates the noise level (sigma_down) to step down to and the amount 44 | of noise to add (sigma_up) when doing an ancestral sampling step.""" 45 | sigma_up = (sigma_to ** 2 * (sigma_from ** 2 - sigma_to ** 2) / sigma_from ** 2) ** 0.5 46 | sigma_down = (sigma_to ** 2 - sigma_up ** 2) ** 0.5 47 | return sigma_down, sigma_up 48 | 49 | 50 | @torch.no_grad() 51 | def sample_euler(model, x, sigmas, extra_args=None, callback=None, disable=None, s_churn=0., s_tmin=0., s_tmax=float('inf'), s_noise=1.): 52 | """Implements Algorithm 2 (Euler steps) from Karras et al. (2022).""" 53 | extra_args = {} if extra_args is None else extra_args 54 | s_in = x.new_ones([x.shape[0]]) 55 | for i in trange(len(sigmas) - 1, disable=disable): 56 | gamma = min(s_churn / (len(sigmas) - 1), 2 ** 0.5 - 1) if s_tmin <= sigmas[i] <= s_tmax else 0. 
57 | eps = torch.randn_like(x) * s_noise 58 | sigma_hat = sigmas[i] * (gamma + 1) 59 | if gamma > 0: 60 | x = x + eps * (sigma_hat ** 2 - sigmas[i] ** 2) ** 0.5 61 | denoised = model(x, sigma_hat * s_in, **extra_args) 62 | d = to_d(x, sigma_hat, denoised) 63 | if callback is not None: 64 | callback({'x': x, 'i': i, 'sigma': sigmas[i], 'sigma_hat': sigma_hat, 'denoised': denoised}) 65 | dt = sigmas[i + 1] - sigma_hat 66 | # Euler method 67 | x = x + d * dt 68 | return x 69 | 70 | 71 | @torch.no_grad() 72 | def sample_euler_ancestral(model, x, sigmas, extra_args=None, callback=None, disable=None): 73 | """Ancestral sampling with Euler method steps.""" 74 | extra_args = {} if extra_args is None else extra_args 75 | s_in = x.new_ones([x.shape[0]]) 76 | for i in trange(len(sigmas) - 1, disable=disable): 77 | denoised = model(x, sigmas[i] * s_in, **extra_args) 78 | sigma_down, sigma_up = get_ancestral_step(sigmas[i], sigmas[i + 1]) 79 | if callback is not None: 80 | callback({'x': x, 'i': i, 'sigma': sigmas[i], 'sigma_hat': sigmas[i], 'denoised': denoised}) 81 | d = to_d(x, sigmas[i], denoised) 82 | # Euler method 83 | dt = sigma_down - sigmas[i] 84 | x = x + d * dt 85 | x = x + torch.randn_like(x) * sigma_up 86 | return x 87 | 88 | 89 | @torch.no_grad() 90 | def sample_heun(model, x, sigmas, extra_args=None, callback=None, disable=None, s_churn=0., s_tmin=0., s_tmax=float('inf'), s_noise=1.): 91 | """Implements Algorithm 2 (Heun steps) from Karras et al. (2022).""" 92 | extra_args = {} if extra_args is None else extra_args 93 | s_in = x.new_ones([x.shape[0]]) 94 | for i in trange(len(sigmas) - 1, disable=disable): 95 | gamma = min(s_churn / (len(sigmas) - 1), 2 ** 0.5 - 1) if s_tmin <= sigmas[i] <= s_tmax else 0. 96 | eps = torch.randn_like(x) * s_noise 97 | sigma_hat = sigmas[i] * (gamma + 1) 98 | if gamma > 0: 99 | x = x + eps * (sigma_hat ** 2 - sigmas[i] ** 2) ** 0.5 100 | denoised = model(x, sigma_hat * s_in, **extra_args) 101 | d = to_d(x, sigma_hat, denoised) 102 | if callback is not None: 103 | callback({'x': x, 'i': i, 'sigma': sigmas[i], 'sigma_hat': sigma_hat, 'denoised': denoised}) 104 | dt = sigmas[i + 1] - sigma_hat 105 | if sigmas[i + 1] == 0: 106 | # Euler method 107 | x = x + d * dt 108 | else: 109 | # Heun's method 110 | x_2 = x + d * dt 111 | denoised_2 = model(x_2, sigmas[i + 1] * s_in, **extra_args) 112 | d_2 = to_d(x_2, sigmas[i + 1], denoised_2) 113 | d_prime = (d + d_2) / 2 114 | x = x + d_prime * dt 115 | return x 116 | 117 | 118 | @torch.no_grad() 119 | def sample_dpm_2(model, x, sigmas, extra_args=None, callback=None, disable=None, s_churn=0., s_tmin=0., s_tmax=float('inf'), s_noise=1.): 120 | """A sampler inspired by DPM-Solver-2 and Algorithm 2 from Karras et al. (2022).""" 121 | extra_args = {} if extra_args is None else extra_args 122 | s_in = x.new_ones([x.shape[0]]) 123 | for i in trange(len(sigmas) - 1, disable=disable): 124 | gamma = min(s_churn / (len(sigmas) - 1), 2 ** 0.5 - 1) if s_tmin <= sigmas[i] <= s_tmax else 0. 
125 | eps = torch.randn_like(x) * s_noise 126 | sigma_hat = sigmas[i] * (gamma + 1) 127 | if gamma > 0: 128 | x = x + eps * (sigma_hat ** 2 - sigmas[i] ** 2) ** 0.5 129 | denoised = model(x, sigma_hat * s_in, **extra_args) 130 | d = to_d(x, sigma_hat, denoised) 131 | if callback is not None: 132 | callback({'x': x, 'i': i, 'sigma': sigmas[i], 'sigma_hat': sigma_hat, 'denoised': denoised}) 133 | # Midpoint method, where the midpoint is chosen according to a rho=3 Karras schedule 134 | sigma_mid = ((sigma_hat ** (1 / 3) + sigmas[i + 1] ** (1 / 3)) / 2) ** 3 135 | dt_1 = sigma_mid - sigma_hat 136 | dt_2 = sigmas[i + 1] - sigma_hat 137 | x_2 = x + d * dt_1 138 | denoised_2 = model(x_2, sigma_mid * s_in, **extra_args) 139 | d_2 = to_d(x_2, sigma_mid, denoised_2) 140 | x = x + d_2 * dt_2 141 | return x 142 | 143 | 144 | @torch.no_grad() 145 | def sample_dpm_2_ancestral(model, x, sigmas, extra_args=None, callback=None, disable=None): 146 | """Ancestral sampling with DPM-Solver inspired second-order steps.""" 147 | extra_args = {} if extra_args is None else extra_args 148 | s_in = x.new_ones([x.shape[0]]) 149 | for i in trange(len(sigmas) - 1, disable=disable): 150 | denoised = model(x, sigmas[i] * s_in, **extra_args) 151 | sigma_down, sigma_up = get_ancestral_step(sigmas[i], sigmas[i + 1]) 152 | if callback is not None: 153 | callback({'x': x, 'i': i, 'sigma': sigmas[i], 'sigma_hat': sigmas[i], 'denoised': denoised}) 154 | d = to_d(x, sigmas[i], denoised) 155 | # Midpoint method, where the midpoint is chosen according to a rho=3 Karras schedule 156 | sigma_mid = ((sigmas[i] ** (1 / 3) + sigma_down ** (1 / 3)) / 2) ** 3 157 | dt_1 = sigma_mid - sigmas[i] 158 | dt_2 = sigma_down - sigmas[i] 159 | x_2 = x + d * dt_1 160 | denoised_2 = model(x_2, sigma_mid * s_in, **extra_args) 161 | d_2 = to_d(x_2, sigma_mid, denoised_2) 162 | x = x + d_2 * dt_2 163 | x = x + torch.randn_like(x) * sigma_up 164 | return x 165 | 166 | 167 | def linear_multistep_coeff(order, t, i, j): 168 | if order - 1 > i: 169 | raise ValueError(f'Order {order} too high for step {i}') 170 | def fn(tau): 171 | prod = 1. 
172 | for k in range(order): 173 | if j == k: 174 | continue 175 | prod *= (tau - t[i - k]) / (t[i - j] - t[i - k]) 176 | return prod 177 | return integrate.quad(fn, t[i], t[i + 1], epsrel=1e-4)[0] 178 | 179 | 180 | @torch.no_grad() 181 | def sample_lms(model, x, sigmas, extra_args=None, callback=None, disable=None, order=4): 182 | extra_args = {} if extra_args is None else extra_args 183 | s_in = x.new_ones([x.shape[0]]) 184 | ds = [] 185 | for i in trange(len(sigmas) - 1, disable=disable): 186 | denoised = model(x, sigmas[i] * s_in, **extra_args) 187 | d = to_d(x, sigmas[i], denoised) 188 | ds.append(d) 189 | if len(ds) > order: 190 | ds.pop(0) 191 | if callback is not None: 192 | callback({'x': x, 'i': i, 'sigma': sigmas[i], 'sigma_hat': sigmas[i], 'denoised': denoised}) 193 | cur_order = min(i + 1, order) 194 | coeffs = [linear_multistep_coeff(cur_order, sigmas.cpu(), i, j) for j in range(cur_order)] 195 | x = x + sum(coeff * d for coeff, d in zip(coeffs, reversed(ds))) 196 | return x 197 | 198 | 199 | @torch.no_grad() 200 | def log_likelihood(model, x, sigma_min, sigma_max, extra_args=None, atol=1e-4, rtol=1e-4): 201 | extra_args = {} if extra_args is None else extra_args 202 | s_in = x.new_ones([x.shape[0]]) 203 | v = torch.randint_like(x, 2) * 2 - 1 204 | fevals = 0 205 | def ode_fn(sigma, x): 206 | nonlocal fevals 207 | with torch.enable_grad(): 208 | x = x[0].detach().requires_grad_() 209 | denoised = model(x, sigma * s_in, **extra_args) 210 | d = to_d(x, sigma, denoised) 211 | fevals += 1 212 | grad = torch.autograd.grad((d * v).sum(), x)[0] 213 | d_ll = (v * grad).flatten(1).sum(1) 214 | return d.detach(), d_ll 215 | x_min = x, x.new_zeros([x.shape[0]]) 216 | t = x.new_tensor([sigma_min, sigma_max]) 217 | sol = odeint(ode_fn, x_min, t, atol=atol, rtol=rtol, method='dopri5') 218 | latent, delta_ll = sol[0][-1], sol[1][-1] 219 | ll_prior = torch.distributions.Normal(0, sigma_max).log_prob(latent).flatten(1).sum(1) 220 | return ll_prior + delta_ll, {'fevals': fevals} 221 | -------------------------------------------------------------------------------- /k_diffusion/utils.py: -------------------------------------------------------------------------------- 1 | from contextlib import contextmanager 2 | import hashlib 3 | import math 4 | from pathlib import Path 5 | import shutil 6 | import urllib 7 | import warnings 8 | 9 | import torch 10 | from torch import optim 11 | from torchvision.transforms import functional as TF 12 | 13 | 14 | def from_pil_image(x): 15 | """Converts from a PIL image to a tensor.""" 16 | x = TF.to_tensor(x) 17 | if x.ndim == 2: 18 | x = x[..., None] 19 | return x * 2 - 1 20 | 21 | 22 | def to_pil_image(x): 23 | """Converts from a tensor to a PIL image.""" 24 | if x.ndim == 4: 25 | assert x.shape[0] == 1 26 | x = x[0] 27 | if x.shape[0] == 1: 28 | x = x[0] 29 | return TF.to_pil_image((x.clamp(-1, 1) + 1) / 2) 30 | 31 | 32 | def hf_datasets_augs_helper(examples, transform, image_key, mode='RGB'): 33 | """Apply passed in transforms for HuggingFace Datasets.""" 34 | images = [transform(image.convert(mode)) for image in examples[image_key]] 35 | return {image_key: images} 36 | 37 | 38 | def append_dims(x, target_dims): 39 | """Appends dimensions to the end of a tensor until it has target_dims dimensions.""" 40 | dims_to_append = target_dims - x.ndim 41 | if dims_to_append < 0: 42 | raise ValueError(f'input has {x.ndim} dims but target_dims is {target_dims}, which is less') 43 | return x[(...,) + (None,) * dims_to_append] 44 | 45 | 46 | def n_params(module): 47 | 
"""Returns the number of trainable parameters in a module.""" 48 | return sum(p.numel() for p in module.parameters()) 49 | 50 | 51 | def download_file(path, url, digest=None): 52 | """Downloads a file if it does not exist, optionally checking its SHA-256 hash.""" 53 | path = Path(path) 54 | path.parent.mkdir(parents=True, exist_ok=True) 55 | if not path.exists(): 56 | with urllib.request.urlopen(url) as response, open(path, 'wb') as f: 57 | shutil.copyfileobj(response, f) 58 | if digest is not None: 59 | file_digest = hashlib.sha256(open(path, 'rb').read()).hexdigest() 60 | if digest != file_digest: 61 | raise OSError(f'hash of {path} (url: {url}) failed to validate') 62 | return path 63 | 64 | 65 | @contextmanager 66 | def train_mode(model, mode=True): 67 | """A context manager that places a model into training mode and restores 68 | the previous mode on exit.""" 69 | modes = [module.training for module in model.modules()] 70 | try: 71 | yield model.train(mode) 72 | finally: 73 | for i, module in enumerate(model.modules()): 74 | module.training = modes[i] 75 | 76 | 77 | def eval_mode(model): 78 | """A context manager that places a model into evaluation mode and restores 79 | the previous mode on exit.""" 80 | return train_mode(model, False) 81 | 82 | 83 | @torch.no_grad() 84 | def ema_update(model, averaged_model, decay): 85 | """Incorporates updated model parameters into an exponential moving averaged 86 | version of a model. It should be called after each optimizer step.""" 87 | model_params = dict(model.named_parameters()) 88 | averaged_params = dict(averaged_model.named_parameters()) 89 | assert model_params.keys() == averaged_params.keys() 90 | 91 | for name, param in model_params.items(): 92 | averaged_params[name].mul_(decay).add_(param, alpha=1 - decay) 93 | 94 | model_buffers = dict(model.named_buffers()) 95 | averaged_buffers = dict(averaged_model.named_buffers()) 96 | assert model_buffers.keys() == averaged_buffers.keys() 97 | 98 | for name, buf in model_buffers.items(): 99 | averaged_buffers[name].copy_(buf) 100 | 101 | 102 | class EMAWarmup: 103 | """Implements an EMA warmup using an inverse decay schedule. 104 | If inv_gamma=1 and power=1, implements a simple average. inv_gamma=1, power=2/3 are 105 | good values for models you plan to train for a million or more steps (reaches decay 106 | factor 0.999 at 31.6K steps, 0.9999 at 1M steps), inv_gamma=1, power=3/4 for models 107 | you plan to train for less (reaches decay factor 0.999 at 10K steps, 0.9999 at 108 | 215.4k steps). 109 | Args: 110 | inv_gamma (float): Inverse multiplicative factor of EMA warmup. Default: 1. 111 | power (float): Exponential factor of EMA warmup. Default: 1. 112 | min_value (float): The minimum EMA decay rate. Default: 0. 113 | max_value (float): The maximum EMA decay rate. Default: 1. 114 | start_at (int): The epoch to start averaging at. Default: 0. 115 | last_epoch (int): The index of last epoch. Default: 0. 116 | """ 117 | 118 | def __init__(self, inv_gamma=1., power=1., min_value=0., max_value=1., start_at=0, 119 | last_epoch=0): 120 | self.inv_gamma = inv_gamma 121 | self.power = power 122 | self.min_value = min_value 123 | self.max_value = max_value 124 | self.start_at = start_at 125 | self.last_epoch = last_epoch 126 | 127 | def state_dict(self): 128 | """Returns the state of the class as a :class:`dict`.""" 129 | return dict(self.__dict__.items()) 130 | 131 | def load_state_dict(self, state_dict): 132 | """Loads the class's state. 133 | Args: 134 | state_dict (dict): scaler state. 
Should be an object returned 135 | from a call to :meth:`state_dict`. 136 | """ 137 | self.__dict__.update(state_dict) 138 | 139 | def get_value(self): 140 | """Gets the current EMA decay rate.""" 141 | epoch = max(0, self.last_epoch - self.start_at) 142 | value = 1 - (1 + epoch / self.inv_gamma) ** -self.power 143 | return 0. if epoch < 0 else min(self.max_value, max(self.min_value, value)) 144 | 145 | def step(self): 146 | """Updates the step count.""" 147 | self.last_epoch += 1 148 | 149 | 150 | class InverseLR(optim.lr_scheduler._LRScheduler): 151 | """Implements an inverse decay learning rate schedule with an optional exponential 152 | warmup. When last_epoch=-1, sets initial lr as lr. 153 | inv_gamma is the number of steps/epochs required for the learning rate to decay to 154 | (1 / 2)**power of its original value. 155 | Args: 156 | optimizer (Optimizer): Wrapped optimizer. 157 | inv_gamma (float): Inverse multiplicative factor of learning rate decay. Default: 1. 158 | power (float): Exponential factor of learning rate decay. Default: 1. 159 | warmup (float): Exponential warmup factor (0 <= warmup < 1, 0 to disable) 160 | Default: 0. 161 | min_lr (float): The minimum learning rate. Default: 0. 162 | last_epoch (int): The index of last epoch. Default: -1. 163 | verbose (bool): If ``True``, prints a message to stdout for 164 | each update. Default: ``False``. 165 | """ 166 | 167 | def __init__(self, optimizer, inv_gamma=1., power=1., warmup=0., min_lr=0., 168 | last_epoch=-1, verbose=False): 169 | self.inv_gamma = inv_gamma 170 | self.power = power 171 | if not 0. <= warmup < 1: 172 | raise ValueError('Invalid value for warmup') 173 | self.warmup = warmup 174 | self.min_lr = min_lr 175 | super().__init__(optimizer, last_epoch, verbose) 176 | 177 | def get_lr(self): 178 | if not self._get_lr_called_within_step: 179 | warnings.warn("To get the last learning rate computed by the scheduler, " 180 | "please use `get_last_lr()`.") 181 | 182 | return self._get_closed_form_lr() 183 | 184 | def _get_closed_form_lr(self): 185 | warmup = 1 - self.warmup ** (self.last_epoch + 1) 186 | lr_mult = (1 + self.last_epoch / self.inv_gamma) ** -self.power 187 | return [warmup * max(self.min_lr, base_lr * lr_mult) 188 | for base_lr in self.base_lrs] 189 | 190 | 191 | class ExponentialLR(optim.lr_scheduler._LRScheduler): 192 | """Implements an exponential learning rate schedule with an optional exponential 193 | warmup. When last_epoch=-1, sets initial lr as lr. Decays the learning rate 194 | continuously by decay (default 0.5) every num_steps steps. 195 | Args: 196 | optimizer (Optimizer): Wrapped optimizer. 197 | num_steps (float): The number of steps to decay the learning rate by decay in. 198 | decay (float): The factor by which to decay the learning rate every num_steps 199 | steps. Default: 0.5. 200 | warmup (float): Exponential warmup factor (0 <= warmup < 1, 0 to disable) 201 | Default: 0. 202 | min_lr (float): The minimum learning rate. Default: 0. 203 | last_epoch (int): The index of last epoch. Default: -1. 204 | verbose (bool): If ``True``, prints a message to stdout for 205 | each update. Default: ``False``. 206 | """ 207 | 208 | def __init__(self, optimizer, num_steps, decay=0.5, warmup=0., min_lr=0., 209 | last_epoch=-1, verbose=False): 210 | self.num_steps = num_steps 211 | self.decay = decay 212 | if not 0. 
<= warmup < 1: 213 | raise ValueError('Invalid value for warmup') 214 | self.warmup = warmup 215 | self.min_lr = min_lr 216 | super().__init__(optimizer, last_epoch, verbose) 217 | 218 | def get_lr(self): 219 | if not self._get_lr_called_within_step: 220 | warnings.warn("To get the last learning rate computed by the scheduler, " 221 | "please use `get_last_lr()`.") 222 | 223 | return self._get_closed_form_lr() 224 | 225 | def _get_closed_form_lr(self): 226 | warmup = 1 - self.warmup ** (self.last_epoch + 1) 227 | lr_mult = (self.decay ** (1 / self.num_steps)) ** self.last_epoch 228 | return [warmup * max(self.min_lr, base_lr * lr_mult) 229 | for base_lr in self.base_lrs] 230 | 231 | 232 | def rand_log_normal(shape, loc=0., scale=1., device='cpu', dtype=torch.float32): 233 | """Draws samples from an lognormal distribution.""" 234 | return (torch.randn(shape, device=device, dtype=dtype) * scale + loc).exp() 235 | 236 | 237 | def rand_log_logistic(shape, loc=0., scale=1., min_value=0., max_value=float('inf'), device='cpu', dtype=torch.float32): 238 | """Draws samples from an optionally truncated log-logistic distribution.""" 239 | min_value = torch.as_tensor(min_value, device=device, dtype=torch.float64) 240 | max_value = torch.as_tensor(max_value, device=device, dtype=torch.float64) 241 | min_cdf = min_value.log().sub(loc).div(scale).sigmoid() 242 | max_cdf = max_value.log().sub(loc).div(scale).sigmoid() 243 | u = torch.rand(shape, device=device, dtype=torch.float64) * (max_cdf - min_cdf) + min_cdf 244 | return u.logit().mul(scale).add(loc).exp().to(dtype) 245 | 246 | 247 | def rand_log_uniform(shape, min_value, max_value, device='cpu', dtype=torch.float32): 248 | """Draws samples from an log-uniform distribution.""" 249 | min_value = math.log(min_value) 250 | max_value = math.log(max_value) 251 | return (torch.rand(shape, device=device, dtype=dtype) * (max_value - min_value) + min_value).exp() 252 | -------------------------------------------------------------------------------- /make_grid.py: -------------------------------------------------------------------------------- 1 | #!/usr/bin/env python3 2 | 3 | """Assembles images into a grid.""" 4 | 5 | import argparse 6 | import math 7 | import sys 8 | 9 | from PIL import Image 10 | 11 | 12 | def main(): 13 | p = argparse.ArgumentParser(description=__doc__) 14 | p.add_argument('images', type=str, nargs='+', metavar='image', 15 | help='the input images') 16 | p.add_argument('--output', '-o', type=str, default='out.png', 17 | help='the output image') 18 | p.add_argument('--nrow', type=int, 19 | help='the number of images per row') 20 | args = p.parse_args() 21 | 22 | images = [Image.open(image) for image in args.images] 23 | mode = images[0].mode 24 | size = images[0].size 25 | for image, name in zip(images, args.images): 26 | if image.mode != mode: 27 | print(f'Error: Image {name} had mode {image.mode}, expected {mode}', file=sys.stderr) 28 | sys.exit(1) 29 | if image.size != size: 30 | print(f'Error: Image {name} had size {image.size}, expected {size}', file=sys.stderr) 31 | sys.exit(1) 32 | 33 | n = len(images) 34 | x = args.nrow if args.nrow else math.ceil(n**0.5) 35 | y = math.ceil(n / x) 36 | 37 | output = Image.new(mode, (size[0] * x, size[1] * y)) 38 | for i, image in enumerate(images): 39 | cur_x, cur_y = i % x, i // x 40 | output.paste(image, (size[0] * cur_x, size[1] * cur_y)) 41 | 42 | output.save(args.output) 43 | 44 | 45 | if __name__ == '__main__': 46 | main() 47 | 
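# Example invocation (hypothetical filenames, matching sample.py's default output prefix):
# $ ./make_grid.py out_00000.png out_00001.png out_00002.png out_00003.png -o grid.png --nrow 2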
-------------------------------------------------------------------------------- /predict.py: -------------------------------------------------------------------------------- 1 | from distutils.command.config import config 2 | import gc 3 | import io 4 | import math 5 | from pyexpat import model 6 | import sys 7 | from einops import rearrange 8 | import numpy as np 9 | import tempfile 10 | import threading 11 | import typing 12 | import queue 13 | import open_clip 14 | 15 | from omegaconf import OmegaConf 16 | import clip 17 | import k_diffusion as K 18 | import lpips 19 | from PIL import Image 20 | import requests 21 | import torch 22 | from torch import nn 23 | from torch.nn import functional as F 24 | from torchvision import transforms, utils 25 | from torchvision.transforms import functional as TF 26 | from tqdm import tqdm 27 | 28 | sys.path.append('./guided-diffusion') 29 | sys.path.append('./latent-diffusion') 30 | from guided_diffusion.script_util import create_model_and_diffusion, model_and_diffusion_defaults 31 | from ldm.util import instantiate_from_config 32 | 33 | from cog import BasePredictor, Input, Path 34 | 35 | def fetch(url_or_path): 36 | if str(url_or_path).startswith('http://') or str(url_or_path).startswith('https://'): 37 | r = requests.get(url_or_path) 38 | r.raise_for_status() 39 | fd = io.BytesIO() 40 | fd.write(r.content) 41 | fd.seek(0) 42 | return fd 43 | return open(url_or_path, 'rb') 44 | 45 | 46 | def parse_prompt(prompt): 47 | if prompt.startswith('http://') or prompt.startswith('https://'): 48 | vals = prompt.rsplit(':', 2) 49 | vals = [vals[0] + ':' + vals[1], *vals[2:]] 50 | else: 51 | vals = prompt.rsplit(':', 1) 52 | vals = vals + ['', '1'][len(vals):] 53 | return vals[0], float(vals[1]) 54 | 55 | 56 | class MakeCutouts(nn.Module): 57 | def __init__(self, cut_size, cutn, cut_pow=1.): 58 | super().__init__() 59 | self.cut_size = cut_size 60 | self.cutn = cutn 61 | self.cut_pow = cut_pow 62 | 63 | def forward(self, input): 64 | sideY, sideX = input.shape[2:4] 65 | max_size = min(sideX, sideY) 66 | min_size = min(sideX, sideY, self.cut_size) 67 | cutouts = [] 68 | for _ in range(self.cutn): 69 | size = int(torch.rand([])**self.cut_pow * (max_size - min_size) + min_size) 70 | offsetx = torch.randint(0, sideX - size + 1, ()) 71 | offsety = torch.randint(0, sideY - size + 1, ()) 72 | cutout = input[:, :, offsety:offsety + size, offsetx:offsetx + size] 73 | cutouts.append(F.adaptive_avg_pool2d(cutout, self.cut_size)) 74 | return torch.cat(cutouts) 75 | 76 | 77 | def spherical_dist_loss(x, y): 78 | x = F.normalize(x, dim=-1) 79 | y = F.normalize(y, dim=-1) 80 | return (x - y).norm(dim=-1).div(2).arcsin().pow(2).mul(2) 81 | 82 | 83 | def tv_loss(input): 84 | """L2 total variation loss, as in Mahendran et al.""" 85 | input = F.pad(input, (0, 1, 0, 1), 'replicate') 86 | x_diff = input[..., :-1, 1:] - input[..., :-1, :-1] 87 | y_diff = input[..., 1:, :-1] - input[..., :-1, :-1] 88 | return (x_diff**2 + y_diff**2).mean([1, 2, 3]) 89 | 90 | 91 | def range_loss(input): 92 | return (input - input.clamp(-1, 1)).pow(2).mean([1, 2, 3]) 93 | 94 | 95 | class GuidedDenoiserWithGrad(nn.Module): 96 | def __init__(self, model, cond_fn): 97 | super().__init__() 98 | self.inner_model = model 99 | self.cond_fn = cond_fn 100 | self.orig_denoised = None 101 | 102 | def forward(self, x, sigma, **kwargs): 103 | with torch.enable_grad(): 104 | x = x.detach().requires_grad_() 105 | denoised = self.inner_model(x, sigma, **kwargs) 106 | self.orig_denoised = denoised.detach() 107 | 
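# cond_fn returns the gradient of the guidance objective with respect to x; scaling it by
# sigma ** 2 (broadcast with append_dims) turns it into a correction of the denoised prediction,
# the same scheme make_cond_model_fn uses in sample_clip_guided.py.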
cond_grad = self.cond_fn(x, sigma, denoised=denoised, **kwargs) 108 | cond_denoised = denoised + cond_grad * K.utils.append_dims(sigma ** 2, x.ndim) 109 | return cond_denoised 110 | 111 | class CFGDenoiser(nn.Module): 112 | def __init__(self, model): 113 | super().__init__() 114 | self.inner_model = model 115 | 116 | def forward(self, x, sigma, uncond, cond, cond_scale): 117 | x_in = torch.cat([x] * 2) 118 | sigma_in = torch.cat([sigma] * 2) 119 | cond_in = torch.cat([uncond, cond]) 120 | uncond, cond = self.inner_model(x_in, sigma_in, cond=cond_in).chunk(2) 121 | return uncond + (cond - uncond) * cond_scale 122 | 123 | class Predictor(BasePredictor): 124 | 125 | def LoadCompVisModel(self): 126 | self.model_config = OmegaConf.load("/src/latent-diffusion/configs/latent-diffusion/txt2img-1p4B-eval.yaml") 127 | # self.model_config.update({ 128 | # 'attention_resolutions': '32, 16, 8', 129 | # 'class_cond': False, 130 | # 'diffusion_steps': 1000, 131 | # 'rescale_timesteps': True, 132 | # 'timestep_respacing': '1000', 133 | # 'learn_sigma': True, 134 | # 'noise_schedule': 'linear', 135 | # 'num_channels': 256, 136 | # 'num_head_channels': 64, 137 | # 'num_res_blocks': 2, 138 | # 'resblock_updown': True, 139 | # 'use_checkpoint': True, 140 | # 'use_fp16': True, 141 | # 'use_scale_shift_norm': True, 142 | # }) 143 | self.model_config['image_size'] = 256 144 | self.model_path = "/root/.cache/k-diffusion/txt2img-f8-large-jack000-finetuned-fp16.ckpt" 145 | sd = torch.load(self.model_path, map_location='cuda') 146 | #sd = pl_sd["state_dict"] 147 | self.model = instantiate_from_config(self.model_config.model) 148 | m, u = self.model.load_state_dict(sd, strict=False) 149 | 150 | def LoadOpenAIModel(self): 151 | self.model_config = model_and_diffusion_defaults() 152 | self.model_config.update({ 153 | 'attention_resolutions': '32, 16, 8', 154 | 'class_cond': False, 155 | 'diffusion_steps': 1000, 156 | 'rescale_timesteps': True, 157 | 'timestep_respacing': '1000', 158 | 'learn_sigma': True, 159 | 'noise_schedule': 'linear', 160 | 'num_channels': 256, 161 | 'num_head_channels': 64, 162 | 'num_res_blocks': 2, 163 | 'resblock_updown': True, 164 | 'use_checkpoint': False, 165 | 'use_fp16': True, 166 | 'use_scale_shift_norm': True, 167 | }) 168 | self.model_config['image_size'] = 256 169 | self.model_path = "/root/.cache/k-diffusion/256x256_diffusion_uncond.pt" 170 | self.model, self.diffusion = create_model_and_diffusion(**self.model_config) 171 | self.model.load_state_dict(torch.load(self.model_path, map_location='cuda')) 172 | 173 | def setup(self): 174 | self.device = torch.device('cuda') 175 | self.LoadCompVisModel() 176 | self.model.requires_grad_().eval().to(self.device) 177 | #if self.model_config['use_fp16']: 178 | # self.model.convert_to_fp16() 179 | 180 | self.clip_model, _, self.preprocess = open_clip.create_model_and_transforms('ViT-B-32', pretrained='laion2b_e16') 181 | self.clip_model = self.clip_model.eval().requires_grad_(False).to(self.device) 182 | #self.clip_model = clip.load('ViT-B/32', jit=False)[0].eval().requires_grad_(False).to(self.device) 183 | #self.clip_size = self.clip_model.vison.input_resolution 184 | self.clip_size = 224 185 | self.normalize = transforms.Normalize(mean=[0.48145466, 0.4578275, 0.40821073], 186 | std=[0.26862954, 0.26130258, 0.27577711]) 187 | self.lpips_model = lpips.LPIPS(net='vgg').to(self.device) 188 | 189 | def predict( 190 | self, 191 | text_prompt: str = Input(description="Prompt",default="A mysterious orb by Ernst Fuchs"), 192 | #model: str = 
Input(description="Diffusion Model",default="latent_diffusion_txt2img_f8_large.ckpt", choices=['latent_diffusion_txt2img_f8_large.ckpt','256x256_diffusion_uncond.pt']), 193 | init_image: Path = Input(description="Initial image for the generation",default=None), 194 | sigma_start: int = Input(description="The starting noise level when using an init image", default=10), 195 | init_scale: int = Input(description="This enhances the effect of the init image, a good value is 1000.", default=1000), 196 | image_prompt: Path = Input(description="Image prompt",default=None), 197 | #batch_size: int = Input(description="The number of generations to run",ge=1,le=10,default=1), 198 | n_steps: int = Input(description="The number of timesteps to use", ge=50,le=1000,default=500), 199 | latent_scale: int = Input(description="Latent guidance scale, higher for stronger latent guidance", default=5.0), 200 | clip_guidance_scale: int = Input(description="Controls how much the image should look like the prompt.", default=1000), 201 | cutn: int = Input(description="The number of random crops per step.", default=16), 202 | cut_pow: float = Input(description="Cut power", default=0.5), 203 | seed: int = Input(description="Seed (leave empty to use a random seed)", default=None, le=(2**32-1), ge=0), 204 | ) -> typing.Iterator[Path]: 205 | prompts = [text_prompt] 206 | self.text_prompt = text_prompt 207 | self.latent_scale = latent_scale 208 | image_prompts = [] 209 | if (image_prompt): 210 | image_prompts = [str(image_prompt)] 211 | if (init_image): 212 | init_image = str(init_image) 213 | 214 | n_batches = 1 215 | 216 | 217 | make_cutouts = MakeCutouts(self.clip_size, cutn, cut_pow) 218 | side_x = side_y = self.model_config['image_size'] 219 | 220 | target_embeds, weights = [], [] 221 | 222 | 223 | 224 | # do_run 225 | 226 | make_cutouts = MakeCutouts(self.clip_size, cutn, cut_pow) 227 | side_x = side_y = self.model_config['image_size'] 228 | 229 | target_embeds, weights = [], [] 230 | 231 | for prompt in prompts: 232 | txt, weight = parse_prompt(prompt) 233 | target_embeds.append(self.clip_model.encode_text(open_clip.tokenize(txt).to(self.device)).float()) 234 | weights.append(weight) 235 | 236 | for prompt in image_prompts: 237 | path, weight = parse_prompt(prompt) 238 | img = Image.open(fetch(path)).convert('RGB') 239 | img = TF.resize(img, min(side_x, side_y, *img.size), transforms.InterpolationMode.LANCZOS) 240 | batch = make_cutouts(TF.to_tensor(img).unsqueeze(0).to(self.device)) 241 | embed = self.clip_model.encode_image(self.normalize(batch)).float() 242 | target_embeds.append(embed) 243 | weights.extend([weight / cutn] * cutn) 244 | 245 | target_embeds = torch.cat(target_embeds) 246 | weights = torch.tensor(weights, device=self.device) 247 | if weights.sum().abs() < 1e-3: 248 | raise RuntimeError('The weights must not sum to 0.') 249 | weights /= weights.sum().abs() 250 | 251 | init = None 252 | if init_image is not None: 253 | init = Image.open(fetch(init_image)).convert('RGB') 254 | init = init.resize((side_x, side_y), Image.Resampling.LANCZOS) 255 | init = TF.to_tensor(init).to(self.device)[None] * 2 - 1 256 | 257 | def cond_fn(x, sigma, denoised, cond, **kwargs): 258 | n = x.shape[0] 259 | 260 | # Anti-grain hack for the 256x256 ImageNet model 261 | #fac = sigma / (sigma ** 2 + 1) ** 0.5 262 | #denoised_in = x.lerp(denoised, fac) 263 | denoised_in = self.model.first_stage_model.decode(denoised / self.model.scale_factor) 264 | 265 | #clip_in = self.normalize(make_cutouts(denoised_in.add(1).div(2))) 266 | 
clip_in = self.normalize(make_cutouts(denoised_in.add(1).div(2))) 267 | image_embeds = self.clip_model.encode_image(clip_in).float() 268 | dists = spherical_dist_loss(image_embeds[:, None], target_embeds[None]) 269 | dists = dists.view([cutn, n, -1]) 270 | losses = dists.mul(weights).sum(2).mean(0) 271 | loss = losses.sum() * clip_guidance_scale 272 | if init is not None and init_scale: 273 | init_losses = self.lpips_model(denoised_in, init) 274 | loss = loss + init_losses.sum() * init_scale 275 | 276 | return -torch.autograd.grad(loss, x)[0] 277 | 278 | #model_wrap = K.external.OpenAIDenoiser(self.model, self.diffusion, device=self.device) 279 | self.model_wrap = K.external.CompVisDenoiser(self.model, False, device=self.device) 280 | sigmas = self.model_wrap.get_sigmas(n_steps) 281 | if init is not None: 282 | sigmas = sigmas[sigmas <= sigma_start] 283 | self.model_wrap_cfg = CFGDenoiser(self.model_wrap) 284 | self.model_guided = GuidedDenoiserWithGrad(self.model_wrap_cfg, cond_fn) 285 | 286 | output = queue.SimpleQueue() 287 | 288 | def callback(info): 289 | if info['i'] % 50 == 0: 290 | denoised = self.model.decode_first_stage(self.model_guided.orig_denoised) 291 | nrow = math.ceil(denoised.shape[0] ** 0.5) 292 | grid = utils.make_grid(denoised, nrow, padding=0) 293 | tqdm.write(f'Step {info["i"]} of {len(sigmas) - 1}:') 294 | filename = f'step_{info["i"]}.png' 295 | K.utils.to_pil_image(grid).save(filename) 296 | output.put(filename) 297 | #display.display(K.utils.to_pil_image(grid)) 298 | tqdm.write('') 299 | 300 | if seed is not None: 301 | torch.manual_seed(seed) 302 | 303 | self.success = False 304 | 305 | for i in range(n_batches): 306 | self.side_y = side_y 307 | self.side_x = side_x 308 | self.sigmas = sigmas 309 | self.init = init 310 | self.callback = callback 311 | t = threading.Thread(target=self.worker, daemon=True) 312 | t.start() 313 | while t.is_alive(): 314 | try: 315 | image = output.get(block=True, timeout=5) 316 | yield Path(image) 317 | except queue.Empty: pass 318 | 319 | tqdm.write('Done!') 320 | if (not self.success): 321 | raise RuntimeError('No output, check logs for errors') 322 | 323 | samples = self.model.decode_first_stage(self.samples) 324 | samples = torch.clamp((samples + 1.0) / 2.0, min=0.0, max=1.0) 325 | 326 | for i, out in enumerate(samples): 327 | sample = 255.0 * rearrange(out.cpu().numpy(), "c h w -> h w c") 328 | filename = f'out_{i}.png' 329 | Image.fromarray(sample.astype(np.uint8)).save(filename) 330 | yield Path(filename) 331 | 332 | @torch.no_grad() 333 | def worker(self): 334 | with self.model.ema_scope(): 335 | self.x = torch.randn([1, 4, self.side_y//8, self.side_x//8], device=self.device) 336 | if self.init is not None: 337 | self.x += self.init 338 | n_samples = 1 339 | c = self.model.get_learned_conditioning(n_samples * [self.text_prompt]) 340 | uc = self.model.get_learned_conditioning(n_samples * [""]) 341 | extra_args = {'cond': c, 'uncond': uc, 'cond_scale': self.latent_scale} 342 | self.samples = K.sampling.sample_heun(self.model_guided, self.x, self.sigmas, second_order=True, s_churn=20, callback=self.callback, extra_args=extra_args) 343 | self.success = True 344 | -------------------------------------------------------------------------------- /pyproject.toml: -------------------------------------------------------------------------------- 1 | [build-system] 2 | requires = ["setuptools"] 3 | build-backend = "setuptools.build_meta" 4 | -------------------------------------------------------------------------------- /requirements.txt:
-------------------------------------------------------------------------------- 1 | accelerate 2 | clean-fid 3 | einops 4 | jsonmerge 5 | kornia 6 | Pillow 7 | resize-right 8 | scikit-image 9 | scipy 10 | torch 11 | torchdiffeq 12 | torchvision 13 | tqdm 14 | wandb 15 | git+https://github.com/openai/CLIP 16 | -------------------------------------------------------------------------------- /sample.py: -------------------------------------------------------------------------------- 1 | #!/usr/bin/env python3 2 | 3 | """Samples from k-diffusion models.""" 4 | 5 | import argparse 6 | import math 7 | 8 | import accelerate 9 | import torch 10 | from tqdm import trange, tqdm 11 | 12 | import k_diffusion as K 13 | 14 | 15 | def main(): 16 | p = argparse.ArgumentParser(description=__doc__, 17 | formatter_class=argparse.ArgumentDefaultsHelpFormatter) 18 | p.add_argument('--batch-size', type=int, default=64, 19 | help='the batch size') 20 | p.add_argument('--checkpoint', type=str, required=True, 21 | help='the checkpoint to use') 22 | p.add_argument('--config', type=str, required=True, 23 | help='the model config') 24 | p.add_argument('-n', type=int, default=64, 25 | help='the number of images to sample') 26 | p.add_argument('--prefix', type=str, default='out', 27 | help='the output prefix') 28 | p.add_argument('--steps', type=int, default=50, 29 | help='the number of denoising steps') 30 | args = p.parse_args() 31 | 32 | config = K.config.load_config(open(args.config)) 33 | model_config = config['model'] 34 | # TODO: allow non-square input sizes 35 | assert len(model_config['input_size']) == 2 and model_config['input_size'][0] == model_config['input_size'][1] 36 | size = model_config['input_size'] 37 | 38 | accelerator = accelerate.Accelerator() 39 | device = accelerator.device 40 | print('Using device:', device, flush=True) 41 | 42 | inner_model = K.config.make_model(config).eval().requires_grad_(False).to(device) 43 | inner_model.load_state_dict(torch.load(args.checkpoint, map_location='cpu')['model_ema']) 44 | accelerator.print('Parameters:', K.utils.n_params(inner_model)) 45 | model = K.Denoiser(inner_model, sigma_data=model_config['sigma_data']) 46 | 47 | sigma_min = model_config['sigma_min'] 48 | sigma_max = model_config['sigma_max'] 49 | 50 | @torch.no_grad() 51 | @K.utils.eval_mode(model) 52 | def run(): 53 | if accelerator.is_local_main_process: 54 | tqdm.write('Sampling...') 55 | sigmas = K.sampling.get_sigmas_karras(args.steps, sigma_min, sigma_max, rho=7., device=device) 56 | def sample_fn(n): 57 | x = torch.randn([n, model_config['input_channels'], size[0], size[1]], device=device) * sigma_max 58 | x_0 = K.sampling.sample_lms(model, x, sigmas, disable=not accelerator.is_local_main_process) 59 | return x_0 60 | x_0 = K.evaluation.compute_features(accelerator, sample_fn, lambda x: x, args.n, args.batch_size) 61 | if accelerator.is_main_process: 62 | for i, out in enumerate(x_0): 63 | filename = f'{args.prefix}_{i:05}.png' 64 | K.utils.to_pil_image(out).save(filename) 65 | 66 | try: 67 | run() 68 | except KeyboardInterrupt: 69 | pass 70 | 71 | 72 | if __name__ == '__main__': 73 | main() 74 | -------------------------------------------------------------------------------- /sample_clip_guided.py: -------------------------------------------------------------------------------- 1 | #!/usr/bin/env python3 2 | 3 | """CLIP guided sampling from k-diffusion models.""" 4 | 5 | import argparse 6 | import math 7 | 8 | import accelerate 9 | import clip 10 | from kornia import augmentation as KA 11 | 
from resize_right import resize 12 | import torch 13 | from torch.nn import functional as F 14 | from torchvision import transforms 15 | from tqdm import trange, tqdm 16 | 17 | import k_diffusion as K 18 | 19 | 20 | def spherical_dist_loss(x, y): 21 | x = F.normalize(x, dim=-1) 22 | y = F.normalize(y, dim=-1) 23 | return (x - y).norm(dim=-1).div(2).arcsin().pow(2).mul(2) 24 | 25 | 26 | def make_cond_model_fn(model, cond_fn): 27 | def model_fn(x, sigma, **kwargs): 28 | with torch.enable_grad(): 29 | x = x.detach().requires_grad_() 30 | denoised = model(x, sigma, **kwargs) 31 | cond_grad = cond_fn(x, sigma, denoised=denoised, **kwargs).detach() 32 | cond_denoised = denoised.detach() + cond_grad * K.utils.append_dims(sigma**2, x.ndim) 33 | return cond_denoised 34 | return model_fn 35 | 36 | 37 | def make_static_thresh_model_fn(model, value=1.): 38 | def model_fn(x, sigma, **kwargs): 39 | return model(x, sigma, **kwargs).clamp(-value, value) 40 | return model_fn 41 | 42 | 43 | def main(): 44 | p = argparse.ArgumentParser(description=__doc__, 45 | formatter_class=argparse.ArgumentDefaultsHelpFormatter) 46 | p.add_argument('prompt', type=str, 47 | help='the prompt to use') 48 | p.add_argument('--batch-size', type=int, default=16, 49 | help='the batch size') 50 | p.add_argument('--checkpoint', type=str, required=True, 51 | help='the checkpoint to use') 52 | p.add_argument('--churn', type=float, default=50., 53 | help='the amount of noise to add during sampling') 54 | p.add_argument('--clip-guidance-scale', '-cgs', type=float, default=500., 55 | help='the CLIP guidance scale') 56 | p.add_argument('--clip-model', type=str, default='ViT-B/16', choices=clip.available_models(), 57 | help='the CLIP model to use') 58 | p.add_argument('--config', type=str, required=True, 59 | help='the model config') 60 | p.add_argument('-n', type=int, default=64, 61 | help='the number of images to sample') 62 | p.add_argument('--prefix', type=str, default='out', 63 | help='the output prefix') 64 | p.add_argument('--steps', type=int, default=100, 65 | help='the number of denoising steps') 66 | args = p.parse_args() 67 | 68 | config = K.config.load_config(open(args.config)) 69 | model_config = config['model'] 70 | # TODO: allow non-square input sizes 71 | assert len(model_config['input_size']) == 2 and model_config['input_size'][0] == model_config['input_size'][1] 72 | size = model_config['input_size'] 73 | 74 | accelerator = accelerate.Accelerator() 75 | device = accelerator.device 76 | print('Using device:', device, flush=True) 77 | 78 | inner_model = K.config.make_model(config).eval().requires_grad_(False).to(device) 79 | inner_model.load_state_dict(torch.load(args.checkpoint, map_location='cpu')['model_ema']) 80 | accelerator.print('Parameters:', K.utils.n_params(inner_model)) 81 | model = K.Denoiser(inner_model, sigma_data=model_config['sigma_data']) 82 | 83 | sigma_min = model_config['sigma_min'] 84 | sigma_max = model_config['sigma_max'] 85 | 86 | clip_model = clip.load(args.clip_model, device=device)[0].eval().requires_grad_(False) 87 | clip_normalize = transforms.Normalize(mean=(0.48145466, 0.4578275, 0.40821073), 88 | std=(0.26862954, 0.26130258, 0.27577711)) 89 | clip_size = (clip_model.visual.input_resolution, clip_model.visual.input_resolution) 90 | aug = KA.RandomAffine(0, (1/14, 1/14), p=1, padding_mode='border') 91 | 92 | def get_image_embed(x): 93 | if x.shape[2:4] != clip_size: 94 | x = resize(x, out_shape=clip_size, pad_mode='reflect') 95 | x = clip_normalize(x) 96 | x =
clip_model.encode_image(x).float() 97 | return F.normalize(x) 98 | 99 | target_embed = F.normalize(clip_model.encode_text(clip.tokenize(args.prompt, truncate=True).to(device)).float()) 100 | 101 | def cond_fn(x, t, denoised): 102 | image_embed = get_image_embed(aug(denoised.add(1).div(2))) 103 | loss = spherical_dist_loss(image_embed, target_embed).sum() * args.clip_guidance_scale 104 | grad = -torch.autograd.grad(loss, x)[0] 105 | return grad 106 | 107 | model_fn = make_cond_model_fn(model, cond_fn) 108 | model_fn = make_static_thresh_model_fn(model_fn) 109 | 110 | @torch.no_grad() 111 | @K.utils.eval_mode(model) 112 | def run(): 113 | if accelerator.is_local_main_process: 114 | tqdm.write('Sampling...') 115 | sigmas = K.sampling.get_sigmas_karras(args.steps, sigma_min, sigma_max, rho=7., device=device) 116 | def sample_fn(n): 117 | x = torch.randn([n, model_config['input_channels'], size[0], size[1]], device=device) * sigmas[0] 118 | x_0 = K.sampling.sample_dpm_2(model_fn, x, sigmas, s_churn=args.churn, disable=not accelerator.is_local_main_process) 119 | return x_0 120 | x_0 = K.evaluation.compute_features(accelerator, sample_fn, lambda x: x, args.n, args.batch_size) 121 | if accelerator.is_main_process: 122 | for i, out in enumerate(x_0): 123 | filename = f'{args.prefix}_{i:05}.png' 124 | K.utils.to_pil_image(out).save(filename) 125 | 126 | try: 127 | run() 128 | except KeyboardInterrupt: 129 | pass 130 | 131 | 132 | if __name__ == '__main__': 133 | main() 134 | -------------------------------------------------------------------------------- /setup.cfg: -------------------------------------------------------------------------------- 1 | [metadata] 2 | name = k-diffusion 3 | version = 0.0.1 4 | author = Katherine Crowson 5 | author_email = crowsonkb@gmail.com 6 | url = https://github.com/crowsonkb/k-diffusion 7 | description = Karras et al. (2022) diffusion models for PyTorch 8 | long_description = file: README.md 9 | long_description_content_type = text/markdown 10 | license = MIT 11 | 12 | [options] 13 | packages = find: 14 | install_requires = 15 | accelerate 16 | clean-fid 17 | einops 18 | jsonmerge 19 | kornia 20 | Pillow 21 | resize-right 22 | scikit-image 23 | scipy 24 | torch 25 | torchdiffeq 26 | torchvision 27 | tqdm 28 | wandb 29 | CLIP @ git+https://github.com/openai/CLIP 30 | -------------------------------------------------------------------------------- /train.py: -------------------------------------------------------------------------------- 1 | #!/usr/bin/env python3 2 | 3 | """Trains Karras et al. 
(2022) diffusion models.""" 4 | 5 | import argparse 6 | from copy import deepcopy 7 | from functools import partial 8 | import math 9 | import json 10 | from pathlib import Path 11 | 12 | import accelerate 13 | import torch 14 | from torch import optim 15 | from torch import multiprocessing as mp 16 | from torch.utils import data 17 | from torchvision import datasets, transforms, utils 18 | from tqdm.auto import trange, tqdm 19 | 20 | import k_diffusion as K 21 | 22 | 23 | def main(): 24 | p = argparse.ArgumentParser(description=__doc__, 25 | formatter_class=argparse.ArgumentDefaultsHelpFormatter) 26 | p.add_argument('--batch-size', type=int, default=64, 27 | help='the batch size') 28 | p.add_argument('--config', type=str, required=True, 29 | help='the configuration file') 30 | p.add_argument('--demo-every', type=int, default=500, 31 | help='save a demo grid every this many steps') 32 | p.add_argument('--evaluate-every', type=int, default=10000, 33 | help='evaluate every this many steps') 34 | p.add_argument('--evaluate-n', type=int, default=2000, 35 | help='the number of samples to draw to evaluate') 36 | p.add_argument('--gns', action='store_true', 37 | help='measure the gradient noise scale (DDP only)') 38 | p.add_argument('--grad-accum-steps', type=int, default=1, 39 | help='the number of gradient accumulation steps') 40 | p.add_argument('--grow', type=str, 41 | help='the checkpoint to grow from') 42 | p.add_argument('--grow-config', type=str, 43 | help='the configuration file of the model to grow from') 44 | p.add_argument('--lr', type=float, 45 | help='the learning rate') 46 | p.add_argument('--name', type=str, default='model', 47 | help='the name of the run') 48 | p.add_argument('--num-workers', type=int, default=8, 49 | help='the number of data loader workers') 50 | p.add_argument('--resume', type=str, 51 | help='the checkpoint to resume from') 52 | p.add_argument('--sample-n', type=int, default=64, 53 | help='the number of images to sample for demo grids') 54 | p.add_argument('--save-every', type=int, default=10000, 55 | help='save every this many steps') 56 | p.add_argument('--seed', type=int, 57 | help='the random seed') 58 | p.add_argument('--start-method', type=str, default='spawn', 59 | choices=['fork', 'forkserver', 'spawn'], 60 | help='the multiprocessing start method') 61 | p.add_argument('--wandb-entity', type=str, 62 | help='the wandb entity name') 63 | p.add_argument('--wandb-group', type=str, 64 | help='the wandb group name') 65 | p.add_argument('--wandb-project', type=str, 66 | help='the wandb project name (specify this to enable wandb)') 67 | p.add_argument('--wandb-save-model', action='store_true', 68 | help='save model to wandb') 69 | args = p.parse_args() 70 | 71 | mp.set_start_method(args.start_method) 72 | 73 | config = K.config.load_config(open(args.config)) 74 | model_config = config['model'] 75 | dataset_config = config['dataset'] 76 | opt_config = config['optimizer'] 77 | sched_config = config['lr_sched'] 78 | ema_sched_config = config['ema_sched'] 79 | 80 | # TODO: allow non-square input sizes 81 | assert len(model_config['input_size']) == 2 and model_config['input_size'][0] == model_config['input_size'][1] 82 | size = model_config['input_size'] 83 | 84 | ddp_kwargs = accelerate.DistributedDataParallelKwargs(find_unused_parameters=True) 85 | accelerator = accelerate.Accelerator(kwargs_handlers=[ddp_kwargs], gradient_accumulation_steps=args.grad_accum_steps) 86 | device = accelerator.device 87 | print(f'Process {accelerator.process_index} using
device: {device}', flush=True) 88 | 89 | if args.seed is not None: 90 | seeds = torch.randint(-2 ** 63, 2 ** 63 - 1, [accelerator.num_processes], generator=torch.Generator().manual_seed(args.seed)) 91 | torch.manual_seed(seeds[accelerator.process_index]) 92 | 93 | inner_model = K.config.make_model(config) 94 | if accelerator.is_main_process: 95 | print('Parameters:', K.utils.n_params(inner_model)) 96 | 97 | # If logging to wandb, initialize the run 98 | use_wandb = accelerator.is_main_process and args.wandb_project 99 | if use_wandb: 100 | import wandb 101 | log_config = vars(args) 102 | log_config['config'] = config 103 | log_config['parameters'] = K.utils.n_params(inner_model) 104 | wandb.init(project=args.wandb_project, entity=args.wandb_entity, group=args.wandb_group, config=log_config, save_code=True) 105 | 106 | assert opt_config['type'] == 'adamw' 107 | opt = optim.AdamW(inner_model.parameters(), 108 | lr=opt_config['lr'] if args.lr is None else args.lr, 109 | betas=tuple(opt_config['betas']), 110 | eps=opt_config['eps'], 111 | weight_decay=opt_config['weight_decay']) 112 | 113 | if sched_config['type'] == 'inverse': 114 | sched = K.utils.InverseLR(opt, 115 | inv_gamma=sched_config['inv_gamma'], 116 | power=sched_config['power'], 117 | warmup=sched_config['warmup']) 118 | elif sched_config['type'] == 'exponential': 119 | sched = K.utils.ExponentialLR(opt, 120 | num_steps=sched_config['num_steps'], 121 | decay=sched_config['decay'], 122 | warmup=sched_config['warmup']) 123 | else: 124 | raise ValueError('Invalid schedule type') 125 | 126 | assert ema_sched_config['type'] == 'inverse' 127 | ema_sched = K.utils.EMAWarmup(power=ema_sched_config['power'], 128 | max_value=ema_sched_config['max_value']) 129 | 130 | tf = transforms.Compose([ 131 | transforms.Resize(size[0], interpolation=transforms.InterpolationMode.LANCZOS), 132 | transforms.CenterCrop(size[0]), 133 | K.augmentation.KarrasAugmentationPipeline(model_config['augment_prob']), 134 | ]) 135 | 136 | if dataset_config['type'] == 'imagefolder': 137 | train_set = datasets.ImageFolder(dataset_config['location'], transform=tf) 138 | elif dataset_config['type'] == 'cifar10': 139 | train_set = datasets.CIFAR10(dataset_config['location'], train=True, download=True, transform=tf) 140 | elif dataset_config['type'] == 'mnist': 141 | train_set = datasets.MNIST(dataset_config['location'], train=True, download=True, transform=tf) 142 | elif dataset_config['type'] == 'huggingface': 143 | from datasets import load_dataset 144 | train_set = load_dataset(dataset_config['location']) 145 | train_set.set_transform(partial(K.utils.hf_datasets_augs_helper, transform=tf, image_key=dataset_config['image_key'])) 146 | train_set = train_set['train'] 147 | else: 148 | raise ValueError('Invalid dataset type') 149 | 150 | image_key = dataset_config.get('image_key', 0) 151 | 152 | train_dl = data.DataLoader(train_set, args.batch_size, shuffle=True, drop_last=True, 153 | num_workers=args.num_workers, persistent_workers=True) 154 | 155 | if args.grow: 156 | if not args.grow_config: 157 | raise ValueError('--grow requires --grow-config') 158 | ckpt = torch.load(args.grow, map_location='cpu') 159 | old_config = K.config.load_config(open(args.grow_config)) 160 | old_inner_model = K.config.make_model(old_config) 161 | old_inner_model.load_state_dict(ckpt['model_ema']) 162 | if old_config['model']['skip_stages'] != model_config['skip_stages']: 163 | old_inner_model.set_skip_stages(model_config['skip_stages']) 164 | if old_config['model']['patch_size'] != 
model_config['patch_size']: 165 | old_inner_model.set_patch_size(model_config['patch_size']) 166 | inner_model.load_state_dict(old_inner_model.state_dict()) 167 | del ckpt, old_inner_model 168 | 169 | inner_model, opt, train_dl = accelerator.prepare(inner_model, opt, train_dl) 170 | if use_wandb: 171 | wandb.watch(inner_model) 172 | if args.gns: 173 | gns_stats_hook = K.gns.DDPGradientStatsHook(inner_model) 174 | gns_stats = K.gns.GradientNoiseScale() 175 | else: 176 | gns_stats = None 177 | 178 | model = K.Denoiser(inner_model, sigma_data=model_config['sigma_data']) 179 | model_ema = deepcopy(model) 180 | 181 | state_path = Path(f'{args.name}_state.json') 182 | 183 | if state_path.exists() or args.resume: 184 | if args.resume: 185 | ckpt_path = args.resume 186 | if not args.resume: 187 | state = json.load(open(state_path)) 188 | ckpt_path = state['latest_checkpoint'] 189 | if accelerator.is_main_process: 190 | print(f'Resuming from {ckpt_path}...') 191 | ckpt = torch.load(ckpt_path, map_location='cpu') 192 | accelerator.unwrap_model(model.inner_model).load_state_dict(ckpt['model']) 193 | accelerator.unwrap_model(model_ema.inner_model).load_state_dict(ckpt['model_ema']) 194 | opt.load_state_dict(ckpt['opt']) 195 | sched.load_state_dict(ckpt['sched']) 196 | ema_sched.load_state_dict(ckpt['ema_sched']) 197 | epoch = ckpt['epoch'] + 1 198 | step = ckpt['step'] + 1 199 | if args.gns and 'gns_stats' in ckpt and ckpt['gns_stats'] is not None: 200 | gns_stats.load_state_dict(ckpt['gns_stats']) 201 | del ckpt 202 | else: 203 | epoch = 0 204 | step = 0 205 | 206 | extractor = K.evaluation.InceptionV3FeatureExtractor(device=device) 207 | train_iter = iter(train_dl) 208 | if accelerator.is_main_process: 209 | print('Computing features for reals...') 210 | reals_features = K.evaluation.compute_features(accelerator, lambda x: next(train_iter)[image_key][1], extractor, args.evaluate_n, args.batch_size) 211 | if accelerator.is_main_process: 212 | metrics_log_filepath = Path(f'{args.name}_metrics.csv') 213 | if metrics_log_filepath.exists(): 214 | metrics_log_file = open(metrics_log_filepath, 'a') 215 | else: 216 | metrics_log_file = open(metrics_log_filepath, 'w') 217 | print('step', 'fid', 'kid', sep=',', file=metrics_log_file, flush=True) 218 | del train_iter 219 | 220 | sigma_min = model_config['sigma_min'] 221 | sigma_max = model_config['sigma_max'] 222 | sample_density = K.config.make_sample_density(model_config) 223 | 224 | @torch.no_grad() 225 | @K.utils.eval_mode(model_ema) 226 | def demo(): 227 | if accelerator.is_main_process: 228 | tqdm.write('Sampling...') 229 | filename = f'{args.name}_demo_{step:08}.png' 230 | n_per_proc = math.ceil(args.sample_n / accelerator.num_processes) 231 | x = torch.randn([n_per_proc, model_config['input_channels'], size[0], size[1]], device=device) * sigma_max 232 | sigmas = K.sampling.get_sigmas_karras(50, sigma_min, sigma_max, rho=7., device=device) 233 | x_0 = K.sampling.sample_lms(model_ema, x, sigmas, disable=not accelerator.is_main_process) 234 | x_0 = accelerator.gather(x_0)[:args.sample_n] 235 | if accelerator.is_main_process: 236 | grid = utils.make_grid(x_0, nrow=math.ceil(args.sample_n ** 0.5), padding=0) 237 | K.utils.to_pil_image(grid).save(filename) 238 | if use_wandb: 239 | wandb.log({'demo_grid': wandb.Image(filename)}, step=step) 240 | 241 | @torch.no_grad() 242 | @K.utils.eval_mode(model_ema) 243 | def evaluate(): 244 | if accelerator.is_main_process: 245 | tqdm.write('Evaluating...') 246 | sigmas = K.sampling.get_sigmas_karras(50, sigma_min, 
sigma_max, rho=7., device=device) 247 | def sample_fn(n): 248 | x = torch.randn([n, model_config['input_channels'], size[0], size[1]], device=device) * sigma_max 249 | x_0 = K.sampling.sample_lms(model_ema, x, sigmas, disable=True) 250 | return x_0 251 | fakes_features = K.evaluation.compute_features(accelerator, sample_fn, extractor, args.evaluate_n, args.batch_size) 252 | if accelerator.is_main_process: 253 | fid = K.evaluation.fid(fakes_features, reals_features) 254 | kid = K.evaluation.kid(fakes_features, reals_features) 255 | print(f'FID: {fid.item():g}, KID: {kid.item():g}') 256 | if accelerator.is_main_process: 257 | print(step, fid.item(), kid.item(), sep=',', file=metrics_log_file, flush=True) 258 | if use_wandb: 259 | wandb.log({'FID': fid.item(), 'KID': kid.item()}, step=step) 260 | 261 | def save(): 262 | accelerator.wait_for_everyone() 263 | filename = f'{args.name}_{step:08}.pth' 264 | if accelerator.is_main_process: 265 | tqdm.write(f'Saving to {filename}...') 266 | obj = { 267 | 'model': accelerator.unwrap_model(model.inner_model).state_dict(), 268 | 'model_ema': accelerator.unwrap_model(model_ema.inner_model).state_dict(), 269 | 'opt': opt.state_dict(), 270 | 'sched': sched.state_dict(), 271 | 'ema_sched': ema_sched.state_dict(), 272 | 'epoch': epoch, 273 | 'step': step, 274 | 'gns_stats': gns_stats.state_dict() if gns_stats is not None else None, 275 | } 276 | accelerator.save(obj, filename) 277 | if accelerator.is_main_process: 278 | state_obj = {'latest_checkpoint': filename} 279 | json.dump(state_obj, open(state_path, 'w')) 280 | if args.wandb_save_model and use_wandb: 281 | wandb.save(filename) 282 | 283 | try: 284 | while True: 285 | for batch in tqdm(train_dl, disable=not accelerator.is_main_process): 286 | with accelerator.accumulate(model): 287 | reals, _, aug_cond = batch[image_key] 288 | noise = torch.randn_like(reals) 289 | sigma = sample_density([reals.shape[0]], device=device) 290 | losses = model.loss(reals, noise, sigma, aug_cond=aug_cond) 291 | losses_all = accelerator.gather(losses.detach()) 292 | loss_local = losses.mean() 293 | loss = losses_all.mean() 294 | accelerator.backward(loss_local) 295 | if args.gns: 296 | sq_norm_small_batch, sq_norm_large_batch = accelerator.reduce(gns_stats_hook.get_stats(), 'mean').tolist() 297 | gns_stats.update(sq_norm_small_batch, sq_norm_large_batch, reals.shape[0], reals.shape[0] * accelerator.num_processes) 298 | opt.step() 299 | sched.step() 300 | opt.zero_grad() 301 | if accelerator.sync_gradients: 302 | ema_decay = ema_sched.get_value() 303 | K.utils.ema_update(model, model_ema, ema_decay) 304 | ema_sched.step() 305 | 306 | if accelerator.is_main_process: 307 | if step % 25 == 0: 308 | if args.gns: 309 | tqdm.write(f'Epoch: {epoch}, step: {step}, loss: {loss.item():g}, gns: {gns_stats.get_gns():g}') 310 | else: 311 | tqdm.write(f'Epoch: {epoch}, step: {step}, loss: {loss.item():g}') 312 | 313 | if use_wandb: 314 | log_dict = { 315 | 'epoch': epoch, 316 | 'loss': loss.item(), 317 | 'lr': sched.get_last_lr()[0], 318 | 'ema_decay': ema_decay, 319 | } 320 | if args.gns: 321 | log_dict['gradient_noise_scale'] = gns_stats.get_gns() 322 | wandb.log(log_dict, step=step) 323 | 324 | if step % args.demo_every == 0: 325 | demo() 326 | 327 | if step > 0 and args.evaluate_every > 0 and step % args.evaluate_every == 0: 328 | evaluate() 329 | 330 | if step > 0 and step % args.save_every == 0: 331 | save() 332 | 333 | step += 1 334 | epoch += 1 335 | except KeyboardInterrupt: 336 | pass 337 | 338 | 339 | if __name__ == '__main__': 
340 | main() 341 | --------------------------------------------------------------------------------
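A minimal usage sketch (not part of the repository; the toy `nn.Linear` model, data, and hyperparameters are placeholders) of how the EMA helpers in `k_diffusion/utils.py` fit together, mirroring the pattern in `train.py`: query the warmup-scheduled decay, update the EMA copy after each optimizer step, then advance the schedule.

```python
import torch
from torch import nn, optim

import k_diffusion as K

# Toy model and an EMA copy of it; train.py uses the Denoiser and a deepcopy of it instead.
model = nn.Linear(4, 4)
model_ema = nn.Linear(4, 4)
model_ema.load_state_dict(model.state_dict())
opt = optim.AdamW(model.parameters(), lr=1e-3)
ema_sched = K.utils.EMAWarmup(power=2 / 3, max_value=0.9999)

for step in range(1000):
    x = torch.randn([16, 4])
    loss = (model(x) - x).pow(2).mean()  # placeholder objective
    opt.zero_grad()
    loss.backward()
    opt.step()
    # The decay follows the inverse schedule described in EMAWarmup's docstring.
    K.utils.ema_update(model, model_ema, ema_sched.get_value())
    ema_sched.step()
```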