├── k_diffusion
│   ├── models
│   │   ├── __init__.py
│   │   └── image_v1.py
│   ├── __init__.py
│   ├── augmentation.py
│   ├── gns.py
│   ├── config.py
│   ├── evaluation.py
│   ├── external.py
│   ├── layers.py
│   ├── utils.py
│   └── sampling.py
├── setup.py
├── pyproject.toml
├── .gitignore
├── requirements.txt
├── setup.cfg
├── LICENSE
├── .github
│   └── workflows
│       └── python-publish.yml
├── configs
│   ├── config_cifar10.json
│   ├── config_mnist.json
│   ├── config_32x32_small.json
│   ├── no-temb.json
│   ├── config_fashion.json
│   └── config_32x32_small_butterflies.json
├── make_grid.py
├── msample.py
├── sample.py
├── README.md
├── eval.py
├── sample_clip_guided.py
├── mtrain.py
└── train.py
/k_diffusion/models/__init__.py: -------------------------------------------------------------------------------- 1 | from .image_v1 import ImageDenoiserModelV1 2 | -------------------------------------------------------------------------------- /setup.py: -------------------------------------------------------------------------------- 1 | from setuptools import setup 2 | 3 | 4 | if __name__ == '__main__': 5 | setup() 6 | -------------------------------------------------------------------------------- /pyproject.toml: -------------------------------------------------------------------------------- 1 | [build-system] 2 | requires = ["setuptools"] 3 | build-backend = "setuptools.build_meta" 4 | -------------------------------------------------------------------------------- /k_diffusion/__init__.py: -------------------------------------------------------------------------------- 1 | from . import augmentation, config, evaluation, external, gns, layers, models, sampling, utils 2 | from .layers import Denoiser 3 | -------------------------------------------------------------------------------- /.gitignore: -------------------------------------------------------------------------------- 1 | out.png 2 | tags 3 | outputs/ 4 | *_state.json 5 | *.swp 6 | venv* 7 | __pycache__ 8 | .ipynb_checkpoints 9 | *.pth 10 | *.egg-info 11 | data 12 | *_demo_*.png 13 | wandb/* 14 | *.csv 15 | .env 16 | -------------------------------------------------------------------------------- /requirements.txt: -------------------------------------------------------------------------------- 1 | accelerate 2 | clean-fid 3 | clip-anytorch 4 | einops 5 | fastcore 6 | jsonmerge 7 | kornia 8 | Pillow 9 | resize-right 10 | scikit-image 11 | scipy 12 | torch 13 | torchdiffeq 14 | torchvision 15 | tqdm 16 | wandb 17 | -------------------------------------------------------------------------------- /setup.cfg: -------------------------------------------------------------------------------- 1 | [metadata] 2 | name = k-diffusion 3 | version = 0.0.9 4 | author = Katherine Crowson 5 | author_email = crowsonkb@gmail.com 6 | url = https://github.com/crowsonkb/k-diffusion 7 | description = Karras et al.
(2022) diffusion models for PyTorch 8 | long_description = file: README.md 9 | long_description_content_type = text/markdown 10 | license = MIT 11 | 12 | [options] 13 | packages = find: 14 | install_requires = 15 | accelerate 16 | clean-fid 17 | clip-anytorch 18 | einops 19 | fastcore 20 | jsonmerge 21 | kornia 22 | Pillow 23 | resize-right 24 | scikit-image 25 | scipy 26 | torch 27 | torchdiffeq 28 | torchvision 29 | tqdm 30 | wandb 31 | -------------------------------------------------------------------------------- /LICENSE: -------------------------------------------------------------------------------- 1 | Copyright (c) 2022 Katherine Crowson 2 | 3 | Permission is hereby granted, free of charge, to any person obtaining a copy 4 | of this software and associated documentation files (the "Software"), to deal 5 | in the Software without restriction, including without limitation the rights 6 | to use, copy, modify, merge, publish, distribute, sublicense, and/or sell 7 | copies of the Software, and to permit persons to whom the Software is 8 | furnished to do so, subject to the following conditions: 9 | 10 | The above copyright notice and this permission notice shall be included in 11 | all copies or substantial portions of the Software. 12 | 13 | THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR 14 | IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, 15 | FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE 16 | AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER 17 | LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, 18 | OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN 19 | THE SOFTWARE. 20 | -------------------------------------------------------------------------------- /.github/workflows/python-publish.yml: -------------------------------------------------------------------------------- 1 | name: Release 2 | 3 | on: 4 | push: 5 | branches: 6 | - master 7 | jobs: 8 | deploy: 9 | runs-on: ubuntu-latest 10 | steps: 11 | - uses: actions/checkout@v2 12 | - uses: actions-ecosystem/action-regex-match@v2 13 | id: regex-match 14 | with: 15 | text: ${{ github.event.head_commit.message }} 16 | regex: '^Release ([^ ]+)' 17 | - name: Set up Python 18 | uses: actions/setup-python@v2 19 | with: 20 | python-version: '3.8' 21 | - name: Install dependencies 22 | run: | 23 | python -m pip install --upgrade pip 24 | pip install setuptools wheel twine 25 | - name: Release 26 | if: ${{ steps.regex-match.outputs.match != '' }} 27 | uses: softprops/action-gh-release@v1 28 | with: 29 | tag_name: v${{ steps.regex-match.outputs.group1 }} 30 | - name: Build and publish 31 | if: ${{ steps.regex-match.outputs.match != '' }} 32 | env: 33 | TWINE_USERNAME: __token__ 34 | TWINE_PASSWORD: ${{ secrets.PYPI_PASSWORD }} 35 | run: | 36 | python setup.py sdist bdist_wheel 37 | twine upload dist/* 38 | -------------------------------------------------------------------------------- /configs/config_cifar10.json: -------------------------------------------------------------------------------- 1 | { 2 | "model": { 3 | "type": "image_v1", 4 | "input_channels": 3, 5 | "input_size": [32, 32], 6 | "patch_size": 1, 7 | "mapping_out": 256, 8 | "depths": [2, 4, 4], 9 | "channels": [128, 256, 512], 10 | "self_attn_depths": [false, true, true], 11 | "has_variance": true, 12 | "dropout_rate": 0.05, 13 | "augment_wrapper": true, 14 | "augment_prob": 0.12, 15 | "sigma_data": 0.5, 16 | "sigma_min": 
1e-2, 17 | "sigma_max": 80, 18 | "sigma_sample_density": { 19 | "type": "lognormal", 20 | "mean": -1.2, 21 | "std": 1.2 22 | } 23 | }, 24 | "dataset": { 25 | "type": "cifar10", 26 | "location": "data" 27 | }, 28 | "optimizer": { 29 | "type": "adamw", 30 | "lr": 4e-4, 31 | "betas": [0.95, 0.999], 32 | "eps": 1e-6, 33 | "weight_decay": 1e-3 34 | }, 35 | "lr_sched": { 36 | "type": "inverse", 37 | "inv_gamma": 20000.0, 38 | "power": 1.0, 39 | "warmup": 0.99 40 | }, 41 | "ema_sched": { 42 | "type": "inverse", 43 | "power": 0.6667, 44 | "max_value": 0.9999 45 | } 46 | } 47 | -------------------------------------------------------------------------------- /configs/config_mnist.json: -------------------------------------------------------------------------------- 1 | { 2 | "model": { 3 | "type": "image_v1", 4 | "input_channels": 1, 5 | "input_size": [28, 28], 6 | "patch_size": 1, 7 | "mapping_out": 256, 8 | "depths": [2, 4, 4], 9 | "channels": [64, 128, 256], 10 | "self_attn_depths": [false, false, true], 11 | "has_variance": true, 12 | "dropout_rate": 0.05, 13 | "augment_wrapper": true, 14 | "augment_prob": 0.12, 15 | "sigma_data": 0.6162, 16 | "sigma_min": 1e-2, 17 | "sigma_max": 80, 18 | "sigma_sample_density": { 19 | "type": "lognormal", 20 | "mean": -1.2, 21 | "std": 1.2 22 | } 23 | }, 24 | "dataset": { 25 | "type": "mnist", 26 | "location": "data" 27 | }, 28 | "optimizer": { 29 | "type": "adamw", 30 | "lr": 2e-4, 31 | "betas": [0.95, 0.999], 32 | "eps": 1e-6, 33 | "weight_decay": 1e-3 34 | }, 35 | "lr_sched": { 36 | "type": "inverse", 37 | "inv_gamma": 20000.0, 38 | "power": 1.0, 39 | "warmup": 0.99 40 | }, 41 | "ema_sched": { 42 | "type": "inverse", 43 | "power": 0.6667, 44 | "max_value": 0.9999 45 | } 46 | } 47 | -------------------------------------------------------------------------------- /configs/config_32x32_small.json: -------------------------------------------------------------------------------- 1 | { 2 | "model": { 3 | "type": "image_v1", 4 | "input_channels": 3, 5 | "input_size": [32, 32], 6 | "patch_size": 1, 7 | "mapping_out": 256, 8 | "depths": [2, 4, 4], 9 | "channels": [128, 256, 512], 10 | "self_attn_depths": [false, true, true], 11 | "has_variance": true, 12 | "dropout_rate": 0.05, 13 | "augment_wrapper": true, 14 | "augment_prob": 0.12, 15 | "sigma_data": 0.5, 16 | "sigma_min": 1e-2, 17 | "sigma_max": 80, 18 | "sigma_sample_density": { 19 | "type": "lognormal", 20 | "mean": -1.2, 21 | "std": 1.2 22 | } 23 | }, 24 | "dataset": { 25 | "type": "imagefolder", 26 | "location": "/path/to/dataset" 27 | }, 28 | "optimizer": { 29 | "type": "adamw", 30 | "lr": 1e-4, 31 | "betas": [0.95, 0.999], 32 | "eps": 1e-6, 33 | "weight_decay": 1e-3 34 | }, 35 | "lr_sched": { 36 | "type": "inverse", 37 | "inv_gamma": 20000.0, 38 | "power": 1.0, 39 | "warmup": 0.99 40 | }, 41 | "ema_sched": { 42 | "type": "inverse", 43 | "power": 0.6667, 44 | "max_value": 0.9999 45 | } 46 | } 47 | -------------------------------------------------------------------------------- /configs/no-temb.json: -------------------------------------------------------------------------------- 1 | { 2 | "model": { 3 | "type": "image_v1", 4 | "input_channels": 1, 5 | "input_size": [28, 28], 6 | "patch_size": 1, 7 | "mapping_out": 256, 8 | "depths": [2, 4, 4], 9 | "channels": [64, 128, 256], 10 | "self_attn_depths": [false, false, true], 11 | "has_variance": false, 12 | "dropout_rate": 0.05, 13 | "augment_wrapper": true, 14 | "augment_prob": 0.0, 15 | "unscaled": true, 16 | "t_embed": false, 17 | "sigma_data": 0.6162, 18 | 
"sigma_min": 1e-2, 19 | "sigma_max": 80, 20 | "sigma_sample_density": { 21 | "type": "lognormal", 22 | "mean": -1.2, 23 | "std": 1.2 24 | } 25 | }, 26 | "dataset": { 27 | "type": "fashion", 28 | "location": "data" 29 | }, 30 | "optimizer": { 31 | "type": "adamw", 32 | "lr": 8e-4, 33 | "betas": [0.95, 0.999], 34 | "eps": 1e-6, 35 | "weight_decay": 1e-3 36 | }, 37 | "lr_sched": { 38 | "type": "inverse", 39 | "inv_gamma": 20000.0, 40 | "power": 1.0, 41 | "warmup": 0.99 42 | }, 43 | "ema_sched": { 44 | "type": "inverse", 45 | "power": 0.6667, 46 | "max_value": 0.9999 47 | } 48 | } 49 | -------------------------------------------------------------------------------- /configs/config_fashion.json: -------------------------------------------------------------------------------- 1 | { 2 | "model": { 3 | "type": "image_v1", 4 | "input_channels": 1, 5 | "input_size": [28, 28], 6 | "patch_size": 1, 7 | "mapping_out": 256, 8 | "depths": [2, 4, 4], 9 | "channels": [64, 128, 256], 10 | "self_attn_depths": [false, false, true], 11 | "has_variance": false, 12 | "dropout_rate": 0.05, 13 | "augment_wrapper": true, 14 | "augment_prob": 0.0, 15 | "unscaled": false, 16 | "t_embed": true, 17 | "sigma_data": 0.6162, 18 | "sigma_min": 1e-2, 19 | "sigma_max": 80, 20 | "sigma_sample_density": { 21 | "type": "lognormal", 22 | "mean": -1.2, 23 | "std": 1.2 24 | } 25 | }, 26 | "dataset": { 27 | "type": "fashion", 28 | "location": "data" 29 | }, 30 | "optimizer": { 31 | "type": "adamw", 32 | "lr": 8e-4, 33 | "betas": [0.95, 0.999], 34 | "eps": 1e-6, 35 | "weight_decay": 1e-3 36 | }, 37 | "lr_sched": { 38 | "type": "inverse", 39 | "inv_gamma": 20000.0, 40 | "power": 1.0, 41 | "warmup": 0.99 42 | }, 43 | "ema_sched": { 44 | "type": "inverse", 45 | "power": 0.6667, 46 | "max_value": 0.9999 47 | } 48 | } 49 | -------------------------------------------------------------------------------- /configs/config_32x32_small_butterflies.json: -------------------------------------------------------------------------------- 1 | { 2 | "model": { 3 | "type": "image_v1", 4 | "input_channels": 3, 5 | "input_size": [32, 32], 6 | "patch_size": 1, 7 | "mapping_out": 256, 8 | "depths": [2, 4, 4], 9 | "channels": [128, 256, 512], 10 | "self_attn_depths": [false, true, true], 11 | "has_variance": true, 12 | "dropout_rate": 0.05, 13 | "augment_wrapper": true, 14 | "augment_prob": 0.12, 15 | "sigma_data": 0.5, 16 | "sigma_min": 1e-2, 17 | "sigma_max": 80, 18 | "sigma_sample_density": { 19 | "type": "lognormal", 20 | "mean": -1.2, 21 | "std": 1.2 22 | } 23 | }, 24 | "dataset": { 25 | "type": "huggingface", 26 | "location": "huggan/smithsonian_butterflies_subset", 27 | "image_key": "image" 28 | }, 29 | "optimizer": { 30 | "type": "adamw", 31 | "lr": 1e-4, 32 | "betas": [0.95, 0.999], 33 | "eps": 1e-6, 34 | "weight_decay": 1e-3 35 | }, 36 | "lr_sched": { 37 | "type": "inverse", 38 | "inv_gamma": 20000.0, 39 | "power": 1.0, 40 | "warmup": 0.99 41 | }, 42 | "ema_sched": { 43 | "type": "inverse", 44 | "power": 0.6667, 45 | "max_value": 0.9999 46 | } 47 | } 48 | -------------------------------------------------------------------------------- /make_grid.py: -------------------------------------------------------------------------------- 1 | #!/usr/bin/env python3 2 | 3 | """Assembles images into a grid.""" 4 | 5 | import argparse 6 | import math 7 | import sys 8 | 9 | from PIL import Image 10 | 11 | 12 | def main(): 13 | p = argparse.ArgumentParser(description=__doc__) 14 | p.add_argument('images', type=str, nargs='+', metavar='image', 15 | help='the 
input images') 16 | p.add_argument('--output', '-o', type=str, default='out.png', 17 | help='the output image') 18 | p.add_argument('--nrow', type=int, 19 | help='the number of images per row') 20 | args = p.parse_args() 21 | 22 | images = [Image.open(image) for image in args.images] 23 | mode = images[0].mode 24 | size = images[0].size 25 | for image, name in zip(images, args.images): 26 | if image.mode != mode: 27 | print(f'Error: Image {name} had mode {image.mode}, expected {mode}', file=sys.stderr) 28 | sys.exit(1) 29 | if image.size != size: 30 | print(f'Error: Image {name} had size {image.size}, expected {size}', file=sys.stderr) 31 | sys.exit(1) 32 | 33 | n = len(images) 34 | x = args.nrow if args.nrow else math.ceil(n**0.5) 35 | y = math.ceil(n / x) 36 | 37 | output = Image.new(mode, (size[0] * x, size[1] * y)) 38 | for i, image in enumerate(images): 39 | cur_x, cur_y = i % x, i // x 40 | output.paste(image, (size[0] * cur_x, size[1] * cur_y)) 41 | 42 | output.save(args.output) 43 | 44 | 45 | if __name__ == '__main__': 46 | main() 47 | -------------------------------------------------------------------------------- /msample.py: -------------------------------------------------------------------------------- 1 | #!/usr/bin/env python3 2 | """Samples from k-diffusion models.""" 3 | 4 | import argparse, math, accelerate, torch, k_diffusion as K 5 | from fastcore.script import call_parse 6 | from tqdm import trange, tqdm 7 | from torchvision import utils 8 | 9 | @call_parse 10 | def main( 11 | config: str, # model config 12 | checkpoint: str, # checkpoint to use 13 | batch_size:int=64, # batch size 14 | n:int=64, # number of images to sample 15 | out:str='out', # output file name without extension 16 | steps:int=50, # number of denoising steps 17 | seed:int=0, # random seed 18 | churn:float=0., # sampler churn 19 | sampler:str='sample_lms', # sample_lms, sample_dpm_2, sample_euler, etc 20 | ): 21 | sampler = getattr(K.sampling, sampler) 22 | if seed: torch.manual_seed(seed) 23 | config = K.config.load_config(open(config)) 24 | model_config = config['model'] 25 | size = model_config['input_size'] 26 | accelerator = accelerate.Accelerator() 27 | device = accelerator.device 28 | inner_model = K.config.make_model(config).eval().requires_grad_(False).to(device) 29 | inner_model.load_state_dict(torch.load(checkpoint, map_location='cpu')['model_ema']) 30 | model = K.config.make_denoiser_wrapper(config)(inner_model) 31 | 32 | sigma_max = model_config['sigma_max'] 33 | sigmas = K.sampling.get_sigmas_karras(steps, model_config['sigma_min'], sigma_max, rho=7., device=device) 34 | def sample_fn(n): 35 | x = torch.randn([n, model_config['input_channels'], *size], device=device) * sigma_max 36 | return sampler(model, x, sigmas, **{'s_churn':churn} if churn else {}) 37 | x_0 = K.evaluation.compute_features(accelerator, sample_fn, lambda x: x, n, batch_size) 38 | grid = utils.make_grid(x_0, nrow=math.ceil(n ** 0.5), padding=0) 39 | K.utils.to_pil_image(grid).save(f'{out}.png') 40 | 41 | -------------------------------------------------------------------------------- /sample.py: -------------------------------------------------------------------------------- 1 | #!/usr/bin/env python3 2 | 3 | """Samples from k-diffusion models.""" 4 | 5 | import argparse 6 | import math 7 | 8 | import accelerate 9 | import torch 10 | from tqdm import trange, tqdm 11 | 12 | import k_diffusion as K 13 | 14 | 15 | def main(): 16 | p = argparse.ArgumentParser(description=__doc__, 17 | 
formatter_class=argparse.ArgumentDefaultsHelpFormatter) 18 | p.add_argument('--batch-size', type=int, default=64, 19 | help='the batch size') 20 | p.add_argument('--checkpoint', type=str, required=True, 21 | help='the checkpoint to use') 22 | p.add_argument('--config', type=str, required=True, 23 | help='the model config') 24 | p.add_argument('-n', type=int, default=64, 25 | help='the number of images to sample') 26 | p.add_argument('--prefix', type=str, default='out', 27 | help='the output prefix') 28 | p.add_argument('--steps', type=int, default=50, 29 | help='the number of denoising steps') 30 | args = p.parse_args() 31 | 32 | config = K.config.load_config(open(args.config)) 33 | model_config = config['model'] 34 | # TODO: allow non-square input sizes 35 | assert len(model_config['input_size']) == 2 and model_config['input_size'][0] == model_config['input_size'][1] 36 | size = model_config['input_size'] 37 | 38 | accelerator = accelerate.Accelerator() 39 | device = accelerator.device 40 | print('Using device:', device, flush=True) 41 | 42 | inner_model = K.config.make_model(config).eval().requires_grad_(False).to(device) 43 | inner_model.load_state_dict(torch.load(args.checkpoint, map_location='cpu')['model_ema']) 44 | accelerator.print('Parameters:', K.utils.n_params(inner_model)) 45 | model = K.Denoiser(inner_model, sigma_data=model_config['sigma_data']) 46 | 47 | sigma_min = model_config['sigma_min'] 48 | sigma_max = model_config['sigma_max'] 49 | 50 | @torch.no_grad() 51 | @K.utils.eval_mode(model) 52 | def run(): 53 | if accelerator.is_local_main_process: 54 | tqdm.write('Sampling...') 55 | sigmas = K.sampling.get_sigmas_karras(args.steps, sigma_min, sigma_max, rho=7., device=device) 56 | def sample_fn(n): 57 | x = torch.randn([n, model_config['input_channels'], size[0], size[1]], device=device) * sigma_max 58 | x_0 = K.sampling.sample_lms(model, x, sigmas, disable=not accelerator.is_local_main_process) 59 | return x_0 60 | x_0 = K.evaluation.compute_features(accelerator, sample_fn, lambda x: x, args.n, args.batch_size) 61 | if accelerator.is_main_process: 62 | for i, out in enumerate(x_0): 63 | filename = f'{args.prefix}_{i:05}.png' 64 | K.utils.to_pil_image(out).save(filename) 65 | 66 | try: 67 | run() 68 | except KeyboardInterrupt: 69 | pass 70 | 71 | 72 | if __name__ == '__main__': 73 | main() 74 | -------------------------------------------------------------------------------- /README.md: -------------------------------------------------------------------------------- 1 | # k-diffusion 2 | 3 | An implementation of [Elucidating the Design Space of Diffusion-Based Generative Models](https://arxiv.org/abs/2206.00364) (Karras et al., 2022) for PyTorch. The patching method in [Improving Diffusion Model Efficiency Through Patching](https://arxiv.org/abs/2207.04316) is implemented as well. 4 | 5 | ## Installation 6 | 7 | `k-diffusion` can be installed via PyPI (`pip install k-diffusion`), but the PyPI package does not include the training and inference scripts, only library code that others can depend on. To run the training and inference scripts, clone this repository and run `pip install -e .` from its root. 8 | 9 | ## Training: 10 | 11 | To train models: 12 | 13 | ```sh 14 | $ ./train.py --config CONFIG_FILE --name RUN_NAME 15 | ``` 16 | 17 | For instance, to train a model on MNIST: 18 | 19 | ```sh 20 | $ ./train.py --config configs/config_mnist.json --name RUN_NAME 21 | ``` 22 | 23 | The configuration file allows you to specify the dataset type.
Currently supported types are `"imagefolder"` (finds all images in that folder and its subfolders, recursively), `"cifar10"` (CIFAR-10), `"mnist"` (MNIST), and `"huggingface"` ([Hugging Face Datasets](https://huggingface.co/docs/datasets/index)). 24 | 25 | Multi-GPU and multi-node training is supported with [Hugging Face Accelerate](https://huggingface.co/docs/accelerate/index). You can configure Accelerate by running: 26 | 27 | ```sh 28 | $ accelerate config 29 | ``` 30 | 31 | on all nodes, then running: 32 | 33 | ```sh 34 | $ accelerate launch train.py --config CONFIG_FILE --name RUN_NAME 35 | ``` 36 | 37 | on all nodes. 38 | 39 | ## Enhancements/additional features: 40 | 41 | - k-diffusion supports an experimental model output type, an isotropic Gaussian, which seems to have a lower gradient noise scale and to train faster than Karras et al. (2022) diffusion models. 42 | 43 | - k-diffusion has wrappers for [v-diffusion-pytorch](https://github.com/crowsonkb/v-diffusion-pytorch), [OpenAI diffusion](https://github.com/openai/guided-diffusion), and [CompVis diffusion](https://github.com/CompVis/latent-diffusion) models allowing them to be used with its samplers and ODE/SDEs. 44 | 45 | - k-diffusion models support progressive growing. 46 | 47 | - k-diffusion implements [DPM-Solver](https://arxiv.org/abs/2206.00927), which produces higher quality samples at the same number of function evaluations as Karras Algorithm 2, as well as supporting adaptive step size control. It also implements a [linear multistep](https://en.wikipedia.org/wiki/Linear_multistep_method#Adams–Bashforth_methods) sampler (comparable to [PLMS](https://arxiv.org/abs/2202.09778)); a minimal end-to-end sampling sketch is shown below. 48 | 49 | - k-diffusion supports [CLIP](https://openai.com/blog/clip/) guided sampling from unconditional diffusion models (see `sample_clip_guided.py`). 50 | 51 | - k-diffusion supports log likelihood calculation (not a variational lower bound) for native models and all wrapped models. 52 | 53 | - k-diffusion can calculate, during training, the [FID](https://papers.nips.cc/paper/2017/file/8a1d694707eb0fefe65871369074926d-Paper.pdf) and [KID](https://arxiv.org/abs/1801.01401) vs the training set. 54 | 55 | - k-diffusion can calculate, during training, the gradient noise scale (1 / SNR), from _An Empirical Model of Large-Batch Training_ (https://arxiv.org/abs/1812.06162).
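
## Sampling from a trained model:

The sketch below strings together the pieces used by `sample.py` and `msample.py` as a library-level example. The config and checkpoint paths (`configs/config_mnist.json`, `model.pth`) are placeholders for your own trained run; the EMA weights are loaded under the `'model_ema'` key, as in the bundled scripts.

```python
import torch
import k_diffusion as K

config = K.config.load_config(open('configs/config_mnist.json'))
model_config = config['model']

# Build the inner model and load the EMA weights from a training checkpoint.
inner_model = K.config.make_model(config).eval().requires_grad_(False)
inner_model.load_state_dict(torch.load('model.pth', map_location='cpu')['model_ema'])
model = K.config.make_denoiser_wrapper(config)(inner_model)

# Karras et al. (2022) noise schedule, then LMS sampling from pure noise.
sigmas = K.sampling.get_sigmas_karras(50, model_config['sigma_min'],
                                      model_config['sigma_max'], rho=7.)
x = torch.randn([4, model_config['input_channels'], *model_config['input_size']])
x = x * model_config['sigma_max']
samples = K.sampling.sample_lms(model, x, sigmas)
K.utils.to_pil_image(samples[0]).save('sample_00.png')
```

The same recipe works with the other samplers in `K.sampling` (for example `sample_heun` or `sample_dpm_2`), and moving the model and `x` to a GPU with `.to(device)` speeds things up considerably.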
56 | 57 | ## To do: 58 | 59 | - Anything except unconditional image diffusion models 60 | 61 | - Latent diffusion 62 | -------------------------------------------------------------------------------- /k_diffusion/augmentation.py: -------------------------------------------------------------------------------- 1 | from functools import reduce 2 | import math 3 | import operator 4 | 5 | import numpy as np 6 | from skimage import transform 7 | import torch 8 | from torch import nn 9 | 10 | 11 | def translate2d(tx, ty): 12 | mat = [[1, 0, tx], 13 | [0, 1, ty], 14 | [0, 0, 1]] 15 | return torch.tensor(mat, dtype=torch.float32) 16 | 17 | 18 | def scale2d(sx, sy): 19 | mat = [[sx, 0, 0], 20 | [ 0, sy, 0], 21 | [ 0, 0, 1]] 22 | return torch.tensor(mat, dtype=torch.float32) 23 | 24 | 25 | def rotate2d(theta): 26 | mat = [[torch.cos(theta), torch.sin(-theta), 0], 27 | [torch.sin(theta), torch.cos(theta), 0], 28 | [ 0, 0, 1]] 29 | return torch.tensor(mat, dtype=torch.float32) 30 | 31 | 32 | class KarrasAugmentationPipeline: 33 | def __init__(self, a_prob=0.12, a_scale=2**0.2, a_aniso=2**0.2, a_trans=1/8): 34 | self.a_prob = a_prob 35 | self.a_scale = a_scale 36 | self.a_aniso = a_aniso 37 | self.a_trans = a_trans 38 | 39 | def __call__(self, image): 40 | h, w = image.size 41 | mats = [translate2d(h / 2 - 0.5, w / 2 - 0.5)] 42 | 43 | # x-flip 44 | a0 = torch.randint(2, []).float() 45 | mats.append(scale2d(1 - 2 * a0, 1)) 46 | # y-flip 47 | do = (torch.rand([]) < self.a_prob).float() 48 | a1 = torch.randint(2, []).float() * do 49 | mats.append(scale2d(1, 1 - 2 * a1)) 50 | # scaling 51 | do = (torch.rand([]) < self.a_prob).float() 52 | a2 = torch.randn([]) * do 53 | mats.append(scale2d(self.a_scale ** a2, self.a_scale ** a2)) 54 | # rotation 55 | do = (torch.rand([]) < self.a_prob).float() 56 | a3 = (torch.rand([]) * 2 * math.pi - math.pi) * do 57 | mats.append(rotate2d(-a3)) 58 | # anisotropy 59 | do = (torch.rand([]) < self.a_prob).float() 60 | a4 = (torch.rand([]) * 2 * math.pi - math.pi) * do 61 | a5 = torch.randn([]) * do 62 | mats.append(rotate2d(a4)) 63 | mats.append(scale2d(self.a_aniso ** a5, self.a_aniso ** -a5)) 64 | mats.append(rotate2d(-a4)) 65 | # translation 66 | do = (torch.rand([]) < self.a_prob).float() 67 | a6 = torch.randn([]) * do 68 | a7 = torch.randn([]) * do 69 | mats.append(translate2d(self.a_trans * w * a6, self.a_trans * h * a7)) 70 | 71 | # form the transformation matrix and conditioning vector 72 | mats.append(translate2d(-h / 2 + 0.5, -w / 2 + 0.5)) 73 | mat = reduce(operator.matmul, mats) 74 | cond = torch.stack([a0, a1, a2, a3.cos() - 1, a3.sin(), a5 * a4.cos(), a5 * a4.sin(), a6, a7]) 75 | 76 | # apply the transformation 77 | image_orig = np.array(image, dtype=np.float32) / 255 78 | if image_orig.ndim == 2: 79 | image_orig = image_orig[..., None] 80 | tf = transform.AffineTransform(mat.numpy()) 81 | image = transform.warp(image_orig, tf.inverse, order=3, mode='reflect', cval=0.5, clip=False, preserve_range=True) 82 | image_orig = torch.as_tensor(image_orig).movedim(2, 0) * 2 - 1 83 | image = torch.as_tensor(image).movedim(2, 0) * 2 - 1 84 | return image, image_orig, cond 85 | 86 | 87 | class KarrasAugmentWrapper(nn.Module): 88 | def __init__(self, model): 89 | super().__init__() 90 | self.inner_model = model 91 | 92 | def forward(self, input, sigma, aug_cond=None, mapping_cond=None, **kwargs): 93 | if aug_cond is None: 94 | aug_cond = input.new_zeros([input.shape[0], 9]) 95 | if mapping_cond is None: 96 | mapping_cond = aug_cond 97 | else: 98 | mapping_cond = 
torch.cat([aug_cond, mapping_cond], dim=1) 99 | return self.inner_model(input, sigma, mapping_cond=mapping_cond, **kwargs) 100 | 101 | def set_skip_stages(self, skip_stages): 102 | return self.inner_model.set_skip_stages(skip_stages) 103 | 104 | def set_patch_size(self, patch_size): 105 | return self.inner_model.set_patch_size(patch_size) 106 | -------------------------------------------------------------------------------- /k_diffusion/gns.py: -------------------------------------------------------------------------------- 1 | import torch 2 | from torch import nn 3 | 4 | 5 | class DDPGradientStatsHook: 6 | def __init__(self, ddp_module): 7 | try: 8 | ddp_module.register_comm_hook(self, self._hook_fn) 9 | except AttributeError: 10 | raise ValueError('DDPGradientStatsHook does not support non-DDP wrapped modules') 11 | self._clear_state() 12 | 13 | def _clear_state(self): 14 | self.bucket_sq_norms_small_batch = [] 15 | self.bucket_sq_norms_large_batch = [] 16 | 17 | @staticmethod 18 | def _hook_fn(self, bucket): 19 | buf = bucket.buffer() 20 | self.bucket_sq_norms_small_batch.append(buf.pow(2).sum()) 21 | fut = torch.distributed.all_reduce(buf, op=torch.distributed.ReduceOp.AVG, async_op=True).get_future() 22 | def callback(fut): 23 | buf = fut.value()[0] 24 | self.bucket_sq_norms_large_batch.append(buf.pow(2).sum()) 25 | return buf 26 | return fut.then(callback) 27 | 28 | def get_stats(self): 29 | sq_norm_small_batch = sum(self.bucket_sq_norms_small_batch) 30 | sq_norm_large_batch = sum(self.bucket_sq_norms_large_batch) 31 | self._clear_state() 32 | stats = torch.stack([sq_norm_small_batch, sq_norm_large_batch]) 33 | torch.distributed.all_reduce(stats, op=torch.distributed.ReduceOp.AVG) 34 | return stats[0].item(), stats[1].item() 35 | 36 | 37 | class GradientNoiseScale: 38 | """Calculates the gradient noise scale (1 / SNR), or critical batch size, 39 | from _An Empirical Model of Large-Batch Training_ 40 | (https://arxiv.org/abs/1812.06162). 41 | 42 | Args: 43 | beta (float): The decay factor for the exponential moving averages used to 44 | calculate the gradient noise scale. 45 | Default: 0.9998 46 | eps (float): Added for numerical stability. 47 | Default: 1e-8 48 | """ 49 | 50 | def __init__(self, beta=0.9998, eps=1e-8): 51 | self.beta = beta 52 | self.eps = eps 53 | self.ema_sq_norm = 0. 54 | self.ema_var = 0. 55 | self.beta_cumprod = 1. 56 | self.gradient_noise_scale = float('nan') 57 | 58 | def state_dict(self): 59 | """Returns the state of the object as a :class:`dict`.""" 60 | return dict(self.__dict__.items()) 61 | 62 | def load_state_dict(self, state_dict): 63 | """Loads the object's state. 64 | Args: 65 | state_dict (dict): object state. Should be an object returned 66 | from a call to :meth:`state_dict`. 67 | """ 68 | self.__dict__.update(state_dict) 69 | 70 | def update(self, sq_norm_small_batch, sq_norm_large_batch, n_small_batch, n_large_batch): 71 | """Updates the state with a new batch's gradient statistics, and returns the 72 | current gradient noise scale. 73 | 74 | Args: 75 | sq_norm_small_batch (float): The mean of the squared 2-norms of microbatch or 76 | per sample gradients. 77 | sq_norm_large_batch (float): The squared 2-norm of the mean of the microbatch or 78 | per sample gradients. 79 | n_small_batch (int): The batch size of the individual microbatch or per sample 80 | gradients (1 if per sample). 81 | n_large_batch (int): The total batch size of the mean of the microbatch or 82 | per sample gradients.
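
        Following the unbiased estimators in the paper, this computes
        est_sq_norm = (n_large_batch * sq_norm_large_batch - n_small_batch * sq_norm_small_batch) / (n_large_batch - n_small_batch)
        and
        est_var = (sq_norm_small_batch - sq_norm_large_batch) / (1 / n_small_batch - 1 / n_large_batch),
        smooths both with exponential moving averages, and returns ema_var / ema_sq_norm.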
83 | """ 84 | est_sq_norm = (n_large_batch * sq_norm_large_batch - n_small_batch * sq_norm_small_batch) / (n_large_batch - n_small_batch) 85 | est_var = (sq_norm_small_batch - sq_norm_large_batch) / (1 / n_small_batch - 1 / n_large_batch) 86 | self.ema_sq_norm = self.beta * self.ema_sq_norm + (1 - self.beta) * est_sq_norm 87 | self.ema_var = self.beta * self.ema_var + (1 - self.beta) * est_var 88 | self.beta_cumprod *= self.beta 89 | self.gradient_noise_scale = max(self.ema_var, self.eps) / max(self.ema_sq_norm, self.eps) 90 | return self.gradient_noise_scale 91 | 92 | def get_gns(self): 93 | """Returns the current gradient noise scale.""" 94 | return self.gradient_noise_scale 95 | 96 | def get_stats(self): 97 | """Returns the current (debiased) estimates of the squared mean gradient 98 | and gradient variance.""" 99 | return self.ema_sq_norm / (1 - self.beta_cumprod), self.ema_var / (1 - self.beta_cumprod) 100 | -------------------------------------------------------------------------------- /k_diffusion/config.py: -------------------------------------------------------------------------------- 1 | from functools import partial 2 | import json 3 | import math 4 | import warnings 5 | 6 | from jsonmerge import merge 7 | 8 | from . import augmentation, layers, models, utils 9 | 10 | 11 | def load_config(file): 12 | defaults = { 13 | 'model': { 14 | 'sigma_data': 1., 15 | 'patch_size': 1, 16 | 'dropout_rate': 0., 17 | 'augment_wrapper': True, 18 | 'augment_prob': 0., 19 | 'mapping_cond_dim': 0, 20 | 'unet_cond_dim': 0, 21 | 'cross_cond_dim': 0, 22 | 'cross_attn_depths': None, 23 | 'skip_stages': 0, 24 | 'has_variance': False, 25 | }, 26 | 'dataset': { 27 | 'type': 'imagefolder', 28 | }, 29 | 'optimizer': { 30 | 'type': 'adamw', 31 | 'lr': 1e-4, 32 | 'betas': [0.95, 0.999], 33 | 'eps': 1e-6, 34 | 'weight_decay': 1e-3, 35 | }, 36 | 'lr_sched': { 37 | 'type': 'inverse', 38 | 'inv_gamma': 20000., 39 | 'power': 1., 40 | 'warmup': 0.99, 41 | }, 42 | 'ema_sched': { 43 | 'type': 'inverse', 44 | 'power': 0.6667, 45 | 'max_value': 0.9999 46 | }, 47 | } 48 | config = json.load(file) 49 | return merge(defaults, config) 50 | 51 | 52 | def make_model(config): 53 | config = config['model'] 54 | assert config['type'] == 'image_v1' 55 | model = models.ImageDenoiserModelV1( 56 | config['input_channels'], 57 | config['mapping_out'], 58 | config['depths'], 59 | config['channels'], 60 | config['self_attn_depths'], 61 | config['cross_attn_depths'], 62 | patch_size=config['patch_size'], 63 | dropout_rate=config['dropout_rate'], 64 | mapping_cond_dim=config['mapping_cond_dim'] + (9 if config['augment_wrapper'] else 0), 65 | unet_cond_dim=config['unet_cond_dim'], 66 | cross_cond_dim=config['cross_cond_dim'], 67 | skip_stages=config['skip_stages'], 68 | has_variance=config['has_variance'], 69 | t_embed=config.get('t_embed', True), 70 | ) 71 | return augmentation.KarrasAugmentWrapper(model) if config['augment_wrapper'] else model 72 | 73 | 74 | def make_denoiser_wrapper(config): 75 | config = config['model'] 76 | sigma_data = config.get('sigma_data', 1.) 
77 | has_variance = config.get('has_variance', False) 78 | if not has_variance: 79 | return partial(layers.Denoiser, sigma_data=sigma_data, unscaled=config.get('unscaled', False)) 80 | return partial(layers.DenoiserWithVariance, sigma_data=sigma_data) 81 | 82 | 83 | def make_sample_density(config): 84 | sd_config = config['sigma_sample_density'] 85 | sigma_data = config['sigma_data'] 86 | if sd_config['type'] == 'lognormal': 87 | loc = sd_config['mean'] if 'mean' in sd_config else sd_config['loc'] 88 | scale = sd_config['std'] if 'std' in sd_config else sd_config['scale'] 89 | return partial(utils.rand_log_normal, loc=loc, scale=scale) 90 | if sd_config['type'] == 'loglogistic': 91 | loc = sd_config['loc'] if 'loc' in sd_config else math.log(sigma_data) 92 | scale = sd_config['scale'] if 'scale' in sd_config else 0.5 93 | min_value = sd_config['min_value'] if 'min_value' in sd_config else 0. 94 | max_value = sd_config['max_value'] if 'max_value' in sd_config else float('inf') 95 | return partial(utils.rand_log_logistic, loc=loc, scale=scale, min_value=min_value, max_value=max_value) 96 | if sd_config['type'] == 'loguniform': 97 | min_value = sd_config['min_value'] if 'min_value' in sd_config else config['sigma_min'] 98 | max_value = sd_config['max_value'] if 'max_value' in sd_config else config['sigma_max'] 99 | return partial(utils.rand_log_uniform, min_value=min_value, max_value=max_value) 100 | if sd_config['type'] == 'v-diffusion': 101 | min_value = sd_config['min_value'] if 'min_value' in sd_config else 0. 102 | max_value = sd_config['max_value'] if 'max_value' in sd_config else float('inf') 103 | return partial(utils.rand_v_diffusion, sigma_data=sigma_data, min_value=min_value, max_value=max_value) 104 | if sd_config['type'] == 'split-lognormal': 105 | loc = sd_config['mean'] if 'mean' in sd_config else sd_config['loc'] 106 | scale_1 = sd_config['std_1'] if 'std_1' in sd_config else sd_config['scale_1'] 107 | scale_2 = sd_config['std_2'] if 'std_2' in sd_config else sd_config['scale_2'] 108 | return partial(utils.rand_split_log_normal, loc=loc, scale_1=scale_1, scale_2=scale_2) 109 | raise ValueError('Unknown sample density type') 110 | -------------------------------------------------------------------------------- /eval.py: -------------------------------------------------------------------------------- 1 | #!/usr/bin/env python3 2 | """Evaluate models.""" 3 | 4 | import math, accelerate, torch 5 | from copy import deepcopy 6 | from functools import partial 7 | from pathlib import Path 8 | from fastcore.script import call_parse 9 | 10 | from torch import nn, optim 11 | from torch import multiprocessing as mp 12 | from torch.utils import data 13 | from torchvision import datasets, transforms, utils 14 | from tqdm.auto import trange, tqdm 15 | 16 | import k_diffusion as K 17 | 18 | #sampler = K.sampling.sample_lms 19 | #sampler = K.sampling.sample_euler 20 | sampler = K.sampling.sample_heun 21 | 22 | @call_parse 23 | def main( 24 | config:str, # the configuration file 25 | batch_size:int=256, # the batch size 26 | sample_steps:int=50, # number of steps to use when sampling 27 | evaluate_n:int=2000, # the number of samples to draw to evaluate 28 | checkpoint:str='model_00050000.pth', # the path of the checkpoint 29 | sample_n:int=64, # the number of images to sample for demo grids 30 | ): 31 | path = Path('outputs') 32 | path.mkdir(exist_ok=True) 33 | 34 | config = K.config.load_config(open(config)) 35 | model_cfg = config['model'] 36 | dataset_cfg = config['dataset'] 37 | 38 | # 
TODO: allow non-square input sizes 39 | assert len(model_cfg['input_size']) == 2 and model_cfg['input_size'][0] == model_cfg['input_size'][1] 40 | size = model_cfg['input_size'] 41 | 42 | accelerator = accelerate.Accelerator() 43 | device = accelerator.device 44 | print(f'Process {accelerator.process_index} using device: {device}', flush=True) 45 | 46 | inner_model = K.config.make_model(config) 47 | print('Parameters:', K.utils.n_params(inner_model)) 48 | 49 | tf = transforms.Compose([ 50 | transforms.Resize(size[0], interpolation=transforms.InterpolationMode.LANCZOS), 51 | transforms.CenterCrop(size[0]), 52 | K.augmentation.KarrasAugmentationPipeline(model_cfg['augment_prob']), 53 | ]) 54 | 55 | if dataset_cfg['type'] == 'imagefolder': 56 | train_set = K.utils.FolderOfImages(dataset_cfg['location'], transform=tf) 57 | elif dataset_cfg['type'] == 'cifar10': 58 | train_set = datasets.CIFAR10(dataset_cfg['location'], train=True, download=True, transform=tf) 59 | elif dataset_cfg['type'] == 'fashion': 60 | train_set = datasets.FashionMNIST(dataset_cfg['location'], train=True, download=True, transform=tf) 61 | elif dataset_cfg['type'] == 'mnist': 62 | train_set = datasets.MNIST(dataset_cfg['location'], train=True, download=True, transform=tf) 63 | elif dataset_cfg['type'] == 'huggingface': 64 | from datasets import load_dataset 65 | train_set = load_dataset(dataset_cfg['location']) 66 | train_set.set_transform(partial(K.utils.hf_datasets_augs_helper, transform=tf, image_key=dataset_cfg['image_key'])) 67 | train_set = train_set['train'] 68 | else: raise ValueError('Invalid dataset type') 69 | 70 | try: print('Number of items in dataset:', len(train_set)) 71 | except TypeError: pass 72 | 73 | image_key = dataset_cfg.get('image_key', 0) 74 | train_dl = data.DataLoader(train_set, batch_size, shuffle=True, drop_last=True, num_workers=8, persistent_workers=True) 75 | 76 | inner_model, train_dl = accelerator.prepare(inner_model, train_dl) 77 | sigma_min = model_cfg['sigma_min'] 78 | sigma_max = model_cfg['sigma_max'] 79 | model = K.config.make_denoiser_wrapper(config)(inner_model) 80 | 81 | # Load checkpoint 82 | ckpt_path = checkpoint 83 | print(f'Loading ema from {ckpt_path}...') 84 | ckpt = torch.load(ckpt_path, map_location='cpu') 85 | accelerator.unwrap_model(model.inner_model).load_state_dict(ckpt['model_ema']) 86 | 87 | extractor = K.evaluation.InceptionV3FeatureExtractor(device=device) 88 | train_iter = iter(train_dl) 89 | reals_features = K.evaluation.compute_features(accelerator, lambda x: next(train_iter)[image_key][1], extractor, evaluate_n, batch_size) 90 | del train_iter 91 | 92 | @torch.no_grad() 93 | @K.utils.eval_mode(model) 94 | def demo(): 95 | tqdm.write('Sampling...') 96 | filename = f'{checkpoint}_eval.png' 97 | x = torch.randn([sample_n, model_cfg['input_channels'], size[0], size[1]], device=device) * sigma_max 98 | sigmas = K.sampling.get_sigmas_karras(sample_steps, sigma_min, sigma_max, rho=7., device=device) 99 | x_0 = sampler(model, x, sigmas) 100 | x_0 = x_0[:sample_n] 101 | grid = utils.make_grid(-x_0, nrow=math.ceil(sample_n ** 0.5), padding=0) 102 | K.utils.to_pil_image(grid).save(filename) 103 | 104 | @torch.no_grad() 105 | @K.utils.eval_mode(model) 106 | def evaluate(): 107 | tqdm.write('Evaluating...') 108 | sigmas = K.sampling.get_sigmas_karras(sample_steps, sigma_min, sigma_max, rho=7., device=device) 109 | def sample_fn(n): 110 | x = torch.randn([n, model_cfg['input_channels'], size[0], size[1]], device=device) * sigma_max 111 | return sampler(model, x, sigmas) 
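        # Extract InceptionV3 features from freshly drawn samples and compare them
        # against the real-image features gathered above.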
112 | fakes_features = K.evaluation.compute_features(accelerator, sample_fn, extractor, evaluate_n, batch_size) 113 | fid = K.evaluation.fid(fakes_features, reals_features) 114 | kid = K.evaluation.kid(fakes_features, reals_features) 115 | print(f'FID: {fid.item():g}, KID: {kid.item():g}') 116 | # metrics_log.write(step, fid.item(), kid.item()) 117 | 118 | demo() 119 | evaluate() 120 | -------------------------------------------------------------------------------- /k_diffusion/evaluation.py: -------------------------------------------------------------------------------- 1 | import math 2 | import os 3 | from pathlib import Path 4 | 5 | from cleanfid.inception_torchscript import InceptionV3W 6 | import clip 7 | from resize_right import resize 8 | import torch 9 | from torch import nn 10 | from torch.nn import functional as F 11 | from torchvision import transforms 12 | from tqdm.auto import trange 13 | 14 | from . import utils 15 | 16 | 17 | class InceptionV3FeatureExtractor(nn.Module): 18 | def __init__(self, device='cpu'): 19 | super().__init__() 20 | path = Path(os.environ.get('XDG_CACHE_HOME', Path.home() / '.cache')) / 'k-diffusion' 21 | url = 'https://nvlabs-fi-cdn.nvidia.com/stylegan2-ada-pytorch/pretrained/metrics/inception-2015-12-05.pt' 22 | digest = 'f58cb9b6ec323ed63459aa4fb441fe750cfe39fafad6da5cb504a16f19e958f4' 23 | utils.download_file(path / 'inception-2015-12-05.pt', url, digest) 24 | self.model = InceptionV3W(str(path), resize_inside=False).to(device) 25 | self.size = (299, 299) 26 | 27 | def forward(self, x): 28 | if x.shape[2:4] != self.size: 29 | x = resize(x, out_shape=self.size, pad_mode='reflect') 30 | if x.shape[1] == 1: 31 | x = torch.cat([x] * 3, dim=1) 32 | x = (x * 127.5 + 127.5).clamp(0, 255) 33 | return self.model(x) 34 | 35 | 36 | class CLIPFeatureExtractor(nn.Module): 37 | def __init__(self, name='ViT-L/14@336px', device='cpu'): 38 | super().__init__() 39 | self.model = clip.load(name, device=device)[0].eval().requires_grad_(False) 40 | self.normalize = transforms.Normalize(mean=(0.48145466, 0.4578275, 0.40821073), 41 | std=(0.26862954, 0.26130258, 0.27577711)) 42 | self.size = (self.model.visual.input_resolution, self.model.visual.input_resolution) 43 | 44 | def forward(self, x): 45 | if x.shape[2:4] != self.size: 46 | x = resize(x.add(1).div(2), out_shape=self.size, pad_mode='reflect').clamp(0, 1) 47 | x = self.normalize(x) 48 | x = self.model.encode_image(x).float() 49 | x = F.normalize(x) * x.shape[1] ** 0.5 50 | return x 51 | 52 | 53 | def compute_features(accelerator, sample_fn, extractor_fn, n, batch_size): 54 | n_per_proc = math.ceil(n / accelerator.num_processes) 55 | feats_all = [] 56 | try: 57 | for i in trange(0, n_per_proc, batch_size, disable=not accelerator.is_main_process): 58 | cur_batch_size = min(n - i, batch_size) 59 | samples = sample_fn(cur_batch_size)[:cur_batch_size] 60 | feats_all.append(accelerator.gather(extractor_fn(samples))) 61 | except StopIteration: 62 | pass 63 | return torch.cat(feats_all)[:n] 64 | 65 | 66 | def polynomial_kernel(x, y): 67 | d = x.shape[-1] 68 | dot = x @ y.transpose(-2, -1) 69 | return (dot / d + 1) ** 3 70 | 71 | 72 | def squared_mmd(x, y, kernel=polynomial_kernel): 73 | m = x.shape[-2] 74 | n = y.shape[-2] 75 | kxx = kernel(x, x) 76 | kyy = kernel(y, y) 77 | kxy = kernel(x, y) 78 | kxx_sum = kxx.sum([-1, -2]) - kxx.diagonal(dim1=-1, dim2=-2).sum(-1) 79 | kyy_sum = kyy.sum([-1, -2]) - kyy.diagonal(dim1=-1, dim2=-2).sum(-1) 80 | kxy_sum = kxy.sum([-1, -2]) 81 | term_1 = kxx_sum / m / (m - 1) 82 | term_2 = 
kyy_sum / n / (n - 1) 83 | term_3 = kxy_sum * 2 / m / n 84 | return term_1 + term_2 - term_3 85 | 86 | 87 | @utils.tf32_mode(matmul=False) 88 | def kid(x, y, max_size=5000): 89 | x_size, y_size = x.shape[0], y.shape[0] 90 | n_partitions = math.ceil(max(x_size / max_size, y_size / max_size)) 91 | total_mmd = x.new_zeros([]) 92 | for i in range(n_partitions): 93 | cur_x = x[round(i * x_size / n_partitions):round((i + 1) * x_size / n_partitions)] 94 | cur_y = y[round(i * y_size / n_partitions):round((i + 1) * y_size / n_partitions)] 95 | total_mmd = total_mmd + squared_mmd(cur_x, cur_y) 96 | return total_mmd / n_partitions 97 | 98 | 99 | class _MatrixSquareRootEig(torch.autograd.Function): 100 | @staticmethod 101 | def forward(ctx, a): 102 | vals, vecs = torch.linalg.eigh(a) 103 | ctx.save_for_backward(vals, vecs) 104 | return vecs @ vals.abs().sqrt().diag_embed() @ vecs.transpose(-2, -1) 105 | 106 | @staticmethod 107 | def backward(ctx, grad_output): 108 | vals, vecs = ctx.saved_tensors 109 | d = vals.abs().sqrt().unsqueeze(-1).repeat_interleave(vals.shape[-1], -1) 110 | vecs_t = vecs.transpose(-2, -1) 111 | return vecs @ (vecs_t @ grad_output @ vecs / (d + d.transpose(-2, -1))) @ vecs_t 112 | 113 | 114 | def sqrtm_eig(a): 115 | if a.ndim < 2: 116 | raise RuntimeError('tensor of matrices must have at least 2 dimensions') 117 | if a.shape[-2] != a.shape[-1]: 118 | raise RuntimeError('tensor must be batches of square matrices') 119 | return _MatrixSquareRootEig.apply(a) 120 | 121 | 122 | @utils.tf32_mode(matmul=False) 123 | def fid(x, y, eps=1e-8): 124 | x_mean = x.mean(dim=0) 125 | y_mean = y.mean(dim=0) 126 | mean_term = (x_mean - y_mean).pow(2).sum() 127 | x_cov = torch.cov(x.T) 128 | y_cov = torch.cov(y.T) 129 | eps_eye = torch.eye(x_cov.shape[0], device=x_cov.device, dtype=x_cov.dtype) * eps 130 | x_cov = x_cov + eps_eye 131 | y_cov = y_cov + eps_eye 132 | x_cov_sqrt = sqrtm_eig(x_cov) 133 | cov_term = torch.trace(x_cov + y_cov - 2 * sqrtm_eig(x_cov_sqrt @ y_cov @ x_cov_sqrt)) 134 | return mean_term + cov_term 135 | -------------------------------------------------------------------------------- /k_diffusion/external.py: -------------------------------------------------------------------------------- 1 | import math 2 | 3 | import torch 4 | from torch import nn 5 | 6 | from . import sampling, utils 7 | 8 | 9 | class VDenoiser(nn.Module): 10 | """A v-diffusion-pytorch model wrapper for k-diffusion.""" 11 | 12 | def __init__(self, inner_model): 13 | super().__init__() 14 | self.inner_model = inner_model 15 | self.sigma_data = 1. 
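    # v-diffusion parameterizes time as t in [0, 1]; the k-diffusion sigma relates to
    # it by sigma = tan(t * pi / 2) (see t_to_sigma below). get_scalings implements
    # the Karras et al. preconditioning with sigma_data fixed at 1, with c_out negated
    # because the inner model makes a v-prediction.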
16 | 17 | def get_scalings(self, sigma): 18 | c_skip = self.sigma_data ** 2 / (sigma ** 2 + self.sigma_data ** 2) 19 | c_out = -sigma * self.sigma_data / (sigma ** 2 + self.sigma_data ** 2) ** 0.5 20 | c_in = 1 / (sigma ** 2 + self.sigma_data ** 2) ** 0.5 21 | return c_skip, c_out, c_in 22 | 23 | def sigma_to_t(self, sigma): 24 | return sigma.atan() / math.pi * 2 25 | 26 | def t_to_sigma(self, t): 27 | return (t * math.pi / 2).tan() 28 | 29 | def loss(self, input, noise, sigma, **kwargs): 30 | c_skip, c_out, c_in = [utils.append_dims(x, input.ndim) for x in self.get_scalings(sigma)] 31 | noised_input = input + noise * utils.append_dims(sigma, input.ndim) 32 | model_output = self.inner_model(noised_input * c_in, self.sigma_to_t(sigma), **kwargs) 33 | target = (input - c_skip * noised_input) / c_out 34 | return (model_output - target).pow(2).flatten(1).mean(1) 35 | 36 | def forward(self, input, sigma, **kwargs): 37 | c_skip, c_out, c_in = [utils.append_dims(x, input.ndim) for x in self.get_scalings(sigma)] 38 | return self.inner_model(input * c_in, self.sigma_to_t(sigma), **kwargs) * c_out + input * c_skip 39 | 40 | 41 | class DiscreteSchedule(nn.Module): 42 | """A mapping between continuous noise levels (sigmas) and a list of discrete noise 43 | levels.""" 44 | 45 | def __init__(self, sigmas, quantize): 46 | super().__init__() 47 | self.register_buffer('sigmas', sigmas) 48 | self.register_buffer('log_sigmas', sigmas.log()) 49 | self.quantize = quantize 50 | 51 | @property 52 | def sigma_min(self): 53 | return self.sigmas[0] 54 | 55 | @property 56 | def sigma_max(self): 57 | return self.sigmas[-1] 58 | 59 | def get_sigmas(self, n=None): 60 | if n is None: 61 | return sampling.append_zero(self.sigmas.flip(0)) 62 | t_max = len(self.sigmas) - 1 63 | t = torch.linspace(t_max, 0, n, device=self.sigmas.device) 64 | return sampling.append_zero(self.t_to_sigma(t)) 65 | 66 | def sigma_to_t(self, sigma, quantize=None): 67 | quantize = self.quantize if quantize is None else quantize 68 | log_sigma = sigma.log() 69 | dists = log_sigma - self.log_sigmas[:, None] 70 | if quantize: 71 | return dists.abs().argmin(dim=0).view(sigma.shape) 72 | low_idx = dists.ge(0).cumsum(dim=0).argmax(dim=0).clamp(max=self.log_sigmas.shape[0] - 2) 73 | high_idx = low_idx + 1 74 | low, high = self.log_sigmas[low_idx], self.log_sigmas[high_idx] 75 | w = (low - log_sigma) / (low - high) 76 | w = w.clamp(0, 1) 77 | t = (1 - w) * low_idx + w * high_idx 78 | return t.view(sigma.shape) 79 | 80 | def t_to_sigma(self, t): 81 | t = t.float() 82 | low_idx, high_idx, w = t.floor().long(), t.ceil().long(), t.frac() 83 | log_sigma = (1 - w) * self.log_sigmas[low_idx] + w * self.log_sigmas[high_idx] 84 | return log_sigma.exp() 85 | 86 | 87 | class DiscreteEpsDDPMDenoiser(DiscreteSchedule): 88 | """A wrapper for discrete schedule DDPM models that output eps (the predicted 89 | noise).""" 90 | 91 | def __init__(self, model, alphas_cumprod, quantize): 92 | super().__init__(((1 - alphas_cumprod) / alphas_cumprod) ** 0.5, quantize) 93 | self.inner_model = model 94 | self.sigma_data = 1. 
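    # The discrete sigmas are sqrt((1 - alphas_cumprod) / alphas_cumprod), set in
    # __init__ above. For an eps-prediction model, denoised = input - sigma * eps,
    # hence c_out = -sigma, with the noised input scaled by c_in = 1 / sqrt(sigma^2 + 1)
    # before it is fed to the inner model.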
95 | 96 | def get_scalings(self, sigma): 97 | c_out = -sigma 98 | c_in = 1 / (sigma ** 2 + self.sigma_data ** 2) ** 0.5 99 | return c_out, c_in 100 | 101 | def get_eps(self, *args, **kwargs): 102 | return self.inner_model(*args, **kwargs) 103 | 104 | def loss(self, input, noise, sigma, **kwargs): 105 | c_out, c_in = [utils.append_dims(x, input.ndim) for x in self.get_scalings(sigma)] 106 | noised_input = input + noise * utils.append_dims(sigma, input.ndim) 107 | eps = self.get_eps(noised_input * c_in, self.sigma_to_t(sigma), **kwargs) 108 | return (eps - noise).pow(2).flatten(1).mean(1) 109 | 110 | def forward(self, input, sigma, **kwargs): 111 | c_out, c_in = [utils.append_dims(x, input.ndim) for x in self.get_scalings(sigma)] 112 | eps = self.get_eps(input * c_in, self.sigma_to_t(sigma), **kwargs) 113 | return input + eps * c_out 114 | 115 | 116 | class OpenAIDenoiser(DiscreteEpsDDPMDenoiser): 117 | """A wrapper for OpenAI diffusion models.""" 118 | 119 | def __init__(self, model, diffusion, quantize=False, has_learned_sigmas=True, device='cpu'): 120 | alphas_cumprod = torch.tensor(diffusion.alphas_cumprod, device=device, dtype=torch.float32) 121 | super().__init__(model, alphas_cumprod, quantize=quantize) 122 | self.has_learned_sigmas = has_learned_sigmas 123 | 124 | def get_eps(self, *args, **kwargs): 125 | model_output = self.inner_model(*args, **kwargs) 126 | if self.has_learned_sigmas: 127 | return model_output.chunk(2, dim=1)[0] 128 | return model_output 129 | 130 | 131 | class CompVisDenoiser(DiscreteEpsDDPMDenoiser): 132 | """A wrapper for CompVis diffusion models.""" 133 | 134 | def __init__(self, model, quantize=False, device='cpu'): 135 | super().__init__(model, model.alphas_cumprod, quantize=quantize) 136 | 137 | def get_eps(self, *args, **kwargs): 138 | return self.inner_model.apply_model(*args, **kwargs) 139 | -------------------------------------------------------------------------------- /sample_clip_guided.py: -------------------------------------------------------------------------------- 1 | #!/usr/bin/env python3 2 | 3 | """CLIP guided sampling from k-diffusion models.""" 4 | 5 | import argparse 6 | import math 7 | 8 | import accelerate 9 | import clip 10 | from kornia import augmentation as KA 11 | from resize_right import resize 12 | import torch 13 | from torch.nn import functional as F 14 | from torchvision import transforms 15 | from tqdm import trange, tqdm 16 | 17 | import k_diffusion as K 18 | 19 | 20 | def spherical_dist_loss(x, y): 21 | x = F.normalize(x, dim=-1) 22 | y = F.normalize(y, dim=-1) 23 | return (x - y).norm(dim=-1).div(2).arcsin().pow(2).mul(2) 24 | 25 | 26 | def make_cond_model_fn(model, cond_fn): 27 | def model_fn(x, sigma, **kwargs): 28 | with torch.enable_grad(): 29 | x = x.detach().requires_grad_() 30 | denoised = model(x, sigma, **kwargs) 31 | cond_grad = cond_fn(x, sigma, denoised=denoised, **kwargs).detach() 32 | cond_denoised = denoised.detach() + cond_grad * K.utils.append_dims(sigma**2, x.ndim) 33 | return cond_denoised 34 | return model_fn 35 | 36 | 37 | def make_static_thresh_model_fn(model, value=1.): 38 | def model_fn(x, sigma, **kwargs): 39 | return model(x, sigma, **kwargs).clamp(-value, value) 40 | return model_fn 41 | 42 | 43 | def main(): 44 | p = argparse.ArgumentParser(description=__doc__, 45 | formatter_class=argparse.ArgumentDefaultsHelpFormatter) 46 | p.add_argument('prompt', type=str, 47 | help='the prompt to use') 48 | p.add_argument('--batch-size', type=int, default=16, 49 | help='the batch size') 50 |
p.add_argument('--checkpoint', type=str, required=True, 51 | help='the checkpoint to use') 52 | p.add_argument('--churn', type=float, default=50., 53 | help='the amount of noise to add during sampling') 54 | p.add_argument('--clip-guidance-scale', '-cgs', type=float, default=500., 55 | help='the CLIP guidance scale') 56 | p.add_argument('--clip-model', type=str, default='ViT-B/16', choices=clip.available_models(), 57 | help='the CLIP model to use') 58 | p.add_argument('--config', type=str, required=True, 59 | help='the model config') 60 | p.add_argument('-n', type=int, default=64, 61 | help='the number of images to sample') 62 | p.add_argument('--prefix', type=str, default='out', 63 | help='the output prefix') 64 | p.add_argument('--steps', type=int, default=100, 65 | help='the number of denoising steps') 66 | args = p.parse_args() 67 | 68 | config = K.config.load_config(open(args.config)) 69 | model_config = config['model'] 70 | # TODO: allow non-square input sizes 71 | assert len(model_config['input_size']) == 2 and model_config['input_size'][0] == model_config['input_size'][1] 72 | size = model_config['input_size'] 73 | 74 | accelerator = accelerate.Accelerator() 75 | device = accelerator.device 76 | print('Using device:', device, flush=True) 77 | 78 | inner_model = K.config.make_model(config).eval().requires_grad_(False).to(device) 79 | inner_model.load_state_dict(torch.load(args.checkpoint, map_location='cpu')['model_ema']) 80 | accelerator.print('Parameters:', K.utils.n_params(inner_model)) 81 | model = K.Denoiser(inner_model, sigma_data=model_config['sigma_data']) 82 | 83 | sigma_min = model_config['sigma_min'] 84 | sigma_max = model_config['sigma_max'] 85 | 86 | clip_model = clip.load(args.clip_model, device=device)[0].eval().requires_grad_(False) 87 | clip_normalize = transforms.Normalize(mean=(0.48145466, 0.4578275, 0.40821073), 88 | std=(0.26862954, 0.26130258, 0.27577711)) 89 | clip_size = (clip_model.visual.input_resolution, clip_model.visual.input_resolution) 90 | aug = KA.RandomAffine(0, (1/14, 1/14), p=1, padding_mode='border') 91 | 92 | def get_image_embed(x): 93 | if x.shape[2:4] != clip_size: 94 | x = resize(x, out_shape=clip_size, pad_mode='reflect') 95 | x = clip_normalize(x) 96 | x = clip_model.encode_image(x).float() 97 | return F.normalize(x) 98 | 99 | target_embed = F.normalize(clip_model.encode_text(clip.tokenize(args.prompt, truncate=True).to(device)).float()) 100 | 101 | def cond_fn(x, t, denoised): 102 | image_embed = get_image_embed(aug(denoised.add(1).div(2))) 103 | loss = spherical_dist_loss(image_embed, target_embed).sum() * args.clip_guidance_scale 104 | grad = -torch.autograd.grad(loss, x)[0] 105 | return grad 106 | 107 | model_fn = make_cond_model_fn(model, cond_fn) 108 | model_fn = make_static_thresh_model_fn(model_fn) 109 | 110 | @torch.no_grad() 111 | @K.utils.eval_mode(model) 112 | def run(): 113 | if accelerator.is_local_main_process: 114 | tqdm.write('Sampling...') 115 | sigmas = K.sampling.get_sigmas_karras(args.steps, sigma_min, sigma_max, rho=7., device=device) 116 | def sample_fn(n): 117 | x = torch.randn([n, model_config['input_channels'], size[0], size[1]], device=device) * sigmas[0] 118 | x_0 = K.sampling.sample_dpm_2(model_fn, x, sigmas, s_churn=args.churn, disable=not accelerator.is_local_main_process) 119 | return x_0 120 | x_0 = K.evaluation.compute_features(accelerator, sample_fn, lambda x: x, args.n, args.batch_size) 121 | if accelerator.is_main_process: 122 | for i, out in enumerate(x_0): 123 | filename = f'{args.prefix}_{i:05}.png' 
124 | K.utils.to_pil_image(out).save(filename) 125 | 126 | try: 127 | run() 128 | except KeyboardInterrupt: 129 | pass 130 | 131 | 132 | if __name__ == '__main__': 133 | main() 134 | -------------------------------------------------------------------------------- /k_diffusion/models/image_v1.py: -------------------------------------------------------------------------------- 1 | import math 2 | 3 | import torch 4 | from torch import nn 5 | from torch.nn import functional as F 6 | 7 | from .. import layers, utils 8 | 9 | 10 | def orthogonal_(module): 11 | nn.init.orthogonal_(module.weight) 12 | return module 13 | 14 | 15 | class ResConvBlock(layers.ConditionedResidualBlock): 16 | def __init__(self, feats_in, c_in, c_mid, c_out, group_size=32, dropout_rate=0.): 17 | skip = None if c_in == c_out else orthogonal_(nn.Conv2d(c_in, c_out, 1, bias=False)) 18 | super().__init__( 19 | layers.AdaGN(feats_in, c_in, max(1, c_in // group_size)), 20 | nn.GELU(), 21 | nn.Conv2d(c_in, c_mid, 3, padding=1), 22 | nn.Dropout2d(dropout_rate, inplace=True), 23 | layers.AdaGN(feats_in, c_mid, max(1, c_mid // group_size)), 24 | nn.GELU(), 25 | nn.Conv2d(c_mid, c_out, 3, padding=1), 26 | nn.Dropout2d(dropout_rate, inplace=True), 27 | skip=skip) 28 | 29 | 30 | class DBlock(layers.ConditionedSequential): 31 | def __init__(self, n_layers, feats_in, c_in, c_mid, c_out, group_size=32, head_size=64, dropout_rate=0., downsample=False, self_attn=False, cross_attn=False, c_enc=0): 32 | modules = [nn.Identity()] 33 | for i in range(n_layers): 34 | my_c_in = c_in if i == 0 else c_mid 35 | my_c_out = c_mid if i < n_layers - 1 else c_out 36 | modules.append(ResConvBlock(feats_in, my_c_in, c_mid, my_c_out, group_size, dropout_rate)) 37 | if self_attn: 38 | norm = lambda c_in: layers.AdaGN(feats_in, c_in, max(1, my_c_out // group_size)) 39 | modules.append(layers.SelfAttention2d(my_c_out, max(1, my_c_out // head_size), norm, dropout_rate)) 40 | if cross_attn: 41 | norm = lambda c_in: layers.AdaGN(feats_in, c_in, max(1, my_c_out // group_size)) 42 | modules.append(layers.CrossAttention2d(my_c_out, c_enc, max(1, my_c_out // head_size), norm, dropout_rate)) 43 | super().__init__(*modules) 44 | self.set_downsample(downsample) 45 | 46 | def set_downsample(self, downsample): 47 | self[0] = layers.Downsample2d() if downsample else nn.Identity() 48 | return self 49 | 50 | 51 | class UBlock(layers.ConditionedSequential): 52 | def __init__(self, n_layers, feats_in, c_in, c_mid, c_out, group_size=32, head_size=64, dropout_rate=0., upsample=False, self_attn=False, cross_attn=False, c_enc=0): 53 | modules = [] 54 | for i in range(n_layers): 55 | my_c_in = c_in if i == 0 else c_mid 56 | my_c_out = c_mid if i < n_layers - 1 else c_out 57 | modules.append(ResConvBlock(feats_in, my_c_in, c_mid, my_c_out, group_size, dropout_rate)) 58 | if self_attn: 59 | norm = lambda c_in: layers.AdaGN(feats_in, c_in, max(1, my_c_out // group_size)) 60 | modules.append(layers.SelfAttention2d(my_c_out, max(1, my_c_out // head_size), norm, dropout_rate)) 61 | if cross_attn: 62 | norm = lambda c_in: layers.AdaGN(feats_in, c_in, max(1, my_c_out // group_size)) 63 | modules.append(layers.CrossAttention2d(my_c_out, c_enc, max(1, my_c_out // head_size), norm, dropout_rate)) 64 | modules.append(nn.Identity()) 65 | super().__init__(*modules) 66 | self.set_upsample(upsample) 67 | 68 | def forward(self, input, cond, skip=None): 69 | if skip is not None: 70 | input = torch.cat([input, skip], dim=1) 71 | return super().forward(input, cond) 72 | 73 | def set_upsample(self, 
upsample): 74 | self[-1] = layers.Upsample2d() if upsample else nn.Identity() 75 | return self 76 | 77 | 78 | class MappingNet(nn.Sequential): 79 | def __init__(self, feats_in, feats_out, n_layers=2): 80 | layers = [] 81 | for i in range(n_layers): 82 | layers.append(orthogonal_(nn.Linear(feats_in if i == 0 else feats_out, feats_out))) 83 | layers.append(nn.GELU()) 84 | super().__init__(*layers) 85 | 86 | 87 | class ImageDenoiserModelV1(nn.Module): 88 | def __init__(self, c_in, feats_in, depths, channels, self_attn_depths, cross_attn_depths=None, mapping_cond_dim=0, unet_cond_dim=0, cross_cond_dim=0, dropout_rate=0., patch_size=1, skip_stages=0, has_variance=False, t_embed=True): 89 | super().__init__() 90 | self.c_in = c_in 91 | self.channels = channels 92 | self.unet_cond_dim = unet_cond_dim 93 | self.patch_size = patch_size 94 | self.has_variance = has_variance 95 | self.t_embed = t_embed 96 | self.timestep_embed = layers.FourierFeatures(1, feats_in) 97 | if mapping_cond_dim > 0: 98 | self.mapping_cond = nn.Linear(mapping_cond_dim, feats_in, bias=False) 99 | self.mapping = MappingNet(feats_in, feats_in) 100 | self.proj_in = nn.Conv2d((c_in + unet_cond_dim) * self.patch_size ** 2, channels[max(0, skip_stages - 1)], 1) 101 | self.proj_out = nn.Conv2d(channels[max(0, skip_stages - 1)], c_in * self.patch_size ** 2 + (1 if self.has_variance else 0), 1) 102 | nn.init.zeros_(self.proj_out.weight) 103 | nn.init.zeros_(self.proj_out.bias) 104 | if cross_cond_dim == 0: 105 | cross_attn_depths = [False] * len(self_attn_depths) 106 | d_blocks, u_blocks = [], [] 107 | for i in range(len(depths)): 108 | my_c_in = channels[max(0, i - 1)] 109 | d_blocks.append(DBlock(depths[i], feats_in, my_c_in, channels[i], channels[i], downsample=i > skip_stages, self_attn=self_attn_depths[i], cross_attn=cross_attn_depths[i], c_enc=cross_cond_dim, dropout_rate=dropout_rate)) 110 | for i in range(len(depths)): 111 | my_c_in = channels[i] * 2 if i < len(depths) - 1 else channels[i] 112 | my_c_out = channels[max(0, i - 1)] 113 | u_blocks.append(UBlock(depths[i], feats_in, my_c_in, channels[i], my_c_out, upsample=i > skip_stages, self_attn=self_attn_depths[i], cross_attn=cross_attn_depths[i], c_enc=cross_cond_dim, dropout_rate=dropout_rate)) 114 | self.u_net = layers.UNet(d_blocks, reversed(u_blocks), skip_stages=skip_stages) 115 | 116 | def forward(self, input, sigma, mapping_cond=None, unet_cond=None, cross_cond=None, cross_cond_padding=None, return_variance=False): 117 | #print(sigma) 118 | #import pdb; pdb.set_trace() 119 | #sigma = sigma*0. + 0.1 120 | #sigma = (sigma.clip(0,1)*4+0.25).round()/4+0.01 121 | #sigma = sigma.clip(0.,0.5) 122 | #sigma = sigma * (torch.rand_like(sigma)+0.5) 123 | c_noise = sigma.log() / 4 124 | timestep_embed = self.timestep_embed(utils.append_dims(c_noise, 2)) 125 | if not self.t_embed: timestep_embed *= 0. 
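# Noise-level conditioning recap: c_noise = log(sigma) / 4 is embedded above with
# random Fourier features (and zeroed when t_embed=False); below it is summed with
# any mapping conditioning and passed through self.mapping to produce the AdaGN
# conditioning vector used throughout the U-Net.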
126 | mapping_cond_embed = torch.zeros_like(timestep_embed) if mapping_cond is None else self.mapping_cond(mapping_cond) 127 | mapping_out = self.mapping(timestep_embed + mapping_cond_embed) 128 | cond = {'cond': mapping_out} 129 | if unet_cond is not None: 130 | input = torch.cat([input, unet_cond], dim=1) 131 | if cross_cond is not None: 132 | cond['cross'] = cross_cond 133 | cond['cross_padding'] = cross_cond_padding 134 | if self.patch_size > 1: 135 | input = F.pixel_unshuffle(input, self.patch_size) 136 | input = self.proj_in(input) 137 | input = self.u_net(input, cond) 138 | input = self.proj_out(input) 139 | if self.has_variance: 140 | input, logvar = input[:, :-1], input[:, -1].flatten(1).mean(1) 141 | if self.patch_size > 1: 142 | input = F.pixel_shuffle(input, self.patch_size) 143 | if self.has_variance and return_variance: 144 | return input, logvar 145 | return input 146 | 147 | def set_skip_stages(self, skip_stages): 148 | self.proj_in = nn.Conv2d(self.proj_in.in_channels, self.channels[max(0, skip_stages - 1)], 1) 149 | self.proj_out = nn.Conv2d(self.channels[max(0, skip_stages - 1)], self.proj_out.out_channels, 1) 150 | nn.init.zeros_(self.proj_out.weight) 151 | nn.init.zeros_(self.proj_out.bias) 152 | self.u_net.skip_stages = skip_stages 153 | for i, block in enumerate(self.u_net.d_blocks): 154 | block.set_downsample(i > skip_stages) 155 | for i, block in enumerate(reversed(self.u_net.u_blocks)): 156 | block.set_upsample(i > skip_stages) 157 | return self 158 | 159 | def set_patch_size(self, patch_size): 160 | self.patch_size = patch_size 161 | self.proj_in = nn.Conv2d((self.c_in + self.unet_cond_dim) * self.patch_size ** 2, self.channels[max(0, self.u_net.skip_stages - 1)], 1) 162 | self.proj_out = nn.Conv2d(self.channels[max(0, self.u_net.skip_stages - 1)], self.c_in * self.patch_size ** 2 + (1 if self.has_variance else 0), 1) 163 | nn.init.zeros_(self.proj_out.weight) 164 | nn.init.zeros_(self.proj_out.bias) 165 | -------------------------------------------------------------------------------- /mtrain.py: -------------------------------------------------------------------------------- 1 | #!/usr/bin/env python3 2 | """Trains Karras et al. 
(2022) diffusion models.""" 3 | 4 | import math, accelerate, torch 5 | from copy import deepcopy 6 | from functools import partial 7 | from pathlib import Path 8 | from fastcore.script import call_parse 9 | 10 | from torch import optim, multiprocessing as mp 11 | from torch.utils import data 12 | from torchvision import datasets, transforms, utils 13 | from tqdm.auto import tqdm 14 | 15 | import k_diffusion as K 16 | 17 | sampler = K.sampling.sample_lms 18 | #sampler = K.sampling.sample_euler 19 | #sampler = K.sampling.sample_heun 20 | 21 | @call_parse 22 | def main( 23 | config:str, # the configuration file 24 | batch_size:int=256, # the batch size 25 | demo_every:int=1000, # save a demo grid every this many steps 26 | sample_steps:int=50, # number of steps to use when sampling 27 | evaluate_every:int=5000, # save a demo grid every this many steps 28 | evaluate_n:int=2000, # the number of samples to draw to evaluate 29 | lr:float=None, # the learning rate 30 | name:str='model', # the name of the run 31 | num_workers:int=8, # the number of data loader workers 32 | sample_n:int=64, # the number of images to sample for demo grids 33 | save_every:int=10000, # save every this many steps 34 | seed:int=None, # the random seed 35 | start_method:str='spawn' # the multiprocessing start method 36 | ): 37 | #choices=['fork', 'forkserver', 'spawn'], 38 | mp.set_start_method(start_method) 39 | torch.backends.cuda.matmul.allow_tf32 = True 40 | path = Path('outputs') 41 | path.mkdir(exist_ok=True) 42 | 43 | config = K.config.load_config(open(config)) 44 | model_cfg = config['model'] 45 | dataset_cfg = config['dataset'] 46 | opt_cfg = config['optimizer'] 47 | sched_cfg = config['lr_sched'] 48 | ema_sched_cfg = config['ema_sched'] 49 | 50 | # TODO: allow non-square input sizes 51 | assert len(model_cfg['input_size']) == 2 and model_cfg['input_size'][0] == model_cfg['input_size'][1] 52 | size = model_cfg['input_size'] 53 | 54 | accelerator = accelerate.Accelerator() 55 | device = accelerator.device 56 | print(f'Process {accelerator.process_index} using device: {device}', flush=True) 57 | 58 | if seed is not None: torch.manual_seed(seed) 59 | inner_model = K.config.make_model(config) 60 | print('Parameters:', K.utils.n_params(inner_model)) 61 | 62 | if not lr: lr = opt_cfg['lr'] 63 | if opt_cfg['type'] == 'adamw': 64 | opt = optim.AdamW(inner_model.parameters(), lr=lr, 65 | betas=tuple(opt_cfg['betas']), 66 | eps=opt_cfg['eps'], 67 | weight_decay=opt_cfg['weight_decay']) 68 | elif opt_cfg['type'] == 'sgd': 69 | opt = optim.SGD(inner_model.parameters(), lr=lr, 70 | momentum=opt_cfg.get('momentum', 0.), 71 | nesterov=opt_cfg.get('nesterov', False), 72 | weight_decay=opt_cfg.get('weight_decay', 0.)) 73 | else: raise ValueError('Invalid optimizer type') 74 | 75 | if sched_cfg['type'] == 'inverse': 76 | sched = K.utils.InverseLR(opt, inv_gamma=sched_cfg['inv_gamma'], power=sched_cfg['power'], warmup=sched_cfg['warmup']) 77 | elif sched_cfg['type'] == 'exponential': 78 | sched = K.utils.ExponentialLR(opt, num_steps=sched_cfg['num_steps'], decay=sched_cfg['decay'], warmup=sched_cfg['warmup']) 79 | else: raise ValueError('Invalid schedule type') 80 | 81 | assert ema_sched_cfg['type'] == 'inverse' 82 | ema_sched = K.utils.EMAWarmup(power=ema_sched_cfg['power'], max_value=ema_sched_cfg['max_value']) 83 | 84 | tf = transforms.Compose([ 85 | transforms.Resize(size[0], interpolation=transforms.InterpolationMode.LANCZOS), 86 | transforms.CenterCrop(size[0]), 87 | 
K.augmentation.KarrasAugmentationPipeline(model_cfg['augment_prob']), 88 | ]) 89 | 90 | if dataset_cfg['type'] == 'imagefolder': 91 | train_set = K.utils.FolderOfImages(dataset_cfg['location'], transform=tf) 92 | elif dataset_cfg['type'] == 'cifar10': 93 | train_set = datasets.CIFAR10(dataset_cfg['location'], train=True, download=True, transform=tf) 94 | elif dataset_cfg['type'] == 'fashion': 95 | train_set = datasets.FashionMNIST(dataset_cfg['location'], train=True, download=True, transform=tf) 96 | elif dataset_cfg['type'] == 'mnist': 97 | train_set = datasets.MNIST(dataset_cfg['location'], train=True, download=True, transform=tf) 98 | elif dataset_cfg['type'] == 'huggingface': 99 | from datasets import load_dataset 100 | train_set = load_dataset(dataset_cfg['location']) 101 | train_set.set_transform(partial(K.utils.hf_datasets_augs_helper, transform=tf, image_key=dataset_cfg['image_key'])) 102 | train_set = train_set['train'] 103 | else: raise ValueError('Invalid dataset type') 104 | 105 | try: print('Number of items in dataset:', len(train_set)) 106 | except TypeError: pass 107 | 108 | image_key = dataset_cfg.get('image_key', 0) 109 | train_dl = data.DataLoader(train_set, batch_size, shuffle=True, drop_last=True, num_workers=num_workers, persistent_workers=True) 110 | 111 | inner_model, opt, train_dl = accelerator.prepare(inner_model, opt, train_dl) 112 | sigma_min = model_cfg['sigma_min'] 113 | sigma_max = model_cfg['sigma_max'] 114 | sample_density = K.config.make_sample_density(model_cfg) 115 | model = K.config.make_denoiser_wrapper(config)(inner_model) 116 | model_ema = deepcopy(model) 117 | epoch,step = 0,0 118 | 119 | evaluate_enabled = evaluate_every > 0 and evaluate_n > 0 120 | if evaluate_enabled: 121 | extractor = K.evaluation.InceptionV3FeatureExtractor(device=device) 122 | train_iter = iter(train_dl) 123 | print('Computing features for reals...') 124 | reals_features = K.evaluation.compute_features(accelerator, lambda x: next(train_iter)[image_key][1], extractor, evaluate_n, batch_size) 125 | metrics_log = K.utils.CSVLogger(path/f'{name}_metrics.csv', ['step', 'fid', 'kid']) 126 | del train_iter 127 | 128 | @torch.no_grad() 129 | @K.utils.eval_mode(model_ema) 130 | def demo(): 131 | tqdm.write('Sampling...') 132 | filename = path/f'{name}_demo_{step:08}.png' 133 | x = torch.randn([sample_n, model_cfg['input_channels'], size[0], size[1]], device=device) * sigma_max 134 | sigmas = K.sampling.get_sigmas_karras(sample_steps, sigma_min, sigma_max, rho=7., device=device) 135 | x_0 = sampler(model_ema, x, sigmas) 136 | x_0 = x_0[:sample_n] 137 | grid = utils.make_grid(x_0, nrow=math.ceil(sample_n ** 0.5), padding=0) 138 | K.utils.to_pil_image(grid).save(filename) 139 | 140 | @torch.no_grad() 141 | @K.utils.eval_mode(model_ema) 142 | def evaluate(): 143 | if not evaluate_enabled: return 144 | tqdm.write('Evaluating...') 145 | sigmas = K.sampling.get_sigmas_karras(sample_steps, sigma_min, sigma_max, rho=7., device=device) 146 | def sample_fn(n): 147 | x = torch.randn([n, model_cfg['input_channels'], size[0], size[1]], device=device) * sigma_max 148 | return sampler(model_ema, x, sigmas) 149 | fakes_features = K.evaluation.compute_features(accelerator, sample_fn, extractor, evaluate_n, batch_size) 150 | fid = K.evaluation.fid(fakes_features, reals_features) 151 | kid = K.evaluation.kid(fakes_features, reals_features) 152 | print(f'FID: {fid.item():g}, KID: {kid.item():g}') 153 | metrics_log.write(step, fid.item(), kid.item()) 154 | 155 | def save(): 156 | filename = 
path/f'{name}_{step:08}.pth' 157 | tqdm.write(f'Saving to {filename}...') 158 | obj = { 159 | 'model': accelerator.unwrap_model(model.inner_model).state_dict(), 160 | 'model_ema': accelerator.unwrap_model(model_ema.inner_model).state_dict(), 161 | 'opt': opt.state_dict(), 'sched': sched.state_dict(), 'ema_sched': ema_sched.state_dict(), 162 | 'epoch': epoch, 'step': step, } 163 | accelerator.save(obj, filename) 164 | 165 | try: 166 | while True: 167 | for batch in tqdm(train_dl): 168 | reals, _, aug_cond = batch[image_key] 169 | noise = torch.randn_like(reals) 170 | sigma = sample_density([reals.shape[0]], device=device) 171 | loss = model.loss(reals, noise, sigma, aug_cond=aug_cond).mean() 172 | accelerator.backward(loss) 173 | opt.step() 174 | sched.step() 175 | opt.zero_grad() 176 | ema_decay = ema_sched.get_value() 177 | K.utils.ema_update(model, model_ema, ema_decay) 178 | ema_sched.step() 179 | 180 | if step % 25 == 0: tqdm.write(f'Epoch: {epoch}, step: {step}, loss: {loss.item():g}, lr: {sched.get_last_lr()[0]:g}') 181 | if step in (100,200,500) or step % demo_every == 0: demo() 182 | if evaluate_enabled and step > 0 and step % evaluate_every == 0: evaluate() 183 | if step > 0 and step % save_every == 0: save() 184 | step += 1 185 | epoch += 1 186 | except KeyboardInterrupt: pass 187 | 188 | -------------------------------------------------------------------------------- /k_diffusion/layers.py: -------------------------------------------------------------------------------- 1 | import math 2 | 3 | from einops import rearrange, repeat 4 | import torch 5 | from torch import nn 6 | from torch.nn import functional as F 7 | 8 | from . import utils 9 | 10 | # Karras et al. preconditioned denoiser 11 | 12 | def huber(x): 13 | a = x.abs() 14 | return (0.5*x**2).where(a<=1, a-0.5) 15 | 16 | class Denoiser(nn.Module): 17 | """A Karras et al.
preconditioner for denoising diffusion models.""" 18 | 19 | def __init__(self, inner_model, sigma_data=1., unscaled=False): 20 | super().__init__() 21 | self.inner_model,self.sigma_data,self.unscaled = inner_model,sigma_data,unscaled 22 | 23 | def get_scalings(self, sigma): 24 | c_skip = self.sigma_data ** 2 / (sigma ** 2 + self.sigma_data ** 2) 25 | c_out = sigma * self.sigma_data / (sigma ** 2 + self.sigma_data ** 2) ** 0.5 26 | c_in = 1 / (sigma ** 2 + self.sigma_data ** 2) ** 0.5 27 | ones,zeros = torch.ones_like(sigma),torch.zeros_like(sigma) 28 | #c_skip,c_out,c_in = ones,-sigma,1/(sigma**2+1).sqrt() 29 | #c_skip,c_out,c_in = ones,-sigma,1/sigma 30 | #c_skip,c_out,c_in = ones,sigma,ones 31 | if self.unscaled: c_skip,c_out,c_in = ones,-ones,ones 32 | return c_skip, c_out, c_in 33 | 34 | def loss(self, input, noise, sigma, **kwargs): 35 | c_skip, c_out, c_in = [utils.append_dims(x, input.ndim) for x in self.get_scalings(sigma)] 36 | noised_input = input + noise * utils.append_dims(sigma, input.ndim) 37 | model_output = self.inner_model(noised_input * c_in, sigma, **kwargs) 38 | target = (input - c_skip * noised_input) / c_out 39 | return (model_output - target).pow(2).flatten(1).mean(1) 40 | 41 | def forward(self, input, sigma, **kwargs): 42 | c_skip, c_out, c_in = [utils.append_dims(x, input.ndim) for x in self.get_scalings(sigma)] 43 | return self.inner_model(input * c_in, sigma, **kwargs) * c_out + input * c_skip 44 | 45 | 46 | class DenoiserWithVariance(Denoiser): 47 | def loss(self, input, noise, sigma, **kwargs): 48 | c_skip, c_out, c_in = [utils.append_dims(x, input.ndim) for x in self.get_scalings(sigma)] 49 | noised_input = input + noise * utils.append_dims(sigma, input.ndim) 50 | model_output, logvar = self.inner_model(noised_input * c_in, sigma, return_variance=True, **kwargs) 51 | logvar = utils.append_dims(logvar, model_output.ndim) 52 | target = (input - c_skip * noised_input) / c_out 53 | losses = ((model_output - target) ** 2 / logvar.exp() + logvar) / 2 54 | return losses.flatten(1).mean(1) 55 | 56 | 57 | # Residual blocks 58 | 59 | class ResidualBlock(nn.Module): 60 | def __init__(self, *main, skip=None): 61 | super().__init__() 62 | self.main = nn.Sequential(*main) 63 | self.skip = skip if skip else nn.Identity() 64 | 65 | def forward(self, input): 66 | return self.main(input) + self.skip(input) 67 | 68 | 69 | # Noise level (and other) conditioning 70 | 71 | class ConditionedModule(nn.Module): 72 | pass 73 | 74 | 75 | class UnconditionedModule(ConditionedModule): 76 | def __init__(self, module): 77 | super().__init__() 78 | self.module = module 79 | 80 | def forward(self, input, cond=None): 81 | return self.module(input) 82 | 83 | 84 | class ConditionedSequential(nn.Sequential, ConditionedModule): 85 | def forward(self, input, cond): 86 | for module in self: 87 | if isinstance(module, ConditionedModule): 88 | input = module(input, cond) 89 | else: 90 | input = module(input) 91 | return input 92 | 93 | 94 | class ConditionedResidualBlock(ConditionedModule): 95 | def __init__(self, *main, skip=None): 96 | super().__init__() 97 | self.main = ConditionedSequential(*main) 98 | self.skip = skip if skip else nn.Identity() 99 | 100 | def forward(self, input, cond): 101 | skip = self.skip(input, cond) if isinstance(self.skip, ConditionedModule) else self.skip(input) 102 | return self.main(input, cond) + skip 103 | 104 | 105 | class AdaGN(ConditionedModule): 106 | def __init__(self, feats_in, c_out, num_groups, eps=1e-5, cond_key='cond'): 107 | super().__init__() 108 | 
self.num_groups = num_groups 109 | self.eps = eps 110 | self.cond_key = cond_key 111 | self.mapper = nn.Linear(feats_in, c_out * 2) 112 | 113 | def forward(self, input, cond): 114 | weight, bias = self.mapper(cond[self.cond_key]).chunk(2, dim=-1) 115 | input = F.group_norm(input, self.num_groups, eps=self.eps) 116 | return torch.addcmul(utils.append_dims(bias, input.ndim), input, utils.append_dims(weight, input.ndim) + 1) 117 | 118 | 119 | # Attention 120 | 121 | class SelfAttention2d(ConditionedModule): 122 | def __init__(self, c_in, n_head, norm, dropout_rate=0.): 123 | super().__init__() 124 | assert c_in % n_head == 0 125 | self.norm_in = norm(c_in) 126 | self.n_head = n_head 127 | self.qkv_proj = nn.Conv2d(c_in, c_in * 3, 1) 128 | self.out_proj = nn.Conv2d(c_in, c_in, 1) 129 | self.dropout = nn.Dropout(dropout_rate) 130 | 131 | def forward(self, input, cond): 132 | n, c, h, w = input.shape 133 | qkv = self.qkv_proj(self.norm_in(input, cond)) 134 | qkv = qkv.view([n, self.n_head * 3, c // self.n_head, h * w]).transpose(2, 3) 135 | q, k, v = qkv.chunk(3, dim=1) 136 | scale = k.shape[3] ** -0.25 137 | att = ((q * scale) @ (k.transpose(2, 3) * scale)).softmax(3) 138 | att = self.dropout(att) 139 | y = (att @ v).transpose(2, 3).contiguous().view([n, c, h, w]) 140 | return input + self.out_proj(y) 141 | 142 | 143 | class CrossAttention2d(ConditionedModule): 144 | def __init__(self, c_dec, c_enc, n_head, norm_dec, dropout_rate=0., 145 | cond_key='cross', cond_key_padding='cross_padding'): 146 | super().__init__() 147 | assert c_dec % n_head == 0 148 | self.cond_key = cond_key 149 | self.cond_key_padding = cond_key_padding 150 | self.norm_enc = nn.LayerNorm(c_enc) 151 | self.norm_dec = norm_dec(c_dec) 152 | self.n_head = n_head 153 | self.q_proj = nn.Conv2d(c_dec, c_dec, 1) 154 | self.kv_proj = nn.Linear(c_enc, c_dec * 2) 155 | self.out_proj = nn.Conv2d(c_dec, c_dec, 1) 156 | self.dropout = nn.Dropout(dropout_rate) 157 | 158 | def forward(self, input, cond): 159 | n, c, h, w = input.shape 160 | q = self.q_proj(self.norm_dec(input, cond)) 161 | q = q.view([n, self.n_head, c // self.n_head, h * w]).transpose(2, 3) 162 | kv = self.kv_proj(self.norm_enc(cond[self.cond_key])) 163 | kv = kv.view([n, -1, self.n_head * 2, c // self.n_head]).transpose(1, 2) 164 | k, v = kv.chunk(2, dim=1) 165 | scale = k.shape[3] ** -0.25 166 | att = ((q * scale) @ (k.transpose(2, 3) * scale)) 167 | att = att - (cond[self.cond_key_padding][:, None, None, :]) * 10000 168 | att = att.softmax(3) 169 | att = self.dropout(att) 170 | y = (att @ v).transpose(2, 3) 171 | y = y.contiguous().view([n, c, h, w]) 172 | return input + self.out_proj(y) 173 | 174 | 175 | # Downsampling/upsampling 176 | 177 | _kernels = { 178 | 'linear': 179 | [1 / 8, 3 / 8, 3 / 8, 1 / 8], 180 | 'cubic': 181 | [-0.01171875, -0.03515625, 0.11328125, 0.43359375, 182 | 0.43359375, 0.11328125, -0.03515625, -0.01171875], 183 | 'lanczos3': 184 | [0.003689131001010537, 0.015056144446134567, -0.03399861603975296, 185 | -0.066637322306633, 0.13550527393817902, 0.44638532400131226, 186 | 0.44638532400131226, 0.13550527393817902, -0.066637322306633, 187 | -0.03399861603975296, 0.015056144446134567, 0.003689131001010537] 188 | } 189 | _kernels['bilinear'] = _kernels['linear'] 190 | _kernels['bicubic'] = _kernels['cubic'] 191 | 192 | 193 | class Downsample2d(nn.Module): 194 | def __init__(self, kernel='linear', pad_mode='reflect'): 195 | super().__init__() 196 | self.pad_mode = pad_mode 197 | kernel_1d = torch.tensor([_kernels[kernel]]) 198 | self.pad = 
kernel_1d.shape[1] // 2 - 1 199 | self.register_buffer('kernel', kernel_1d.T @ kernel_1d) 200 | 201 | def forward(self, x): 202 | x = F.pad(x, (self.pad,) * 4, self.pad_mode) 203 | weight = x.new_zeros([x.shape[1], x.shape[1], self.kernel.shape[0], self.kernel.shape[1]]) 204 | indices = torch.arange(x.shape[1], device=x.device) 205 | weight[indices, indices] = self.kernel.to(weight) 206 | return F.conv2d(x, weight, stride=2) 207 | 208 | 209 | class Upsample2d(nn.Module): 210 | def __init__(self, kernel='linear', pad_mode='reflect'): 211 | super().__init__() 212 | self.pad_mode = pad_mode 213 | kernel_1d = torch.tensor([_kernels[kernel]]) * 2 214 | self.pad = kernel_1d.shape[1] // 2 - 1 215 | self.register_buffer('kernel', kernel_1d.T @ kernel_1d) 216 | 217 | def forward(self, x): 218 | x = F.pad(x, ((self.pad + 1) // 2,) * 4, self.pad_mode) 219 | weight = x.new_zeros([x.shape[1], x.shape[1], self.kernel.shape[0], self.kernel.shape[1]]) 220 | indices = torch.arange(x.shape[1], device=x.device) 221 | weight[indices, indices] = self.kernel.to(weight) 222 | return F.conv_transpose2d(x, weight, stride=2, padding=self.pad * 2 + 1) 223 | 224 | 225 | # Embeddings 226 | 227 | class FourierFeatures(nn.Module): 228 | def __init__(self, in_features, out_features): 229 | super().__init__() 230 | assert out_features % 2 == 0 231 | self.register_buffer('weight', torch.randn([out_features // 2, in_features]) / in_features) 232 | 233 | def forward(self, input): 234 | #f = input@self.weight.T 235 | f = 2 * math.pi * input @ self.weight.T 236 | return torch.cat([f.cos(), f.sin()], dim=-1) 237 | 238 | 239 | # U-Nets 240 | 241 | class UNet(ConditionedModule): 242 | def __init__(self, d_blocks, u_blocks, skip_stages=0): 243 | super().__init__() 244 | self.d_blocks = nn.ModuleList(d_blocks) 245 | self.u_blocks = nn.ModuleList(u_blocks) 246 | self.skip_stages = skip_stages 247 | 248 | def forward(self, input, cond): 249 | skips = [] 250 | for block in self.d_blocks[self.skip_stages:]: 251 | input = block(input, cond) 252 | skips.append(input) 253 | for i, (block, skip) in enumerate(zip(self.u_blocks, reversed(skips))): 254 | input = block(input, cond, skip if i > 0 else None) 255 | return input 256 | -------------------------------------------------------------------------------- /k_diffusion/utils.py: -------------------------------------------------------------------------------- 1 | from contextlib import contextmanager 2 | import hashlib 3 | import math 4 | from pathlib import Path 5 | import shutil 6 | import urllib 7 | import warnings 8 | 9 | from PIL import Image 10 | import torch 11 | from torch import nn, optim 12 | from torch.utils import data 13 | from torchvision.transforms import functional as TF 14 | 15 | 16 | def from_pil_image(x): 17 | """Converts from a PIL image to a tensor.""" 18 | x = TF.to_tensor(x) 19 | if x.ndim == 2: 20 | x = x[..., None] 21 | return x * 2 - 1 22 | 23 | 24 | def to_pil_image(x): 25 | """Converts from a tensor to a PIL image.""" 26 | if x.ndim == 4: 27 | assert x.shape[0] == 1 28 | x = x[0] 29 | if x.shape[0] == 1: 30 | x = -x[0] 31 | return TF.to_pil_image((x.clamp(-1, 1) + 1) / 2) 32 | 33 | 34 | def hf_datasets_augs_helper(examples, transform, image_key, mode='RGB'): 35 | """Apply passed in transforms for HuggingFace Datasets.""" 36 | images = [transform(image.convert(mode)) for image in examples[image_key]] 37 | return {image_key: images} 38 | 39 | 40 | def append_dims(x, target_dims): 41 | """Appends dimensions to the end of a tensor until it has target_dims 
dimensions.""" 42 | dims_to_append = target_dims - x.ndim 43 | if dims_to_append < 0: 44 | raise ValueError(f'input has {x.ndim} dims but target_dims is {target_dims}, which is less') 45 | return x[(...,) + (None,) * dims_to_append] 46 | 47 | 48 | def n_params(module): 49 | """Returns the number of trainable parameters in a module.""" 50 | return sum(p.numel() for p in module.parameters()) 51 | 52 | 53 | def download_file(path, url, digest=None): 54 | """Downloads a file if it does not exist, optionally checking its SHA-256 hash.""" 55 | path = Path(path) 56 | path.parent.mkdir(parents=True, exist_ok=True) 57 | if not path.exists(): 58 | with urllib.request.urlopen(url) as response, open(path, 'wb') as f: 59 | shutil.copyfileobj(response, f) 60 | if digest is not None: 61 | file_digest = hashlib.sha256(open(path, 'rb').read()).hexdigest() 62 | if digest != file_digest: 63 | raise OSError(f'hash of {path} (url: {url}) failed to validate') 64 | return path 65 | 66 | 67 | @contextmanager 68 | def train_mode(model, mode=True): 69 | """A context manager that places a model into training mode and restores 70 | the previous mode on exit.""" 71 | modes = [module.training for module in model.modules()] 72 | try: 73 | yield model.train(mode) 74 | finally: 75 | for i, module in enumerate(model.modules()): 76 | module.training = modes[i] 77 | 78 | 79 | def eval_mode(model): 80 | """A context manager that places a model into evaluation mode and restores 81 | the previous mode on exit.""" 82 | return train_mode(model, False) 83 | 84 | 85 | @torch.no_grad() 86 | def ema_update(model, averaged_model, decay): 87 | """Incorporates updated model parameters into an exponential moving averaged 88 | version of a model. It should be called after each optimizer step.""" 89 | model_params = dict(model.named_parameters()) 90 | averaged_params = dict(averaged_model.named_parameters()) 91 | assert model_params.keys() == averaged_params.keys() 92 | 93 | for name, param in model_params.items(): 94 | averaged_params[name].mul_(decay).add_(param, alpha=1 - decay) 95 | 96 | model_buffers = dict(model.named_buffers()) 97 | averaged_buffers = dict(averaged_model.named_buffers()) 98 | assert model_buffers.keys() == averaged_buffers.keys() 99 | 100 | for name, buf in model_buffers.items(): 101 | averaged_buffers[name].copy_(buf) 102 | 103 | 104 | class EMAWarmup: 105 | """Implements an EMA warmup using an inverse decay schedule. 106 | If inv_gamma=1 and power=1, implements a simple average. inv_gamma=1, power=2/3 are 107 | good values for models you plan to train for a million or more steps (reaches decay 108 | factor 0.999 at 31.6K steps, 0.9999 at 1M steps), inv_gamma=1, power=3/4 for models 109 | you plan to train for less (reaches decay factor 0.999 at 10K steps, 0.9999 at 110 | 215.4k steps). 111 | Args: 112 | inv_gamma (float): Inverse multiplicative factor of EMA warmup. Default: 1. 113 | power (float): Exponential factor of EMA warmup. Default: 1. 114 | min_value (float): The minimum EMA decay rate. Default: 0. 115 | max_value (float): The maximum EMA decay rate. Default: 1. 116 | start_at (int): The epoch to start averaging at. Default: 0. 117 | last_epoch (int): The index of last epoch. Default: 0. 
118 | """ 119 | 120 | def __init__(self, inv_gamma=1., power=1., min_value=0., max_value=1., start_at=0, 121 | last_epoch=0): 122 | self.inv_gamma = inv_gamma 123 | self.power = power 124 | self.min_value = min_value 125 | self.max_value = max_value 126 | self.start_at = start_at 127 | self.last_epoch = last_epoch 128 | 129 | def state_dict(self): 130 | """Returns the state of the class as a :class:`dict`.""" 131 | return dict(self.__dict__.items()) 132 | 133 | def load_state_dict(self, state_dict): 134 | """Loads the class's state. 135 | Args: 136 | state_dict (dict): scaler state. Should be an object returned 137 | from a call to :meth:`state_dict`. 138 | """ 139 | self.__dict__.update(state_dict) 140 | 141 | def get_value(self): 142 | """Gets the current EMA decay rate.""" 143 | epoch = max(0, self.last_epoch - self.start_at) 144 | value = 1 - (1 + epoch / self.inv_gamma) ** -self.power 145 | return 0. if epoch < 0 else min(self.max_value, max(self.min_value, value)) 146 | 147 | def step(self): 148 | """Updates the step count.""" 149 | self.last_epoch += 1 150 | 151 | 152 | class InverseLR(optim.lr_scheduler._LRScheduler): 153 | """Implements an inverse decay learning rate schedule with an optional exponential 154 | warmup. When last_epoch=-1, sets initial lr as lr. 155 | inv_gamma is the number of steps/epochs required for the learning rate to decay to 156 | (1 / 2)**power of its original value. 157 | Args: 158 | optimizer (Optimizer): Wrapped optimizer. 159 | inv_gamma (float): Inverse multiplicative factor of learning rate decay. Default: 1. 160 | power (float): Exponential factor of learning rate decay. Default: 1. 161 | warmup (float): Exponential warmup factor (0 <= warmup < 1, 0 to disable) 162 | Default: 0. 163 | min_lr (float): The minimum learning rate. Default: 0. 164 | last_epoch (int): The index of last epoch. Default: -1. 165 | verbose (bool): If ``True``, prints a message to stdout for 166 | each update. Default: ``False``. 167 | """ 168 | 169 | def __init__(self, optimizer, inv_gamma=1., power=1., warmup=0., min_lr=0., 170 | last_epoch=-1, verbose=False): 171 | self.inv_gamma = inv_gamma 172 | self.power = power 173 | if not 0. <= warmup < 1: 174 | raise ValueError('Invalid value for warmup') 175 | self.warmup = warmup 176 | self.min_lr = min_lr 177 | super().__init__(optimizer, last_epoch, verbose) 178 | 179 | def get_lr(self): 180 | if not self._get_lr_called_within_step: 181 | warnings.warn("To get the last learning rate computed by the scheduler, " 182 | "please use `get_last_lr()`.") 183 | 184 | return self._get_closed_form_lr() 185 | 186 | def _get_closed_form_lr(self): 187 | warmup = 1 - self.warmup ** (self.last_epoch + 1) 188 | lr_mult = (1 + self.last_epoch / self.inv_gamma) ** -self.power 189 | return [warmup * max(self.min_lr, base_lr * lr_mult) 190 | for base_lr in self.base_lrs] 191 | 192 | 193 | class ExponentialLR(optim.lr_scheduler._LRScheduler): 194 | """Implements an exponential learning rate schedule with an optional exponential 195 | warmup. When last_epoch=-1, sets initial lr as lr. Decays the learning rate 196 | continuously by decay (default 0.5) every num_steps steps. 197 | Args: 198 | optimizer (Optimizer): Wrapped optimizer. 199 | num_steps (float): The number of steps to decay the learning rate by decay in. 200 | decay (float): The factor by which to decay the learning rate every num_steps 201 | steps. Default: 0.5. 202 | warmup (float): Exponential warmup factor (0 <= warmup < 1, 0 to disable) 203 | Default: 0. 
204 | min_lr (float): The minimum learning rate. Default: 0. 205 | last_epoch (int): The index of last epoch. Default: -1. 206 | verbose (bool): If ``True``, prints a message to stdout for 207 | each update. Default: ``False``. 208 | """ 209 | 210 | def __init__(self, optimizer, num_steps, decay=0.5, warmup=0., min_lr=0., 211 | last_epoch=-1, verbose=False): 212 | self.num_steps = num_steps 213 | self.decay = decay 214 | if not 0. <= warmup < 1: 215 | raise ValueError('Invalid value for warmup') 216 | self.warmup = warmup 217 | self.min_lr = min_lr 218 | super().__init__(optimizer, last_epoch, verbose) 219 | 220 | def get_lr(self): 221 | if not self._get_lr_called_within_step: 222 | warnings.warn("To get the last learning rate computed by the scheduler, " 223 | "please use `get_last_lr()`.") 224 | 225 | return self._get_closed_form_lr() 226 | 227 | def _get_closed_form_lr(self): 228 | warmup = 1 - self.warmup ** (self.last_epoch + 1) 229 | lr_mult = (self.decay ** (1 / self.num_steps)) ** self.last_epoch 230 | return [warmup * max(self.min_lr, base_lr * lr_mult) 231 | for base_lr in self.base_lrs] 232 | 233 | 234 | def rand_log_normal(shape, loc=0., scale=1., device='cpu', dtype=torch.float32): 235 | """Draws samples from a lognormal distribution.""" 236 | return (torch.randn(shape, device=device, dtype=dtype) * scale + loc).exp() 237 | 238 | 239 | def rand_log_logistic(shape, loc=0., scale=1., min_value=0., max_value=float('inf'), device='cpu', dtype=torch.float32): 240 | """Draws samples from an optionally truncated log-logistic distribution.""" 241 | min_value = torch.as_tensor(min_value, device=device, dtype=torch.float64) 242 | max_value = torch.as_tensor(max_value, device=device, dtype=torch.float64) 243 | min_cdf = min_value.log().sub(loc).div(scale).sigmoid() 244 | max_cdf = max_value.log().sub(loc).div(scale).sigmoid() 245 | u = torch.rand(shape, device=device, dtype=torch.float64) * (max_cdf - min_cdf) + min_cdf 246 | return u.logit().mul(scale).add(loc).exp().to(dtype) 247 | 248 | 249 | def rand_log_uniform(shape, min_value, max_value, device='cpu', dtype=torch.float32): 250 | """Draws samples from a log-uniform distribution.""" 251 | min_value = math.log(min_value) 252 | max_value = math.log(max_value) 253 | return (torch.rand(shape, device=device, dtype=dtype) * (max_value - min_value) + min_value).exp() 254 | 255 | 256 | def rand_v_diffusion(shape, sigma_data=1., min_value=0., max_value=float('inf'), device='cpu', dtype=torch.float32): 257 | """Draws samples from a truncated v-diffusion training timestep distribution.""" 258 | min_cdf = math.atan(min_value / sigma_data) * 2 / math.pi 259 | max_cdf = math.atan(max_value / sigma_data) * 2 / math.pi 260 | u = torch.rand(shape, device=device, dtype=dtype) * (max_cdf - min_cdf) + min_cdf 261 | return torch.tan(u * math.pi / 2) * sigma_data 262 | 263 | 264 | def rand_split_log_normal(shape, loc, scale_1, scale_2, device='cpu', dtype=torch.float32): 265 | """Draws samples from a split lognormal distribution.""" 266 | n = torch.randn(shape, device=device, dtype=dtype).abs() 267 | u = torch.rand(shape, device=device, dtype=dtype) 268 | n_left = n * -scale_1 + loc 269 | n_right = n * scale_2 + loc 270 | ratio = scale_1 / (scale_1 + scale_2) 271 | return torch.where(u < ratio, n_left, n_right).exp() 272 | 273 | 274 | class FolderOfImages(data.Dataset): 275 | """Recursively finds all images in a directory.
It does not support 276 | classes/targets.""" 277 | 278 | IMG_EXTENSIONS = {'.jpg', '.jpeg', '.png', '.ppm', '.bmp', '.pgm', '.tif', '.tiff', '.webp'} 279 | 280 | def __init__(self, root, transform=None): 281 | super().__init__() 282 | self.root = Path(root) 283 | self.transform = nn.Identity() if transform is None else transform 284 | self.paths = sorted(path for path in self.root.rglob('*') if path.suffix.lower() in self.IMG_EXTENSIONS) 285 | 286 | def __repr__(self): 287 | return f'FolderOfImages(root="{self.root}", len: {len(self)})' 288 | 289 | def __len__(self): 290 | return len(self.paths) 291 | 292 | def __getitem__(self, key): 293 | path = self.paths[key] 294 | with open(path, 'rb') as f: 295 | image = Image.open(f).convert('RGB') 296 | image = self.transform(image) 297 | return image, 298 | 299 | 300 | class CSVLogger: 301 | def __init__(self, filename, columns): 302 | self.filename = Path(filename) 303 | self.columns = columns 304 | if self.filename.exists(): 305 | self.file = open(self.filename, 'a') 306 | else: 307 | self.file = open(self.filename, 'w') 308 | self.write(*self.columns) 309 | 310 | def write(self, *args): 311 | print(*args, sep=',', file=self.file, flush=True) 312 | 313 | 314 | @contextmanager 315 | def tf32_mode(cudnn=None, matmul=None): 316 | """A context manager that sets whether TF32 is allowed on cuDNN or matmul.""" 317 | cudnn_old = torch.backends.cudnn.allow_tf32 318 | matmul_old = torch.backends.cuda.matmul.allow_tf32 319 | try: 320 | if cudnn is not None: 321 | torch.backends.cudnn.allow_tf32 = cudnn 322 | if matmul is not None: 323 | torch.backends.cuda.matmul.allow_tf32 = matmul 324 | yield 325 | finally: 326 | if cudnn is not None: 327 | torch.backends.cudnn.allow_tf32 = cudnn_old 328 | if matmul is not None: 329 | torch.backends.cuda.matmul.allow_tf32 = matmul_old 330 | -------------------------------------------------------------------------------- /train.py: -------------------------------------------------------------------------------- 1 | #!/usr/bin/env python3 2 | 3 | """Trains Karras et al. 
(2022) diffusion models.""" 4 | 5 | import argparse 6 | from copy import deepcopy 7 | from functools import partial 8 | import math 9 | import json 10 | from pathlib import Path 11 | 12 | import accelerate 13 | import torch 14 | from torch import nn, optim 15 | from torch import multiprocessing as mp 16 | from torch.utils import data 17 | from torchvision import datasets, transforms, utils 18 | from tqdm.auto import trange, tqdm 19 | 20 | import k_diffusion as K 21 | 22 | 23 | def main(): 24 | p = argparse.ArgumentParser(description=__doc__, 25 | formatter_class=argparse.ArgumentDefaultsHelpFormatter) 26 | p.add_argument('--batch-size', type=int, default=256, 27 | help='the batch size') 28 | p.add_argument('--config', type=str, required=True, 29 | help='the configuration file') 30 | p.add_argument('--demo-every', type=int, default=500, 31 | help='save a demo grid every this many steps') 32 | p.add_argument('--evaluate-every', type=int, default=10000, 33 | help='save a demo grid every this many steps') 34 | p.add_argument('--evaluate-n', type=int, default=2000, 35 | help='the number of samples to draw to evaluate') 36 | p.add_argument('--gns', action='store_true', 37 | help='measure the gradient noise scale (DDP only)') 38 | p.add_argument('--grad-accum-steps', type=int, default=1, 39 | help='the number of gradient accumulation steps') 40 | p.add_argument('--grow', type=str, 41 | help='the checkpoint to grow from') 42 | p.add_argument('--grow-config', type=str, 43 | help='the configuration file of the model to grow from') 44 | p.add_argument('--lr', type=float, 45 | help='the learning rate') 46 | p.add_argument('--name', type=str, default='model', 47 | help='the name of the run') 48 | p.add_argument('--num-workers', type=int, default=8, 49 | help='the number of data loader workers') 50 | p.add_argument('--resume', type=str, 51 | help='the checkpoint to resume from') 52 | p.add_argument('--sample-n', type=int, default=64, 53 | help='the number of images to sample for demo grids') 54 | p.add_argument('--save-every', type=int, default=10000, 55 | help='save every this many steps') 56 | p.add_argument('--seed', type=int, 57 | help='the random seed') 58 | p.add_argument('--start-method', type=str, default='spawn', 59 | choices=['fork', 'forkserver', 'spawn'], 60 | help='the multiprocessing start method') 61 | p.add_argument('--wandb-entity', type=str, 62 | help='the wandb entity name') 63 | p.add_argument('--wandb-group', type=str, 64 | help='the wandb group name') 65 | p.add_argument('--wandb-project', type=str, 66 | help='the wandb project name (specify this to enable wandb)') 67 | p.add_argument('--wandb-save-model', action='store_true', 68 | help='save model to wandb') 69 | args = p.parse_args() 70 | 71 | mp.set_start_method(args.start_method) 72 | torch.backends.cuda.matmul.allow_tf32 = True 73 | 74 | config = K.config.load_config(open(args.config)) 75 | model_config = config['model'] 76 | dataset_config = config['dataset'] 77 | opt_config = config['optimizer'] 78 | sched_config = config['lr_sched'] 79 | ema_sched_config = config['ema_sched'] 80 | 81 | # TODO: allow non-square input sizes 82 | assert len(model_config['input_size']) == 2 and model_config['input_size'][0] == model_config['input_size'][1] 83 | size = model_config['input_size'] 84 | 85 | ddp_kwargs = accelerate.DistributedDataParallelKwargs(find_unused_parameters=model_config['skip_stages'] > 0) 86 | accelerator = accelerate.Accelerator(kwargs_handlers=[ddp_kwargs], gradient_accumulation_steps=args.grad_accum_steps) 87 | device 
= accelerator.device 88 | print(f'Process {accelerator.process_index} using device: {device}', flush=True) 89 | 90 | if args.seed is not None: 91 | seeds = torch.randint(-2 ** 63, 2 ** 63 - 1, [accelerator.num_processes], generator=torch.Generator().manual_seed(args.seed)) 92 | torch.manual_seed(seeds[accelerator.process_index]) 93 | 94 | inner_model = K.config.make_model(config) 95 | if accelerator.is_main_process: 96 | print('Parameters:', K.utils.n_params(inner_model)) 97 | 98 | # If logging to wandb, initialize the run 99 | use_wandb = accelerator.is_main_process and args.wandb_project 100 | if use_wandb: 101 | import wandb 102 | log_config = vars(args) 103 | log_config['config'] = config 104 | log_config['parameters'] = K.utils.n_params(inner_model) 105 | wandb.init(project=args.wandb_project, entity=args.wandb_entity, group=args.wandb_group, config=log_config, save_code=True) 106 | 107 | if opt_config['type'] == 'adamw': 108 | opt = optim.AdamW(inner_model.parameters(), 109 | lr=opt_config['lr'] if args.lr is None else args.lr, 110 | betas=tuple(opt_config['betas']), 111 | eps=opt_config['eps'], 112 | weight_decay=opt_config['weight_decay']) 113 | elif opt_config['type'] == 'sgd': 114 | opt = optim.SGD(inner_model.parameters(), 115 | lr=opt_config['lr'] if args.lr is None else args.lr, 116 | momentum=opt_config.get('momentum', 0.), 117 | nesterov=opt_config.get('nesterov', False), 118 | weight_decay=opt_config.get('weight_decay', 0.)) 119 | else: 120 | raise ValueError('Invalid optimizer type') 121 | 122 | if sched_config['type'] == 'inverse': 123 | sched = K.utils.InverseLR(opt, 124 | inv_gamma=sched_config['inv_gamma'], 125 | power=sched_config['power'], 126 | warmup=sched_config['warmup']) 127 | elif sched_config['type'] == 'exponential': 128 | sched = K.utils.ExponentialLR(opt, 129 | num_steps=sched_config['num_steps'], 130 | decay=sched_config['decay'], 131 | warmup=sched_config['warmup']) 132 | else: 133 | raise ValueError('Invalid schedule type') 134 | 135 | assert ema_sched_config['type'] == 'inverse' 136 | ema_sched = K.utils.EMAWarmup(power=ema_sched_config['power'], 137 | max_value=ema_sched_config['max_value']) 138 | 139 | tf = transforms.Compose([ 140 | transforms.Resize(size[0], interpolation=transforms.InterpolationMode.LANCZOS), 141 | transforms.CenterCrop(size[0]), 142 | K.augmentation.KarrasAugmentationPipeline(model_config['augment_prob']), 143 | ]) 144 | 145 | if dataset_config['type'] == 'imagefolder': 146 | train_set = K.utils.FolderOfImages(dataset_config['location'], transform=tf) 147 | elif dataset_config['type'] == 'cifar10': 148 | train_set = datasets.CIFAR10(dataset_config['location'], train=True, download=True, transform=tf) 149 | elif dataset_config['type'] == 'fashion': 150 | train_set = datasets.FashionMNIST(dataset_config['location'], train=True, download=True, transform=tf) 151 | elif dataset_config['type'] == 'mnist': 152 | train_set = datasets.MNIST(dataset_config['location'], train=True, download=True, transform=tf) 153 | elif dataset_config['type'] == 'huggingface': 154 | from datasets import load_dataset 155 | train_set = load_dataset(dataset_config['location']) 156 | train_set.set_transform(partial(K.utils.hf_datasets_augs_helper, transform=tf, image_key=dataset_config['image_key'])) 157 | train_set = train_set['train'] 158 | else: 159 | raise ValueError('Invalid dataset type') 160 | 161 | if accelerator.is_main_process: 162 | try: 163 | print('Number of items in dataset:', len(train_set)) 164 | except TypeError: 165 | pass 166 | 167 | 
image_key = dataset_config.get('image_key', 0) 168 | 169 | train_dl = data.DataLoader(train_set, args.batch_size, shuffle=True, drop_last=True, 170 | num_workers=args.num_workers, persistent_workers=True) 171 | 172 | if args.grow: 173 | if not args.grow_config: 174 | raise ValueError('--grow requires --grow-config') 175 | ckpt = torch.load(args.grow, map_location='cpu') 176 | old_config = K.config.load_config(open(args.grow_config)) 177 | old_inner_model = K.config.make_model(old_config) 178 | old_inner_model.load_state_dict(ckpt['model_ema']) 179 | if old_config['model']['skip_stages'] != model_config['skip_stages']: 180 | old_inner_model.set_skip_stages(model_config['skip_stages']) 181 | if old_config['model']['patch_size'] != model_config['patch_size']: 182 | old_inner_model.set_patch_size(model_config['patch_size']) 183 | inner_model.load_state_dict(old_inner_model.state_dict()) 184 | del ckpt, old_inner_model 185 | 186 | inner_model, opt, train_dl = accelerator.prepare(inner_model, opt, train_dl) 187 | if use_wandb: 188 | wandb.watch(inner_model) 189 | if args.gns: 190 | gns_stats_hook = K.gns.DDPGradientStatsHook(inner_model) 191 | gns_stats = K.gns.GradientNoiseScale() 192 | else: 193 | gns_stats = None 194 | sigma_min = model_config['sigma_min'] 195 | sigma_max = model_config['sigma_max'] 196 | sample_density = K.config.make_sample_density(model_config) 197 | 198 | model = K.config.make_denoiser_wrapper(config)(inner_model) 199 | model_ema = deepcopy(model) 200 | 201 | state_path = Path(f'{args.name}_state.json') 202 | 203 | if state_path.exists() or args.resume: 204 | if args.resume: 205 | ckpt_path = args.resume 206 | if not args.resume: 207 | state = json.load(open(state_path)) 208 | ckpt_path = state['latest_checkpoint'] 209 | if accelerator.is_main_process: 210 | print(f'Resuming from {ckpt_path}...') 211 | ckpt = torch.load(ckpt_path, map_location='cpu') 212 | accelerator.unwrap_model(model.inner_model).load_state_dict(ckpt['model']) 213 | accelerator.unwrap_model(model_ema.inner_model).load_state_dict(ckpt['model_ema']) 214 | opt.load_state_dict(ckpt['opt']) 215 | sched.load_state_dict(ckpt['sched']) 216 | ema_sched.load_state_dict(ckpt['ema_sched']) 217 | epoch = ckpt['epoch'] + 1 218 | step = ckpt['step'] + 1 219 | if args.gns and ckpt.get('gns_stats', None) is not None: 220 | gns_stats.load_state_dict(ckpt['gns_stats']) 221 | 222 | del ckpt 223 | else: 224 | epoch = 0 225 | step = 0 226 | 227 | evaluate_enabled = args.evaluate_every > 0 and args.evaluate_n > 0 228 | if evaluate_enabled: 229 | extractor = K.evaluation.InceptionV3FeatureExtractor(device=device) 230 | train_iter = iter(train_dl) 231 | if accelerator.is_main_process: 232 | print('Computing features for reals...') 233 | reals_features = K.evaluation.compute_features(accelerator, lambda x: next(train_iter)[image_key][1], extractor, args.evaluate_n, args.batch_size) 234 | if accelerator.is_main_process: 235 | metrics_log = K.utils.CSVLogger(f'{args.name}_metrics.csv', ['step', 'fid', 'kid']) 236 | del train_iter 237 | 238 | @torch.no_grad() 239 | @K.utils.eval_mode(model_ema) 240 | def demo(): 241 | if accelerator.is_main_process: 242 | tqdm.write('Sampling...') 243 | filename = f'{args.name}_demo_{step:08}.png' 244 | n_per_proc = math.ceil(args.sample_n / accelerator.num_processes) 245 | x = torch.randn([n_per_proc, model_config['input_channels'], size[0], size[1]], device=device) * sigma_max 246 | sigmas = K.sampling.get_sigmas_karras(50, sigma_min, sigma_max, rho=7., device=device) 247 | x_0 = 
K.sampling.sample_lms(model_ema, x, sigmas, disable=not accelerator.is_main_process) 248 | x_0 = accelerator.gather(x_0)[:args.sample_n] 249 | if accelerator.is_main_process: 250 | grid = utils.make_grid(x_0, nrow=math.ceil(args.sample_n ** 0.5), padding=0) 251 | K.utils.to_pil_image(grid).save(filename) 252 | if use_wandb: 253 | wandb.log({'demo_grid': wandb.Image(filename)}, step=step) 254 | 255 | @torch.no_grad() 256 | @K.utils.eval_mode(model_ema) 257 | def evaluate(): 258 | if not evaluate_enabled: 259 | return 260 | if accelerator.is_main_process: 261 | tqdm.write('Evaluating...') 262 | sigmas = K.sampling.get_sigmas_karras(50, sigma_min, sigma_max, rho=7., device=device) 263 | def sample_fn(n): 264 | x = torch.randn([n, model_config['input_channels'], size[0], size[1]], device=device) * sigma_max 265 | x_0 = K.sampling.sample_lms(model_ema, x, sigmas, disable=True) 266 | return x_0 267 | fakes_features = K.evaluation.compute_features(accelerator, sample_fn, extractor, args.evaluate_n, args.batch_size) 268 | if accelerator.is_main_process: 269 | fid = K.evaluation.fid(fakes_features, reals_features) 270 | kid = K.evaluation.kid(fakes_features, reals_features) 271 | print(f'FID: {fid.item():g}, KID: {kid.item():g}') 272 | if accelerator.is_main_process: 273 | metrics_log.write(step, fid.item(), kid.item()) 274 | if use_wandb: 275 | wandb.log({'FID': fid.item(), 'KID': kid.item()}, step=step) 276 | 277 | def save(): 278 | accelerator.wait_for_everyone() 279 | filename = f'{args.name}_{step:08}.pth' 280 | if accelerator.is_main_process: 281 | tqdm.write(f'Saving to {filename}...') 282 | obj = { 283 | 'model': accelerator.unwrap_model(model.inner_model).state_dict(), 284 | 'model_ema': accelerator.unwrap_model(model_ema.inner_model).state_dict(), 285 | 'opt': opt.state_dict(), 286 | 'sched': sched.state_dict(), 287 | 'ema_sched': ema_sched.state_dict(), 288 | 'epoch': epoch, 289 | 'step': step, 290 | 'gns_stats': gns_stats.state_dict() if gns_stats is not None else None, 291 | } 292 | accelerator.save(obj, filename) 293 | if accelerator.is_main_process: 294 | state_obj = {'latest_checkpoint': filename} 295 | json.dump(state_obj, open(state_path, 'w')) 296 | if args.wandb_save_model and use_wandb: 297 | wandb.save(filename) 298 | 299 | try: 300 | while True: 301 | for batch in tqdm(train_dl, disable=not accelerator.is_main_process): 302 | with accelerator.accumulate(model): 303 | reals, _, aug_cond = batch[image_key] 304 | noise = torch.randn_like(reals) 305 | sigma = sample_density([reals.shape[0]], device=device) 306 | losses = model.loss(reals, noise, sigma, aug_cond=aug_cond) 307 | losses_all = accelerator.gather(losses) 308 | loss = losses_all.mean() 309 | accelerator.backward(losses.mean()) 310 | if args.gns: 311 | sq_norm_small_batch, sq_norm_large_batch = gns_stats_hook.get_stats() 312 | gns_stats.update(sq_norm_small_batch, sq_norm_large_batch, reals.shape[0], reals.shape[0] * accelerator.num_processes) 313 | opt.step() 314 | sched.step() 315 | opt.zero_grad() 316 | if accelerator.sync_gradients: 317 | ema_decay = ema_sched.get_value() 318 | K.utils.ema_update(model, model_ema, ema_decay) 319 | ema_sched.step() 320 | 321 | if accelerator.is_main_process: 322 | if step % 25 == 0: 323 | if args.gns: 324 | tqdm.write(f'Epoch: {epoch}, step: {step}, loss: {loss.item():g}, gns: {gns_stats.get_gns():g}') 325 | else: 326 | tqdm.write(f'Epoch: {epoch}, step: {step}, loss: {loss.item():g}') 327 | 328 | if use_wandb: 329 | log_dict = { 330 | 'epoch': epoch, 331 | 'loss': loss.item(), 332 | 
'lr': sched.get_last_lr()[0], 333 | 'ema_decay': ema_decay, 334 | } 335 | if args.gns: 336 | log_dict['gradient_noise_scale'] = gns_stats.get_gns() 337 | wandb.log(log_dict, step=step) 338 | 339 | if step % args.demo_every == 0: 340 | demo() 341 | 342 | if evaluate_enabled and step > 0 and step % args.evaluate_every == 0: 343 | evaluate() 344 | 345 | if step > 0 and step % args.save_every == 0: 346 | save() 347 | 348 | step += 1 349 | epoch += 1 350 | except KeyboardInterrupt: 351 | pass 352 | 353 | 354 | if __name__ == '__main__': 355 | main() 356 | -------------------------------------------------------------------------------- /k_diffusion/sampling.py: -------------------------------------------------------------------------------- 1 | import math 2 | 3 | from scipy import integrate 4 | import torch 5 | from torch import nn 6 | from torchdiffeq import odeint 7 | from tqdm.auto import trange, tqdm 8 | 9 | from . import utils 10 | 11 | 12 | def append_zero(x): 13 | return torch.cat([x, x.new_zeros([1])]) 14 | 15 | 16 | def get_sigmas_karras(n, sigma_min, sigma_max, rho=7., device='cpu'): 17 | """Constructs the noise schedule of Karras et al. (2022).""" 18 | ramp = torch.linspace(0, 1, n) 19 | min_inv_rho = sigma_min ** (1 / rho) 20 | max_inv_rho = sigma_max ** (1 / rho) 21 | sigmas = (max_inv_rho + ramp * (min_inv_rho - max_inv_rho)) ** rho 22 | return append_zero(sigmas).to(device) 23 | 24 | 25 | def get_sigmas_exponential(n, sigma_min, sigma_max, device='cpu'): 26 | """Constructs an exponential noise schedule.""" 27 | sigmas = torch.linspace(math.log(sigma_max), math.log(sigma_min), n, device=device).exp() 28 | return append_zero(sigmas) 29 | 30 | 31 | def get_sigmas_vp(n, beta_d=19.9, beta_min=0.1, eps_s=1e-3, device='cpu'): 32 | """Constructs a continuous VP noise schedule.""" 33 | t = torch.linspace(1, eps_s, n, device=device) 34 | sigmas = torch.sqrt(torch.exp(beta_d * t ** 2 / 2 + beta_min * t) - 1) 35 | return append_zero(sigmas) 36 | 37 | 38 | def to_d(x, sigma, denoised): 39 | """Converts a denoiser output to a Karras ODE derivative.""" 40 | return (x - denoised) / utils.append_dims(sigma, x.ndim) 41 | 42 | 43 | def get_ancestral_step(sigma_from, sigma_to, eta=1.): 44 | """Calculates the noise level (sigma_down) to step down to and the amount 45 | of noise to add (sigma_up) when doing an ancestral sampling step.""" 46 | if not eta: 47 | return sigma_to, 0. 48 | sigma_up = min(sigma_to, eta * (sigma_to ** 2 * (sigma_from ** 2 - sigma_to ** 2) / sigma_from ** 2) ** 0.5) 49 | sigma_down = (sigma_to ** 2 - sigma_up ** 2) ** 0.5 50 | return sigma_down, sigma_up 51 | 52 | 53 | @torch.no_grad() 54 | def sample_euler(model, x, sigmas, extra_args=None, callback=None, disable=None, s_churn=0., s_tmin=0., s_tmax=float('inf'), s_noise=1.): 55 | """Implements Algorithm 2 (Euler steps) from Karras et al. (2022).""" 56 | extra_args = {} if extra_args is None else extra_args 57 | s_in = x.new_ones([x.shape[0]]) 58 | for i in trange(len(sigmas) - 1, disable=disable): 59 | gamma = min(s_churn / (len(sigmas) - 1), 2 ** 0.5 - 1) if s_tmin <= sigmas[i] <= s_tmax else 0. 
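# Churn: when gamma > 0, fresh noise is added below so the effective noise level
# rises to sigma_hat = sigmas[i] * (gamma + 1), and the Euler step then integrates
# back down from sigma_hat; s_churn=0 recovers the deterministic sampler.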
60 | eps = torch.randn_like(x) * s_noise 61 | sigma_hat = sigmas[i] * (gamma + 1) 62 | if gamma > 0: 63 | x = x + eps * (sigma_hat ** 2 - sigmas[i] ** 2) ** 0.5 64 | denoised = model(x, sigma_hat * s_in, **extra_args) 65 | d = to_d(x, sigma_hat, denoised) 66 | if callback is not None: 67 | callback({'x': x, 'i': i, 'sigma': sigmas[i], 'sigma_hat': sigma_hat, 'denoised': denoised}) 68 | dt = sigmas[i + 1] - sigma_hat 69 | # Euler method 70 | x = x + d * dt 71 | return x 72 | 73 | 74 | @torch.no_grad() 75 | def sample_euler_ancestral(model, x, sigmas, extra_args=None, callback=None, disable=None, eta=1.): 76 | """Ancestral sampling with Euler method steps.""" 77 | extra_args = {} if extra_args is None else extra_args 78 | s_in = x.new_ones([x.shape[0]]) 79 | for i in trange(len(sigmas) - 1, disable=disable): 80 | denoised = model(x, sigmas[i] * s_in, **extra_args) 81 | sigma_down, sigma_up = get_ancestral_step(sigmas[i], sigmas[i + 1], eta=eta) 82 | if callback is not None: 83 | callback({'x': x, 'i': i, 'sigma': sigmas[i], 'sigma_hat': sigmas[i], 'denoised': denoised}) 84 | d = to_d(x, sigmas[i], denoised) 85 | # Euler method 86 | dt = sigma_down - sigmas[i] 87 | x = x + d * dt 88 | x = x + torch.randn_like(x) * sigma_up 89 | return x 90 | 91 | 92 | @torch.no_grad() 93 | def sample_heun(model, x, sigmas, extra_args=None, callback=None, disable=None, s_churn=0., s_tmin=0., s_tmax=float('inf'), s_noise=1.): 94 | """Implements Algorithm 2 (Heun steps) from Karras et al. (2022).""" 95 | extra_args = {} if extra_args is None else extra_args 96 | s_in = x.new_ones([x.shape[0]]) 97 | for i in trange(len(sigmas) - 1, disable=disable): 98 | gamma = min(s_churn / (len(sigmas) - 1), 2 ** 0.5 - 1) if s_tmin <= sigmas[i] <= s_tmax else 0. 99 | eps = torch.randn_like(x) * s_noise 100 | sigma_hat = sigmas[i] * (gamma + 1) 101 | if gamma > 0: 102 | x = x + eps * (sigma_hat ** 2 - sigmas[i] ** 2) ** 0.5 103 | denoised = model(x, sigma_hat * s_in, **extra_args) 104 | d = to_d(x, sigma_hat, denoised) 105 | if callback is not None: 106 | callback({'x': x, 'i': i, 'sigma': sigmas[i], 'sigma_hat': sigma_hat, 'denoised': denoised}) 107 | dt = sigmas[i + 1] - sigma_hat 108 | if sigmas[i + 1] == 0: 109 | # Euler method 110 | x = x + d * dt 111 | else: 112 | # Heun's method 113 | x_2 = x + d * dt 114 | denoised_2 = model(x_2, sigmas[i + 1] * s_in, **extra_args) 115 | d_2 = to_d(x_2, sigmas[i + 1], denoised_2) 116 | d_prime = (d + d_2) / 2 117 | x = x + d_prime * dt 118 | return x 119 | 120 | 121 | @torch.no_grad() 122 | def sample_dpm_2(model, x, sigmas, extra_args=None, callback=None, disable=None, s_churn=0., s_tmin=0., s_tmax=float('inf'), s_noise=1.): 123 | """A sampler inspired by DPM-Solver-2 and Algorithm 2 from Karras et al. (2022).""" 124 | extra_args = {} if extra_args is None else extra_args 125 | s_in = x.new_ones([x.shape[0]]) 126 | for i in trange(len(sigmas) - 1, disable=disable): 127 | gamma = min(s_churn / (len(sigmas) - 1), 2 ** 0.5 - 1) if s_tmin <= sigmas[i] <= s_tmax else 0. 
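# Same churn mechanism as in sample_euler/sample_heun above; the second-order
# branch below re-evaluates the model at sigma_mid, the log-space midpoint
# (geometric mean) of sigma_hat and sigmas[i + 1].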
128 | eps = torch.randn_like(x) * s_noise 129 | sigma_hat = sigmas[i] * (gamma + 1) 130 | if gamma > 0: 131 | x = x + eps * (sigma_hat ** 2 - sigmas[i] ** 2) ** 0.5 132 | denoised = model(x, sigma_hat * s_in, **extra_args) 133 | d = to_d(x, sigma_hat, denoised) 134 | if callback is not None: 135 | callback({'x': x, 'i': i, 'sigma': sigmas[i], 'sigma_hat': sigma_hat, 'denoised': denoised}) 136 | if sigmas[i + 1] == 0: 137 | # Euler method 138 | dt = sigmas[i + 1] - sigma_hat 139 | x = x + d * dt 140 | else: 141 | # DPM-Solver-2 142 | sigma_mid = sigma_hat.log().lerp(sigmas[i + 1].log(), 0.5).exp() 143 | dt_1 = sigma_mid - sigma_hat 144 | dt_2 = sigmas[i + 1] - sigma_hat 145 | x_2 = x + d * dt_1 146 | denoised_2 = model(x_2, sigma_mid * s_in, **extra_args) 147 | d_2 = to_d(x_2, sigma_mid, denoised_2) 148 | x = x + d_2 * dt_2 149 | return x 150 | 151 | 152 | @torch.no_grad() 153 | def sample_dpm_2_ancestral(model, x, sigmas, extra_args=None, callback=None, disable=None, eta=1.): 154 | """Ancestral sampling with DPM-Solver inspired second-order steps.""" 155 | extra_args = {} if extra_args is None else extra_args 156 | s_in = x.new_ones([x.shape[0]]) 157 | for i in trange(len(sigmas) - 1, disable=disable): 158 | denoised = model(x, sigmas[i] * s_in, **extra_args) 159 | sigma_down, sigma_up = get_ancestral_step(sigmas[i], sigmas[i + 1], eta=eta) 160 | if callback is not None: 161 | callback({'x': x, 'i': i, 'sigma': sigmas[i], 'sigma_hat': sigmas[i], 'denoised': denoised}) 162 | d = to_d(x, sigmas[i], denoised) 163 | if sigma_down == 0: 164 | # Euler method 165 | dt = sigma_down - sigmas[i] 166 | x = x + d * dt 167 | else: 168 | # DPM-Solver-2 169 | sigma_mid = sigmas[i].log().lerp(sigma_down.log(), 0.5).exp() 170 | dt_1 = sigma_mid - sigmas[i] 171 | dt_2 = sigma_down - sigmas[i] 172 | x_2 = x + d * dt_1 173 | denoised_2 = model(x_2, sigma_mid * s_in, **extra_args) 174 | d_2 = to_d(x_2, sigma_mid, denoised_2) 175 | x = x + d_2 * dt_2 176 | x = x + torch.randn_like(x) * sigma_up 177 | return x 178 | 179 | 180 | def linear_multistep_coeff(order, t, i, j): 181 | if order - 1 > i: 182 | raise ValueError(f'Order {order} too high for step {i}') 183 | def fn(tau): 184 | prod = 1. 
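        # fn is the j-th Lagrange basis polynomial over the nodes
        # t[i], t[i - 1], ..., t[i - order + 1]; integrating it over
        # [t[i], t[i + 1]] below gives the Adams-Bashforth-style weight
        # that sample_lms applies to the stored derivative from step i - j.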
185 | for k in range(order): 186 | if j == k: 187 | continue 188 | prod *= (tau - t[i - k]) / (t[i - j] - t[i - k]) 189 | return prod 190 | return integrate.quad(fn, t[i], t[i + 1], epsrel=1e-4)[0] 191 | 192 | 193 | @torch.no_grad() 194 | def sample_lms(model, x, sigmas, extra_args=None, callback=None, disable=None, order=4): 195 | extra_args = {} if extra_args is None else extra_args 196 | s_in = x.new_ones([x.shape[0]]) 197 | sigmas_cpu = sigmas.detach().cpu().numpy() 198 | ds = [] 199 | for i in trange(len(sigmas) - 1, disable=disable): 200 | denoised = model(x, sigmas[i] * s_in, **extra_args) 201 | d = to_d(x, sigmas[i], denoised) 202 | ds.append(d) 203 | if len(ds) > order: 204 | ds.pop(0) 205 | if callback is not None: 206 | callback({'x': x, 'i': i, 'sigma': sigmas[i], 'sigma_hat': sigmas[i], 'denoised': denoised}) 207 | cur_order = min(i + 1, order) 208 | coeffs = [linear_multistep_coeff(cur_order, sigmas_cpu, i, j) for j in range(cur_order)] 209 | x = x + sum(coeff * d for coeff, d in zip(coeffs, reversed(ds))) 210 | return x 211 | 212 | 213 | @torch.no_grad() 214 | def log_likelihood(model, x, sigma_min, sigma_max, extra_args=None, atol=1e-4, rtol=1e-4): 215 | extra_args = {} if extra_args is None else extra_args 216 | s_in = x.new_ones([x.shape[0]]) 217 | v = torch.randint_like(x, 2) * 2 - 1 218 | fevals = 0 219 | def ode_fn(sigma, x): 220 | nonlocal fevals 221 | with torch.enable_grad(): 222 | x = x[0].detach().requires_grad_() 223 | denoised = model(x, sigma * s_in, **extra_args) 224 | d = to_d(x, sigma, denoised) 225 | fevals += 1 226 | grad = torch.autograd.grad((d * v).sum(), x)[0] 227 | d_ll = (v * grad).flatten(1).sum(1) 228 | return d.detach(), d_ll 229 | x_min = x, x.new_zeros([x.shape[0]]) 230 | t = x.new_tensor([sigma_min, sigma_max]) 231 | sol = odeint(ode_fn, x_min, t, atol=atol, rtol=rtol, method='dopri5') 232 | latent, delta_ll = sol[0][-1], sol[1][-1] 233 | ll_prior = torch.distributions.Normal(0, sigma_max).log_prob(latent).flatten(1).sum(1) 234 | return ll_prior + delta_ll, {'fevals': fevals} 235 | 236 | 237 | class PIDStepSizeController: 238 | """A PID controller for ODE adaptive step size control.""" 239 | def __init__(self, h, pcoeff, icoeff, dcoeff, order=1, accept_safety=0.81, eps=1e-8): 240 | self.h = h 241 | self.b1 = (pcoeff + icoeff + dcoeff) / order 242 | self.b2 = -(pcoeff + 2 * dcoeff) / order 243 | self.b3 = dcoeff / order 244 | self.accept_safety = accept_safety 245 | self.eps = eps 246 | self.errs = [] 247 | 248 | def limiter(self, x): 249 | return 1 + math.atan(x - 1) 250 | 251 | def propose_step(self, error): 252 | inv_error = 1 / (float(error) + self.eps) 253 | if not self.errs: 254 | self.errs = [inv_error, inv_error, inv_error] 255 | self.errs[0] = inv_error 256 | factor = self.errs[0] ** self.b1 * self.errs[1] ** self.b2 * self.errs[2] ** self.b3 257 | factor = self.limiter(factor) 258 | accept = factor >= self.accept_safety 259 | if accept: 260 | self.errs[2] = self.errs[1] 261 | self.errs[1] = self.errs[0] 262 | self.h *= factor 263 | return accept 264 | 265 | 266 | class DPMSolver(nn.Module): 267 | """DPM-Solver. 
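    Works in the variable t = -log(sigma), and caches eps predictions per
    step so that higher-order updates can share model evaluations.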
See https://arxiv.org/abs/2206.00927.""" 268 | 269 | def __init__(self, model, extra_args=None, eps_callback=None, info_callback=None): 270 | super().__init__() 271 | self.model = model 272 | self.extra_args = {} if extra_args is None else extra_args 273 | self.eps_callback = eps_callback 274 | self.info_callback = info_callback 275 | 276 | def t(self, sigma): 277 | return -sigma.log() 278 | 279 | def sigma(self, t): 280 | return t.neg().exp() 281 | 282 | def eps(self, eps_cache, key, x, t, *args, **kwargs): 283 | if key in eps_cache: 284 | return eps_cache[key], eps_cache 285 | sigma = self.sigma(t) * x.new_ones([x.shape[0]]) 286 | eps = (x - self.model(x, sigma, *args, **self.extra_args, **kwargs)) / self.sigma(t) 287 | if self.eps_callback is not None: 288 | self.eps_callback() 289 | return eps, {key: eps, **eps_cache} 290 | 291 | def dpm_solver_1_step(self, x, t, t_next, eps_cache=None): 292 | eps_cache = {} if eps_cache is None else eps_cache 293 | h = t_next - t 294 | eps, eps_cache = self.eps(eps_cache, 'eps', x, t) 295 | x_1 = x - self.sigma(t_next) * h.expm1() * eps 296 | return x_1, eps_cache 297 | 298 | def dpm_solver_2_step(self, x, t, t_next, r1=1 / 2, eps_cache=None): 299 | eps_cache = {} if eps_cache is None else eps_cache 300 | h = t_next - t 301 | eps, eps_cache = self.eps(eps_cache, 'eps', x, t) 302 | s1 = t + r1 * h 303 | u1 = x - self.sigma(s1) * (r1 * h).expm1() * eps 304 | eps_r1, eps_cache = self.eps(eps_cache, 'eps_r1', u1, s1) 305 | x_2 = x - self.sigma(t_next) * h.expm1() * eps - self.sigma(t_next) / (2 * r1) * h.expm1() * (eps_r1 - eps) 306 | return x_2, eps_cache 307 | 308 | def dpm_solver_3_step(self, x, t, t_next, r1=1 / 3, r2=2 / 3, eps_cache=None): 309 | eps_cache = {} if eps_cache is None else eps_cache 310 | h = t_next - t 311 | eps, eps_cache = self.eps(eps_cache, 'eps', x, t) 312 | s1 = t + r1 * h 313 | s2 = t + r2 * h 314 | u1 = x - self.sigma(s1) * (r1 * h).expm1() * eps 315 | eps_r1, eps_cache = self.eps(eps_cache, 'eps_r1', u1, s1) 316 | u2 = x - self.sigma(s2) * (r2 * h).expm1() * eps - self.sigma(s2) * (r2 / r1) * ((r2 * h).expm1() / (r2 * h) - 1) * (eps_r1 - eps) 317 | eps_r2, eps_cache = self.eps(eps_cache, 'eps_r2', u2, s2) 318 | x_3 = x - self.sigma(t_next) * h.expm1() * eps - self.sigma(t_next) / r2 * (h.expm1() / h - 1) * (eps_r2 - eps) 319 | return x_3, eps_cache 320 | 321 | def dpm_solver_fast(self, x, t_start, t_end, nfe, eta=0., s_noise=1.): 322 | if not t_end > t_start and eta: 323 | raise ValueError('eta must be 0 for reverse sampling') 324 | 325 | m = math.floor(nfe / 3) + 1 326 | ts = torch.linspace(t_start, t_end, m + 1, device=x.device) 327 | 328 | if nfe % 3 == 0: 329 | orders = [3] * (m - 2) + [2, 1] 330 | else: 331 | orders = [3] * (m - 1) + [nfe % 3] 332 | 333 | for i in range(len(orders)): 334 | eps_cache = {} 335 | t, t_next = ts[i], ts[i + 1] 336 | if eta: 337 | sd, su = get_ancestral_step(self.sigma(t), self.sigma(t_next), eta) 338 | t_next_ = torch.minimum(t_end, self.t(sd)) 339 | su = (self.sigma(t_next) ** 2 - self.sigma(t_next_) ** 2) ** 0.5 340 | else: 341 | t_next_, su = t_next, 0. 
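            # With eta > 0 this is an ancestral step: the solver integrates
            # only down to t_next_ (the sigma_down level), and noise with
            # standard deviation su is added after the update, landing the
            # total noise level back at sigma(t_next).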
342 | 343 | eps, eps_cache = self.eps(eps_cache, 'eps', x, t) 344 | denoised = x - self.sigma(t) * eps 345 | if self.info_callback is not None: 346 | self.info_callback({'x': x, 'i': i, 't': ts[i], 't_up': t, 'denoised': denoised}) 347 | 348 | if orders[i] == 1: 349 | x, eps_cache = self.dpm_solver_1_step(x, t, t_next_, eps_cache=eps_cache) 350 | elif orders[i] == 2: 351 | x, eps_cache = self.dpm_solver_2_step(x, t, t_next_, eps_cache=eps_cache) 352 | else: 353 | x, eps_cache = self.dpm_solver_3_step(x, t, t_next_, eps_cache=eps_cache) 354 | 355 | x = x + su * s_noise * torch.randn_like(x) 356 | 357 | return x 358 | 359 | def dpm_solver_adaptive(self, x, t_start, t_end, order=3, rtol=0.05, atol=0.0078, h_init=0.05, pcoeff=0., icoeff=1., dcoeff=0., accept_safety=0.81, eta=0., s_noise=1.): 360 | if order not in {2, 3}: 361 | raise ValueError('order should be 2 or 3') 362 | forward = t_end > t_start 363 | if not forward and eta: 364 | raise ValueError('eta must be 0 for reverse sampling') 365 | h_init = abs(h_init) * (1 if forward else -1) 366 | atol = torch.tensor(atol) 367 | rtol = torch.tensor(rtol) 368 | s = t_start 369 | x_prev = x 370 | accept = True 371 | pid = PIDStepSizeController(h_init, pcoeff, icoeff, dcoeff, 1.5 if eta else order, accept_safety) 372 | info = {'steps': 0, 'nfe': 0, 'n_accept': 0, 'n_reject': 0} 373 | 374 | while s < t_end - 1e-5 if forward else s > t_end + 1e-5: 375 | eps_cache = {} 376 | t = torch.minimum(t_end, s + pid.h) if forward else torch.maximum(t_end, s + pid.h) 377 | if eta: 378 | sd, su = get_ancestral_step(self.sigma(s), self.sigma(t), eta) 379 | t_ = torch.minimum(t_end, self.t(sd)) 380 | su = (self.sigma(t) ** 2 - self.sigma(t_) ** 2) ** 0.5 381 | else: 382 | t_, su = t, 0. 383 | 384 | eps, eps_cache = self.eps(eps_cache, 'eps', x, s) 385 | denoised = x - self.sigma(s) * eps 386 | 387 | if order == 2: 388 | x_low, eps_cache = self.dpm_solver_1_step(x, s, t_, eps_cache=eps_cache) 389 | x_high, eps_cache = self.dpm_solver_2_step(x, s, t_, eps_cache=eps_cache) 390 | else: 391 | x_low, eps_cache = self.dpm_solver_2_step(x, s, t_, r1=1 / 3, eps_cache=eps_cache) 392 | x_high, eps_cache = self.dpm_solver_3_step(x, s, t_, eps_cache=eps_cache) 393 | delta = torch.maximum(atol, rtol * torch.maximum(x_low.abs(), x_prev.abs())) 394 | error = torch.linalg.norm((x_low - x_high) / delta) / x.numel() ** 0.5 395 | accept = pid.propose_step(error) 396 | if accept: 397 | x_prev = x_low 398 | x = x_high + su * s_noise * torch.randn_like(x_high) 399 | s = t 400 | info['n_accept'] += 1 401 | else: 402 | info['n_reject'] += 1 403 | info['nfe'] += order 404 | info['steps'] += 1 405 | 406 | if self.info_callback is not None: 407 | self.info_callback({'x': x, 'i': info['steps'] - 1, 't': s, 't_up': s, 'denoised': denoised, 'error': error, 'h': pid.h, **info}) 408 | 409 | return x, info 410 | 411 | 412 | @torch.no_grad() 413 | def sample_dpm_fast(model, x, sigma_min, sigma_max, n, extra_args=None, callback=None, disable=None, eta=0., s_noise=1.): 414 | """DPM-Solver-Fast (fixed step size). 
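    n is the total number of model evaluations (NFE) to spend.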
See https://arxiv.org/abs/2206.00927."""
415 |     if sigma_min <= 0 or sigma_max <= 0:
416 |         raise ValueError('sigma_min and sigma_max must be greater than 0')
417 |     with tqdm(total=n, disable=disable) as pbar:
418 |         dpm_solver = DPMSolver(model, extra_args, eps_callback=pbar.update)
419 |         if callback is not None:
420 |             dpm_solver.info_callback = lambda info: callback({'sigma': dpm_solver.sigma(info['t']), 'sigma_hat': dpm_solver.sigma(info['t_up']), **info})
421 |         return dpm_solver.dpm_solver_fast(x, dpm_solver.t(torch.tensor(sigma_max)), dpm_solver.t(torch.tensor(sigma_min)), n, eta, s_noise)
422 | 
423 | 
424 | @torch.no_grad()
425 | def sample_dpm_adaptive(model, x, sigma_min, sigma_max, extra_args=None, callback=None, disable=None, order=3, rtol=0.05, atol=0.0078, h_init=0.05, pcoeff=0., icoeff=1., dcoeff=0., accept_safety=0.81, eta=0., s_noise=1., return_info=False):
426 |     """DPM-Solver-12 and 23 (adaptive step size). See https://arxiv.org/abs/2206.00927."""
427 |     if sigma_min <= 0 or sigma_max <= 0:
428 |         raise ValueError('sigma_min and sigma_max must be greater than 0')
429 |     with tqdm(disable=disable) as pbar:
430 |         dpm_solver = DPMSolver(model, extra_args, eps_callback=pbar.update)
431 |         if callback is not None:
432 |             dpm_solver.info_callback = lambda info: callback({'sigma': dpm_solver.sigma(info['t']), 'sigma_hat': dpm_solver.sigma(info['t_up']), **info})
433 |         x, info = dpm_solver.dpm_solver_adaptive(x, dpm_solver.t(torch.tensor(sigma_max)), dpm_solver.t(torch.tensor(sigma_min)), order, rtol, atol, h_init, pcoeff, icoeff, dcoeff, accept_safety, eta, s_noise)
434 |     if return_info:
435 |         return x, info
436 |     return x
437 | 
--------------------------------------------------------------------------------
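Usage sketch (illustrative, not part of the repository): the fixed-schedule samplers above share the call shape sampler(model, x, sigmas, ...), where `model` is assumed to be a wrapped denoiser (e.g. k_diffusion.layers.Denoiser around a trained inner model) with the signature model(x, sigma, **extra_args), and `x` starts as Gaussian noise scaled to the first sigma. The batch shape and sigma range below are assumptions for illustration, not values fixed by the library.

    import torch
    import k_diffusion as K

    # model: a Denoiser-wrapped network returning the denoised image (assumed given).
    sigmas = K.sampling.get_sigmas_karras(n=50, sigma_min=1e-2, sigma_max=80., device='cpu')
    x = torch.randn([4, 3, 32, 32]) * sigmas[0]  # start at the highest noise level
    samples = K.sampling.sample_heun(model, x, sigmas)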