├── .gitignore ├── LICENSE ├── README.md ├── cm ├── __init__.py ├── dist_util.py ├── fp16_util.py ├── image_datasets.py ├── karras_diffusion.py ├── logger.py ├── losses.py ├── nn.py ├── random_util.py ├── resample.py ├── script_util.py ├── train_util.py └── unet.py ├── datasets ├── README.md └── lsun_bedroom.py ├── docker ├── Dockerfile └── Makefile ├── evaluations ├── __init__.py ├── evaluator.py ├── inception_v3.py ├── requirements.txt └── th_evaluator.py ├── model-card.md ├── scripts ├── cm_train.py ├── edm_train.py ├── image_sample.py ├── launch.sh └── ternary_search.py └── setup.py /.gitignore: -------------------------------------------------------------------------------- 1 | # Compiled python modules. 2 | *.pyc 3 | 4 | # Byte-compiled 5 | _pycache__/ 6 | .cache/ 7 | .idea/ 8 | 9 | # Python egg metadata, regenerated from source files by setuptools. 10 | /*.egg-info 11 | .eggs/ 12 | 13 | # PyPI distribution artifacts. 14 | build/ 15 | dist/ 16 | 17 | # Tests 18 | .pytest_cache/ 19 | 20 | # Other 21 | *.DS_Store 22 | -------------------------------------------------------------------------------- /LICENSE: -------------------------------------------------------------------------------- 1 | MIT License 2 | 3 | Copyright (c) 2023 OpenAI 4 | 5 | Permission is hereby granted, free of charge, to any person obtaining a copy 6 | of this software and associated documentation files (the "Software"), to deal 7 | in the Software without restriction, including without limitation the rights 8 | to use, copy, modify, merge, publish, distribute, sublicense, and/or sell 9 | copies of the Software, and to permit persons to whom the Software is 10 | furnished to do so, subject to the following conditions: 11 | 12 | The above copyright notice and this permission notice shall be included in all 13 | copies or substantial portions of the Software. 14 | 15 | THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR 16 | IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, 17 | FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE 18 | AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER 19 | LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, 20 | OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE 21 | SOFTWARE. 22 | -------------------------------------------------------------------------------- /README.md: -------------------------------------------------------------------------------- 1 | # Consistency Models 2 | 3 | This repository contains the codebase for [Consistency Models](https://arxiv.org/abs/2303.01469), implemented using PyTorch for conducting large-scale experiments on ImageNet-64, LSUN Bedroom-256, and LSUN Cat-256. We have based our repository on [openai/guided-diffusion](https://github.com/openai/guided-diffusion), which was initially released under the MIT license. Our modifications have enabled support for consistency distillation, consistency training, as well as several sampling and editing algorithms discussed in the paper. 4 | 5 | The repository for CIFAR-10 experiments is in JAX and can be found at [openai/consistency_models_cifar10](https://github.com/openai/consistency_models_cifar10). 6 | 7 | # Pre-trained models 8 | 9 | We have released checkpoints for the main models in the paper. Before using these models, please review the corresponding [model card](model-card.md) to understand the intended use and limitations of these models. 
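Once downloaded from the links below, each checkpoint is a standard PyTorch `.pt` state dict. As a quick sanity check (a minimal sketch, assuming `cd_imagenet64_l2.pt` has already been saved locally), it can be inspected without a GPU:

```python
import torch

# The released checkpoints are plain state dicts keyed by the model's
# parameter names, so they load on CPU for inspection before sampling.
state_dict = torch.load("cd_imagenet64_l2.pt", map_location="cpu")
print(len(state_dict), "tensors")
```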
10 | 11 | Here are the download links for each model checkpoint: 12 | 13 | * EDM on ImageNet-64: [edm_imagenet64_ema.pt](https://openaipublic.blob.core.windows.net/consistency/edm_imagenet64_ema.pt) 14 | * CD on ImageNet-64 with l2 metric: [cd_imagenet64_l2.pt](https://openaipublic.blob.core.windows.net/consistency/cd_imagenet64_l2.pt) 15 | * CD on ImageNet-64 with LPIPS metric: [cd_imagenet64_lpips.pt](https://openaipublic.blob.core.windows.net/consistency/cd_imagenet64_lpips.pt) 16 | * CT on ImageNet-64: [ct_imagenet64.pt](https://openaipublic.blob.core.windows.net/consistency/ct_imagenet64.pt) 17 | * EDM on LSUN Bedroom-256: [edm_bedroom256_ema.pt](https://openaipublic.blob.core.windows.net/consistency/edm_bedroom256_ema.pt) 18 | * CD on LSUN Bedroom-256 with l2 metric: [cd_bedroom256_l2.pt](https://openaipublic.blob.core.windows.net/consistency/cd_bedroom256_l2.pt) 19 | * CD on LSUN Bedroom-256 with LPIPS metric: [cd_bedroom256_lpips.pt](https://openaipublic.blob.core.windows.net/consistency/cd_bedroom256_lpips.pt) 20 | * CT on LSUN Bedroom-256: [ct_bedroom256.pt](https://openaipublic.blob.core.windows.net/consistency/ct_bedroom256.pt) 21 | * EDM on LSUN Cat-256: [edm_cat256_ema.pt](https://openaipublic.blob.core.windows.net/consistency/edm_cat256_ema.pt) 22 | * CD on LSUN Cat-256 with l2 metric: [cd_cat256_l2.pt](https://openaipublic.blob.core.windows.net/consistency/cd_cat256_l2.pt) 23 | * CD on LSUN Cat-256 with LPIPS metric: [cd_cat256_lpips.pt](https://openaipublic.blob.core.windows.net/consistency/cd_cat256_lpips.pt) 24 | * CT on LSUN Cat-256: [ct_cat256.pt](https://openaipublic.blob.core.windows.net/consistency/ct_cat256.pt) 25 | 26 | # Dependencies 27 | 28 | To install all packages in this codebase along with their dependencies, run 29 | ```sh 30 | pip install -e . 31 | ``` 32 | 33 | To install with Docker, run the following commands: 34 | ```sh 35 | cd docker && make build && make run 36 | ``` 37 | 38 | # Model training and sampling 39 | 40 | We provide examples of EDM training, consistency distillation, consistency training, single-step generation, and multistep generation in [scripts/launch.sh](scripts/launch.sh). 41 | 42 | # Evaluations 43 | 44 | To compare different generative models, we use FID, Precision, Recall, and Inception Score. These metrics can all be calculated using batches of samples stored in `.npz` (numpy) files. One can evaluate samples with [cm/evaluations/evaluator.py](evaluations/evaluator.py) in the same way as described in [openai/guided-diffusion](https://github.com/openai/guided-diffusion), with reference dataset batches provided therein. 45 | 46 | ## Use in 🧨 diffusers 47 | 48 | Consistency models are supported in [🧨 diffusers](https://github.com/huggingface/diffusers) via the [`ConsistencyModelPipeline` class](https://huggingface.co/docs/diffusers/main/en/api/pipelines/consistency_models). Below we provide an example: 49 | 50 | ```python 51 | import torch 52 | 53 | from diffusers import ConsistencyModelPipeline 54 | 55 | device = "cuda" 56 | # Load the cd_imagenet64_l2 checkpoint. 
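# (A note on precision: torch_dtype=torch.float16 assumes a CUDA device; for a
# CPU-only run it is generally safer to omit torch_dtype and use float32.)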
57 | model_id_or_path = "openai/diffusers-cd_imagenet64_l2" 58 | pipe = ConsistencyModelPipeline.from_pretrained(model_id_or_path, torch_dtype=torch.float16) 59 | pipe.to(device) 60 | 61 | # Onestep Sampling 62 | image = pipe(num_inference_steps=1).images[0] 63 | image.save("consistency_model_onestep_sample.png") 64 | 65 | # Onestep sampling, class-conditional image generation 66 | # ImageNet-64 class label 145 corresponds to king penguins 67 | 68 | class_id = 145 69 | class_id = torch.tensor(class_id, dtype=torch.long) 70 | 71 | image = pipe(num_inference_steps=1, class_labels=class_id).images[0] 72 | image.save("consistency_model_onestep_sample_penguin.png") 73 | 74 | # Multistep sampling, class-conditional image generation 75 | # Timesteps can be explicitly specified; the particular timesteps below are from the original Github repo. 76 | # https://github.com/openai/consistency_models/blob/main/scripts/launch.sh#L77 77 | image = pipe(timesteps=[22, 0], class_labels=class_id).images[0] 78 | image.save("consistency_model_multistep_sample_penguin.png") 79 | ``` 80 | You can further speed up the inference process by using `torch.compile()` on `pipe.unet` (only supported from PyTorch 2.0). For more details, please check out the [official documentation](https://huggingface.co/docs/diffusers/main/en/api/pipelines/consistency_models). This support was contributed to 🧨 diffusers by [dg845](https://github.com/dg845) and [ayushtues](https://github.com/ayushtues). 81 | 82 | # Citation 83 | 84 | If you find this method and/or code useful, please consider citing 85 | 86 | ```bibtex 87 | @article{song2023consistency, 88 | title={Consistency Models}, 89 | author={Song, Yang and Dhariwal, Prafulla and Chen, Mark and Sutskever, Ilya}, 90 | journal={arXiv preprint arXiv:2303.01469}, 91 | year={2023}, 92 | } 93 | ``` 94 | -------------------------------------------------------------------------------- /cm/__init__.py: -------------------------------------------------------------------------------- 1 | """ 2 | Codebase for "Improved Denoising Diffusion Probabilistic Models". 3 | """ 4 | -------------------------------------------------------------------------------- /cm/dist_util.py: -------------------------------------------------------------------------------- 1 | """ 2 | Helpers for distributed training. 3 | """ 4 | 5 | import io 6 | import os 7 | import socket 8 | 9 | import blobfile as bf 10 | from mpi4py import MPI 11 | import torch as th 12 | import torch.distributed as dist 13 | 14 | # Change this to reflect your cluster layout. 15 | # The GPU for a given rank is (rank % GPUS_PER_NODE). 16 | GPUS_PER_NODE = 8 17 | 18 | SETUP_RETRY_COUNT = 3 19 | 20 | 21 | def setup_dist(): 22 | """ 23 | Setup a distributed process group. 24 | """ 25 | if dist.is_initialized(): 26 | return 27 | os.environ["CUDA_VISIBLE_DEVICES"] = f"{MPI.COMM_WORLD.Get_rank() % GPUS_PER_NODE}" 28 | 29 | comm = MPI.COMM_WORLD 30 | backend = "gloo" if not th.cuda.is_available() else "nccl" 31 | 32 | if backend == "gloo": 33 | hostname = "localhost" 34 | else: 35 | hostname = socket.gethostbyname(socket.getfqdn()) 36 | os.environ["MASTER_ADDR"] = comm.bcast(hostname, root=0) 37 | os.environ["RANK"] = str(comm.rank) 38 | os.environ["WORLD_SIZE"] = str(comm.size) 39 | 40 | port = comm.bcast(_find_free_port(), root=0) 41 | os.environ["MASTER_PORT"] = str(port) 42 | dist.init_process_group(backend=backend, init_method="env://") 43 | 44 | 45 | def dev(): 46 | """ 47 | Get the device to use for torch.distributed. 
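Note that setup_dist() restricts CUDA_VISIBLE_DEVICES to a single GPU per MPI
rank, so the generic "cuda" device returned here already refers to that
rank's assigned GPU.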
48 | """ 49 | if th.cuda.is_available(): 50 | return th.device("cuda") 51 | return th.device("cpu") 52 | 53 | 54 | def load_state_dict(path, **kwargs): 55 | """ 56 | Load a PyTorch file without redundant fetches across MPI ranks. 57 | """ 58 | chunk_size = 2**30 # MPI has a relatively small size limit 59 | if MPI.COMM_WORLD.Get_rank() == 0: 60 | with bf.BlobFile(path, "rb") as f: 61 | data = f.read() 62 | num_chunks = len(data) // chunk_size 63 | if len(data) % chunk_size: 64 | num_chunks += 1 65 | MPI.COMM_WORLD.bcast(num_chunks) 66 | for i in range(0, len(data), chunk_size): 67 | MPI.COMM_WORLD.bcast(data[i : i + chunk_size]) 68 | else: 69 | num_chunks = MPI.COMM_WORLD.bcast(None) 70 | data = bytes() 71 | for _ in range(num_chunks): 72 | data += MPI.COMM_WORLD.bcast(None) 73 | 74 | return th.load(io.BytesIO(data), **kwargs) 75 | 76 | 77 | def sync_params(params): 78 | """ 79 | Synchronize a sequence of Tensors across ranks from rank 0. 80 | """ 81 | for p in params: 82 | with th.no_grad(): 83 | dist.broadcast(p, 0) 84 | 85 | 86 | def _find_free_port(): 87 | try: 88 | s = socket.socket(socket.AF_INET, socket.SOCK_STREAM) 89 | s.bind(("", 0)) 90 | s.setsockopt(socket.SOL_SOCKET, socket.SO_REUSEADDR, 1) 91 | return s.getsockname()[1] 92 | finally: 93 | s.close() 94 | -------------------------------------------------------------------------------- /cm/fp16_util.py: -------------------------------------------------------------------------------- 1 | """ 2 | Helpers to train with 16-bit precision. 3 | """ 4 | 5 | import numpy as np 6 | import torch as th 7 | import torch.nn as nn 8 | from torch._utils import _flatten_dense_tensors, _unflatten_dense_tensors 9 | 10 | from . import logger 11 | 12 | INITIAL_LOG_LOSS_SCALE = 20.0 13 | 14 | 15 | def convert_module_to_f16(l): 16 | """ 17 | Convert primitive modules to float16. 18 | """ 19 | if isinstance(l, (nn.Conv1d, nn.Conv2d, nn.Conv3d)): 20 | l.weight.data = l.weight.data.half() 21 | if l.bias is not None: 22 | l.bias.data = l.bias.data.half() 23 | 24 | 25 | def convert_module_to_f32(l): 26 | """ 27 | Convert primitive modules to float32, undoing convert_module_to_f16(). 28 | """ 29 | if isinstance(l, (nn.Conv1d, nn.Conv2d, nn.Conv3d)): 30 | l.weight.data = l.weight.data.float() 31 | if l.bias is not None: 32 | l.bias.data = l.bias.data.float() 33 | 34 | 35 | def make_master_params(param_groups_and_shapes): 36 | """ 37 | Copy model parameters into a (differently-shaped) list of full-precision 38 | parameters. 39 | """ 40 | master_params = [] 41 | for param_group, shape in param_groups_and_shapes: 42 | master_param = nn.Parameter( 43 | _flatten_dense_tensors( 44 | [param.detach().float() for (_, param) in param_group] 45 | ).view(shape) 46 | ) 47 | master_param.requires_grad = True 48 | master_params.append(master_param) 49 | return master_params 50 | 51 | 52 | def model_grads_to_master_grads(param_groups_and_shapes, master_params): 53 | """ 54 | Copy the gradients from the model parameters into the master parameters 55 | from make_master_params(). 56 | """ 57 | for master_param, (param_group, shape) in zip( 58 | master_params, param_groups_and_shapes 59 | ): 60 | master_param.grad = _flatten_dense_tensors( 61 | [param_grad_or_zeros(param) for (_, param) in param_group] 62 | ).view(shape) 63 | 64 | 65 | def master_params_to_model_params(param_groups_and_shapes, master_params): 66 | """ 67 | Copy the master parameter data back into the model parameters. 
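In the fp16 training loop this runs right after opt.step(), so the freshly
updated float32 master values are mirrored back into the (possibly float16)
model weights.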
68 | """ 69 | # Without copying to a list, if a generator is passed, this will 70 | # silently not copy any parameters. 71 | for master_param, (param_group, _) in zip(master_params, param_groups_and_shapes): 72 | for (_, param), unflat_master_param in zip( 73 | param_group, unflatten_master_params(param_group, master_param.view(-1)) 74 | ): 75 | param.detach().copy_(unflat_master_param) 76 | 77 | 78 | def unflatten_master_params(param_group, master_param): 79 | return _unflatten_dense_tensors(master_param, [param for (_, param) in param_group]) 80 | 81 | 82 | def get_param_groups_and_shapes(named_model_params): 83 | named_model_params = list(named_model_params) 84 | scalar_vector_named_params = ( 85 | [(n, p) for (n, p) in named_model_params if p.ndim <= 1], 86 | (-1), 87 | ) 88 | matrix_named_params = ( 89 | [(n, p) for (n, p) in named_model_params if p.ndim > 1], 90 | (1, -1), 91 | ) 92 | return [scalar_vector_named_params, matrix_named_params] 93 | 94 | 95 | def master_params_to_state_dict( 96 | model, param_groups_and_shapes, master_params, use_fp16 97 | ): 98 | if use_fp16: 99 | state_dict = model.state_dict() 100 | for master_param, (param_group, _) in zip( 101 | master_params, param_groups_and_shapes 102 | ): 103 | for (name, _), unflat_master_param in zip( 104 | param_group, unflatten_master_params(param_group, master_param.view(-1)) 105 | ): 106 | assert name in state_dict 107 | state_dict[name] = unflat_master_param 108 | else: 109 | state_dict = model.state_dict() 110 | for i, (name, _value) in enumerate(model.named_parameters()): 111 | assert name in state_dict 112 | state_dict[name] = master_params[i] 113 | return state_dict 114 | 115 | 116 | def state_dict_to_master_params(model, state_dict, use_fp16): 117 | if use_fp16: 118 | named_model_params = [ 119 | (name, state_dict[name]) for name, _ in model.named_parameters() 120 | ] 121 | param_groups_and_shapes = get_param_groups_and_shapes(named_model_params) 122 | master_params = make_master_params(param_groups_and_shapes) 123 | else: 124 | master_params = [state_dict[name] for name, _ in model.named_parameters()] 125 | return master_params 126 | 127 | 128 | def zero_master_grads(master_params): 129 | for param in master_params: 130 | param.grad = None 131 | 132 | 133 | def zero_grad(model_params): 134 | for param in model_params: 135 | # Taken from https://pytorch.org/docs/stable/_modules/torch/optim/optimizer.html#Optimizer.add_param_group 136 | if param.grad is not None: 137 | param.grad.detach_() 138 | param.grad.zero_() 139 | 140 | 141 | def param_grad_or_zeros(param): 142 | if param.grad is not None: 143 | return param.grad.data.detach() 144 | else: 145 | return th.zeros_like(param) 146 | 147 | 148 | class MixedPrecisionTrainer: 149 | def __init__( 150 | self, 151 | *, 152 | model, 153 | use_fp16=False, 154 | fp16_scale_growth=1e-3, 155 | initial_lg_loss_scale=INITIAL_LOG_LOSS_SCALE, 156 | ): 157 | self.model = model 158 | self.use_fp16 = use_fp16 159 | self.fp16_scale_growth = fp16_scale_growth 160 | 161 | self.model_params = list(self.model.parameters()) 162 | self.master_params = self.model_params 163 | self.param_groups_and_shapes = None 164 | self.lg_loss_scale = initial_lg_loss_scale 165 | 166 | if self.use_fp16: 167 | self.param_groups_and_shapes = get_param_groups_and_shapes( 168 | self.model.named_parameters() 169 | ) 170 | self.master_params = make_master_params(self.param_groups_and_shapes) 171 | self.model.convert_to_fp16() 172 | 173 | def zero_grad(self): 174 | zero_grad(self.model_params) 175 | 176 | def 
backward(self, loss: th.Tensor): 177 | if self.use_fp16: 178 | loss_scale = 2**self.lg_loss_scale 179 | (loss * loss_scale).backward() 180 | else: 181 | loss.backward() 182 | 183 | def optimize(self, opt: th.optim.Optimizer): 184 | if self.use_fp16: 185 | return self._optimize_fp16(opt) 186 | else: 187 | return self._optimize_normal(opt) 188 | 189 | def _optimize_fp16(self, opt: th.optim.Optimizer): 190 | logger.logkv_mean("lg_loss_scale", self.lg_loss_scale) 191 | model_grads_to_master_grads(self.param_groups_and_shapes, self.master_params) 192 | grad_norm, param_norm = self._compute_norms(grad_scale=2**self.lg_loss_scale) 193 | if check_overflow(grad_norm): 194 | self.lg_loss_scale -= 1 195 | logger.log(f"Found NaN, decreased lg_loss_scale to {self.lg_loss_scale}") 196 | zero_master_grads(self.master_params) 197 | return False 198 | 199 | logger.logkv_mean("grad_norm", grad_norm) 200 | logger.logkv_mean("param_norm", param_norm) 201 | 202 | for p in self.master_params: 203 | p.grad.mul_(1.0 / (2**self.lg_loss_scale)) 204 | opt.step() 205 | zero_master_grads(self.master_params) 206 | master_params_to_model_params(self.param_groups_and_shapes, self.master_params) 207 | self.lg_loss_scale += self.fp16_scale_growth 208 | return True 209 | 210 | def _optimize_normal(self, opt: th.optim.Optimizer): 211 | grad_norm, param_norm = self._compute_norms() 212 | logger.logkv_mean("grad_norm", grad_norm) 213 | logger.logkv_mean("param_norm", param_norm) 214 | opt.step() 215 | return True 216 | 217 | def _compute_norms(self, grad_scale=1.0): 218 | grad_norm = 0.0 219 | param_norm = 0.0 220 | for p in self.master_params: 221 | with th.no_grad(): 222 | param_norm += th.norm(p, p=2, dtype=th.float32).item() ** 2 223 | if p.grad is not None: 224 | grad_norm += th.norm(p.grad, p=2, dtype=th.float32).item() ** 2 225 | return np.sqrt(grad_norm) / grad_scale, np.sqrt(param_norm) 226 | 227 | def master_params_to_state_dict(self, master_params): 228 | return master_params_to_state_dict( 229 | self.model, self.param_groups_and_shapes, master_params, self.use_fp16 230 | ) 231 | 232 | def state_dict_to_master_params(self, state_dict): 233 | return state_dict_to_master_params(self.model, state_dict, self.use_fp16) 234 | 235 | 236 | def check_overflow(value): 237 | return (value == float("inf")) or (value == -float("inf")) or (value != value) 238 | -------------------------------------------------------------------------------- /cm/image_datasets.py: -------------------------------------------------------------------------------- 1 | import math 2 | import random 3 | 4 | from PIL import Image 5 | import blobfile as bf 6 | from mpi4py import MPI 7 | import numpy as np 8 | from torch.utils.data import DataLoader, Dataset 9 | 10 | 11 | def load_data( 12 | *, 13 | data_dir, 14 | batch_size, 15 | image_size, 16 | class_cond=False, 17 | deterministic=False, 18 | random_crop=False, 19 | random_flip=True, 20 | ): 21 | """ 22 | For a dataset, create a generator over (images, kwargs) pairs. 23 | 24 | Each images is an NCHW float tensor, and the kwargs dict contains zero or 25 | more keys, each of which map to a batched Tensor of their own. 26 | The kwargs dict can be used for class labels, in which case the key is "y" 27 | and the values are integer tensors of class labels. 28 | 29 | :param data_dir: a dataset directory. 30 | :param batch_size: the batch size of each returned pair. 31 | :param image_size: the size to which images are resized. 
32 | :param class_cond: if True, include a "y" key in returned dicts for class 33 | label. If classes are not available and this is true, an 34 | exception will be raised. 35 | :param deterministic: if True, yield results in a deterministic order. 36 | :param random_crop: if True, randomly crop the images for augmentation. 37 | :param random_flip: if True, randomly flip the images for augmentation. 38 | """ 39 | if not data_dir: 40 | raise ValueError("unspecified data directory") 41 | all_files = _list_image_files_recursively(data_dir) 42 | classes = None 43 | if class_cond: 44 | # Assume classes are the first part of the filename, 45 | # before an underscore. 46 | class_names = [bf.basename(path).split("_")[0] for path in all_files] 47 | sorted_classes = {x: i for i, x in enumerate(sorted(set(class_names)))} 48 | classes = [sorted_classes[x] for x in class_names] 49 | dataset = ImageDataset( 50 | image_size, 51 | all_files, 52 | classes=classes, 53 | shard=MPI.COMM_WORLD.Get_rank(), 54 | num_shards=MPI.COMM_WORLD.Get_size(), 55 | random_crop=random_crop, 56 | random_flip=random_flip, 57 | ) 58 | if deterministic: 59 | loader = DataLoader( 60 | dataset, batch_size=batch_size, shuffle=False, num_workers=1, drop_last=True 61 | ) 62 | else: 63 | loader = DataLoader( 64 | dataset, batch_size=batch_size, shuffle=True, num_workers=1, drop_last=True 65 | ) 66 | while True: 67 | yield from loader 68 | 69 | 70 | def _list_image_files_recursively(data_dir): 71 | results = [] 72 | for entry in sorted(bf.listdir(data_dir)): 73 | full_path = bf.join(data_dir, entry) 74 | ext = entry.split(".")[-1] 75 | if "." in entry and ext.lower() in ["jpg", "jpeg", "png", "gif"]: 76 | results.append(full_path) 77 | elif bf.isdir(full_path): 78 | results.extend(_list_image_files_recursively(full_path)) 79 | return results 80 | 81 | 82 | class ImageDataset(Dataset): 83 | def __init__( 84 | self, 85 | resolution, 86 | image_paths, 87 | classes=None, 88 | shard=0, 89 | num_shards=1, 90 | random_crop=False, 91 | random_flip=True, 92 | ): 93 | super().__init__() 94 | self.resolution = resolution 95 | self.local_images = image_paths[shard:][::num_shards] 96 | self.local_classes = None if classes is None else classes[shard:][::num_shards] 97 | self.random_crop = random_crop 98 | self.random_flip = random_flip 99 | 100 | def __len__(self): 101 | return len(self.local_images) 102 | 103 | def __getitem__(self, idx): 104 | path = self.local_images[idx] 105 | with bf.BlobFile(path, "rb") as f: 106 | pil_image = Image.open(f) 107 | pil_image.load() 108 | pil_image = pil_image.convert("RGB") 109 | 110 | if self.random_crop: 111 | arr = random_crop_arr(pil_image, self.resolution) 112 | else: 113 | arr = center_crop_arr(pil_image, self.resolution) 114 | 115 | if self.random_flip and random.random() < 0.5: 116 | arr = arr[:, ::-1] 117 | 118 | arr = arr.astype(np.float32) / 127.5 - 1 119 | 120 | out_dict = {} 121 | if self.local_classes is not None: 122 | out_dict["y"] = np.array(self.local_classes[idx], dtype=np.int64) 123 | return np.transpose(arr, [2, 0, 1]), out_dict 124 | 125 | 126 | def center_crop_arr(pil_image, image_size): 127 | # We are not on a new enough PIL to support the `reducing_gap` 128 | # argument, which uses BOX downsampling at powers of two first. 129 | # Thus, we do it by hand to improve downsample quality. 
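# For example, with image_size=256 a 512x768 input is halved once to 256x384
# with BOX filtering, the bicubic resize is then a no-op (scale == 1.0), and
# the final crop keeps the central 256x256 region.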
130 | while min(*pil_image.size) >= 2 * image_size: 131 | pil_image = pil_image.resize( 132 | tuple(x // 2 for x in pil_image.size), resample=Image.BOX 133 | ) 134 | 135 | scale = image_size / min(*pil_image.size) 136 | pil_image = pil_image.resize( 137 | tuple(round(x * scale) for x in pil_image.size), resample=Image.BICUBIC 138 | ) 139 | 140 | arr = np.array(pil_image) 141 | crop_y = (arr.shape[0] - image_size) // 2 142 | crop_x = (arr.shape[1] - image_size) // 2 143 | return arr[crop_y : crop_y + image_size, crop_x : crop_x + image_size] 144 | 145 | 146 | def random_crop_arr(pil_image, image_size, min_crop_frac=0.8, max_crop_frac=1.0): 147 | min_smaller_dim_size = math.ceil(image_size / max_crop_frac) 148 | max_smaller_dim_size = math.ceil(image_size / min_crop_frac) 149 | smaller_dim_size = random.randrange(min_smaller_dim_size, max_smaller_dim_size + 1) 150 | 151 | # We are not on a new enough PIL to support the `reducing_gap` 152 | # argument, which uses BOX downsampling at powers of two first. 153 | # Thus, we do it by hand to improve downsample quality. 154 | while min(*pil_image.size) >= 2 * smaller_dim_size: 155 | pil_image = pil_image.resize( 156 | tuple(x // 2 for x in pil_image.size), resample=Image.BOX 157 | ) 158 | 159 | scale = smaller_dim_size / min(*pil_image.size) 160 | pil_image = pil_image.resize( 161 | tuple(round(x * scale) for x in pil_image.size), resample=Image.BICUBIC 162 | ) 163 | 164 | arr = np.array(pil_image) 165 | crop_y = random.randrange(arr.shape[0] - image_size + 1) 166 | crop_x = random.randrange(arr.shape[1] - image_size + 1) 167 | return arr[crop_y : crop_y + image_size, crop_x : crop_x + image_size] 168 | -------------------------------------------------------------------------------- /cm/logger.py: -------------------------------------------------------------------------------- 1 | """ 2 | Logger copied from OpenAI baselines to avoid extra RL-based dependencies: 3 | https://github.com/openai/baselines/blob/ea25b9e8b234e6ee1bca43083f8f3cf974143998/baselines/logger.py 4 | """ 5 | 6 | import os 7 | import sys 8 | import shutil 9 | import os.path as osp 10 | import json 11 | import time 12 | import datetime 13 | import tempfile 14 | import warnings 15 | from collections import defaultdict 16 | from contextlib import contextmanager 17 | 18 | DEBUG = 10 19 | INFO = 20 20 | WARN = 30 21 | ERROR = 40 22 | 23 | DISABLED = 50 24 | 25 | 26 | class KVWriter(object): 27 | def writekvs(self, kvs): 28 | raise NotImplementedError 29 | 30 | 31 | class SeqWriter(object): 32 | def writeseq(self, seq): 33 | raise NotImplementedError 34 | 35 | 36 | class HumanOutputFormat(KVWriter, SeqWriter): 37 | def __init__(self, filename_or_file): 38 | if isinstance(filename_or_file, str): 39 | self.file = open(filename_or_file, "wt") 40 | self.own_file = True 41 | else: 42 | assert hasattr(filename_or_file, "read"), ( 43 | "expected file or str, got %s" % filename_or_file 44 | ) 45 | self.file = filename_or_file 46 | self.own_file = False 47 | 48 | def writekvs(self, kvs): 49 | # Create strings for printing 50 | key2str = {} 51 | for (key, val) in sorted(kvs.items()): 52 | if hasattr(val, "__float__"): 53 | valstr = "%-8.3g" % val 54 | else: 55 | valstr = str(val) 56 | key2str[self._truncate(key)] = self._truncate(valstr) 57 | 58 | # Find max widths 59 | if len(key2str) == 0: 60 | print("WARNING: tried to write empty key-value dict") 61 | return 62 | else: 63 | keywidth = max(map(len, key2str.keys())) 64 | valwidth = max(map(len, key2str.values())) 65 | 66 | # Write out the data 
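# Each pair is rendered as a boxed row such as `| grad_norm | 0.0371   |`,
# framed by dashed rules sized to the widest key and value.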
67 | dashes = "-" * (keywidth + valwidth + 7) 68 | lines = [dashes] 69 | for (key, val) in sorted(key2str.items(), key=lambda kv: kv[0].lower()): 70 | lines.append( 71 | "| %s%s | %s%s |" 72 | % (key, " " * (keywidth - len(key)), val, " " * (valwidth - len(val))) 73 | ) 74 | lines.append(dashes) 75 | self.file.write("\n".join(lines) + "\n") 76 | 77 | # Flush the output to the file 78 | self.file.flush() 79 | 80 | def _truncate(self, s): 81 | maxlen = 30 82 | return s[: maxlen - 3] + "..." if len(s) > maxlen else s 83 | 84 | def writeseq(self, seq): 85 | seq = list(seq) 86 | for (i, elem) in enumerate(seq): 87 | self.file.write(elem) 88 | if i < len(seq) - 1: # add space unless this is the last one 89 | self.file.write(" ") 90 | self.file.write("\n") 91 | self.file.flush() 92 | 93 | def close(self): 94 | if self.own_file: 95 | self.file.close() 96 | 97 | 98 | class JSONOutputFormat(KVWriter): 99 | def __init__(self, filename): 100 | self.file = open(filename, "wt") 101 | 102 | def writekvs(self, kvs): 103 | for k, v in sorted(kvs.items()): 104 | if hasattr(v, "dtype"): 105 | kvs[k] = float(v) 106 | self.file.write(json.dumps(kvs) + "\n") 107 | self.file.flush() 108 | 109 | def close(self): 110 | self.file.close() 111 | 112 | 113 | class CSVOutputFormat(KVWriter): 114 | def __init__(self, filename): 115 | self.file = open(filename, "w+t") 116 | self.keys = [] 117 | self.sep = "," 118 | 119 | def writekvs(self, kvs): 120 | # Add our current row to the history 121 | extra_keys = list(kvs.keys() - self.keys) 122 | extra_keys.sort() 123 | if extra_keys: 124 | self.keys.extend(extra_keys) 125 | self.file.seek(0) 126 | lines = self.file.readlines() 127 | self.file.seek(0) 128 | for (i, k) in enumerate(self.keys): 129 | if i > 0: 130 | self.file.write(",") 131 | self.file.write(k) 132 | self.file.write("\n") 133 | for line in lines[1:]: 134 | self.file.write(line[:-1]) 135 | self.file.write(self.sep * len(extra_keys)) 136 | self.file.write("\n") 137 | for (i, k) in enumerate(self.keys): 138 | if i > 0: 139 | self.file.write(",") 140 | v = kvs.get(k) 141 | if v is not None: 142 | self.file.write(str(v)) 143 | self.file.write("\n") 144 | self.file.flush() 145 | 146 | def close(self): 147 | self.file.close() 148 | 149 | 150 | class TensorBoardOutputFormat(KVWriter): 151 | """ 152 | Dumps key/value pairs into TensorBoard's numeric format. 153 | """ 154 | 155 | def __init__(self, dir): 156 | os.makedirs(dir, exist_ok=True) 157 | self.dir = dir 158 | self.step = 1 159 | prefix = "events" 160 | path = osp.join(osp.abspath(dir), prefix) 161 | import tensorflow as tf 162 | from tensorflow.python import pywrap_tensorflow 163 | from tensorflow.core.util import event_pb2 164 | from tensorflow.python.util import compat 165 | 166 | self.tf = tf 167 | self.event_pb2 = event_pb2 168 | self.pywrap_tensorflow = pywrap_tensorflow 169 | self.writer = pywrap_tensorflow.EventsWriter(compat.as_bytes(path)) 170 | 171 | def writekvs(self, kvs): 172 | def summary_val(k, v): 173 | kwargs = {"tag": k, "simple_value": float(v)} 174 | return self.tf.Summary.Value(**kwargs) 175 | 176 | summary = self.tf.Summary(value=[summary_val(k, v) for k, v in kvs.items()]) 177 | event = self.event_pb2.Event(wall_time=time.time(), summary=summary) 178 | event.step = ( 179 | self.step 180 | ) # is there any reason why you'd want to specify the step? 
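# self.step advances by one per writekvs() call (i.e. per dumpkvs()), so the
# TensorBoard x-axis simply counts how many times diagnostics were flushed.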
181 | self.writer.WriteEvent(event) 182 | self.writer.Flush() 183 | self.step += 1 184 | 185 | def close(self): 186 | if self.writer: 187 | self.writer.Close() 188 | self.writer = None 189 | 190 | 191 | def make_output_format(format, ev_dir, log_suffix=""): 192 | os.makedirs(ev_dir, exist_ok=True) 193 | if format == "stdout": 194 | return HumanOutputFormat(sys.stdout) 195 | elif format == "log": 196 | return HumanOutputFormat(osp.join(ev_dir, "log%s.txt" % log_suffix)) 197 | elif format == "json": 198 | return JSONOutputFormat(osp.join(ev_dir, "progress%s.json" % log_suffix)) 199 | elif format == "csv": 200 | return CSVOutputFormat(osp.join(ev_dir, "progress%s.csv" % log_suffix)) 201 | elif format == "tensorboard": 202 | return TensorBoardOutputFormat(osp.join(ev_dir, "tb%s" % log_suffix)) 203 | else: 204 | raise ValueError("Unknown format specified: %s" % (format,)) 205 | 206 | 207 | # ================================================================ 208 | # API 209 | # ================================================================ 210 | 211 | 212 | def logkv(key, val): 213 | """ 214 | Log a value of some diagnostic 215 | Call this once for each diagnostic quantity, each iteration 216 | If called many times, last value will be used. 217 | """ 218 | get_current().logkv(key, val) 219 | 220 | 221 | def logkv_mean(key, val): 222 | """ 223 | The same as logkv(), but if called many times, values averaged. 224 | """ 225 | get_current().logkv_mean(key, val) 226 | 227 | 228 | def logkvs(d): 229 | """ 230 | Log a dictionary of key-value pairs 231 | """ 232 | for (k, v) in d.items(): 233 | logkv(k, v) 234 | 235 | 236 | def dumpkvs(): 237 | """ 238 | Write all of the diagnostics from the current iteration 239 | """ 240 | return get_current().dumpkvs() 241 | 242 | 243 | def getkvs(): 244 | return get_current().name2val 245 | 246 | 247 | def log(*args, level=INFO): 248 | """ 249 | Write the sequence of args, with no separators, to the console and output files (if you've configured an output file). 250 | """ 251 | get_current().log(*args, level=level) 252 | 253 | 254 | def debug(*args): 255 | log(*args, level=DEBUG) 256 | 257 | 258 | def info(*args): 259 | log(*args, level=INFO) 260 | 261 | 262 | def warn(*args): 263 | log(*args, level=WARN) 264 | 265 | 266 | def error(*args): 267 | log(*args, level=ERROR) 268 | 269 | 270 | def set_level(level): 271 | """ 272 | Set logging threshold on current logger. 273 | """ 274 | get_current().set_level(level) 275 | 276 | 277 | def set_comm(comm): 278 | get_current().set_comm(comm) 279 | 280 | 281 | def get_dir(): 282 | """ 283 | Get directory that log files are being written to. 
284 | will be None if there is no output directory (i.e., if you didn't call start) 285 | """ 286 | return get_current().get_dir() 287 | 288 | 289 | record_tabular = logkv 290 | dump_tabular = dumpkvs 291 | 292 | 293 | @contextmanager 294 | def profile_kv(scopename): 295 | logkey = "wait_" + scopename 296 | tstart = time.time() 297 | try: 298 | yield 299 | finally: 300 | get_current().name2val[logkey] += time.time() - tstart 301 | 302 | 303 | def profile(n): 304 | """ 305 | Usage: 306 | @profile("my_func") 307 | def my_func(): code 308 | """ 309 | 310 | def decorator_with_name(func): 311 | def func_wrapper(*args, **kwargs): 312 | with profile_kv(n): 313 | return func(*args, **kwargs) 314 | 315 | return func_wrapper 316 | 317 | return decorator_with_name 318 | 319 | 320 | # ================================================================ 321 | # Backend 322 | # ================================================================ 323 | 324 | 325 | def get_current(): 326 | if Logger.CURRENT is None: 327 | _configure_default_logger() 328 | 329 | return Logger.CURRENT 330 | 331 | 332 | class Logger(object): 333 | DEFAULT = None # A logger with no output files. (See right below class definition) 334 | # So that you can still log to the terminal without setting up any output files 335 | CURRENT = None # Current logger being used by the free functions above 336 | 337 | def __init__(self, dir, output_formats, comm=None): 338 | self.name2val = defaultdict(float) # values this iteration 339 | self.name2cnt = defaultdict(int) 340 | self.level = INFO 341 | self.dir = dir 342 | self.output_formats = output_formats 343 | self.comm = comm 344 | 345 | # Logging API, forwarded 346 | # ---------------------------------------- 347 | def logkv(self, key, val): 348 | self.name2val[key] = val 349 | 350 | def logkv_mean(self, key, val): 351 | oldval, cnt = self.name2val[key], self.name2cnt[key] 352 | self.name2val[key] = oldval * cnt / (cnt + 1) + val / (cnt + 1) 353 | self.name2cnt[key] = cnt + 1 354 | 355 | def dumpkvs(self): 356 | if self.comm is None: 357 | d = self.name2val 358 | else: 359 | d = mpi_weighted_mean( 360 | self.comm, 361 | { 362 | name: (val, self.name2cnt.get(name, 1)) 363 | for (name, val) in self.name2val.items() 364 | }, 365 | ) 366 | if self.comm.rank != 0: 367 | d["dummy"] = 1 # so we don't get a warning about empty dict 368 | out = d.copy() # Return the dict for unit testing purposes 369 | for fmt in self.output_formats: 370 | if isinstance(fmt, KVWriter): 371 | fmt.writekvs(d) 372 | self.name2val.clear() 373 | self.name2cnt.clear() 374 | return out 375 | 376 | def log(self, *args, level=INFO): 377 | if self.level <= level: 378 | self._do_log(args) 379 | 380 | # Configuration 381 | # ---------------------------------------- 382 | def set_level(self, level): 383 | self.level = level 384 | 385 | def set_comm(self, comm): 386 | self.comm = comm 387 | 388 | def get_dir(self): 389 | return self.dir 390 | 391 | def close(self): 392 | for fmt in self.output_formats: 393 | fmt.close() 394 | 395 | # Misc 396 | # ---------------------------------------- 397 | def _do_log(self, args): 398 | for fmt in self.output_formats: 399 | if isinstance(fmt, SeqWriter): 400 | fmt.writeseq(map(str, args)) 401 | 402 | 403 | def get_rank_without_mpi_import(): 404 | # check environment variables here instead of importing mpi4py 405 | # to avoid calling MPI_Init() when this module is imported 406 | for varname in ["PMI_RANK", "OMPI_COMM_WORLD_RANK"]: 407 | if varname in os.environ: 408 | return int(os.environ[varname]) 
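# Neither Open MPI nor PMI exported a rank variable; fall back to a
# single-process run (rank 0).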
409 | return 0 410 | 411 | 412 | def mpi_weighted_mean(comm, local_name2valcount): 413 | """ 414 | Copied from: https://github.com/openai/baselines/blob/ea25b9e8b234e6ee1bca43083f8f3cf974143998/baselines/common/mpi_util.py#L110 415 | Perform a weighted average over dicts that are each on a different node 416 | Input: local_name2valcount: dict mapping key -> (value, count) 417 | Returns: key -> mean 418 | """ 419 | all_name2valcount = comm.gather(local_name2valcount) 420 | if comm.rank == 0: 421 | name2sum = defaultdict(float) 422 | name2count = defaultdict(float) 423 | for n2vc in all_name2valcount: 424 | for (name, (val, count)) in n2vc.items(): 425 | try: 426 | val = float(val) 427 | except ValueError: 428 | if comm.rank == 0: 429 | warnings.warn( 430 | "WARNING: tried to compute mean on non-float {}={}".format( 431 | name, val 432 | ) 433 | ) 434 | else: 435 | name2sum[name] += val * count 436 | name2count[name] += count 437 | return {name: name2sum[name] / name2count[name] for name in name2sum} 438 | else: 439 | return {} 440 | 441 | 442 | def configure(dir=None, format_strs=None, comm=None, log_suffix=""): 443 | """ 444 | If comm is provided, average all numerical stats across that comm 445 | """ 446 | if dir is None: 447 | dir = os.getenv("OPENAI_LOGDIR") 448 | if dir is None: 449 | dir = osp.join( 450 | tempfile.gettempdir(), 451 | datetime.datetime.now().strftime("openai-%Y-%m-%d-%H-%M-%S-%f"), 452 | ) 453 | assert isinstance(dir, str) 454 | dir = os.path.expanduser(dir) 455 | os.makedirs(os.path.expanduser(dir), exist_ok=True) 456 | 457 | rank = get_rank_without_mpi_import() 458 | if rank > 0: 459 | log_suffix = log_suffix + "-rank%03i" % rank 460 | 461 | if format_strs is None: 462 | if rank == 0: 463 | format_strs = os.getenv("OPENAI_LOG_FORMAT", "stdout,log,csv").split(",") 464 | else: 465 | format_strs = os.getenv("OPENAI_LOG_FORMAT_MPI", "log").split(",") 466 | format_strs = filter(None, format_strs) 467 | output_formats = [make_output_format(f, dir, log_suffix) for f in format_strs] 468 | 469 | Logger.CURRENT = Logger(dir=dir, output_formats=output_formats, comm=comm) 470 | if output_formats: 471 | log("Logging to %s" % dir) 472 | 473 | 474 | def _configure_default_logger(): 475 | configure() 476 | Logger.DEFAULT = Logger.CURRENT 477 | 478 | 479 | def reset(): 480 | if Logger.CURRENT is not Logger.DEFAULT: 481 | Logger.CURRENT.close() 482 | Logger.CURRENT = Logger.DEFAULT 483 | log("Reset logger") 484 | 485 | 486 | @contextmanager 487 | def scoped_configure(dir=None, format_strs=None, comm=None): 488 | prevlogger = Logger.CURRENT 489 | configure(dir=dir, format_strs=format_strs, comm=comm) 490 | try: 491 | yield 492 | finally: 493 | Logger.CURRENT.close() 494 | Logger.CURRENT = prevlogger 495 | 496 | -------------------------------------------------------------------------------- /cm/losses.py: -------------------------------------------------------------------------------- 1 | """ 2 | Helpers for various likelihood-based losses. These are ported from the original 3 | Ho et al. diffusion models codebase: 4 | https://github.com/hojonathanho/diffusion/blob/1e0dceb3b3495bbe19116a5e1b3596cd0706c543/diffusion_tf/utils.py 5 | """ 6 | 7 | import numpy as np 8 | 9 | import torch as th 10 | 11 | 12 | def normal_kl(mean1, logvar1, mean2, logvar2): 13 | """ 14 | Compute the KL divergence between two gaussians. 15 | 16 | Shapes are automatically broadcasted, so batches can be compared to 17 | scalars, among other use cases. 
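The result is the elementwise KL(N(mean1, exp(logvar1)) || N(mean2, exp(logvar2))),
expressed in nats.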
18 | """ 19 | tensor = None 20 | for obj in (mean1, logvar1, mean2, logvar2): 21 | if isinstance(obj, th.Tensor): 22 | tensor = obj 23 | break 24 | assert tensor is not None, "at least one argument must be a Tensor" 25 | 26 | # Force variances to be Tensors. Broadcasting helps convert scalars to 27 | # Tensors, but it does not work for th.exp(). 28 | logvar1, logvar2 = [ 29 | x if isinstance(x, th.Tensor) else th.tensor(x).to(tensor) 30 | for x in (logvar1, logvar2) 31 | ] 32 | 33 | return 0.5 * ( 34 | -1.0 35 | + logvar2 36 | - logvar1 37 | + th.exp(logvar1 - logvar2) 38 | + ((mean1 - mean2) ** 2) * th.exp(-logvar2) 39 | ) 40 | 41 | 42 | def approx_standard_normal_cdf(x): 43 | """ 44 | A fast approximation of the cumulative distribution function of the 45 | standard normal. 46 | """ 47 | return 0.5 * (1.0 + th.tanh(np.sqrt(2.0 / np.pi) * (x + 0.044715 * th.pow(x, 3)))) 48 | 49 | 50 | def discretized_gaussian_log_likelihood(x, *, means, log_scales): 51 | """ 52 | Compute the log-likelihood of a Gaussian distribution discretizing to a 53 | given image. 54 | 55 | :param x: the target images. It is assumed that this was uint8 values, 56 | rescaled to the range [-1, 1]. 57 | :param means: the Gaussian mean Tensor. 58 | :param log_scales: the Gaussian log stddev Tensor. 59 | :return: a tensor like x of log probabilities (in nats). 60 | """ 61 | assert x.shape == means.shape == log_scales.shape 62 | centered_x = x - means 63 | inv_stdv = th.exp(-log_scales) 64 | plus_in = inv_stdv * (centered_x + 1.0 / 255.0) 65 | cdf_plus = approx_standard_normal_cdf(plus_in) 66 | min_in = inv_stdv * (centered_x - 1.0 / 255.0) 67 | cdf_min = approx_standard_normal_cdf(min_in) 68 | log_cdf_plus = th.log(cdf_plus.clamp(min=1e-12)) 69 | log_one_minus_cdf_min = th.log((1.0 - cdf_min).clamp(min=1e-12)) 70 | cdf_delta = cdf_plus - cdf_min 71 | log_probs = th.where( 72 | x < -0.999, 73 | log_cdf_plus, 74 | th.where(x > 0.999, log_one_minus_cdf_min, th.log(cdf_delta.clamp(min=1e-12))), 75 | ) 76 | assert log_probs.shape == x.shape 77 | return log_probs 78 | -------------------------------------------------------------------------------- /cm/nn.py: -------------------------------------------------------------------------------- 1 | """ 2 | Various utilities for neural networks. 3 | """ 4 | 5 | import math 6 | 7 | import torch as th 8 | import torch.nn as nn 9 | import numpy as np 10 | import torch.nn.functional as F 11 | 12 | 13 | # PyTorch 1.7 has SiLU, but we support PyTorch 1.5. 14 | class SiLU(nn.Module): 15 | def forward(self, x): 16 | return x * th.sigmoid(x) 17 | 18 | 19 | class GroupNorm32(nn.GroupNorm): 20 | def forward(self, x): 21 | return super().forward(x.float()).type(x.dtype) 22 | 23 | 24 | def conv_nd(dims, *args, **kwargs): 25 | """ 26 | Create a 1D, 2D, or 3D convolution module. 27 | """ 28 | if dims == 1: 29 | return nn.Conv1d(*args, **kwargs) 30 | elif dims == 2: 31 | return nn.Conv2d(*args, **kwargs) 32 | elif dims == 3: 33 | return nn.Conv3d(*args, **kwargs) 34 | raise ValueError(f"unsupported dimensions: {dims}") 35 | 36 | 37 | def linear(*args, **kwargs): 38 | """ 39 | Create a linear module. 40 | """ 41 | return nn.Linear(*args, **kwargs) 42 | 43 | 44 | def avg_pool_nd(dims, *args, **kwargs): 45 | """ 46 | Create a 1D, 2D, or 3D average pooling module. 
47 | """ 48 | if dims == 1: 49 | return nn.AvgPool1d(*args, **kwargs) 50 | elif dims == 2: 51 | return nn.AvgPool2d(*args, **kwargs) 52 | elif dims == 3: 53 | return nn.AvgPool3d(*args, **kwargs) 54 | raise ValueError(f"unsupported dimensions: {dims}") 55 | 56 | 57 | def update_ema(target_params, source_params, rate=0.99): 58 | """ 59 | Update target parameters to be closer to those of source parameters using 60 | an exponential moving average. 61 | 62 | :param target_params: the target parameter sequence. 63 | :param source_params: the source parameter sequence. 64 | :param rate: the EMA rate (closer to 1 means slower). 65 | """ 66 | for targ, src in zip(target_params, source_params): 67 | targ.detach().mul_(rate).add_(src, alpha=1 - rate) 68 | 69 | 70 | def zero_module(module): 71 | """ 72 | Zero out the parameters of a module and return it. 73 | """ 74 | for p in module.parameters(): 75 | p.detach().zero_() 76 | return module 77 | 78 | 79 | def scale_module(module, scale): 80 | """ 81 | Scale the parameters of a module and return it. 82 | """ 83 | for p in module.parameters(): 84 | p.detach().mul_(scale) 85 | return module 86 | 87 | 88 | def mean_flat(tensor): 89 | """ 90 | Take the mean over all non-batch dimensions. 91 | """ 92 | return tensor.mean(dim=list(range(1, len(tensor.shape)))) 93 | 94 | 95 | def append_dims(x, target_dims): 96 | """Appends dimensions to the end of a tensor until it has target_dims dimensions.""" 97 | dims_to_append = target_dims - x.ndim 98 | if dims_to_append < 0: 99 | raise ValueError( 100 | f"input has {x.ndim} dims but target_dims is {target_dims}, which is less" 101 | ) 102 | return x[(...,) + (None,) * dims_to_append] 103 | 104 | 105 | def append_zero(x): 106 | return th.cat([x, x.new_zeros([1])]) 107 | 108 | 109 | def normalization(channels): 110 | """ 111 | Make a standard normalization layer. 112 | 113 | :param channels: number of input channels. 114 | :return: an nn.Module for normalization. 115 | """ 116 | return GroupNorm32(32, channels) 117 | 118 | 119 | def timestep_embedding(timesteps, dim, max_period=10000): 120 | """ 121 | Create sinusoidal timestep embeddings. 122 | 123 | :param timesteps: a 1-D Tensor of N indices, one per batch element. 124 | These may be fractional. 125 | :param dim: the dimension of the output. 126 | :param max_period: controls the minimum frequency of the embeddings. 127 | :return: an [N x dim] Tensor of positional embeddings. 128 | """ 129 | half = dim // 2 130 | freqs = th.exp( 131 | -math.log(max_period) * th.arange(start=0, end=half, dtype=th.float32) / half 132 | ).to(device=timesteps.device) 133 | args = timesteps[:, None].float() * freqs[None] 134 | embedding = th.cat([th.cos(args), th.sin(args)], dim=-1) 135 | if dim % 2: 136 | embedding = th.cat([embedding, th.zeros_like(embedding[:, :1])], dim=-1) 137 | return embedding 138 | 139 | 140 | def checkpoint(func, inputs, params, flag): 141 | """ 142 | Evaluate a function without caching intermediate activations, allowing for 143 | reduced memory at the expense of extra compute in the backward pass. 144 | 145 | :param func: the function to evaluate. 146 | :param inputs: the argument sequence to pass to `func`. 147 | :param params: a sequence of parameters `func` depends on but does not 148 | explicitly take as arguments. 149 | :param flag: if False, disable gradient checkpointing. 
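:return: the output of func(*inputs), computed without storing intermediate
activations when flag is True.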
150 | """ 151 | if flag: 152 | args = tuple(inputs) + tuple(params) 153 | return CheckpointFunction.apply(func, len(inputs), *args) 154 | else: 155 | return func(*inputs) 156 | 157 | 158 | class CheckpointFunction(th.autograd.Function): 159 | @staticmethod 160 | def forward(ctx, run_function, length, *args): 161 | ctx.run_function = run_function 162 | ctx.input_tensors = list(args[:length]) 163 | ctx.input_params = list(args[length:]) 164 | with th.no_grad(): 165 | output_tensors = ctx.run_function(*ctx.input_tensors) 166 | return output_tensors 167 | 168 | @staticmethod 169 | def backward(ctx, *output_grads): 170 | ctx.input_tensors = [x.detach().requires_grad_(True) for x in ctx.input_tensors] 171 | with th.enable_grad(): 172 | # Fixes a bug where the first op in run_function modifies the 173 | # Tensor storage in place, which is not allowed for detach()'d 174 | # Tensors. 175 | shallow_copies = [x.view_as(x) for x in ctx.input_tensors] 176 | output_tensors = ctx.run_function(*shallow_copies) 177 | input_grads = th.autograd.grad( 178 | output_tensors, 179 | ctx.input_tensors + ctx.input_params, 180 | output_grads, 181 | allow_unused=True, 182 | ) 183 | del ctx.input_tensors 184 | del ctx.input_params 185 | del output_tensors 186 | return (None, None) + input_grads 187 | -------------------------------------------------------------------------------- /cm/random_util.py: -------------------------------------------------------------------------------- 1 | import torch as th 2 | import torch.distributed as dist 3 | from . import dist_util 4 | 5 | 6 | def get_generator(generator, num_samples=0, seed=0): 7 | if generator == "dummy": 8 | return DummyGenerator() 9 | elif generator == "determ": 10 | return DeterministicGenerator(num_samples, seed) 11 | elif generator == "determ-indiv": 12 | return DeterministicIndividualGenerator(num_samples, seed) 13 | else: 14 | raise NotImplementedError 15 | 16 | 17 | class DummyGenerator: 18 | def randn(self, *args, **kwargs): 19 | return th.randn(*args, **kwargs) 20 | 21 | def randint(self, *args, **kwargs): 22 | return th.randint(*args, **kwargs) 23 | 24 | def randn_like(self, *args, **kwargs): 25 | return th.randn_like(*args, **kwargs) 26 | 27 | 28 | class DeterministicGenerator: 29 | """ 30 | RNG to deterministically sample num_samples samples that does not depend on batch_size or mpi_machines 31 | Uses a single rng and samples num_samples sized randomness and subsamples the current indices 32 | """ 33 | 34 | def __init__(self, num_samples, seed=0): 35 | if dist.is_initialized(): 36 | self.rank = dist.get_rank() 37 | self.world_size = dist.get_world_size() 38 | else: 39 | print("Warning: Distributed not initialised, using single rank") 40 | self.rank = 0 41 | self.world_size = 1 42 | self.num_samples = num_samples 43 | self.done_samples = 0 44 | self.seed = seed 45 | self.rng_cpu = th.Generator() 46 | if th.cuda.is_available(): 47 | self.rng_cuda = th.Generator(dist_util.dev()) 48 | self.set_seed(seed) 49 | 50 | def get_global_size_and_indices(self, size): 51 | global_size = (self.num_samples, *size[1:]) 52 | indices = th.arange( 53 | self.done_samples + self.rank, 54 | self.done_samples + self.world_size * int(size[0]), 55 | self.world_size, 56 | ) 57 | indices = th.clamp(indices, 0, self.num_samples - 1) 58 | assert ( 59 | len(indices) == size[0] 60 | ), f"rank={self.rank}, ws={self.world_size}, l={len(indices)}, bs={size[0]}" 61 | return global_size, indices 62 | 63 | def get_generator(self, device): 64 | return self.rng_cpu if th.device(device).type 
== "cpu" else self.rng_cuda 65 | 66 | def randn(self, *size, dtype=th.float, device="cpu"): 67 | global_size, indices = self.get_global_size_and_indices(size) 68 | generator = self.get_generator(device) 69 | return th.randn(*global_size, generator=generator, dtype=dtype, device=device)[ 70 | indices 71 | ] 72 | 73 | def randint(self, low, high, size, dtype=th.long, device="cpu"): 74 | global_size, indices = self.get_global_size_and_indices(size) 75 | generator = self.get_generator(device) 76 | return th.randint( 77 | low, high, generator=generator, size=global_size, dtype=dtype, device=device 78 | )[indices] 79 | 80 | def randn_like(self, tensor): 81 | size, dtype, device = tensor.size(), tensor.dtype, tensor.device 82 | return self.randn(*size, dtype=dtype, device=device) 83 | 84 | def set_done_samples(self, done_samples): 85 | self.done_samples = done_samples 86 | self.set_seed(self.seed) 87 | 88 | def get_seed(self): 89 | return self.seed 90 | 91 | def set_seed(self, seed): 92 | self.rng_cpu.manual_seed(seed) 93 | if th.cuda.is_available(): 94 | self.rng_cuda.manual_seed(seed) 95 | 96 | 97 | class DeterministicIndividualGenerator: 98 | """ 99 | RNG to deterministically sample num_samples samples that does not depend on batch_size or mpi_machines 100 | Uses a separate rng for each sample to reduce memoery usage 101 | """ 102 | 103 | def __init__(self, num_samples, seed=0): 104 | if dist.is_initialized(): 105 | self.rank = dist.get_rank() 106 | self.world_size = dist.get_world_size() 107 | else: 108 | print("Warning: Distributed not initialised, using single rank") 109 | self.rank = 0 110 | self.world_size = 1 111 | self.num_samples = num_samples 112 | self.done_samples = 0 113 | self.seed = seed 114 | self.rng_cpu = [th.Generator() for _ in range(num_samples)] 115 | if th.cuda.is_available(): 116 | self.rng_cuda = [th.Generator(dist_util.dev()) for _ in range(num_samples)] 117 | self.set_seed(seed) 118 | 119 | def get_size_and_indices(self, size): 120 | indices = th.arange( 121 | self.done_samples + self.rank, 122 | self.done_samples + self.world_size * int(size[0]), 123 | self.world_size, 124 | ) 125 | indices = th.clamp(indices, 0, self.num_samples - 1) 126 | assert ( 127 | len(indices) == size[0] 128 | ), f"rank={self.rank}, ws={self.world_size}, l={len(indices)}, bs={size[0]}" 129 | return (1, *size[1:]), indices 130 | 131 | def get_generator(self, device): 132 | return self.rng_cpu if th.device(device).type == "cpu" else self.rng_cuda 133 | 134 | def randn(self, *size, dtype=th.float, device="cpu"): 135 | size, indices = self.get_size_and_indices(size) 136 | generator = self.get_generator(device) 137 | return th.cat( 138 | [ 139 | th.randn(*size, generator=generator[i], dtype=dtype, device=device) 140 | for i in indices 141 | ], 142 | dim=0, 143 | ) 144 | 145 | def randint(self, low, high, size, dtype=th.long, device="cpu"): 146 | size, indices = self.get_size_and_indices(size) 147 | generator = self.get_generator(device) 148 | return th.cat( 149 | [ 150 | th.randint( 151 | low, 152 | high, 153 | generator=generator[i], 154 | size=size, 155 | dtype=dtype, 156 | device=device, 157 | ) 158 | for i in indices 159 | ], 160 | dim=0, 161 | ) 162 | 163 | def randn_like(self, tensor): 164 | size, dtype, device = tensor.size(), tensor.dtype, tensor.device 165 | return self.randn(*size, dtype=dtype, device=device) 166 | 167 | def set_done_samples(self, done_samples): 168 | self.done_samples = done_samples 169 | 170 | def get_seed(self): 171 | return self.seed 172 | 173 | def set_seed(self, 
seed): 174 | [ 175 | rng_cpu.manual_seed(i + self.num_samples * seed) 176 | for i, rng_cpu in enumerate(self.rng_cpu) 177 | ] 178 | if th.cuda.is_available(): 179 | [ 180 | rng_cuda.manual_seed(i + self.num_samples * seed) 181 | for i, rng_cuda in enumerate(self.rng_cuda) 182 | ] 183 | -------------------------------------------------------------------------------- /cm/resample.py: -------------------------------------------------------------------------------- 1 | from abc import ABC, abstractmethod 2 | 3 | import numpy as np 4 | import torch as th 5 | from scipy.stats import norm 6 | import torch.distributed as dist 7 | 8 | 9 | def create_named_schedule_sampler(name, diffusion): 10 | """ 11 | Create a ScheduleSampler from a library of pre-defined samplers. 12 | 13 | :param name: the name of the sampler. 14 | :param diffusion: the diffusion object to sample for. 15 | """ 16 | if name == "uniform": 17 | return UniformSampler(diffusion) 18 | elif name == "loss-second-moment": 19 | return LossSecondMomentResampler(diffusion) 20 | elif name == "lognormal": 21 | return LogNormalSampler() 22 | else: 23 | raise NotImplementedError(f"unknown schedule sampler: {name}") 24 | 25 | 26 | class ScheduleSampler(ABC): 27 | """ 28 | A distribution over timesteps in the diffusion process, intended to reduce 29 | variance of the objective. 30 | 31 | By default, samplers perform unbiased importance sampling, in which the 32 | objective's mean is unchanged. 33 | However, subclasses may override sample() to change how the resampled 34 | terms are reweighted, allowing for actual changes in the objective. 35 | """ 36 | 37 | @abstractmethod 38 | def weights(self): 39 | """ 40 | Get a numpy array of weights, one per diffusion step. 41 | 42 | The weights needn't be normalized, but must be positive. 43 | """ 44 | 45 | def sample(self, batch_size, device): 46 | """ 47 | Importance-sample timesteps for a batch. 48 | 49 | :param batch_size: the number of timesteps. 50 | :param device: the torch device to save to. 51 | :return: a tuple (timesteps, weights): 52 | - timesteps: a tensor of timestep indices. 53 | - weights: a tensor of weights to scale the resulting losses. 54 | """ 55 | w = self.weights() 56 | p = w / np.sum(w) 57 | indices_np = np.random.choice(len(p), size=(batch_size,), p=p) 58 | indices = th.from_numpy(indices_np).long().to(device) 59 | weights_np = 1 / (len(p) * p[indices_np]) 60 | weights = th.from_numpy(weights_np).float().to(device) 61 | return indices, weights 62 | 63 | 64 | class UniformSampler(ScheduleSampler): 65 | def __init__(self, diffusion): 66 | self.diffusion = diffusion 67 | self._weights = np.ones([diffusion.num_timesteps]) 68 | 69 | def weights(self): 70 | return self._weights 71 | 72 | 73 | class LossAwareSampler(ScheduleSampler): 74 | def update_with_local_losses(self, local_ts, local_losses): 75 | """ 76 | Update the reweighting using losses from a model. 77 | 78 | Call this method from each rank with a batch of timesteps and the 79 | corresponding losses for each of those timesteps. 80 | This method will perform synchronization to make sure all of the ranks 81 | maintain the exact same reweighting. 82 | 83 | :param local_ts: an integer Tensor of timesteps. 84 | :param local_losses: a 1D Tensor of losses. 
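This is a collective call: every rank must invoke it for the same step, or
the all_gather operations below will block.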
85 | """ 86 | batch_sizes = [ 87 | th.tensor([0], dtype=th.int32, device=local_ts.device) 88 | for _ in range(dist.get_world_size()) 89 | ] 90 | dist.all_gather( 91 | batch_sizes, 92 | th.tensor([len(local_ts)], dtype=th.int32, device=local_ts.device), 93 | ) 94 | 95 | # Pad all_gather batches to be the maximum batch size. 96 | batch_sizes = [x.item() for x in batch_sizes] 97 | max_bs = max(batch_sizes) 98 | 99 | timestep_batches = [th.zeros(max_bs).to(local_ts) for bs in batch_sizes] 100 | loss_batches = [th.zeros(max_bs).to(local_losses) for bs in batch_sizes] 101 | dist.all_gather(timestep_batches, local_ts) 102 | dist.all_gather(loss_batches, local_losses) 103 | timesteps = [ 104 | x.item() for y, bs in zip(timestep_batches, batch_sizes) for x in y[:bs] 105 | ] 106 | losses = [x.item() for y, bs in zip(loss_batches, batch_sizes) for x in y[:bs]] 107 | self.update_with_all_losses(timesteps, losses) 108 | 109 | @abstractmethod 110 | def update_with_all_losses(self, ts, losses): 111 | """ 112 | Update the reweighting using losses from a model. 113 | 114 | Sub-classes should override this method to update the reweighting 115 | using losses from the model. 116 | 117 | This method directly updates the reweighting without synchronizing 118 | between workers. It is called by update_with_local_losses from all 119 | ranks with identical arguments. Thus, it should have deterministic 120 | behavior to maintain state across workers. 121 | 122 | :param ts: a list of int timesteps. 123 | :param losses: a list of float losses, one per timestep. 124 | """ 125 | 126 | 127 | class LossSecondMomentResampler(LossAwareSampler): 128 | def __init__(self, diffusion, history_per_term=10, uniform_prob=0.001): 129 | self.diffusion = diffusion 130 | self.history_per_term = history_per_term 131 | self.uniform_prob = uniform_prob 132 | self._loss_history = np.zeros( 133 | [diffusion.num_timesteps, history_per_term], dtype=np.float64 134 | ) 135 | self._loss_counts = np.zeros([diffusion.num_timesteps], dtype=np.int) 136 | 137 | def weights(self): 138 | if not self._warmed_up(): 139 | return np.ones([self.diffusion.num_timesteps], dtype=np.float64) 140 | weights = np.sqrt(np.mean(self._loss_history**2, axis=-1)) 141 | weights /= np.sum(weights) 142 | weights *= 1 - self.uniform_prob 143 | weights += self.uniform_prob / len(weights) 144 | return weights 145 | 146 | def update_with_all_losses(self, ts, losses): 147 | for t, loss in zip(ts, losses): 148 | if self._loss_counts[t] == self.history_per_term: 149 | # Shift out the oldest loss term. 
150 | self._loss_history[t, :-1] = self._loss_history[t, 1:] 151 | self._loss_history[t, -1] = loss 152 | else: 153 | self._loss_history[t, self._loss_counts[t]] = loss 154 | self._loss_counts[t] += 1 155 | 156 | def _warmed_up(self): 157 | return (self._loss_counts == self.history_per_term).all() 158 | 159 | 160 | class LogNormalSampler: 161 | def __init__(self, p_mean=-1.2, p_std=1.2, even=False): 162 | self.p_mean = p_mean 163 | self.p_std = p_std 164 | self.even = even 165 | if self.even: 166 | self.inv_cdf = lambda x: norm.ppf(x, loc=p_mean, scale=p_std) 167 | self.rank, self.size = dist.get_rank(), dist.get_world_size() 168 | 169 | def sample(self, bs, device): 170 | if self.even: 171 | # buckets = [1/G] 172 | start_i, end_i = self.rank * bs, (self.rank + 1) * bs 173 | global_batch_size = self.size * bs 174 | locs = (th.arange(start_i, end_i) + th.rand(bs)) / global_batch_size 175 | log_sigmas = th.tensor(self.inv_cdf(locs), dtype=th.float32, device=device) 176 | else: 177 | log_sigmas = self.p_mean + self.p_std * th.randn(bs, device=device) 178 | sigmas = th.exp(log_sigmas) 179 | weights = th.ones_like(sigmas) 180 | return sigmas, weights 181 | -------------------------------------------------------------------------------- /cm/script_util.py: -------------------------------------------------------------------------------- 1 | import argparse 2 | 3 | from .karras_diffusion import KarrasDenoiser 4 | from .unet import UNetModel 5 | import numpy as np 6 | 7 | NUM_CLASSES = 1000 8 | 9 | 10 | def cm_train_defaults(): 11 | return dict( 12 | teacher_model_path="", 13 | teacher_dropout=0.1, 14 | training_mode="consistency_distillation", 15 | target_ema_mode="fixed", 16 | scale_mode="fixed", 17 | total_training_steps=600000, 18 | start_ema=0.0, 19 | start_scales=40, 20 | end_scales=40, 21 | distill_steps_per_iter=50000, 22 | loss_norm="lpips", 23 | ) 24 | 25 | 26 | def model_and_diffusion_defaults(): 27 | """ 28 | Defaults for image training. 
29 | """ 30 | res = dict( 31 | sigma_min=0.002, 32 | sigma_max=80.0, 33 | image_size=64, 34 | num_channels=128, 35 | num_res_blocks=2, 36 | num_heads=4, 37 | num_heads_upsample=-1, 38 | num_head_channels=-1, 39 | attention_resolutions="32,16,8", 40 | channel_mult="", 41 | dropout=0.0, 42 | class_cond=False, 43 | use_checkpoint=False, 44 | use_scale_shift_norm=True, 45 | resblock_updown=False, 46 | use_fp16=False, 47 | use_new_attention_order=False, 48 | learn_sigma=False, 49 | weight_schedule="karras", 50 | ) 51 | return res 52 | 53 | 54 | def create_model_and_diffusion( 55 | image_size, 56 | class_cond, 57 | learn_sigma, 58 | num_channels, 59 | num_res_blocks, 60 | channel_mult, 61 | num_heads, 62 | num_head_channels, 63 | num_heads_upsample, 64 | attention_resolutions, 65 | dropout, 66 | use_checkpoint, 67 | use_scale_shift_norm, 68 | resblock_updown, 69 | use_fp16, 70 | use_new_attention_order, 71 | weight_schedule, 72 | sigma_min=0.002, 73 | sigma_max=80.0, 74 | distillation=False, 75 | ): 76 | model = create_model( 77 | image_size, 78 | num_channels, 79 | num_res_blocks, 80 | channel_mult=channel_mult, 81 | learn_sigma=learn_sigma, 82 | class_cond=class_cond, 83 | use_checkpoint=use_checkpoint, 84 | attention_resolutions=attention_resolutions, 85 | num_heads=num_heads, 86 | num_head_channels=num_head_channels, 87 | num_heads_upsample=num_heads_upsample, 88 | use_scale_shift_norm=use_scale_shift_norm, 89 | dropout=dropout, 90 | resblock_updown=resblock_updown, 91 | use_fp16=use_fp16, 92 | use_new_attention_order=use_new_attention_order, 93 | ) 94 | diffusion = KarrasDenoiser( 95 | sigma_data=0.5, 96 | sigma_max=sigma_max, 97 | sigma_min=sigma_min, 98 | distillation=distillation, 99 | weight_schedule=weight_schedule, 100 | ) 101 | return model, diffusion 102 | 103 | 104 | def create_model( 105 | image_size, 106 | num_channels, 107 | num_res_blocks, 108 | channel_mult="", 109 | learn_sigma=False, 110 | class_cond=False, 111 | use_checkpoint=False, 112 | attention_resolutions="16", 113 | num_heads=1, 114 | num_head_channels=-1, 115 | num_heads_upsample=-1, 116 | use_scale_shift_norm=False, 117 | dropout=0, 118 | resblock_updown=False, 119 | use_fp16=False, 120 | use_new_attention_order=False, 121 | ): 122 | if channel_mult == "": 123 | if image_size == 512: 124 | channel_mult = (0.5, 1, 1, 2, 2, 4, 4) 125 | elif image_size == 256: 126 | channel_mult = (1, 1, 2, 2, 4, 4) 127 | elif image_size == 128: 128 | channel_mult = (1, 1, 2, 3, 4) 129 | elif image_size == 64: 130 | channel_mult = (1, 2, 3, 4) 131 | else: 132 | raise ValueError(f"unsupported image size: {image_size}") 133 | else: 134 | channel_mult = tuple(int(ch_mult) for ch_mult in channel_mult.split(",")) 135 | 136 | attention_ds = [] 137 | for res in attention_resolutions.split(","): 138 | attention_ds.append(image_size // int(res)) 139 | 140 | return UNetModel( 141 | image_size=image_size, 142 | in_channels=3, 143 | model_channels=num_channels, 144 | out_channels=(3 if not learn_sigma else 6), 145 | num_res_blocks=num_res_blocks, 146 | attention_resolutions=tuple(attention_ds), 147 | dropout=dropout, 148 | channel_mult=channel_mult, 149 | num_classes=(NUM_CLASSES if class_cond else None), 150 | use_checkpoint=use_checkpoint, 151 | use_fp16=use_fp16, 152 | num_heads=num_heads, 153 | num_head_channels=num_head_channels, 154 | num_heads_upsample=num_heads_upsample, 155 | use_scale_shift_norm=use_scale_shift_norm, 156 | resblock_updown=resblock_updown, 157 | use_new_attention_order=use_new_attention_order, 158 | ) 159 | 160 | 161 | 
def create_ema_and_scales_fn( 162 | target_ema_mode, 163 | start_ema, 164 | scale_mode, 165 | start_scales, 166 | end_scales, 167 | total_steps, 168 | distill_steps_per_iter, 169 | ): 170 | def ema_and_scales_fn(step): 171 | if target_ema_mode == "fixed" and scale_mode == "fixed": 172 | target_ema = start_ema 173 | scales = start_scales 174 | elif target_ema_mode == "fixed" and scale_mode == "progressive": 175 | target_ema = start_ema 176 | scales = np.ceil( 177 | np.sqrt( 178 | (step / total_steps) * ((end_scales + 1) ** 2 - start_scales**2) 179 | + start_scales**2 180 | ) 181 | - 1 182 | ).astype(np.int32) 183 | scales = np.maximum(scales, 1) 184 | scales = scales + 1 185 | 186 | elif target_ema_mode == "adaptive" and scale_mode == "progressive": 187 | scales = np.ceil( 188 | np.sqrt( 189 | (step / total_steps) * ((end_scales + 1) ** 2 - start_scales**2) 190 | + start_scales**2 191 | ) 192 | - 1 193 | ).astype(np.int32) 194 | scales = np.maximum(scales, 1) 195 | c = -np.log(start_ema) * start_scales 196 | target_ema = np.exp(-c / scales) 197 | scales = scales + 1 198 | elif target_ema_mode == "fixed" and scale_mode == "progdist": 199 | distill_stage = step // distill_steps_per_iter 200 | scales = start_scales // (2**distill_stage) 201 | scales = np.maximum(scales, 2) 202 | 203 | sub_stage = np.maximum( 204 | step - distill_steps_per_iter * (np.log2(start_scales) - 1), 205 | 0, 206 | ) 207 | sub_stage = sub_stage // (distill_steps_per_iter * 2) 208 | sub_scales = 2 // (2**sub_stage) 209 | sub_scales = np.maximum(sub_scales, 1) 210 | 211 | scales = np.where(scales == 2, sub_scales, scales) 212 | 213 | target_ema = 1.0 214 | else: 215 | raise NotImplementedError 216 | 217 | return float(target_ema), int(scales) 218 | 219 | return ema_and_scales_fn 220 | 221 | 222 | def add_dict_to_argparser(parser, default_dict): 223 | for k, v in default_dict.items(): 224 | v_type = type(v) 225 | if v is None: 226 | v_type = str 227 | elif isinstance(v, bool): 228 | v_type = str2bool 229 | parser.add_argument(f"--{k}", default=v, type=v_type) 230 | 231 | 232 | def args_to_dict(args, keys): 233 | return {k: getattr(args, k) for k in keys} 234 | 235 | 236 | def str2bool(v): 237 | """ 238 | https://stackoverflow.com/questions/15008758/parsing-boolean-values-with-argparse 239 | """ 240 | if isinstance(v, bool): 241 | return v 242 | if v.lower() in ("yes", "true", "t", "y", "1"): 243 | return True 244 | elif v.lower() in ("no", "false", "f", "n", "0"): 245 | return False 246 | else: 247 | raise argparse.ArgumentTypeError("boolean value expected") 248 | -------------------------------------------------------------------------------- /cm/train_util.py: -------------------------------------------------------------------------------- 1 | import copy 2 | import functools 3 | import os 4 | 5 | import blobfile as bf 6 | import torch as th 7 | import torch.distributed as dist 8 | from torch.nn.parallel.distributed import DistributedDataParallel as DDP 9 | from torch.optim import RAdam 10 | 11 | from . import dist_util, logger 12 | from .fp16_util import MixedPrecisionTrainer 13 | from .nn import update_ema 14 | from .resample import LossAwareSampler, UniformSampler 15 | 16 | from .fp16_util import ( 17 | get_param_groups_and_shapes, 18 | make_master_params, 19 | master_params_to_model_params, 20 | ) 21 | import numpy as np 22 | 23 | # For ImageNet experiments, this was a good default value. 24 | # We found that the lg_loss_scale quickly climbed to 25 | # 20-21 within the first ~1K steps of training. 
26 | INITIAL_LOG_LOSS_SCALE = 20.0 27 | 28 | 29 | class TrainLoop: 30 | def __init__( 31 | self, 32 | *, 33 | model, 34 | diffusion, 35 | data, 36 | batch_size, 37 | microbatch, 38 | lr, 39 | ema_rate, 40 | log_interval, 41 | save_interval, 42 | resume_checkpoint, 43 | use_fp16=False, 44 | fp16_scale_growth=1e-3, 45 | schedule_sampler=None, 46 | weight_decay=0.0, 47 | lr_anneal_steps=0, 48 | ): 49 | self.model = model 50 | self.diffusion = diffusion 51 | self.data = data 52 | self.batch_size = batch_size 53 | self.microbatch = microbatch if microbatch > 0 else batch_size 54 | self.lr = lr 55 | self.ema_rate = ( 56 | [ema_rate] 57 | if isinstance(ema_rate, float) 58 | else [float(x) for x in ema_rate.split(",")] 59 | ) 60 | self.log_interval = log_interval 61 | self.save_interval = save_interval 62 | self.resume_checkpoint = resume_checkpoint 63 | self.use_fp16 = use_fp16 64 | self.fp16_scale_growth = fp16_scale_growth 65 | self.schedule_sampler = schedule_sampler or UniformSampler(diffusion) 66 | self.weight_decay = weight_decay 67 | self.lr_anneal_steps = lr_anneal_steps 68 | 69 | self.step = 0 70 | self.resume_step = 0 71 | self.global_batch = self.batch_size * dist.get_world_size() 72 | 73 | self.sync_cuda = th.cuda.is_available() 74 | 75 | self._load_and_sync_parameters() 76 | self.mp_trainer = MixedPrecisionTrainer( 77 | model=self.model, 78 | use_fp16=self.use_fp16, 79 | fp16_scale_growth=fp16_scale_growth, 80 | ) 81 | 82 | self.opt = RAdam( 83 | self.mp_trainer.master_params, lr=self.lr, weight_decay=self.weight_decay 84 | ) 85 | if self.resume_step: 86 | self._load_optimizer_state() 87 | # Model was resumed, either due to a restart or a checkpoint 88 | # being specified at the command line. 89 | self.ema_params = [ 90 | self._load_ema_parameters(rate) for rate in self.ema_rate 91 | ] 92 | else: 93 | self.ema_params = [ 94 | copy.deepcopy(self.mp_trainer.master_params) 95 | for _ in range(len(self.ema_rate)) 96 | ] 97 | 98 | if th.cuda.is_available(): 99 | self.use_ddp = True 100 | self.ddp_model = DDP( 101 | self.model, 102 | device_ids=[dist_util.dev()], 103 | output_device=dist_util.dev(), 104 | broadcast_buffers=False, 105 | bucket_cap_mb=128, 106 | find_unused_parameters=False, 107 | ) 108 | else: 109 | if dist.get_world_size() > 1: 110 | logger.warn( 111 | "Distributed training requires CUDA. " 112 | "Gradients will not be synchronized properly!" 
113 | ) 114 | self.use_ddp = False 115 | self.ddp_model = self.model 116 | 117 | self.step = self.resume_step 118 | 119 | def _load_and_sync_parameters(self): 120 | resume_checkpoint = find_resume_checkpoint() or self.resume_checkpoint 121 | 122 | if resume_checkpoint: 123 | self.resume_step = parse_resume_step_from_filename(resume_checkpoint) 124 | if dist.get_rank() == 0: 125 | logger.log(f"loading model from checkpoint: {resume_checkpoint}...") 126 | self.model.load_state_dict( 127 | dist_util.load_state_dict( 128 | resume_checkpoint, map_location=dist_util.dev() 129 | ), 130 | ) 131 | 132 | dist_util.sync_params(self.model.parameters()) 133 | dist_util.sync_params(self.model.buffers()) 134 | 135 | def _load_ema_parameters(self, rate): 136 | ema_params = copy.deepcopy(self.mp_trainer.master_params) 137 | 138 | main_checkpoint = find_resume_checkpoint() or self.resume_checkpoint 139 | ema_checkpoint = find_ema_checkpoint(main_checkpoint, self.resume_step, rate) 140 | if ema_checkpoint: 141 | if dist.get_rank() == 0: 142 | logger.log(f"loading EMA from checkpoint: {ema_checkpoint}...") 143 | state_dict = dist_util.load_state_dict( 144 | ema_checkpoint, map_location=dist_util.dev() 145 | ) 146 | ema_params = self.mp_trainer.state_dict_to_master_params(state_dict) 147 | 148 | dist_util.sync_params(ema_params) 149 | return ema_params 150 | 151 | def _load_optimizer_state(self): 152 | main_checkpoint = find_resume_checkpoint() or self.resume_checkpoint 153 | opt_checkpoint = bf.join( 154 | bf.dirname(main_checkpoint), f"opt{self.resume_step:06}.pt" 155 | ) 156 | if bf.exists(opt_checkpoint): 157 | logger.log(f"loading optimizer state from checkpoint: {opt_checkpoint}") 158 | state_dict = dist_util.load_state_dict( 159 | opt_checkpoint, map_location=dist_util.dev() 160 | ) 161 | self.opt.load_state_dict(state_dict) 162 | 163 | def run_loop(self): 164 | while not self.lr_anneal_steps or self.step < self.lr_anneal_steps: 165 | batch, cond = next(self.data) 166 | self.run_step(batch, cond) 167 | if self.step % self.log_interval == 0: 168 | logger.dumpkvs() 169 | if self.step % self.save_interval == 0: 170 | self.save() 171 | # Run for a finite amount of time in integration tests. 172 | if os.environ.get("DIFFUSION_TRAINING_TEST", "") and self.step > 0: 173 | return 174 | # Save the last checkpoint if it wasn't already saved. 
175 | if (self.step - 1) % self.save_interval != 0: 176 | self.save() 177 | 178 | def run_step(self, batch, cond): 179 | self.forward_backward(batch, cond) 180 | took_step = self.mp_trainer.optimize(self.opt) 181 | if took_step: 182 | self.step += 1 183 | self._update_ema() 184 | self._anneal_lr() 185 | self.log_step() 186 | 187 | def forward_backward(self, batch, cond): 188 | self.mp_trainer.zero_grad() 189 | for i in range(0, batch.shape[0], self.microbatch): 190 | micro = batch[i : i + self.microbatch].to(dist_util.dev()) 191 | micro_cond = { 192 | k: v[i : i + self.microbatch].to(dist_util.dev()) 193 | for k, v in cond.items() 194 | } 195 | last_batch = (i + self.microbatch) >= batch.shape[0] 196 | t, weights = self.schedule_sampler.sample(micro.shape[0], dist_util.dev()) 197 | 198 | compute_losses = functools.partial( 199 | self.diffusion.training_losses, 200 | self.ddp_model, 201 | micro, 202 | t, 203 | model_kwargs=micro_cond, 204 | ) 205 | 206 | if last_batch or not self.use_ddp: 207 | losses = compute_losses() 208 | else: 209 | with self.ddp_model.no_sync(): 210 | losses = compute_losses() 211 | 212 | if isinstance(self.schedule_sampler, LossAwareSampler): 213 | self.schedule_sampler.update_with_local_losses( 214 | t, losses["loss"].detach() 215 | ) 216 | 217 | loss = (losses["loss"] * weights).mean() 218 | log_loss_dict( 219 | self.diffusion, t, {k: v * weights for k, v in losses.items()} 220 | ) 221 | self.mp_trainer.backward(loss) 222 | 223 | def _update_ema(self): 224 | for rate, params in zip(self.ema_rate, self.ema_params): 225 | update_ema(params, self.mp_trainer.master_params, rate=rate) 226 | 227 | def _anneal_lr(self): 228 | if not self.lr_anneal_steps: 229 | return 230 | frac_done = (self.step + self.resume_step) / self.lr_anneal_steps 231 | lr = self.lr * (1 - frac_done) 232 | for param_group in self.opt.param_groups: 233 | param_group["lr"] = lr 234 | 235 | def log_step(self): 236 | logger.logkv("step", self.step + self.resume_step) 237 | logger.logkv("samples", (self.step + self.resume_step + 1) * self.global_batch) 238 | 239 | def save(self): 240 | def save_checkpoint(rate, params): 241 | state_dict = self.mp_trainer.master_params_to_state_dict(params) 242 | if dist.get_rank() == 0: 243 | logger.log(f"saving model {rate}...") 244 | if not rate: 245 | filename = f"model{(self.step+self.resume_step):06d}.pt" 246 | else: 247 | filename = f"ema_{rate}_{(self.step+self.resume_step):06d}.pt" 248 | with bf.BlobFile(bf.join(get_blob_logdir(), filename), "wb") as f: 249 | th.save(state_dict, f) 250 | 251 | for rate, params in zip(self.ema_rate, self.ema_params): 252 | save_checkpoint(rate, params) 253 | 254 | if dist.get_rank() == 0: 255 | with bf.BlobFile( 256 | bf.join(get_blob_logdir(), f"opt{(self.step+self.resume_step):06d}.pt"), 257 | "wb", 258 | ) as f: 259 | th.save(self.opt.state_dict(), f) 260 | 261 | # Save model parameters last to prevent race conditions where a restart 262 | # loads model at step N, but opt/ema state isn't saved for step N. 
263 | save_checkpoint(0, self.mp_trainer.master_params) 264 | dist.barrier() 265 | 266 | 267 | class CMTrainLoop(TrainLoop): 268 | def __init__( 269 | self, 270 | *, 271 | target_model, 272 | teacher_model, 273 | teacher_diffusion, 274 | training_mode, 275 | ema_scale_fn, 276 | total_training_steps, 277 | **kwargs, 278 | ): 279 | super().__init__(**kwargs) 280 | self.training_mode = training_mode 281 | self.ema_scale_fn = ema_scale_fn 282 | self.target_model = target_model 283 | self.teacher_model = teacher_model 284 | self.teacher_diffusion = teacher_diffusion 285 | self.total_training_steps = total_training_steps 286 | 287 | if target_model: 288 | self._load_and_sync_target_parameters() 289 | self.target_model.requires_grad_(False) 290 | self.target_model.train() 291 | 292 | self.target_model_param_groups_and_shapes = get_param_groups_and_shapes( 293 | self.target_model.named_parameters() 294 | ) 295 | self.target_model_master_params = make_master_params( 296 | self.target_model_param_groups_and_shapes 297 | ) 298 | 299 | if teacher_model: 300 | self._load_and_sync_teacher_parameters() 301 | self.teacher_model.requires_grad_(False) 302 | self.teacher_model.eval() 303 | 304 | self.global_step = self.step 305 | if training_mode == "progdist": 306 | self.target_model.eval() 307 | _, scale = ema_scale_fn(self.global_step) 308 | if scale == 1 or scale == 2: 309 | _, start_scale = ema_scale_fn(0) 310 | n_normal_steps = int(np.log2(start_scale // 2)) * self.lr_anneal_steps 311 | step = self.global_step - n_normal_steps 312 | if step != 0: 313 | self.lr_anneal_steps *= 2 314 | self.step = step % self.lr_anneal_steps 315 | else: 316 | self.step = 0 317 | else: 318 | self.step = self.global_step % self.lr_anneal_steps 319 | 320 | def _load_and_sync_target_parameters(self): 321 | resume_checkpoint = find_resume_checkpoint() or self.resume_checkpoint 322 | if resume_checkpoint: 323 | path, name = os.path.split(resume_checkpoint) 324 | target_name = name.replace("model", "target_model") 325 | resume_target_checkpoint = os.path.join(path, target_name) 326 | if bf.exists(resume_target_checkpoint) and dist.get_rank() == 0: 327 | logger.log( 328 | "loading model from checkpoint: {resume_target_checkpoint}..." 329 | ) 330 | self.target_model.load_state_dict( 331 | dist_util.load_state_dict( 332 | resume_target_checkpoint, map_location=dist_util.dev() 333 | ), 334 | ) 335 | 336 | dist_util.sync_params(self.target_model.parameters()) 337 | dist_util.sync_params(self.target_model.buffers()) 338 | 339 | def _load_and_sync_teacher_parameters(self): 340 | resume_checkpoint = find_resume_checkpoint() or self.resume_checkpoint 341 | if resume_checkpoint: 342 | path, name = os.path.split(resume_checkpoint) 343 | teacher_name = name.replace("model", "teacher_model") 344 | resume_teacher_checkpoint = os.path.join(path, teacher_name) 345 | 346 | if bf.exists(resume_teacher_checkpoint) and dist.get_rank() == 0: 347 | logger.log( 348 | "loading model from checkpoint: {resume_teacher_checkpoint}..." 
349 | ) 350 | self.teacher_model.load_state_dict( 351 | dist_util.load_state_dict( 352 | resume_teacher_checkpoint, map_location=dist_util.dev() 353 | ), 354 | ) 355 | 356 | dist_util.sync_params(self.teacher_model.parameters()) 357 | dist_util.sync_params(self.teacher_model.buffers()) 358 | 359 | def run_loop(self): 360 | saved = False 361 | while ( 362 | not self.lr_anneal_steps 363 | or self.step < self.lr_anneal_steps 364 | or self.global_step < self.total_training_steps 365 | ): 366 | batch, cond = next(self.data) 367 | self.run_step(batch, cond) 368 | saved = False 369 | if ( 370 | self.global_step 371 | and self.save_interval != -1 372 | and self.global_step % self.save_interval == 0 373 | ): 374 | self.save() 375 | saved = True 376 | th.cuda.empty_cache() 377 | # Run for a finite amount of time in integration tests. 378 | if os.environ.get("DIFFUSION_TRAINING_TEST", "") and self.step > 0: 379 | return 380 | 381 | if self.global_step % self.log_interval == 0: 382 | logger.dumpkvs() 383 | 384 | # Save the last checkpoint if it wasn't already saved. 385 | if not saved: 386 | self.save() 387 | 388 | def run_step(self, batch, cond): 389 | self.forward_backward(batch, cond) 390 | took_step = self.mp_trainer.optimize(self.opt) 391 | if took_step: 392 | self._update_ema() 393 | if self.target_model: 394 | self._update_target_ema() 395 | if self.training_mode == "progdist": 396 | self.reset_training_for_progdist() 397 | self.step += 1 398 | self.global_step += 1 399 | 400 | self._anneal_lr() 401 | self.log_step() 402 | 403 | def _update_target_ema(self): 404 | target_ema, scales = self.ema_scale_fn(self.global_step) 405 | with th.no_grad(): 406 | update_ema( 407 | self.target_model_master_params, 408 | self.mp_trainer.master_params, 409 | rate=target_ema, 410 | ) 411 | master_params_to_model_params( 412 | self.target_model_param_groups_and_shapes, 413 | self.target_model_master_params, 414 | ) 415 | 416 | def reset_training_for_progdist(self): 417 | assert self.training_mode == "progdist", "Training mode must be progdist" 418 | if self.global_step > 0: 419 | scales = self.ema_scale_fn(self.global_step)[1] 420 | scales2 = self.ema_scale_fn(self.global_step - 1)[1] 421 | if scales != scales2: 422 | with th.no_grad(): 423 | update_ema( 424 | self.teacher_model.parameters(), 425 | self.model.parameters(), 426 | 0.0, 427 | ) 428 | # reset optimizer 429 | self.opt = RAdam( 430 | self.mp_trainer.master_params, 431 | lr=self.lr, 432 | weight_decay=self.weight_decay, 433 | ) 434 | 435 | self.ema_params = [ 436 | copy.deepcopy(self.mp_trainer.master_params) 437 | for _ in range(len(self.ema_rate)) 438 | ] 439 | if scales == 2: 440 | self.lr_anneal_steps *= 2 441 | self.teacher_model.eval() 442 | self.step = 0 443 | 444 | def forward_backward(self, batch, cond): 445 | self.mp_trainer.zero_grad() 446 | for i in range(0, batch.shape[0], self.microbatch): 447 | micro = batch[i : i + self.microbatch].to(dist_util.dev()) 448 | micro_cond = { 449 | k: v[i : i + self.microbatch].to(dist_util.dev()) 450 | for k, v in cond.items() 451 | } 452 | last_batch = (i + self.microbatch) >= batch.shape[0] 453 | t, weights = self.schedule_sampler.sample(micro.shape[0], dist_util.dev()) 454 | 455 | ema, num_scales = self.ema_scale_fn(self.global_step) 456 | if self.training_mode == "progdist": 457 | if num_scales == self.ema_scale_fn(0)[1]: 458 | compute_losses = functools.partial( 459 | self.diffusion.progdist_losses, 460 | self.ddp_model, 461 | micro, 462 | num_scales, 463 | target_model=self.teacher_model, 464 | 
target_diffusion=self.teacher_diffusion, 465 | model_kwargs=micro_cond, 466 | ) 467 | else: 468 | compute_losses = functools.partial( 469 | self.diffusion.progdist_losses, 470 | self.ddp_model, 471 | micro, 472 | num_scales, 473 | target_model=self.target_model, 474 | target_diffusion=self.diffusion, 475 | model_kwargs=micro_cond, 476 | ) 477 | elif self.training_mode == "consistency_distillation": 478 | compute_losses = functools.partial( 479 | self.diffusion.consistency_losses, 480 | self.ddp_model, 481 | micro, 482 | num_scales, 483 | target_model=self.target_model, 484 | teacher_model=self.teacher_model, 485 | teacher_diffusion=self.teacher_diffusion, 486 | model_kwargs=micro_cond, 487 | ) 488 | elif self.training_mode == "consistency_training": 489 | compute_losses = functools.partial( 490 | self.diffusion.consistency_losses, 491 | self.ddp_model, 492 | micro, 493 | num_scales, 494 | target_model=self.target_model, 495 | model_kwargs=micro_cond, 496 | ) 497 | else: 498 | raise ValueError(f"Unknown training mode {self.training_mode}") 499 | 500 | if last_batch or not self.use_ddp: 501 | losses = compute_losses() 502 | else: 503 | with self.ddp_model.no_sync(): 504 | losses = compute_losses() 505 | 506 | if isinstance(self.schedule_sampler, LossAwareSampler): 507 | self.schedule_sampler.update_with_local_losses( 508 | t, losses["loss"].detach() 509 | ) 510 | 511 | loss = (losses["loss"] * weights).mean() 512 | 513 | log_loss_dict( 514 | self.diffusion, t, {k: v * weights for k, v in losses.items()} 515 | ) 516 | self.mp_trainer.backward(loss) 517 | 518 | def save(self): 519 | import blobfile as bf 520 | 521 | step = self.global_step 522 | 523 | def save_checkpoint(rate, params): 524 | state_dict = self.mp_trainer.master_params_to_state_dict(params) 525 | if dist.get_rank() == 0: 526 | logger.log(f"saving model {rate}...") 527 | if not rate: 528 | filename = f"model{step:06d}.pt" 529 | else: 530 | filename = f"ema_{rate}_{step:06d}.pt" 531 | with bf.BlobFile(bf.join(get_blob_logdir(), filename), "wb") as f: 532 | th.save(state_dict, f) 533 | 534 | for rate, params in zip(self.ema_rate, self.ema_params): 535 | save_checkpoint(rate, params) 536 | 537 | logger.log("saving optimizer state...") 538 | if dist.get_rank() == 0: 539 | with bf.BlobFile( 540 | bf.join(get_blob_logdir(), f"opt{step:06d}.pt"), 541 | "wb", 542 | ) as f: 543 | th.save(self.opt.state_dict(), f) 544 | 545 | if dist.get_rank() == 0: 546 | if self.target_model: 547 | logger.log("saving target model state") 548 | filename = f"target_model{step:06d}.pt" 549 | with bf.BlobFile(bf.join(get_blob_logdir(), filename), "wb") as f: 550 | th.save(self.target_model.state_dict(), f) 551 | if self.teacher_model and self.training_mode == "progdist": 552 | logger.log("saving teacher model state") 553 | filename = f"teacher_model{step:06d}.pt" 554 | with bf.BlobFile(bf.join(get_blob_logdir(), filename), "wb") as f: 555 | th.save(self.teacher_model.state_dict(), f) 556 | 557 | # Save model parameters last to prevent race conditions where a restart 558 | # loads model at step N, but opt/ema state isn't saved for step N. 559 | save_checkpoint(0, self.mp_trainer.master_params) 560 | dist.barrier() 561 | 562 | def log_step(self): 563 | step = self.global_step 564 | logger.logkv("step", step) 565 | logger.logkv("samples", (step + 1) * self.global_batch) 566 | 567 | 568 | def parse_resume_step_from_filename(filename): 569 | """ 570 | Parse filenames of the form path/to/modelNNNNNN.pt, where NNNNNN is the 571 | checkpoint's number of steps. 
572 | """ 573 | split = filename.split("model") 574 | if len(split) < 2: 575 | return 0 576 | split1 = split[-1].split(".")[0] 577 | try: 578 | return int(split1) 579 | except ValueError: 580 | return 0 581 | 582 | 583 | def get_blob_logdir(): 584 | # You can change this to be a separate path to save checkpoints to 585 | # a blobstore or some external drive. 586 | return logger.get_dir() 587 | 588 | 589 | def find_resume_checkpoint(): 590 | # On your infrastructure, you may want to override this to automatically 591 | # discover the latest checkpoint on your blob storage, etc. 592 | return None 593 | 594 | 595 | def find_ema_checkpoint(main_checkpoint, step, rate): 596 | if main_checkpoint is None: 597 | return None 598 | filename = f"ema_{rate}_{(step):06d}.pt" 599 | path = bf.join(bf.dirname(main_checkpoint), filename) 600 | if bf.exists(path): 601 | return path 602 | return None 603 | 604 | 605 | def log_loss_dict(diffusion, ts, losses): 606 | for key, values in losses.items(): 607 | logger.logkv_mean(key, values.mean().item()) 608 | # Log the quantiles (four quartiles, in particular). 609 | for sub_t, sub_loss in zip(ts.cpu().numpy(), values.detach().cpu().numpy()): 610 | quartile = int(4 * sub_t / diffusion.num_timesteps) 611 | logger.logkv_mean(f"{key}_q{quartile}", sub_loss) 612 | -------------------------------------------------------------------------------- /cm/unet.py: -------------------------------------------------------------------------------- 1 | from abc import abstractmethod 2 | 3 | import math 4 | 5 | import numpy as np 6 | import torch as th 7 | import torch.nn as nn 8 | import torch.nn.functional as F 9 | 10 | from .fp16_util import convert_module_to_f16, convert_module_to_f32 11 | from .nn import ( 12 | checkpoint, 13 | conv_nd, 14 | linear, 15 | avg_pool_nd, 16 | zero_module, 17 | normalization, 18 | timestep_embedding, 19 | ) 20 | 21 | 22 | class AttentionPool2d(nn.Module): 23 | """ 24 | Adapted from CLIP: https://github.com/openai/CLIP/blob/main/clip/model.py 25 | """ 26 | 27 | def __init__( 28 | self, 29 | spacial_dim: int, 30 | embed_dim: int, 31 | num_heads_channels: int, 32 | output_dim: int = None, 33 | ): 34 | super().__init__() 35 | self.positional_embedding = nn.Parameter( 36 | th.randn(embed_dim, spacial_dim**2 + 1) / embed_dim**0.5 37 | ) 38 | self.qkv_proj = conv_nd(1, embed_dim, 3 * embed_dim, 1) 39 | self.c_proj = conv_nd(1, embed_dim, output_dim or embed_dim, 1) 40 | self.num_heads = embed_dim // num_heads_channels 41 | self.attention = QKVAttention(self.num_heads) 42 | 43 | def forward(self, x): 44 | b, c, *_spatial = x.shape 45 | x = x.reshape(b, c, -1) # NC(HW) 46 | x = th.cat([x.mean(dim=-1, keepdim=True), x], dim=-1) # NC(HW+1) 47 | x = x + self.positional_embedding[None, :, :].to(x.dtype) # NC(HW+1) 48 | x = self.qkv_proj(x) 49 | x = self.attention(x) 50 | x = self.c_proj(x) 51 | return x[:, :, 0] 52 | 53 | 54 | class TimestepBlock(nn.Module): 55 | """ 56 | Any module where forward() takes timestep embeddings as a second argument. 57 | """ 58 | 59 | @abstractmethod 60 | def forward(self, x, emb): 61 | """ 62 | Apply the module to `x` given `emb` timestep embeddings. 63 | """ 64 | 65 | 66 | class TimestepEmbedSequential(nn.Sequential, TimestepBlock): 67 | """ 68 | A sequential module that passes timestep embeddings to the children that 69 | support it as an extra input. 
70 | """ 71 | 72 | def forward(self, x, emb): 73 | for layer in self: 74 | if isinstance(layer, TimestepBlock): 75 | x = layer(x, emb) 76 | else: 77 | x = layer(x) 78 | return x 79 | 80 | 81 | class Upsample(nn.Module): 82 | """ 83 | An upsampling layer with an optional convolution. 84 | 85 | :param channels: channels in the inputs and outputs. 86 | :param use_conv: a bool determining if a convolution is applied. 87 | :param dims: determines if the signal is 1D, 2D, or 3D. If 3D, then 88 | upsampling occurs in the inner-two dimensions. 89 | """ 90 | 91 | def __init__(self, channels, use_conv, dims=2, out_channels=None): 92 | super().__init__() 93 | self.channels = channels 94 | self.out_channels = out_channels or channels 95 | self.use_conv = use_conv 96 | self.dims = dims 97 | if use_conv: 98 | self.conv = conv_nd(dims, self.channels, self.out_channels, 3, padding=1) 99 | 100 | def forward(self, x): 101 | assert x.shape[1] == self.channels 102 | if self.dims == 3: 103 | x = F.interpolate( 104 | x, (x.shape[2], x.shape[3] * 2, x.shape[4] * 2), mode="nearest" 105 | ) 106 | else: 107 | x = F.interpolate(x, scale_factor=2, mode="nearest") 108 | if self.use_conv: 109 | x = self.conv(x) 110 | return x 111 | 112 | 113 | class Downsample(nn.Module): 114 | """ 115 | A downsampling layer with an optional convolution. 116 | 117 | :param channels: channels in the inputs and outputs. 118 | :param use_conv: a bool determining if a convolution is applied. 119 | :param dims: determines if the signal is 1D, 2D, or 3D. If 3D, then 120 | downsampling occurs in the inner-two dimensions. 121 | """ 122 | 123 | def __init__(self, channels, use_conv, dims=2, out_channels=None): 124 | super().__init__() 125 | self.channels = channels 126 | self.out_channels = out_channels or channels 127 | self.use_conv = use_conv 128 | self.dims = dims 129 | stride = 2 if dims != 3 else (1, 2, 2) 130 | if use_conv: 131 | self.op = conv_nd( 132 | dims, self.channels, self.out_channels, 3, stride=stride, padding=1 133 | ) 134 | else: 135 | assert self.channels == self.out_channels 136 | self.op = avg_pool_nd(dims, kernel_size=stride, stride=stride) 137 | 138 | def forward(self, x): 139 | assert x.shape[1] == self.channels 140 | return self.op(x) 141 | 142 | 143 | class ResBlock(TimestepBlock): 144 | """ 145 | A residual block that can optionally change the number of channels. 146 | 147 | :param channels: the number of input channels. 148 | :param emb_channels: the number of timestep embedding channels. 149 | :param dropout: the rate of dropout. 150 | :param out_channels: if specified, the number of out channels. 151 | :param use_conv: if True and out_channels is specified, use a spatial 152 | convolution instead of a smaller 1x1 convolution to change the 153 | channels in the skip connection. 154 | :param dims: determines if the signal is 1D, 2D, or 3D. 155 | :param use_checkpoint: if True, use gradient checkpointing on this module. 156 | :param up: if True, use this block for upsampling. 157 | :param down: if True, use this block for downsampling. 
158 | """ 159 | 160 | def __init__( 161 | self, 162 | channels, 163 | emb_channels, 164 | dropout, 165 | out_channels=None, 166 | use_conv=False, 167 | use_scale_shift_norm=False, 168 | dims=2, 169 | use_checkpoint=False, 170 | up=False, 171 | down=False, 172 | ): 173 | super().__init__() 174 | self.channels = channels 175 | self.emb_channels = emb_channels 176 | self.dropout = dropout 177 | self.out_channels = out_channels or channels 178 | self.use_conv = use_conv 179 | self.use_checkpoint = use_checkpoint 180 | self.use_scale_shift_norm = use_scale_shift_norm 181 | 182 | self.in_layers = nn.Sequential( 183 | normalization(channels), 184 | nn.SiLU(), 185 | conv_nd(dims, channels, self.out_channels, 3, padding=1), 186 | ) 187 | 188 | self.updown = up or down 189 | 190 | if up: 191 | self.h_upd = Upsample(channels, False, dims) 192 | self.x_upd = Upsample(channels, False, dims) 193 | elif down: 194 | self.h_upd = Downsample(channels, False, dims) 195 | self.x_upd = Downsample(channels, False, dims) 196 | else: 197 | self.h_upd = self.x_upd = nn.Identity() 198 | 199 | self.emb_layers = nn.Sequential( 200 | nn.SiLU(), 201 | linear( 202 | emb_channels, 203 | 2 * self.out_channels if use_scale_shift_norm else self.out_channels, 204 | ), 205 | ) 206 | self.out_layers = nn.Sequential( 207 | normalization(self.out_channels), 208 | nn.SiLU(), 209 | nn.Dropout(p=dropout), 210 | zero_module( 211 | conv_nd(dims, self.out_channels, self.out_channels, 3, padding=1) 212 | ), 213 | ) 214 | 215 | if self.out_channels == channels: 216 | self.skip_connection = nn.Identity() 217 | elif use_conv: 218 | self.skip_connection = conv_nd( 219 | dims, channels, self.out_channels, 3, padding=1 220 | ) 221 | else: 222 | self.skip_connection = conv_nd(dims, channels, self.out_channels, 1) 223 | 224 | def forward(self, x, emb): 225 | """ 226 | Apply the block to a Tensor, conditioned on a timestep embedding. 227 | 228 | :param x: an [N x C x ...] Tensor of features. 229 | :param emb: an [N x emb_channels] Tensor of timestep embeddings. 230 | :return: an [N x C x ...] Tensor of outputs. 231 | """ 232 | return checkpoint( 233 | self._forward, (x, emb), self.parameters(), self.use_checkpoint 234 | ) 235 | 236 | def _forward(self, x, emb): 237 | if self.updown: 238 | in_rest, in_conv = self.in_layers[:-1], self.in_layers[-1] 239 | h = in_rest(x) 240 | h = self.h_upd(h) 241 | x = self.x_upd(x) 242 | h = in_conv(h) 243 | else: 244 | h = self.in_layers(x) 245 | emb_out = self.emb_layers(emb).type(h.dtype) 246 | while len(emb_out.shape) < len(h.shape): 247 | emb_out = emb_out[..., None] 248 | if self.use_scale_shift_norm: 249 | out_norm, out_rest = self.out_layers[0], self.out_layers[1:] 250 | scale, shift = th.chunk(emb_out, 2, dim=1) 251 | h = out_norm(h) * (1 + scale) + shift 252 | h = out_rest(h) 253 | else: 254 | h = h + emb_out 255 | h = self.out_layers(h) 256 | return self.skip_connection(x) + h 257 | 258 | 259 | class AttentionBlock(nn.Module): 260 | """ 261 | An attention block that allows spatial positions to attend to each other. 262 | 263 | Originally ported from here, but adapted to the N-d case. 264 | https://github.com/hojonathanho/diffusion/blob/1e0dceb3b3495bbe19116a5e1b3596cd0706c543/diffusion_tf/models/unet.py#L66. 
265 | """ 266 | 267 | def __init__( 268 | self, 269 | channels, 270 | num_heads=1, 271 | num_head_channels=-1, 272 | use_checkpoint=False, 273 | attention_type="flash", 274 | encoder_channels=None, 275 | dims=2, 276 | channels_last=False, 277 | use_new_attention_order=False, 278 | ): 279 | super().__init__() 280 | self.channels = channels 281 | if num_head_channels == -1: 282 | self.num_heads = num_heads 283 | else: 284 | assert ( 285 | channels % num_head_channels == 0 286 | ), f"q,k,v channels {channels} is not divisible by num_head_channels {num_head_channels}" 287 | self.num_heads = channels // num_head_channels 288 | self.use_checkpoint = use_checkpoint 289 | self.norm = normalization(channels) 290 | self.qkv = conv_nd(dims, channels, channels * 3, 1) 291 | self.attention_type = attention_type 292 | if attention_type == "flash": 293 | self.attention = QKVFlashAttention(channels, self.num_heads) 294 | else: 295 | # split heads before split qkv 296 | self.attention = QKVAttentionLegacy(self.num_heads) 297 | 298 | self.use_attention_checkpoint = not ( 299 | self.use_checkpoint or self.attention_type == "flash" 300 | ) 301 | if encoder_channels is not None: 302 | assert attention_type != "flash" 303 | self.encoder_kv = conv_nd(1, encoder_channels, channels * 2, 1) 304 | self.proj_out = zero_module(conv_nd(dims, channels, channels, 1)) 305 | 306 | def forward(self, x, encoder_out=None): 307 | if encoder_out is None: 308 | return checkpoint( 309 | self._forward, (x,), self.parameters(), self.use_checkpoint 310 | ) 311 | else: 312 | return checkpoint( 313 | self._forward, (x, encoder_out), self.parameters(), self.use_checkpoint 314 | ) 315 | 316 | def _forward(self, x, encoder_out=None): 317 | b, _, *spatial = x.shape 318 | qkv = self.qkv(self.norm(x)).view(b, -1, np.prod(spatial)) 319 | if encoder_out is not None: 320 | encoder_out = self.encoder_kv(encoder_out) 321 | h = checkpoint( 322 | self.attention, (qkv, encoder_out), (), self.use_attention_checkpoint 323 | ) 324 | else: 325 | h = checkpoint(self.attention, (qkv,), (), self.use_attention_checkpoint) 326 | h = h.view(b, -1, *spatial) 327 | h = self.proj_out(h) 328 | return x + h 329 | 330 | 331 | class QKVFlashAttention(nn.Module): 332 | def __init__( 333 | self, 334 | embed_dim, 335 | num_heads, 336 | batch_first=True, 337 | attention_dropout=0.0, 338 | causal=False, 339 | device=None, 340 | dtype=None, 341 | **kwargs, 342 | ) -> None: 343 | from einops import rearrange 344 | from flash_attn.flash_attention import FlashAttention 345 | 346 | assert batch_first 347 | factory_kwargs = {"device": device, "dtype": dtype} 348 | super().__init__() 349 | self.embed_dim = embed_dim 350 | self.num_heads = num_heads 351 | self.causal = causal 352 | 353 | assert ( 354 | self.embed_dim % num_heads == 0 355 | ), "self.kdim must be divisible by num_heads" 356 | self.head_dim = self.embed_dim // num_heads 357 | assert self.head_dim in [16, 32, 64], "Only support head_dim == 16, 32, or 64" 358 | 359 | self.inner_attn = FlashAttention( 360 | attention_dropout=attention_dropout, **factory_kwargs 361 | ) 362 | self.rearrange = rearrange 363 | 364 | def forward(self, qkv, attn_mask=None, key_padding_mask=None, need_weights=False): 365 | qkv = self.rearrange( 366 | qkv, "b (three h d) s -> b s three h d", three=3, h=self.num_heads 367 | ) 368 | qkv, _ = self.inner_attn( 369 | qkv, 370 | key_padding_mask=key_padding_mask, 371 | need_weights=need_weights, 372 | causal=self.causal, 373 | ) 374 | return self.rearrange(qkv, "b s h d -> b (h d) s") 375 | 376 | 
377 | def count_flops_attn(model, _x, y): 378 | """ 379 | A counter for the `thop` package to count the operations in an 380 | attention operation. 381 | Meant to be used like: 382 | macs, params = thop.profile( 383 | model, 384 | inputs=(inputs, timestamps), 385 | custom_ops={QKVAttention: QKVAttention.count_flops}, 386 | ) 387 | """ 388 | b, c, *spatial = y[0].shape 389 | num_spatial = int(np.prod(spatial)) 390 | # We perform two matmuls with the same number of ops. 391 | # The first computes the weight matrix, the second computes 392 | # the combination of the value vectors. 393 | matmul_ops = 2 * b * (num_spatial**2) * c 394 | model.total_ops += th.DoubleTensor([matmul_ops]) 395 | 396 | 397 | class QKVAttentionLegacy(nn.Module): 398 | """ 399 | A module which performs QKV attention. Matches legacy QKVAttention + input/output heads shaping 400 | """ 401 | 402 | def __init__(self, n_heads): 403 | super().__init__() 404 | self.n_heads = n_heads 405 | from einops import rearrange 406 | self.rearrange = rearrange 407 | 408 | 409 | def forward(self, qkv): 410 | """ 411 | Apply QKV attention. 412 | 413 | :param qkv: an [N x (H * 3 * C) x T] tensor of Qs, Ks, and Vs. 414 | :return: an [N x (H * C) x T] tensor after attention. 415 | """ 416 | bs, width, length = qkv.shape 417 | assert width % (3 * self.n_heads) == 0 418 | ch = width // (3 * self.n_heads) 419 | qkv = qkv.half() 420 | 421 | qkv = self.rearrange( 422 | qkv, "b (three h d) s -> b s three h d", three=3, h=self.n_heads 423 | ) 424 | q, k, v = qkv.transpose(1, 3).transpose(3, 4).split(1, dim=2) 425 | q = q.reshape(bs*self.n_heads, ch, length) 426 | k = k.reshape(bs*self.n_heads, ch, length) 427 | v = v.reshape(bs*self.n_heads, ch, length) 428 | 429 | scale = 1 / math.sqrt(math.sqrt(ch)) 430 | weight = th.einsum( 431 | "bct,bcs->bts", q * scale, k * scale 432 | ) # More stable with f16 than dividing afterwards 433 | weight = th.softmax(weight, dim=-1).type(weight.dtype) 434 | a = th.einsum("bts,bcs->bct", weight, v) 435 | a = a.float() 436 | return a.reshape(bs, -1, length) 437 | 438 | @staticmethod 439 | def count_flops(model, _x, y): 440 | return count_flops_attn(model, _x, y) 441 | 442 | 443 | # class QKVAttention(nn.Module): 444 | # """ 445 | # A module which performs QKV attention and splits in a different order. 446 | # """ 447 | 448 | # def __init__(self, n_heads): 449 | # super().__init__() 450 | # self.n_heads = n_heads 451 | 452 | # def forward(self, qkv): 453 | # """ 454 | # Apply QKV attention. 455 | 456 | # :param qkv: an [N x (3 * H * C) x T] tensor of Qs, Ks, and Vs. 457 | # :return: an [N x (H * C) x T] tensor after attention. 458 | # """ 459 | # bs, width, length = qkv.shape 460 | # assert width % (3 * self.n_heads) == 0 461 | # ch = width // (3 * self.n_heads) 462 | # q, k, v = qkv.chunk(3, dim=1) 463 | # scale = 1 / math.sqrt(math.sqrt(ch)) 464 | # weight = th.einsum( 465 | # "bct,bcs->bts", 466 | # (q * scale).view(bs * self.n_heads, ch, length), 467 | # (k * scale).view(bs * self.n_heads, ch, length), 468 | # ) # More stable with f16 than dividing afterwards 469 | # weight = th.softmax(weight.float(), dim=-1).type(weight.dtype) 470 | # a = th.einsum("bts,bcs->bct", weight, v.reshape(bs * self.n_heads, ch, length)) 471 | # return a.reshape(bs, -1, length) 472 | 473 | # @staticmethod 474 | # def count_flops(model, _x, y): 475 | # return count_flops_attn(model, _x, y) 476 | 477 | 478 | class QKVAttention(nn.Module): 479 | """ 480 | A module which performs QKV attention. 
Fallback from Blocksparse if use_fp16=False 481 | """ 482 | 483 | def __init__(self, n_heads): 484 | super().__init__() 485 | self.n_heads = n_heads 486 | 487 | def forward(self, qkv, encoder_kv=None): 488 | """ 489 | Apply QKV attention. 490 | 491 | :param qkv: an [N x (3 * H * C) x T] tensor of Qs, Ks, and Vs. 492 | :return: an [N x (H * C) x T] tensor after attention. 493 | """ 494 | bs, width, length = qkv.shape 495 | assert width % (3 * self.n_heads) == 0 496 | ch = width // (3 * self.n_heads) 497 | q, k, v = qkv.chunk(3, dim=1) 498 | if encoder_kv is not None: 499 | assert encoder_kv.shape[1] == 2 * ch * self.n_heads 500 | ek, ev = encoder_kv.chunk(2, dim=1) 501 | k = th.cat([ek, k], dim=-1) 502 | v = th.cat([ev, v], dim=-1) 503 | scale = 1 / math.sqrt(math.sqrt(ch)) 504 | weight = th.einsum( 505 | "bct,bcs->bts", 506 | (q * scale).view(bs * self.n_heads, ch, length), 507 | (k * scale).view(bs * self.n_heads, ch, -1), 508 | ) # More stable with f16 than dividing afterwards 509 | weight = th.softmax(weight.float(), dim=-1).type(weight.dtype) 510 | a = th.einsum("bts,bcs->bct", weight, v.reshape(bs * self.n_heads, ch, -1)) 511 | return a.reshape(bs, -1, length) 512 | 513 | @staticmethod 514 | def count_flops(model, _x, y): 515 | return count_flops_attn(model, _x, y) 516 | 517 | 518 | class UNetModel(nn.Module): 519 | """ 520 | The full UNet model with attention and timestep embedding. 521 | 522 | :param in_channels: channels in the input Tensor. 523 | :param model_channels: base channel count for the model. 524 | :param out_channels: channels in the output Tensor. 525 | :param num_res_blocks: number of residual blocks per downsample. 526 | :param attention_resolutions: a collection of downsample rates at which 527 | attention will take place. May be a set, list, or tuple. 528 | For example, if this contains 4, then at 4x downsampling, attention 529 | will be used. 530 | :param dropout: the dropout probability. 531 | :param channel_mult: channel multiplier for each level of the UNet. 532 | :param conv_resample: if True, use learned convolutions for upsampling and 533 | downsampling. 534 | :param dims: determines if the signal is 1D, 2D, or 3D. 535 | :param num_classes: if specified (as an int), then this model will be 536 | class-conditional with `num_classes` classes. 537 | :param use_checkpoint: use gradient checkpointing to reduce memory usage. 538 | :param num_heads: the number of attention heads in each attention layer. 539 | :param num_heads_channels: if specified, ignore num_heads and instead use 540 | a fixed channel width per attention head. 541 | :param num_heads_upsample: works with num_heads to set a different number 542 | of heads for upsampling. Deprecated. 543 | :param use_scale_shift_norm: use a FiLM-like conditioning mechanism. 544 | :param resblock_updown: use residual blocks for up/downsampling. 545 | :param use_new_attention_order: use a different attention pattern for potentially 546 | increased efficiency. 
547 | """ 548 | 549 | def __init__( 550 | self, 551 | image_size, 552 | in_channels, 553 | model_channels, 554 | out_channels, 555 | num_res_blocks, 556 | attention_resolutions, 557 | dropout=0, 558 | channel_mult=(1, 2, 4, 8), 559 | conv_resample=True, 560 | dims=2, 561 | num_classes=None, 562 | use_checkpoint=False, 563 | use_fp16=False, 564 | num_heads=1, 565 | num_head_channels=-1, 566 | num_heads_upsample=-1, 567 | use_scale_shift_norm=False, 568 | resblock_updown=False, 569 | use_new_attention_order=False, 570 | ): 571 | super().__init__() 572 | 573 | if num_heads_upsample == -1: 574 | num_heads_upsample = num_heads 575 | 576 | self.image_size = image_size 577 | self.in_channels = in_channels 578 | self.model_channels = model_channels 579 | self.out_channels = out_channels 580 | self.num_res_blocks = num_res_blocks 581 | self.attention_resolutions = attention_resolutions 582 | self.dropout = dropout 583 | self.channel_mult = channel_mult 584 | self.conv_resample = conv_resample 585 | self.num_classes = num_classes 586 | self.use_checkpoint = use_checkpoint 587 | self.dtype = th.float16 if use_fp16 else th.float32 588 | self.num_heads = num_heads 589 | self.num_head_channels = num_head_channels 590 | self.num_heads_upsample = num_heads_upsample 591 | 592 | time_embed_dim = model_channels * 4 593 | self.time_embed = nn.Sequential( 594 | linear(model_channels, time_embed_dim), 595 | nn.SiLU(), 596 | linear(time_embed_dim, time_embed_dim), 597 | ) 598 | 599 | if self.num_classes is not None: 600 | self.label_emb = nn.Embedding(num_classes, time_embed_dim) 601 | 602 | ch = input_ch = int(channel_mult[0] * model_channels) 603 | self.input_blocks = nn.ModuleList( 604 | [TimestepEmbedSequential(conv_nd(dims, in_channels, ch, 3, padding=1))] 605 | ) 606 | self._feature_size = ch 607 | input_block_chans = [ch] 608 | ds = 1 609 | for level, mult in enumerate(channel_mult): 610 | for _ in range(num_res_blocks): 611 | layers = [ 612 | ResBlock( 613 | ch, 614 | time_embed_dim, 615 | dropout, 616 | out_channels=int(mult * model_channels), 617 | dims=dims, 618 | use_checkpoint=use_checkpoint, 619 | use_scale_shift_norm=use_scale_shift_norm, 620 | ) 621 | ] 622 | ch = int(mult * model_channels) 623 | if ds in attention_resolutions: 624 | layers.append( 625 | AttentionBlock( 626 | ch, 627 | use_checkpoint=use_checkpoint, 628 | num_heads=num_heads, 629 | num_head_channels=num_head_channels, 630 | use_new_attention_order=use_new_attention_order, 631 | ) 632 | ) 633 | self.input_blocks.append(TimestepEmbedSequential(*layers)) 634 | self._feature_size += ch 635 | input_block_chans.append(ch) 636 | if level != len(channel_mult) - 1: 637 | out_ch = ch 638 | self.input_blocks.append( 639 | TimestepEmbedSequential( 640 | ResBlock( 641 | ch, 642 | time_embed_dim, 643 | dropout, 644 | out_channels=out_ch, 645 | dims=dims, 646 | use_checkpoint=use_checkpoint, 647 | use_scale_shift_norm=use_scale_shift_norm, 648 | down=True, 649 | ) 650 | if resblock_updown 651 | else Downsample( 652 | ch, conv_resample, dims=dims, out_channels=out_ch 653 | ) 654 | ) 655 | ) 656 | ch = out_ch 657 | input_block_chans.append(ch) 658 | ds *= 2 659 | self._feature_size += ch 660 | 661 | self.middle_block = TimestepEmbedSequential( 662 | ResBlock( 663 | ch, 664 | time_embed_dim, 665 | dropout, 666 | dims=dims, 667 | use_checkpoint=use_checkpoint, 668 | use_scale_shift_norm=use_scale_shift_norm, 669 | ), 670 | AttentionBlock( 671 | ch, 672 | use_checkpoint=use_checkpoint, 673 | num_heads=num_heads, 674 | 
num_head_channels=num_head_channels, 675 | use_new_attention_order=use_new_attention_order, 676 | ), 677 | ResBlock( 678 | ch, 679 | time_embed_dim, 680 | dropout, 681 | dims=dims, 682 | use_checkpoint=use_checkpoint, 683 | use_scale_shift_norm=use_scale_shift_norm, 684 | ), 685 | ) 686 | self._feature_size += ch 687 | 688 | self.output_blocks = nn.ModuleList([]) 689 | for level, mult in list(enumerate(channel_mult))[::-1]: 690 | for i in range(num_res_blocks + 1): 691 | ich = input_block_chans.pop() 692 | layers = [ 693 | ResBlock( 694 | ch + ich, 695 | time_embed_dim, 696 | dropout, 697 | out_channels=int(model_channels * mult), 698 | dims=dims, 699 | use_checkpoint=use_checkpoint, 700 | use_scale_shift_norm=use_scale_shift_norm, 701 | ) 702 | ] 703 | ch = int(model_channels * mult) 704 | if ds in attention_resolutions: 705 | layers.append( 706 | AttentionBlock( 707 | ch, 708 | use_checkpoint=use_checkpoint, 709 | num_heads=num_heads_upsample, 710 | num_head_channels=num_head_channels, 711 | use_new_attention_order=use_new_attention_order, 712 | ) 713 | ) 714 | if level and i == num_res_blocks: 715 | out_ch = ch 716 | layers.append( 717 | ResBlock( 718 | ch, 719 | time_embed_dim, 720 | dropout, 721 | out_channels=out_ch, 722 | dims=dims, 723 | use_checkpoint=use_checkpoint, 724 | use_scale_shift_norm=use_scale_shift_norm, 725 | up=True, 726 | ) 727 | if resblock_updown 728 | else Upsample(ch, conv_resample, dims=dims, out_channels=out_ch) 729 | ) 730 | ds //= 2 731 | self.output_blocks.append(TimestepEmbedSequential(*layers)) 732 | self._feature_size += ch 733 | 734 | self.out = nn.Sequential( 735 | normalization(ch), 736 | nn.SiLU(), 737 | zero_module(conv_nd(dims, input_ch, out_channels, 3, padding=1)), 738 | ) 739 | 740 | def convert_to_fp16(self): 741 | """ 742 | Convert the torso of the model to float16. 743 | """ 744 | self.input_blocks.apply(convert_module_to_f16) 745 | self.middle_block.apply(convert_module_to_f16) 746 | self.output_blocks.apply(convert_module_to_f16) 747 | 748 | def convert_to_fp32(self): 749 | """ 750 | Convert the torso of the model to float32. 751 | """ 752 | self.input_blocks.apply(convert_module_to_f32) 753 | self.middle_block.apply(convert_module_to_f32) 754 | self.output_blocks.apply(convert_module_to_f32) 755 | 756 | def forward(self, x, timesteps, y=None): 757 | """ 758 | Apply the model to an input batch. 759 | 760 | :param x: an [N x C x ...] Tensor of inputs. 761 | :param timesteps: a 1-D batch of timesteps. 762 | :param y: an [N] Tensor of labels, if class-conditional. 763 | :return: an [N x C x ...] Tensor of outputs. 
764 | """ 765 | assert (y is not None) == ( 766 | self.num_classes is not None 767 | ), "must specify y if and only if the model is class-conditional" 768 | 769 | hs = [] 770 | emb = self.time_embed(timestep_embedding(timesteps, self.model_channels)) 771 | 772 | if self.num_classes is not None: 773 | assert y.shape == (x.shape[0],) 774 | emb = emb + self.label_emb(y) 775 | 776 | h = x.type(self.dtype) 777 | for module in self.input_blocks: 778 | h = module(h, emb) 779 | hs.append(h) 780 | h = self.middle_block(h, emb) 781 | for module in self.output_blocks: 782 | h = th.cat([h, hs.pop()], dim=1) 783 | h = module(h, emb) 784 | h = h.type(x.dtype) 785 | return self.out(h) 786 | -------------------------------------------------------------------------------- /datasets/README.md: -------------------------------------------------------------------------------- 1 | # Downloading datasets 2 | 3 | This directory includes instructions and scripts for downloading ImageNet and LSUN bedrooms for use in this codebase. 4 | 5 | ## Class-conditional ImageNet 6 | 7 | For our class-conditional models, we use the official ILSVRC2012 dataset with manual center cropping and downsampling. To obtain this dataset, navigate to [this page on image-net.org](https://image-net.org/challenges/LSVRC/2012/2012-downloads.php) and sign in (or create an account if you do not already have one). Then click on the link reading "Training images (Task 1 & 2)". This is a 138GB tar file containing 1000 sub-tar files, one per class. 8 | 9 | Once the file is downloaded, extract it and look inside. You should see 1000 `.tar` files. You need to extract each of these, which may be impractical to do by hand on your operating system. To automate the process on a Unix-based system, you can `cd` into the directory and run this short shell script: 10 | 11 | ``` 12 | for file in *.tar; do tar xf "$file"; rm "$file"; done 13 | ``` 14 | 15 | This will extract and remove each tar file in turn. 16 | 17 | Once all of the images have been extracted, the resulting directory should be usable as a data directory (the `--data_dir` argument for the training script). The filenames should all start with WNID (class ids) followed by underscores, like `n01440764_2708.JPEG`. Conveniently (but not by accident) this is how the automated data-loader expects to discover class labels. 18 | 19 | ## LSUN bedroom 20 | 21 | To download and pre-process LSUN bedroom, clone [fyu/lsun](https://github.com/fyu/lsun) on GitHub and run their download script `python3 download.py -c bedroom`. The result will be an "lmdb" database named like `bedroom_train_lmdb`. You can pass this to our [lsun_bedroom.py](lsun_bedroom.py) script like so: 22 | 23 | ``` 24 | python lsun_bedroom.py bedroom_train_lmdb lsun_train_output_dir 25 | ``` 26 | 27 | This creates a directory called `lsun_train_output_dir`. This directory can be passed to the training scripts via the `--data_dir` argument. 28 | -------------------------------------------------------------------------------- /datasets/lsun_bedroom.py: -------------------------------------------------------------------------------- 1 | """ 2 | Convert an LSUN lmdb database into a directory of images. 
3 | """ 4 | 5 | import argparse 6 | import io 7 | import os 8 | 9 | from PIL import Image 10 | import lmdb 11 | import numpy as np 12 | 13 | 14 | def read_images(lmdb_path, image_size): 15 | env = lmdb.open(lmdb_path, map_size=1099511627776, max_readers=100, readonly=True) 16 | with env.begin(write=False) as transaction: 17 | cursor = transaction.cursor() 18 | for _, webp_data in cursor: 19 | img = Image.open(io.BytesIO(webp_data)) 20 | width, height = img.size 21 | scale = image_size / min(width, height) 22 | img = img.resize( 23 | (int(round(scale * width)), int(round(scale * height))), 24 | resample=Image.BOX, 25 | ) 26 | arr = np.array(img) 27 | h, w, _ = arr.shape 28 | h_off = (h - image_size) // 2 29 | w_off = (w - image_size) // 2 30 | arr = arr[h_off : h_off + image_size, w_off : w_off + image_size] 31 | yield arr 32 | 33 | 34 | def dump_images(out_dir, images, prefix): 35 | if not os.path.exists(out_dir): 36 | os.mkdir(out_dir) 37 | for i, img in enumerate(images): 38 | Image.fromarray(img).save(os.path.join(out_dir, f"{prefix}_{i:07d}.png")) 39 | 40 | 41 | def main(): 42 | parser = argparse.ArgumentParser() 43 | parser.add_argument("--image-size", help="new image size", type=int, default=256) 44 | parser.add_argument("--prefix", help="class name", type=str, default="bedroom") 45 | parser.add_argument("lmdb_path", help="path to an LSUN lmdb database") 46 | parser.add_argument("out_dir", help="path to output directory") 47 | args = parser.parse_args() 48 | 49 | images = read_images(args.lmdb_path, args.image_size) 50 | dump_images(args.out_dir, images, args.prefix) 51 | 52 | 53 | if __name__ == "__main__": 54 | main() 55 | -------------------------------------------------------------------------------- /docker/Dockerfile: -------------------------------------------------------------------------------- 1 | FROM nvidia/cuda:11.8.0-cudnn8-devel-ubuntu22.04 2 | 3 | ENV DEBIAN_FRONTEND=noninteractive PIP_PREFER_BINARY=1 4 | 5 | RUN apt-get update && apt-get install -y --no-install-recommends \ 6 | libgl1-mesa-dev libopenmpi-dev git wget \ 7 | python3 python3-dev python3-pip python3-setuptools python3-wheel \ 8 | && apt-get clean && rm -rf /var/lib/apt/lists/* 9 | 10 | RUN echo "export PATH=/usr/local/cuda/bin:$PATH" >> /etc/bash.bashrc \ 11 | && echo "export LD_LIBRARY_PATH=/usr/local/cuda/lib64:$LD_LIBRARY_PATH" >> /etc/bash.bashrc 12 | 13 | RUN pip3 install --no-cache-dir --upgrade pip setuptools wheel packaging mpi4py \ 14 | && pip3 install --no-cache-dir torch torchvision --index-url https://download.pytorch.org/whl/cu118 \ 15 | && pip3 install flash-attn==0.2.8 16 | 17 | WORKDIR /home/ 18 | RUN pip3 install -e git+https://github.com/openai/consistency_models.git@main#egg=consistency_models \ 19 | && ln -s /usr/bin/python3 /usr/bin/python 20 | -------------------------------------------------------------------------------- /docker/Makefile: -------------------------------------------------------------------------------- 1 | NAME=consistency_models 2 | TAG=0.1 3 | PROJECT_DIRECTORY = $(shell pwd)/.. 4 | 5 | build: 6 | docker build -t ${NAME}:${TAG} -f Dockerfile . 
7 | 8 | run: 9 | docker container run --gpus all\ 10 | --restart=always\ 11 | -it -d \ 12 | -v $(PROJECT_DIRECTORY):/home/${NAME}\ 13 | --name ${NAME} ${NAME}:${TAG} /bin/bash 14 | -------------------------------------------------------------------------------- /evaluations/__init__.py: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/openai/consistency_models/e32b69ee436d518377db86fb2127a3972d0d8716/evaluations/__init__.py -------------------------------------------------------------------------------- /evaluations/evaluator.py: -------------------------------------------------------------------------------- 1 | import argparse 2 | import io 3 | import os 4 | import random 5 | import warnings 6 | import zipfile 7 | from abc import ABC, abstractmethod 8 | from contextlib import contextmanager 9 | from functools import partial 10 | from multiprocessing import cpu_count 11 | from multiprocessing.pool import ThreadPool 12 | from typing import Iterable, Optional, Tuple 13 | 14 | import numpy as np 15 | import requests 16 | import tensorflow.compat.v1 as tf 17 | from scipy import linalg 18 | from tqdm.auto import tqdm 19 | 20 | INCEPTION_V3_URL = "https://openaipublic.blob.core.windows.net/diffusion/jul-2021/ref_batches/classify_image_graph_def.pb" 21 | INCEPTION_V3_PATH = "classify_image_graph_def.pb" 22 | 23 | FID_POOL_NAME = "pool_3:0" 24 | FID_SPATIAL_NAME = "mixed_6/conv:0" 25 | 26 | 27 | def main(): 28 | parser = argparse.ArgumentParser() 29 | parser.add_argument("ref_batch", help="path to reference batch npz file") 30 | parser.add_argument("sample_batch", help="path to sample batch npz file") 31 | args = parser.parse_args() 32 | 33 | config = tf.ConfigProto( 34 | allow_soft_placement=True # allows DecodeJpeg to run on CPU in Inception graph 35 | ) 36 | config.gpu_options.allow_growth = True 37 | evaluator = Evaluator(tf.Session(config=config)) 38 | 39 | print("warming up TensorFlow...") 40 | # This will cause TF to print a bunch of verbose stuff now rather 41 | # than after the next print(), to help prevent confusion. 42 | evaluator.warmup() 43 | 44 | print("computing reference batch activations...") 45 | ref_acts = evaluator.read_activations(args.ref_batch) 46 | print("computing/reading reference batch statistics...") 47 | ref_stats, ref_stats_spatial = evaluator.read_statistics(args.ref_batch, ref_acts) 48 | 49 | print("computing sample batch activations...") 50 | sample_acts = evaluator.read_activations(args.sample_batch) 51 | print("computing/reading sample batch statistics...") 52 | sample_stats, sample_stats_spatial = evaluator.read_statistics( 53 | args.sample_batch, sample_acts 54 | ) 55 | 56 | print("Computing evaluations...") 57 | print("Inception Score:", evaluator.compute_inception_score(sample_acts[0])) 58 | print("FID:", sample_stats.frechet_distance(ref_stats)) 59 | print("sFID:", sample_stats_spatial.frechet_distance(ref_stats_spatial)) 60 | prec, recall = evaluator.compute_prec_recall(ref_acts[0], sample_acts[0]) 61 | print("Precision:", prec) 62 | print("Recall:", recall) 63 | 64 | 65 | class InvalidFIDException(Exception): 66 | pass 67 | 68 | 69 | class FIDStatistics: 70 | def __init__(self, mu: np.ndarray, sigma: np.ndarray): 71 | self.mu = mu 72 | self.sigma = sigma 73 | 74 | def frechet_distance(self, other, eps=1e-6): 75 | """ 76 | Compute the Frechet distance between two sets of statistics. 
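        Concretely, this returns
        ||mu_1 - mu_2||^2 + Tr(sigma_1 + sigma_2 - 2 (sigma_1 sigma_2)^{1/2}).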
77 | """ 78 | # https://github.com/bioinf-jku/TTUR/blob/73ab375cdf952a12686d9aa7978567771084da42/fid.py#L132 79 | mu1, sigma1 = self.mu, self.sigma 80 | mu2, sigma2 = other.mu, other.sigma 81 | 82 | mu1 = np.atleast_1d(mu1) 83 | mu2 = np.atleast_1d(mu2) 84 | 85 | sigma1 = np.atleast_2d(sigma1) 86 | sigma2 = np.atleast_2d(sigma2) 87 | 88 | assert ( 89 | mu1.shape == mu2.shape 90 | ), f"Training and test mean vectors have different lengths: {mu1.shape}, {mu2.shape}" 91 | assert ( 92 | sigma1.shape == sigma2.shape 93 | ), f"Training and test covariances have different dimensions: {sigma1.shape}, {sigma2.shape}" 94 | 95 | diff = mu1 - mu2 96 | 97 | # product might be almost singular 98 | covmean, _ = linalg.sqrtm(sigma1.dot(sigma2), disp=False) 99 | if not np.isfinite(covmean).all(): 100 | msg = ( 101 | "fid calculation produces singular product; adding %s to diagonal of cov estimates" 102 | % eps 103 | ) 104 | warnings.warn(msg) 105 | offset = np.eye(sigma1.shape[0]) * eps 106 | covmean = linalg.sqrtm((sigma1 + offset).dot(sigma2 + offset)) 107 | 108 | # numerical error might give slight imaginary component 109 | if np.iscomplexobj(covmean): 110 | if not np.allclose(np.diagonal(covmean).imag, 0, atol=1e-3): 111 | m = np.max(np.abs(covmean.imag)) 112 | raise ValueError("Imaginary component {}".format(m)) 113 | covmean = covmean.real 114 | 115 | tr_covmean = np.trace(covmean) 116 | 117 | return diff.dot(diff) + np.trace(sigma1) + np.trace(sigma2) - 2 * tr_covmean 118 | 119 | 120 | class Evaluator: 121 | def __init__( 122 | self, 123 | session, 124 | batch_size=64, 125 | softmax_batch_size=512, 126 | ): 127 | self.sess = session 128 | self.batch_size = batch_size 129 | self.softmax_batch_size = softmax_batch_size 130 | self.manifold_estimator = ManifoldEstimator(session) 131 | with self.sess.graph.as_default(): 132 | self.image_input = tf.placeholder(tf.float32, shape=[None, None, None, 3]) 133 | self.softmax_input = tf.placeholder(tf.float32, shape=[None, 2048]) 134 | self.pool_features, self.spatial_features = _create_feature_graph( 135 | self.image_input 136 | ) 137 | self.softmax = _create_softmax_graph(self.softmax_input) 138 | 139 | def warmup(self): 140 | self.compute_activations(np.zeros([1, 8, 64, 64, 3])) 141 | 142 | def read_activations(self, npz_path: str) -> Tuple[np.ndarray, np.ndarray]: 143 | with open_npz_array(npz_path, "arr_0") as reader: 144 | return self.compute_activations(reader.read_batches(self.batch_size)) 145 | 146 | def compute_activations( 147 | self, batches: Iterable[np.ndarray] 148 | ) -> Tuple[np.ndarray, np.ndarray]: 149 | """ 150 | Compute image features for downstream evals. 151 | 152 | :param batches: a iterator over NHWC numpy arrays in [0, 255]. 153 | :return: a tuple of numpy arrays of shape [N x X], where X is a feature 154 | dimension. The tuple is (pool_3, spatial). 
155 | """ 156 | preds = [] 157 | spatial_preds = [] 158 | for batch in tqdm(batches): 159 | batch = batch.astype(np.float32) 160 | pred, spatial_pred = self.sess.run( 161 | [self.pool_features, self.spatial_features], {self.image_input: batch} 162 | ) 163 | preds.append(pred.reshape([pred.shape[0], -1])) 164 | spatial_preds.append(spatial_pred.reshape([spatial_pred.shape[0], -1])) 165 | return ( 166 | np.concatenate(preds, axis=0), 167 | np.concatenate(spatial_preds, axis=0), 168 | ) 169 | 170 | def read_statistics( 171 | self, npz_path: str, activations: Tuple[np.ndarray, np.ndarray] 172 | ) -> Tuple[FIDStatistics, FIDStatistics]: 173 | obj = np.load(npz_path) 174 | if "mu" in list(obj.keys()): 175 | return FIDStatistics(obj["mu"], obj["sigma"]), FIDStatistics( 176 | obj["mu_s"], obj["sigma_s"] 177 | ) 178 | return tuple(self.compute_statistics(x) for x in activations) 179 | 180 | def compute_statistics(self, activations: np.ndarray) -> FIDStatistics: 181 | mu = np.mean(activations, axis=0) 182 | sigma = np.cov(activations, rowvar=False) 183 | return FIDStatistics(mu, sigma) 184 | 185 | def compute_inception_score( 186 | self, activations: np.ndarray, split_size: int = 5000 187 | ) -> float: 188 | softmax_out = [] 189 | for i in range(0, len(activations), self.softmax_batch_size): 190 | acts = activations[i : i + self.softmax_batch_size] 191 | softmax_out.append( 192 | self.sess.run(self.softmax, feed_dict={self.softmax_input: acts}) 193 | ) 194 | preds = np.concatenate(softmax_out, axis=0) 195 | # https://github.com/openai/improved-gan/blob/4f5d1ec5c16a7eceb206f42bfc652693601e1d5c/inception_score/model.py#L46 196 | scores = [] 197 | for i in range(0, len(preds), split_size): 198 | part = preds[i : i + split_size] 199 | kl = part * (np.log(part) - np.log(np.expand_dims(np.mean(part, 0), 0))) 200 | kl = np.mean(np.sum(kl, 1)) 201 | scores.append(np.exp(kl)) 202 | return float(np.mean(scores)) 203 | 204 | def compute_prec_recall( 205 | self, activations_ref: np.ndarray, activations_sample: np.ndarray 206 | ) -> Tuple[float, float]: 207 | radii_1 = self.manifold_estimator.manifold_radii(activations_ref) 208 | radii_2 = self.manifold_estimator.manifold_radii(activations_sample) 209 | pr = self.manifold_estimator.evaluate_pr( 210 | activations_ref, radii_1, activations_sample, radii_2 211 | ) 212 | return (float(pr[0][0]), float(pr[1][0])) 213 | 214 | 215 | class ManifoldEstimator: 216 | """ 217 | A helper for comparing manifolds of feature vectors. 218 | 219 | Adapted from https://github.com/kynkaat/improved-precision-and-recall-metric/blob/f60f25e5ad933a79135c783fcda53de30f42c9b9/precision_recall.py#L57 220 | """ 221 | 222 | def __init__( 223 | self, 224 | session, 225 | row_batch_size=10000, 226 | col_batch_size=10000, 227 | nhood_sizes=(3,), 228 | clamp_to_percentile=None, 229 | eps=1e-5, 230 | ): 231 | """ 232 | Estimate the manifold of given feature vectors. 233 | 234 | :param session: the TensorFlow session. 235 | :param row_batch_size: row batch size to compute pairwise distances 236 | (parameter to trade-off between memory usage and performance). 237 | :param col_batch_size: column batch size to compute pairwise distances. 238 | :param nhood_sizes: number of neighbors used to estimate the manifold. 239 | :param clamp_to_percentile: prune hyperspheres that have radius larger than 240 | the given percentile. 241 | :param eps: small number for numerical stability. 
242 | """ 243 | self.distance_block = DistanceBlock(session) 244 | self.row_batch_size = row_batch_size 245 | self.col_batch_size = col_batch_size 246 | self.nhood_sizes = nhood_sizes 247 | self.num_nhoods = len(nhood_sizes) 248 | self.clamp_to_percentile = clamp_to_percentile 249 | self.eps = eps 250 | 251 | def warmup(self): 252 | feats, radii = ( 253 | np.zeros([1, 2048], dtype=np.float32), 254 | np.zeros([1, 1], dtype=np.float32), 255 | ) 256 | self.evaluate_pr(feats, radii, feats, radii) 257 | 258 | def manifold_radii(self, features: np.ndarray) -> np.ndarray: 259 | num_images = len(features) 260 | 261 | # Estimate manifold of features by calculating distances to k-NN of each sample. 262 | radii = np.zeros([num_images, self.num_nhoods], dtype=np.float32) 263 | distance_batch = np.zeros([self.row_batch_size, num_images], dtype=np.float32) 264 | seq = np.arange(max(self.nhood_sizes) + 1, dtype=np.int32) 265 | 266 | for begin1 in range(0, num_images, self.row_batch_size): 267 | end1 = min(begin1 + self.row_batch_size, num_images) 268 | row_batch = features[begin1:end1] 269 | 270 | for begin2 in range(0, num_images, self.col_batch_size): 271 | end2 = min(begin2 + self.col_batch_size, num_images) 272 | col_batch = features[begin2:end2] 273 | 274 | # Compute distances between batches. 275 | distance_batch[ 276 | 0 : end1 - begin1, begin2:end2 277 | ] = self.distance_block.pairwise_distances(row_batch, col_batch) 278 | 279 | # Find the k-nearest neighbor from the current batch. 280 | radii[begin1:end1, :] = np.concatenate( 281 | [ 282 | x[:, self.nhood_sizes] 283 | for x in _numpy_partition( 284 | distance_batch[0 : end1 - begin1, :], seq, axis=1 285 | ) 286 | ], 287 | axis=0, 288 | ) 289 | 290 | if self.clamp_to_percentile is not None: 291 | max_distances = np.percentile(radii, self.clamp_to_percentile, axis=0) 292 | radii[radii > max_distances] = 0 293 | return radii 294 | 295 | def evaluate( 296 | self, features: np.ndarray, radii: np.ndarray, eval_features: np.ndarray 297 | ): 298 | """ 299 | Evaluate if new feature vectors are at the manifold. 300 | """ 301 | num_eval_images = eval_features.shape[0] 302 | num_ref_images = radii.shape[0] 303 | distance_batch = np.zeros( 304 | [self.row_batch_size, num_ref_images], dtype=np.float32 305 | ) 306 | batch_predictions = np.zeros([num_eval_images, self.num_nhoods], dtype=np.int32) 307 | max_realism_score = np.zeros([num_eval_images], dtype=np.float32) 308 | nearest_indices = np.zeros([num_eval_images], dtype=np.int32) 309 | 310 | for begin1 in range(0, num_eval_images, self.row_batch_size): 311 | end1 = min(begin1 + self.row_batch_size, num_eval_images) 312 | feature_batch = eval_features[begin1:end1] 313 | 314 | for begin2 in range(0, num_ref_images, self.col_batch_size): 315 | end2 = min(begin2 + self.col_batch_size, num_ref_images) 316 | ref_batch = features[begin2:end2] 317 | 318 | distance_batch[ 319 | 0 : end1 - begin1, begin2:end2 320 | ] = self.distance_block.pairwise_distances(feature_batch, ref_batch) 321 | 322 | # From the minibatch of new feature vectors, determine if they are in the estimated manifold. 323 | # If a feature vector is inside a hypersphere of some reference sample, then 324 | # the new sample lies at the estimated manifold. 325 | # The radii of the hyperspheres are determined from distances of neighborhood size k. 
326 | samples_in_manifold = distance_batch[0 : end1 - begin1, :, None] <= radii 327 | batch_predictions[begin1:end1] = np.any(samples_in_manifold, axis=1).astype( 328 | np.int32 329 | ) 330 | 331 | max_realism_score[begin1:end1] = np.max( 332 | radii[:, 0] / (distance_batch[0 : end1 - begin1, :] + self.eps), axis=1 333 | ) 334 | nearest_indices[begin1:end1] = np.argmin( 335 | distance_batch[0 : end1 - begin1, :], axis=1 336 | ) 337 | 338 | return { 339 | "fraction": float(np.mean(batch_predictions)), 340 | "batch_predictions": batch_predictions, 341 | "max_realisim_score": max_realism_score, 342 | "nearest_indices": nearest_indices, 343 | } 344 | 345 | def evaluate_pr( 346 | self, 347 | features_1: np.ndarray, 348 | radii_1: np.ndarray, 349 | features_2: np.ndarray, 350 | radii_2: np.ndarray, 351 | ) -> Tuple[np.ndarray, np.ndarray]: 352 | """ 353 | Evaluate precision and recall efficiently. 354 | 355 | :param features_1: [N1 x D] feature vectors for reference batch. 356 | :param radii_1: [N1 x K1] radii for reference vectors. 357 | :param features_2: [N2 x D] feature vectors for the other batch. 358 | :param radii_2: [N x K2] radii for other vectors. 359 | :return: a tuple of arrays for (precision, recall): 360 | - precision: an np.ndarray of length K1 361 | - recall: an np.ndarray of length K2 362 | """ 363 | features_1_status = np.zeros([len(features_1), radii_2.shape[1]], dtype=np.bool) 364 | features_2_status = np.zeros([len(features_2), radii_1.shape[1]], dtype=np.bool) 365 | for begin_1 in range(0, len(features_1), self.row_batch_size): 366 | end_1 = begin_1 + self.row_batch_size 367 | batch_1 = features_1[begin_1:end_1] 368 | for begin_2 in range(0, len(features_2), self.col_batch_size): 369 | end_2 = begin_2 + self.col_batch_size 370 | batch_2 = features_2[begin_2:end_2] 371 | batch_1_in, batch_2_in = self.distance_block.less_thans( 372 | batch_1, radii_1[begin_1:end_1], batch_2, radii_2[begin_2:end_2] 373 | ) 374 | features_1_status[begin_1:end_1] |= batch_1_in 375 | features_2_status[begin_2:end_2] |= batch_2_in 376 | return ( 377 | np.mean(features_2_status.astype(np.float64), axis=0), 378 | np.mean(features_1_status.astype(np.float64), axis=0), 379 | ) 380 | 381 | 382 | class DistanceBlock: 383 | """ 384 | Calculate pairwise distances between vectors. 385 | 386 | Adapted from https://github.com/kynkaat/improved-precision-and-recall-metric/blob/f60f25e5ad933a79135c783fcda53de30f42c9b9/precision_recall.py#L34 387 | """ 388 | 389 | def __init__(self, session): 390 | self.session = session 391 | 392 | # Initialize TF graph to calculate pairwise distances. 393 | with session.graph.as_default(): 394 | self._features_batch1 = tf.placeholder(tf.float32, shape=[None, None]) 395 | self._features_batch2 = tf.placeholder(tf.float32, shape=[None, None]) 396 | distance_block_16 = _batch_pairwise_distances( 397 | tf.cast(self._features_batch1, tf.float16), 398 | tf.cast(self._features_batch2, tf.float16), 399 | ) 400 | self.distance_block = tf.cond( 401 | tf.reduce_all(tf.math.is_finite(distance_block_16)), 402 | lambda: tf.cast(distance_block_16, tf.float32), 403 | lambda: _batch_pairwise_distances( 404 | self._features_batch1, self._features_batch2 405 | ), 406 | ) 407 | 408 | # Extra logic for less thans. 
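            # For each neighborhood size k, _batch_1_in[i, k] is True when batch_1[i]
            # lies within radius radii_2[j, k] of some batch_2 sample j, and
            # _batch_2_in[j, k] is True when batch_2[j] lies within radius
            # radii_1[i, k] of some batch_1 sample i.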
409 | self._radii1 = tf.placeholder(tf.float32, shape=[None, None]) 410 | self._radii2 = tf.placeholder(tf.float32, shape=[None, None]) 411 | dist32 = tf.cast(self.distance_block, tf.float32)[..., None] 412 | self._batch_1_in = tf.math.reduce_any(dist32 <= self._radii2, axis=1) 413 | self._batch_2_in = tf.math.reduce_any( 414 | dist32 <= self._radii1[:, None], axis=0 415 | ) 416 | 417 | def pairwise_distances(self, U, V): 418 | """ 419 | Evaluate pairwise distances between two batches of feature vectors. 420 | """ 421 | return self.session.run( 422 | self.distance_block, 423 | feed_dict={self._features_batch1: U, self._features_batch2: V}, 424 | ) 425 | 426 | def less_thans(self, batch_1, radii_1, batch_2, radii_2): 427 | return self.session.run( 428 | [self._batch_1_in, self._batch_2_in], 429 | feed_dict={ 430 | self._features_batch1: batch_1, 431 | self._features_batch2: batch_2, 432 | self._radii1: radii_1, 433 | self._radii2: radii_2, 434 | }, 435 | ) 436 | 437 | 438 | def _batch_pairwise_distances(U, V): 439 | """ 440 | Compute pairwise distances between two batches of feature vectors. 441 | """ 442 | with tf.variable_scope("pairwise_dist_block"): 443 | # Squared norms of each row in U and V. 444 | norm_u = tf.reduce_sum(tf.square(U), 1) 445 | norm_v = tf.reduce_sum(tf.square(V), 1) 446 | 447 | # norm_u as a column and norm_v as a row vectors. 448 | norm_u = tf.reshape(norm_u, [-1, 1]) 449 | norm_v = tf.reshape(norm_v, [1, -1]) 450 | 451 | # Pairwise squared Euclidean distances. 452 | D = tf.maximum(norm_u - 2 * tf.matmul(U, V, False, True) + norm_v, 0.0) 453 | 454 | return D 455 | 456 | 457 | class NpzArrayReader(ABC): 458 | @abstractmethod 459 | def read_batch(self, batch_size: int) -> Optional[np.ndarray]: 460 | pass 461 | 462 | @abstractmethod 463 | def remaining(self) -> int: 464 | pass 465 | 466 | def read_batches(self, batch_size: int) -> Iterable[np.ndarray]: 467 | def gen_fn(): 468 | while True: 469 | batch = self.read_batch(batch_size) 470 | if batch is None: 471 | break 472 | yield batch 473 | 474 | rem = self.remaining() 475 | num_batches = rem // batch_size + int(rem % batch_size != 0) 476 | return BatchIterator(gen_fn, num_batches) 477 | 478 | 479 | class BatchIterator: 480 | def __init__(self, gen_fn, length): 481 | self.gen_fn = gen_fn 482 | self.length = length 483 | 484 | def __len__(self): 485 | return self.length 486 | 487 | def __iter__(self): 488 | return self.gen_fn() 489 | 490 | 491 | class StreamingNpzArrayReader(NpzArrayReader): 492 | def __init__(self, arr_f, shape, dtype): 493 | self.arr_f = arr_f 494 | self.shape = shape 495 | self.dtype = dtype 496 | self.idx = 0 497 | 498 | def read_batch(self, batch_size: int) -> Optional[np.ndarray]: 499 | if self.idx >= self.shape[0]: 500 | return None 501 | 502 | bs = min(batch_size, self.shape[0] - self.idx) 503 | self.idx += bs 504 | 505 | if self.dtype.itemsize == 0: 506 | return np.ndarray([bs, *self.shape[1:]], dtype=self.dtype) 507 | 508 | read_count = bs * np.prod(self.shape[1:]) 509 | read_size = int(read_count * self.dtype.itemsize) 510 | data = _read_bytes(self.arr_f, read_size, "array data") 511 | return np.frombuffer(data, dtype=self.dtype).reshape([bs, *self.shape[1:]]) 512 | 513 | def remaining(self) -> int: 514 | return max(0, self.shape[0] - self.idx) 515 | 516 | 517 | class MemoryNpzArrayReader(NpzArrayReader): 518 | def __init__(self, arr): 519 | self.arr = arr 520 | self.idx = 0 521 | 522 | @classmethod 523 | def load(cls, path: str, arr_name: str): 524 | with open(path, "rb") as f: 525 | arr = 
np.load(f)[arr_name] 526 | return cls(arr) 527 | 528 | def read_batch(self, batch_size: int) -> Optional[np.ndarray]: 529 | if self.idx >= self.arr.shape[0]: 530 | return None 531 | 532 | res = self.arr[self.idx : self.idx + batch_size] 533 | self.idx += batch_size 534 | return res 535 | 536 | def remaining(self) -> int: 537 | return max(0, self.arr.shape[0] - self.idx) 538 | 539 | 540 | @contextmanager 541 | def open_npz_array(path: str, arr_name: str) -> NpzArrayReader: 542 | with _open_npy_file(path, arr_name) as arr_f: 543 | version = np.lib.format.read_magic(arr_f) 544 | if version == (1, 0): 545 | header = np.lib.format.read_array_header_1_0(arr_f) 546 | elif version == (2, 0): 547 | header = np.lib.format.read_array_header_2_0(arr_f) 548 | else: 549 | yield MemoryNpzArrayReader.load(path, arr_name) 550 | return 551 | shape, fortran, dtype = header 552 | if fortran or dtype.hasobject: 553 | yield MemoryNpzArrayReader.load(path, arr_name) 554 | else: 555 | yield StreamingNpzArrayReader(arr_f, shape, dtype) 556 | 557 | 558 | def _read_bytes(fp, size, error_template="ran out of data"): 559 | """ 560 | Copied from: https://github.com/numpy/numpy/blob/fb215c76967739268de71aa4bda55dd1b062bc2e/numpy/lib/format.py#L788-L886 561 | 562 | Read from file-like object until size bytes are read. 563 | Raises ValueError if not EOF is encountered before size bytes are read. 564 | Non-blocking objects only supported if they derive from io objects. 565 | Required as e.g. ZipExtFile in python 2.6 can return less data than 566 | requested. 567 | """ 568 | data = bytes() 569 | while True: 570 | # io files (default in python3) return None or raise on 571 | # would-block, python2 file will truncate, probably nothing can be 572 | # done about that. note that regular files can't be non-blocking 573 | try: 574 | r = fp.read(size - len(data)) 575 | data += r 576 | if len(r) == 0 or len(data) == size: 577 | break 578 | except io.BlockingIOError: 579 | pass 580 | if len(data) != size: 581 | msg = "EOF: reading %s, expected %d bytes got %d" 582 | raise ValueError(msg % (error_template, size, len(data))) 583 | else: 584 | return data 585 | 586 | 587 | @contextmanager 588 | def _open_npy_file(path: str, arr_name: str): 589 | with open(path, "rb") as f: 590 | with zipfile.ZipFile(f, "r") as zip_f: 591 | if f"{arr_name}.npy" not in zip_f.namelist(): 592 | raise ValueError(f"missing {arr_name} in npz file") 593 | with zip_f.open(f"{arr_name}.npy", "r") as arr_f: 594 | yield arr_f 595 | 596 | 597 | def _download_inception_model(): 598 | if os.path.exists(INCEPTION_V3_PATH): 599 | return 600 | print("downloading InceptionV3 model...") 601 | with requests.get(INCEPTION_V3_URL, stream=True) as r: 602 | r.raise_for_status() 603 | tmp_path = INCEPTION_V3_PATH + ".tmp" 604 | with open(tmp_path, "wb") as f: 605 | for chunk in tqdm(r.iter_content(chunk_size=8192)): 606 | f.write(chunk) 607 | os.rename(tmp_path, INCEPTION_V3_PATH) 608 | 609 | 610 | def _create_feature_graph(input_batch): 611 | _download_inception_model() 612 | prefix = f"{random.randrange(2**32)}_{random.randrange(2**32)}" 613 | with open(INCEPTION_V3_PATH, "rb") as f: 614 | graph_def = tf.GraphDef() 615 | graph_def.ParseFromString(f.read()) 616 | pool3, spatial = tf.import_graph_def( 617 | graph_def, 618 | input_map={f"ExpandDims:0": input_batch}, 619 | return_elements=[FID_POOL_NAME, FID_SPATIAL_NAME], 620 | name=prefix, 621 | ) 622 | _update_shapes(pool3) 623 | spatial = spatial[..., :7] 624 | return pool3, spatial 625 | 626 | 627 | def 
_create_softmax_graph(input_batch): 628 | _download_inception_model() 629 | prefix = f"{random.randrange(2**32)}_{random.randrange(2**32)}" 630 | with open(INCEPTION_V3_PATH, "rb") as f: 631 | graph_def = tf.GraphDef() 632 | graph_def.ParseFromString(f.read()) 633 | (matmul,) = tf.import_graph_def( 634 | graph_def, return_elements=[f"softmax/logits/MatMul"], name=prefix 635 | ) 636 | w = matmul.inputs[1] 637 | logits = tf.matmul(input_batch, w) 638 | return tf.nn.softmax(logits) 639 | 640 | 641 | def _update_shapes(pool3): 642 | # https://github.com/bioinf-jku/TTUR/blob/73ab375cdf952a12686d9aa7978567771084da42/fid.py#L50-L63 643 | ops = pool3.graph.get_operations() 644 | for op in ops: 645 | for o in op.outputs: 646 | shape = o.get_shape() 647 | if shape._dims is not None: # pylint: disable=protected-access 648 | # shape = [s.value for s in shape] TF 1.x 649 | shape = [s for s in shape] # TF 2.x 650 | new_shape = [] 651 | for j, s in enumerate(shape): 652 | if s == 1 and j == 0: 653 | new_shape.append(None) 654 | else: 655 | new_shape.append(s) 656 | o.__dict__["_shape_val"] = tf.TensorShape(new_shape) 657 | return pool3 658 | 659 | 660 | def _numpy_partition(arr, kth, **kwargs): 661 | num_workers = min(cpu_count(), len(arr)) 662 | chunk_size = len(arr) // num_workers 663 | extra = len(arr) % num_workers 664 | 665 | start_idx = 0 666 | batches = [] 667 | for i in range(num_workers): 668 | size = chunk_size + (1 if i < extra else 0) 669 | batches.append(arr[start_idx : start_idx + size]) 670 | start_idx += size 671 | 672 | with ThreadPool(num_workers) as pool: 673 | return list(pool.map(partial(np.partition, kth=kth, **kwargs), batches)) 674 | 675 | 676 | if __name__ == "__main__": 677 | main() 678 | -------------------------------------------------------------------------------- /evaluations/inception_v3.py: -------------------------------------------------------------------------------- 1 | # Ported from the model here: 2 | # https://github.com/NVlabs/stylegan3/blob/407db86e6fe432540a22515310188288687858fa/metrics/frechet_inception_distance.py#L22 3 | # 4 | # I have verified that the spatial features and output features are correct 5 | # within a mean absolute error of ~3e-5. 
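#
# Rough usage sketch (mirroring evaluations/th_evaluator.py): after downloading the
# checkpoint referenced there (inception-2015-12-05.pt), one can do
#   model = InceptionV3()
#   model.load_state_dict(torch.load("inception-2015-12-05.pt"))
#   features, spatial_features = model(images)  # images: float NCHW in [0, 255]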
6 | 7 | import collections 8 | 9 | import torch 10 | 11 | 12 | class Conv2dLayer(torch.nn.Module): 13 | def __init__(self, in_channels, out_channels, kh, kw, stride=1, padding=0): 14 | super().__init__() 15 | self.stride = stride 16 | self.padding = padding 17 | self.weight = torch.nn.Parameter(torch.zeros(out_channels, in_channels, kh, kw)) 18 | self.beta = torch.nn.Parameter(torch.zeros(out_channels)) 19 | self.mean = torch.nn.Parameter(torch.zeros(out_channels)) 20 | self.var = torch.nn.Parameter(torch.zeros(out_channels)) 21 | 22 | def forward(self, x): 23 | x = torch.nn.functional.conv2d( 24 | x, self.weight.to(x.dtype), stride=self.stride, padding=self.padding 25 | ) 26 | x = torch.nn.functional.batch_norm( 27 | x, running_mean=self.mean, running_var=self.var, bias=self.beta, eps=1e-3 28 | ) 29 | x = torch.nn.functional.relu(x) 30 | return x 31 | 32 | 33 | # ---------------------------------------------------------------------------- 34 | 35 | 36 | class InceptionA(torch.nn.Module): 37 | def __init__(self, in_channels, tmp_channels): 38 | super().__init__() 39 | self.conv = Conv2dLayer(in_channels, 64, kh=1, kw=1) 40 | self.tower = torch.nn.Sequential( 41 | collections.OrderedDict( 42 | [ 43 | ("conv", Conv2dLayer(in_channels, 48, kh=1, kw=1)), 44 | ("conv_1", Conv2dLayer(48, 64, kh=5, kw=5, padding=2)), 45 | ] 46 | ) 47 | ) 48 | self.tower_1 = torch.nn.Sequential( 49 | collections.OrderedDict( 50 | [ 51 | ("conv", Conv2dLayer(in_channels, 64, kh=1, kw=1)), 52 | ("conv_1", Conv2dLayer(64, 96, kh=3, kw=3, padding=1)), 53 | ("conv_2", Conv2dLayer(96, 96, kh=3, kw=3, padding=1)), 54 | ] 55 | ) 56 | ) 57 | self.tower_2 = torch.nn.Sequential( 58 | collections.OrderedDict( 59 | [ 60 | ( 61 | "pool", 62 | torch.nn.AvgPool2d( 63 | kernel_size=3, stride=1, padding=1, count_include_pad=False 64 | ), 65 | ), 66 | ("conv", Conv2dLayer(in_channels, tmp_channels, kh=1, kw=1)), 67 | ] 68 | ) 69 | ) 70 | 71 | def forward(self, x): 72 | return torch.cat( 73 | [ 74 | self.conv(x).contiguous(), 75 | self.tower(x).contiguous(), 76 | self.tower_1(x).contiguous(), 77 | self.tower_2(x).contiguous(), 78 | ], 79 | dim=1, 80 | ) 81 | 82 | 83 | # ---------------------------------------------------------------------------- 84 | 85 | 86 | class InceptionB(torch.nn.Module): 87 | def __init__(self, in_channels): 88 | super().__init__() 89 | self.conv = Conv2dLayer(in_channels, 384, kh=3, kw=3, stride=2) 90 | self.tower = torch.nn.Sequential( 91 | collections.OrderedDict( 92 | [ 93 | ("conv", Conv2dLayer(in_channels, 64, kh=1, kw=1)), 94 | ("conv_1", Conv2dLayer(64, 96, kh=3, kw=3, padding=1)), 95 | ("conv_2", Conv2dLayer(96, 96, kh=3, kw=3, stride=2)), 96 | ] 97 | ) 98 | ) 99 | self.pool = torch.nn.MaxPool2d(kernel_size=3, stride=2) 100 | 101 | def forward(self, x): 102 | return torch.cat( 103 | [ 104 | self.conv(x).contiguous(), 105 | self.tower(x).contiguous(), 106 | self.pool(x).contiguous(), 107 | ], 108 | dim=1, 109 | ) 110 | 111 | 112 | # ---------------------------------------------------------------------------- 113 | 114 | 115 | class InceptionC(torch.nn.Module): 116 | def __init__(self, in_channels, tmp_channels): 117 | super().__init__() 118 | self.conv = Conv2dLayer(in_channels, 192, kh=1, kw=1) 119 | self.tower = torch.nn.Sequential( 120 | collections.OrderedDict( 121 | [ 122 | ("conv", Conv2dLayer(in_channels, tmp_channels, kh=1, kw=1)), 123 | ( 124 | "conv_1", 125 | Conv2dLayer( 126 | tmp_channels, tmp_channels, kh=1, kw=7, padding=[0, 3] 127 | ), 128 | ), 129 | ( 130 | "conv_2", 131 | 
Conv2dLayer(tmp_channels, 192, kh=7, kw=1, padding=[3, 0]), 132 | ), 133 | ] 134 | ) 135 | ) 136 | self.tower_1 = torch.nn.Sequential( 137 | collections.OrderedDict( 138 | [ 139 | ("conv", Conv2dLayer(in_channels, tmp_channels, kh=1, kw=1)), 140 | ( 141 | "conv_1", 142 | Conv2dLayer( 143 | tmp_channels, tmp_channels, kh=7, kw=1, padding=[3, 0] 144 | ), 145 | ), 146 | ( 147 | "conv_2", 148 | Conv2dLayer( 149 | tmp_channels, tmp_channels, kh=1, kw=7, padding=[0, 3] 150 | ), 151 | ), 152 | ( 153 | "conv_3", 154 | Conv2dLayer( 155 | tmp_channels, tmp_channels, kh=7, kw=1, padding=[3, 0] 156 | ), 157 | ), 158 | ( 159 | "conv_4", 160 | Conv2dLayer(tmp_channels, 192, kh=1, kw=7, padding=[0, 3]), 161 | ), 162 | ] 163 | ) 164 | ) 165 | self.tower_2 = torch.nn.Sequential( 166 | collections.OrderedDict( 167 | [ 168 | ( 169 | "pool", 170 | torch.nn.AvgPool2d( 171 | kernel_size=3, stride=1, padding=1, count_include_pad=False 172 | ), 173 | ), 174 | ("conv", Conv2dLayer(in_channels, 192, kh=1, kw=1)), 175 | ] 176 | ) 177 | ) 178 | 179 | def forward(self, x): 180 | return torch.cat( 181 | [ 182 | self.conv(x).contiguous(), 183 | self.tower(x).contiguous(), 184 | self.tower_1(x).contiguous(), 185 | self.tower_2(x).contiguous(), 186 | ], 187 | dim=1, 188 | ) 189 | 190 | 191 | # ---------------------------------------------------------------------------- 192 | 193 | 194 | class InceptionD(torch.nn.Module): 195 | def __init__(self, in_channels): 196 | super().__init__() 197 | self.tower = torch.nn.Sequential( 198 | collections.OrderedDict( 199 | [ 200 | ("conv", Conv2dLayer(in_channels, 192, kh=1, kw=1)), 201 | ("conv_1", Conv2dLayer(192, 320, kh=3, kw=3, stride=2)), 202 | ] 203 | ) 204 | ) 205 | self.tower_1 = torch.nn.Sequential( 206 | collections.OrderedDict( 207 | [ 208 | ("conv", Conv2dLayer(in_channels, 192, kh=1, kw=1)), 209 | ("conv_1", Conv2dLayer(192, 192, kh=1, kw=7, padding=[0, 3])), 210 | ("conv_2", Conv2dLayer(192, 192, kh=7, kw=1, padding=[3, 0])), 211 | ("conv_3", Conv2dLayer(192, 192, kh=3, kw=3, stride=2)), 212 | ] 213 | ) 214 | ) 215 | self.pool = torch.nn.MaxPool2d(kernel_size=3, stride=2) 216 | 217 | def forward(self, x): 218 | return torch.cat( 219 | [ 220 | self.tower(x).contiguous(), 221 | self.tower_1(x).contiguous(), 222 | self.pool(x).contiguous(), 223 | ], 224 | dim=1, 225 | ) 226 | 227 | 228 | # ---------------------------------------------------------------------------- 229 | 230 | 231 | class InceptionE(torch.nn.Module): 232 | def __init__(self, in_channels, use_avg_pool): 233 | super().__init__() 234 | self.conv = Conv2dLayer(in_channels, 320, kh=1, kw=1) 235 | self.tower_conv = Conv2dLayer(in_channels, 384, kh=1, kw=1) 236 | self.tower_mixed_conv = Conv2dLayer(384, 384, kh=1, kw=3, padding=[0, 1]) 237 | self.tower_mixed_conv_1 = Conv2dLayer(384, 384, kh=3, kw=1, padding=[1, 0]) 238 | self.tower_1_conv = Conv2dLayer(in_channels, 448, kh=1, kw=1) 239 | self.tower_1_conv_1 = Conv2dLayer(448, 384, kh=3, kw=3, padding=1) 240 | self.tower_1_mixed_conv = Conv2dLayer(384, 384, kh=1, kw=3, padding=[0, 1]) 241 | self.tower_1_mixed_conv_1 = Conv2dLayer(384, 384, kh=3, kw=1, padding=[1, 0]) 242 | if use_avg_pool: 243 | self.tower_2_pool = torch.nn.AvgPool2d( 244 | kernel_size=3, stride=1, padding=1, count_include_pad=False 245 | ) 246 | else: 247 | self.tower_2_pool = torch.nn.MaxPool2d(kernel_size=3, stride=1, padding=1) 248 | self.tower_2_conv = Conv2dLayer(in_channels, 192, kh=1, kw=1) 249 | 250 | def forward(self, x): 251 | a = self.tower_conv(x) 252 | b = 
self.tower_1_conv_1(self.tower_1_conv(x)) 253 | return torch.cat( 254 | [ 255 | self.conv(x).contiguous(), 256 | self.tower_mixed_conv(a).contiguous(), 257 | self.tower_mixed_conv_1(a).contiguous(), 258 | self.tower_1_mixed_conv(b).contiguous(), 259 | self.tower_1_mixed_conv_1(b).contiguous(), 260 | self.tower_2_conv(self.tower_2_pool(x)).contiguous(), 261 | ], 262 | dim=1, 263 | ) 264 | 265 | 266 | # ---------------------------------------------------------------------------- 267 | 268 | 269 | class InceptionV3(torch.nn.Module): 270 | def __init__(self): 271 | super().__init__() 272 | self.layers = torch.nn.Sequential( 273 | collections.OrderedDict( 274 | [ 275 | ("conv", Conv2dLayer(3, 32, kh=3, kw=3, stride=2)), 276 | ("conv_1", Conv2dLayer(32, 32, kh=3, kw=3)), 277 | ("conv_2", Conv2dLayer(32, 64, kh=3, kw=3, padding=1)), 278 | ("pool0", torch.nn.MaxPool2d(kernel_size=3, stride=2)), 279 | ("conv_3", Conv2dLayer(64, 80, kh=1, kw=1)), 280 | ("conv_4", Conv2dLayer(80, 192, kh=3, kw=3)), 281 | ("pool1", torch.nn.MaxPool2d(kernel_size=3, stride=2)), 282 | ("mixed", InceptionA(192, tmp_channels=32)), 283 | ("mixed_1", InceptionA(256, tmp_channels=64)), 284 | ("mixed_2", InceptionA(288, tmp_channels=64)), 285 | ("mixed_3", InceptionB(288)), 286 | ("mixed_4", InceptionC(768, tmp_channels=128)), 287 | ("mixed_5", InceptionC(768, tmp_channels=160)), 288 | ("mixed_6", InceptionC(768, tmp_channels=160)), 289 | ("mixed_7", InceptionC(768, tmp_channels=192)), 290 | ("mixed_8", InceptionD(768)), 291 | ("mixed_9", InceptionE(1280, use_avg_pool=True)), 292 | ("mixed_10", InceptionE(2048, use_avg_pool=False)), 293 | ("pool2", torch.nn.AvgPool2d(kernel_size=8)), 294 | ] 295 | ) 296 | ) 297 | self.output = torch.nn.Linear(2048, 1008) 298 | 299 | def forward( 300 | self, 301 | img, 302 | return_features: bool = True, 303 | use_fp16: bool = False, 304 | no_output_bias: bool = False, 305 | ): 306 | batch_size, channels, height, width = img.shape # [NCHW] 307 | assert channels == 3 308 | 309 | # Cast to float. 310 | x = img.to(torch.float16 if use_fp16 else torch.float32) 311 | 312 | # Emulate tf.image.resize_bilinear(x, [299, 299]), including the funky alignment. 313 | new_width, new_height = 299, 299 314 | theta = torch.eye(2, 3, device=x.device) 315 | theta[0, 2] += theta[0, 0] / width - theta[0, 0] / new_width 316 | theta[1, 2] += theta[1, 1] / height - theta[1, 1] / new_height 317 | theta = theta.to(x.dtype).unsqueeze(0).repeat([batch_size, 1, 1]) 318 | grid = torch.nn.functional.affine_grid( 319 | theta, [batch_size, channels, new_height, new_width], align_corners=False 320 | ) 321 | x = torch.nn.functional.grid_sample( 322 | x, grid, mode="bilinear", padding_mode="border", align_corners=False 323 | ) 324 | 325 | # Scale dynamic range from [0,255] to [-1,1[. 326 | x -= 128 327 | x /= 128 328 | 329 | # Main layers. 330 | intermediate = self.layers[:-6](x) 331 | spatial_features = ( 332 | self.layers[-6] 333 | .conv(intermediate)[:, :7] 334 | .permute(0, 2, 3, 1) 335 | .reshape(-1, 2023) 336 | ) 337 | features = self.layers[-6:](intermediate).reshape(-1, 2048).to(torch.float32) 338 | if return_features: 339 | return features, spatial_features 340 | 341 | # Output layer. 
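        # 1008-way class probabilities computed from the 2048-d pooled features,
        # matching the output width of the original TF graph.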
342 | return self.acts_to_probs(features, no_output_bias=no_output_bias) 343 | 344 | def acts_to_probs(self, features, no_output_bias: bool = False): 345 | if no_output_bias: 346 | logits = torch.nn.functional.linear(features, self.output.weight) 347 | else: 348 | logits = self.output(features) 349 | probs = torch.nn.functional.softmax(logits, dim=1) 350 | return probs 351 | 352 | def create_softmax_model(self): 353 | return SoftmaxModel(self.output.weight) 354 | 355 | 356 | class SoftmaxModel(torch.nn.Module): 357 | def __init__(self, weight: torch.Tensor): 358 | super().__init__() 359 | self.weight = torch.nn.Parameter(weight.detach().clone()) 360 | 361 | def forward(self, x): 362 | logits = torch.nn.functional.linear(x, self.weight) 363 | probs = torch.nn.functional.softmax(logits, dim=1) 364 | return probs 365 | -------------------------------------------------------------------------------- /evaluations/requirements.txt: -------------------------------------------------------------------------------- 1 | tensorflow-gpu>=2.0 2 | scipy 3 | requests 4 | tqdm -------------------------------------------------------------------------------- /evaluations/th_evaluator.py: -------------------------------------------------------------------------------- 1 | from .inception_v3 import InceptionV3 2 | import blobfile as bf 3 | import torch 4 | import torch.distributed as dist 5 | import torch.nn as nn 6 | from cm import dist_util 7 | import numpy as np 8 | import warnings 9 | from scipy import linalg 10 | from PIL import Image 11 | from tqdm import tqdm 12 | 13 | 14 | def clip_preproc(preproc_fn, x): 15 | return preproc_fn(Image.fromarray(x.astype(np.uint8))) 16 | 17 | 18 | def all_gather(x, dim=0): 19 | xs = [torch.zeros_like(x) for _ in range(dist.get_world_size())] 20 | dist.all_gather(xs, x) 21 | return torch.cat(xs, dim=dim) 22 | 23 | 24 | class FIDStatistics: 25 | def __init__(self, mu: np.ndarray, sigma: np.ndarray, resolution: int): 26 | self.mu = mu 27 | self.sigma = sigma 28 | self.resolution = resolution 29 | 30 | def frechet_distance(self, other, eps=1e-6): 31 | """ 32 | Compute the Frechet distance between two sets of statistics. 
33 | """ 34 | mu1, sigma1 = self.mu, self.sigma 35 | mu2, sigma2 = other.mu, other.sigma 36 | 37 | mu1 = np.atleast_1d(mu1) 38 | mu2 = np.atleast_1d(mu2) 39 | 40 | sigma1 = np.atleast_2d(sigma1) 41 | sigma2 = np.atleast_2d(sigma2) 42 | 43 | assert ( 44 | mu1.shape == mu2.shape 45 | ), f"Training and test mean vectors have different lengths: {mu1.shape}, {mu2.shape}" 46 | assert ( 47 | sigma1.shape == sigma2.shape 48 | ), f"Training and test covariances have different dimensions: {sigma1.shape}, {sigma2.shape}" 49 | 50 | diff = mu1 - mu2 51 | 52 | # product might be almost singular 53 | covmean, _ = linalg.sqrtm(sigma1.dot(sigma2), disp=False) 54 | if not np.isfinite(covmean).all(): 55 | msg = ( 56 | "fid calculation produces singular product; adding %s to diagonal of cov estimates" 57 | % eps 58 | ) 59 | warnings.warn(msg) 60 | offset = np.eye(sigma1.shape[0]) * eps 61 | covmean = linalg.sqrtm((sigma1 + offset).dot(sigma2 + offset)) 62 | 63 | # numerical error might give slight imaginary component 64 | if np.iscomplexobj(covmean): 65 | if not np.allclose(np.diagonal(covmean).imag, 0, atol=1e-3): 66 | m = np.max(np.abs(covmean.imag)) 67 | raise ValueError("Imaginary component {}".format(m)) 68 | covmean = covmean.real 69 | 70 | tr_covmean = np.trace(covmean) 71 | 72 | return diff.dot(diff) + np.trace(sigma1) + np.trace(sigma2) - 2 * tr_covmean 73 | 74 | 75 | class FIDAndIS: 76 | def __init__( 77 | self, 78 | softmax_batch_size=512, 79 | clip_score_batch_size=512, 80 | path="https://openaipublic.blob.core.windows.net/consistency/inception/inception-2015-12-05.pt", 81 | ): 82 | import clip 83 | 84 | super().__init__() 85 | 86 | self.softmax_batch_size = softmax_batch_size 87 | self.clip_score_batch_size = clip_score_batch_size 88 | self.inception = InceptionV3() 89 | with bf.BlobFile(path, "rb") as f: 90 | self.inception.load_state_dict(torch.load(f)) 91 | self.inception.eval() 92 | self.inception.to(dist_util.dev()) 93 | 94 | self.inception_softmax = self.inception.create_softmax_model() 95 | 96 | if dist.get_rank() % 8 == 0: 97 | clip_model, self.clip_preproc_fn = clip.load( 98 | "ViT-B/32", device=dist_util.dev() 99 | ) 100 | dist.barrier() 101 | if dist.get_rank() % 8 != 0: 102 | clip_model, self.clip_preproc_fn = clip.load( 103 | "ViT-B/32", device=dist_util.dev() 104 | ) 105 | dist.barrier() 106 | 107 | # Compute the probe features separately from the final projection. 108 | class ProjLayer(nn.Module): 109 | def __init__(self, param): 110 | super().__init__() 111 | self.param = param 112 | 113 | def forward(self, x): 114 | return x @ self.param 115 | 116 | self.clip_visual = clip_model.visual 117 | self.clip_proj = ProjLayer(self.clip_visual.proj) 118 | self.clip_visual.proj = None 119 | 120 | class TextModel(nn.Module): 121 | def __init__(self, clip_model): 122 | super().__init__() 123 | self.clip_model = clip_model 124 | 125 | def forward(self, x): 126 | return self.clip_model.encode_text(x) 127 | 128 | self.clip_tokenizer = lambda captions: clip.tokenize(captions, truncate=True) 129 | self.clip_text = TextModel(clip_model) 130 | self.clip_logit_scale = clip_model.logit_scale.exp().item() 131 | self.ref_features = {} 132 | self.is_root = not dist.is_initialized() or dist.get_rank() == 0 133 | 134 | def get_statistics(self, activations: np.ndarray, resolution: int): 135 | """ 136 | Compute activation statistics for a batch of images. 137 | 138 | :param activations: an [N x D] batch of activations. 139 | :return: an FIDStatistics object. 
140 | """ 141 | mu = np.mean(activations, axis=0) 142 | sigma = np.cov(activations, rowvar=False) 143 | return FIDStatistics(mu, sigma, resolution) 144 | 145 | def get_preds(self, batch, captions=None): 146 | with torch.no_grad(): 147 | batch = 127.5 * (batch + 1) 148 | np_batch = batch.to(torch.uint8).cpu().numpy().transpose((0, 2, 3, 1)) 149 | 150 | pred, spatial_pred = self.inception(batch) 151 | pred, spatial_pred = pred.reshape( 152 | [pred.shape[0], -1] 153 | ), spatial_pred.reshape([spatial_pred.shape[0], -1]) 154 | 155 | clip_in = torch.stack( 156 | [clip_preproc(self.clip_preproc_fn, img) for img in np_batch] 157 | ) 158 | clip_pred = self.clip_visual(clip_in.half().to(dist_util.dev())) 159 | if captions is not None: 160 | text_in = self.clip_tokenizer(captions) 161 | text_pred = self.clip_text(text_in.to(dist_util.dev())) 162 | else: 163 | # Hack to easily deal with no captions 164 | text_pred = self.clip_proj(clip_pred.half()) 165 | text_pred = text_pred / text_pred.norm(dim=-1, keepdim=True) 166 | 167 | return pred, spatial_pred, clip_pred, text_pred, np_batch 168 | 169 | def get_inception_score( 170 | self, activations: np.ndarray, split_size: int = 5000 171 | ) -> float: 172 | """ 173 | Compute the inception score using a batch of activations. 174 | :param activations: an [N x D] batch of activations. 175 | :param split_size: the number of samples per split. This is used to 176 | make results consistent with other work, even when 177 | using a different number of samples. 178 | :return: an inception score estimate. 179 | """ 180 | softmax_out = [] 181 | for i in range(0, len(activations), self.softmax_batch_size): 182 | acts = activations[i : i + self.softmax_batch_size] 183 | with torch.no_grad(): 184 | softmax_out.append( 185 | self.inception_softmax(torch.from_numpy(acts).to(dist_util.dev())) 186 | .cpu() 187 | .numpy() 188 | ) 189 | preds = np.concatenate(softmax_out, axis=0) 190 | # https://github.com/openai/improved-gan/blob/4f5d1ec5c16a7eceb206f42bfc652693601e1d5c/inception_score/model.py#L46 191 | scores = [] 192 | for i in range(0, len(preds), split_size): 193 | part = preds[i : i + split_size] 194 | kl = part * (np.log(part) - np.log(np.expand_dims(np.mean(part, 0), 0))) 195 | kl = np.mean(np.sum(kl, 1)) 196 | scores.append(np.exp(kl)) 197 | return float(np.mean(scores)) 198 | 199 | def get_clip_score( 200 | self, activations: np.ndarray, text_features: np.ndarray 201 | ) -> float: 202 | # Sizes should never mismatch, but if they do we want to compute 203 | # _some_ value instead of crash looping. 
204 | size = min(len(activations), len(text_features)) 205 | activations = activations[:size] 206 | text_features = text_features[:size] 207 | 208 | scores_out = [] 209 | for i in range(0, len(activations), self.clip_score_batch_size): 210 | acts = activations[i : i + self.clip_score_batch_size] 211 | sub_features = text_features[i : i + self.clip_score_batch_size] 212 | with torch.no_grad(): 213 | image_features = self.clip_proj( 214 | torch.from_numpy(acts).half().to(dist_util.dev()) 215 | ) 216 | image_features = image_features / image_features.norm( 217 | dim=-1, keepdim=True 218 | ) 219 | image_features = image_features.detach().cpu().float().numpy() 220 | scores_out.extend(np.sum(sub_features * image_features, axis=-1).tolist()) 221 | return np.mean(scores_out) * self.clip_logit_scale 222 | 223 | def get_activations(self, data, num_samples, global_batch_size, pr_samples=50000): 224 | if self.is_root: 225 | preds = [] 226 | spatial_preds = [] 227 | clip_preds = [] 228 | pr_images = [] 229 | 230 | for _ in tqdm(range(0, int(np.ceil(num_samples / global_batch_size)))): 231 | batch, cond, _ = next(data) 232 | batch, cond = batch.to(dist_util.dev()), { 233 | k: v.to(dist_util.dev()) for k, v in cond.items() 234 | } 235 | pred, spatial_pred, clip_pred, _, np_batch = self.get_preds(batch) 236 | pred, spatial_pred, clip_pred = ( 237 | all_gather(pred).cpu().numpy(), 238 | all_gather(spatial_pred).cpu().numpy(), 239 | all_gather(clip_pred).cpu().numpy(), 240 | ) 241 | if self.is_root: 242 | preds.append(pred) 243 | spatial_preds.append(spatial_pred) 244 | clip_preds.append(clip_pred) 245 | if len(pr_images) * np_batch.shape[0] < pr_samples: 246 | pr_images.append(np_batch) 247 | 248 | if self.is_root: 249 | preds, spatial_preds, clip_preds, pr_images = ( 250 | np.concatenate(preds, axis=0), 251 | np.concatenate(spatial_preds, axis=0), 252 | np.concatenate(clip_preds, axis=0), 253 | np.concatenate(pr_images, axis=0), 254 | ) 255 | # assert len(pr_images) >= pr_samples 256 | return ( 257 | preds[:num_samples], 258 | spatial_preds[:num_samples], 259 | clip_preds[:num_samples], 260 | pr_images[:pr_samples], 261 | ) 262 | else: 263 | return [], [], [], [] 264 | 265 | def get_virtual_batch(self, data, num_samples, global_batch_size, resolution): 266 | preds, spatial_preds, clip_preds, batch = self.get_activations( 267 | data, num_samples, global_batch_size, pr_samples=10000 268 | ) 269 | if self.is_root: 270 | fid_stats = self.get_statistics(preds, resolution) 271 | spatial_stats = self.get_statistics(spatial_preds, resolution) 272 | clip_stats = self.get_statistics(clip_preds, resolution) 273 | return batch, dict( 274 | mu=fid_stats.mu, 275 | sigma=fid_stats.sigma, 276 | mu_s=spatial_stats.mu, 277 | sigma_s=spatial_stats.sigma, 278 | mu_clip=clip_stats.mu, 279 | sigma_clip=clip_stats.sigma, 280 | ) 281 | else: 282 | return None, dict() 283 | 284 | def set_ref_batch(self, ref_batch): 285 | with bf.BlobFile(ref_batch, "rb") as f: 286 | data = np.load(f) 287 | fid_stats = FIDStatistics(mu=data["mu"], sigma=data["sigma"], resolution=-1) 288 | spatial_stats = FIDStatistics( 289 | mu=data["mu_s"], sigma=data["sigma_s"], resolution=-1 290 | ) 291 | clip_stats = FIDStatistics( 292 | mu=data["mu_clip"], sigma=data["sigma_clip"], resolution=-1 293 | ) 294 | 295 | self.ref_features[ref_batch] = (fid_stats, spatial_stats, clip_stats) 296 | 297 | def get_ref_batch(self, ref_batch): 298 | return self.ref_features[ref_batch] 299 | -------------------------------------------------------------------------------- 
/model-card.md: -------------------------------------------------------------------------------- 1 | # Overview 2 | 3 | These are diffusion models and consistency models described in the paper [Consistency Models](https://arxiv.org/abs/2303.01469). We include the following models in this release: 4 | 5 | * Consistency models trained by CD (with both l2 and LPIPS metrics) on ImageNet 64x64, LSUN Bedroom 256x256, and LSUN Cat 256x256. 6 | * Consistency models trained by CT on ImageNet 64x64, LSUN Bedroom 256x256, and LSUN Cat 256x256. 7 | 8 | # Datasets 9 | 10 | The models that we are making available have been trained on the [ILSVRC 2012 subset of ImageNet](http://www.image-net.org/challenges/LSVRC/2012/) or on individual categories from [LSUN](https://arxiv.org/abs/1506.03365). Here we outline the characteristics of these datasets that influence the behavior of the models: 11 | 12 | **ILSVRC 2012 subset of ImageNet**: This dataset was curated in 2012 and has around a million pictures, each of which belongs to one of 1,000 categories. A significant number of the categories in this dataset are animals, plants, and other naturally occurring objects. Although many photographs include humans, these humans are typically not represented by the class label (for example, the category "Tench, tinca tinca" includes many photographs of individuals holding fish). 13 | 14 | **LSUN**: This dataset was collected in 2015 by a combination of human labeling via Amazon Mechanical Turk and automated data labeling. Both classes that we consider have more than a million images. The dataset creators discovered that when assessed by trained experts, the label accuracy was approximately 90% throughout the entire LSUN dataset. The pictures are gathered from the internet, and those in the cat class often follow a "meme" format. Occasionally, people, including faces, appear in these photographs. 15 | 16 | 17 | # Performance 18 | 19 | These models are intended to generate samples consistent with their training distributions. 20 | This has been measured in terms of FID, Inception Score, Precision, and Recall. 21 | These metrics all rely on the representations of a [pre-trained Inception-V3 model](https://arxiv.org/abs/1512.00567), 22 | which was trained on ImageNet, and so is likely to focus more on the ImageNet classes (such as animals) than on other visual features (such as human faces). 23 | 24 | 25 | # Intended Use 26 | 27 | These models are intended to be used for research purposes only. In particular, they can be used as a baseline for generative modeling research, or as a starting point for advancing such research. These models are not intended to be commercially deployed. Additionally, they are not intended to be used to create propaganda or offensive imagery. 28 | 29 | # Limitations 30 | 31 | These models sometimes produce highly unrealistic outputs, particularly when generating images containing human faces. 32 | This may stem from ImageNet's emphasis on non-human objects. 33 | 34 | In consistency distillation and training, minimizing LPIPS results in better sample quality, as evidenced by improved FID and Inception scores. However, it also carries the risk of overestimating model performance, because LPIPS uses a VGG network pre-trained on ImageNet, while FID and Inception scores also rely on convolutional neural networks (the Inception network in particular) pre-trained on the same ImageNet dataset. 
Although these two convolutional neural networks do not share the same architecture and we extract latents from them in substantially different ways, knowledge leakage is still plausible which can undermine the fidelity of FID and Inception scores. 35 | 36 | Because ImageNet and LSUN contain images from the internet, they include photos of real people, and the model may have memorized some of the information contained in these photos. However, these images are already publicly available, and existing generative models trained on ImageNet have not demonstrated significant leakage of this information. 37 | -------------------------------------------------------------------------------- /scripts/cm_train.py: -------------------------------------------------------------------------------- 1 | """ 2 | Train a diffusion model on images. 3 | """ 4 | 5 | import argparse 6 | 7 | from cm import dist_util, logger 8 | from cm.image_datasets import load_data 9 | from cm.resample import create_named_schedule_sampler 10 | from cm.script_util import ( 11 | model_and_diffusion_defaults, 12 | create_model_and_diffusion, 13 | cm_train_defaults, 14 | args_to_dict, 15 | add_dict_to_argparser, 16 | create_ema_and_scales_fn, 17 | ) 18 | from cm.train_util import CMTrainLoop 19 | import torch.distributed as dist 20 | import copy 21 | 22 | 23 | def main(): 24 | args = create_argparser().parse_args() 25 | 26 | dist_util.setup_dist() 27 | logger.configure() 28 | 29 | logger.log("creating model and diffusion...") 30 | ema_scale_fn = create_ema_and_scales_fn( 31 | target_ema_mode=args.target_ema_mode, 32 | start_ema=args.start_ema, 33 | scale_mode=args.scale_mode, 34 | start_scales=args.start_scales, 35 | end_scales=args.end_scales, 36 | total_steps=args.total_training_steps, 37 | distill_steps_per_iter=args.distill_steps_per_iter, 38 | ) 39 | if args.training_mode == "progdist": 40 | distillation = False 41 | elif "consistency" in args.training_mode: 42 | distillation = True 43 | else: 44 | raise ValueError(f"unknown training mode {args.training_mode}") 45 | 46 | model_and_diffusion_kwargs = args_to_dict( 47 | args, model_and_diffusion_defaults().keys() 48 | ) 49 | model_and_diffusion_kwargs["distillation"] = distillation 50 | model, diffusion = create_model_and_diffusion(**model_and_diffusion_kwargs) 51 | model.to(dist_util.dev()) 52 | model.train() 53 | if args.use_fp16: 54 | model.convert_to_fp16() 55 | 56 | schedule_sampler = create_named_schedule_sampler(args.schedule_sampler, diffusion) 57 | 58 | logger.log("creating data loader...") 59 | if args.batch_size == -1: 60 | batch_size = args.global_batch_size // dist.get_world_size() 61 | if args.global_batch_size % dist.get_world_size() != 0: 62 | logger.log( 63 | f"warning, using smaller global_batch_size of {dist.get_world_size()*batch_size} instead of {args.global_batch_size}" 64 | ) 65 | else: 66 | batch_size = args.batch_size 67 | 68 | data = load_data( 69 | data_dir=args.data_dir, 70 | batch_size=batch_size, 71 | image_size=args.image_size, 72 | class_cond=args.class_cond, 73 | ) 74 | 75 | if len(args.teacher_model_path) > 0: # path to the teacher score model. 
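        # Consistency distillation: the student is initialized from the teacher's
        # weights (copied below), and the teacher is put in eval mode.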
76 | logger.log(f"loading the teacher model from {args.teacher_model_path}") 77 | teacher_model_and_diffusion_kwargs = copy.deepcopy(model_and_diffusion_kwargs) 78 | teacher_model_and_diffusion_kwargs["dropout"] = args.teacher_dropout 79 | teacher_model_and_diffusion_kwargs["distillation"] = False 80 | teacher_model, teacher_diffusion = create_model_and_diffusion( 81 | **teacher_model_and_diffusion_kwargs, 82 | ) 83 | 84 | teacher_model.load_state_dict( 85 | dist_util.load_state_dict(args.teacher_model_path, map_location="cpu"), 86 | ) 87 | 88 | teacher_model.to(dist_util.dev()) 89 | teacher_model.eval() 90 | 91 | for dst, src in zip(model.parameters(), teacher_model.parameters()): 92 | dst.data.copy_(src.data) 93 | 94 | if args.use_fp16: 95 | teacher_model.convert_to_fp16() 96 | 97 | else: 98 | teacher_model = None 99 | teacher_diffusion = None 100 | 101 | # load the target model for distillation, if path specified. 102 | 103 | logger.log("creating the target model") 104 | target_model, _ = create_model_and_diffusion( 105 | **model_and_diffusion_kwargs, 106 | ) 107 | 108 | target_model.to(dist_util.dev()) 109 | target_model.train() 110 | 111 | dist_util.sync_params(target_model.parameters()) 112 | dist_util.sync_params(target_model.buffers()) 113 | 114 | for dst, src in zip(target_model.parameters(), model.parameters()): 115 | dst.data.copy_(src.data) 116 | 117 | if args.use_fp16: 118 | target_model.convert_to_fp16() 119 | 120 | logger.log("training...") 121 | CMTrainLoop( 122 | model=model, 123 | target_model=target_model, 124 | teacher_model=teacher_model, 125 | teacher_diffusion=teacher_diffusion, 126 | training_mode=args.training_mode, 127 | ema_scale_fn=ema_scale_fn, 128 | total_training_steps=args.total_training_steps, 129 | diffusion=diffusion, 130 | data=data, 131 | batch_size=batch_size, 132 | microbatch=args.microbatch, 133 | lr=args.lr, 134 | ema_rate=args.ema_rate, 135 | log_interval=args.log_interval, 136 | save_interval=args.save_interval, 137 | resume_checkpoint=args.resume_checkpoint, 138 | use_fp16=args.use_fp16, 139 | fp16_scale_growth=args.fp16_scale_growth, 140 | schedule_sampler=schedule_sampler, 141 | weight_decay=args.weight_decay, 142 | lr_anneal_steps=args.lr_anneal_steps, 143 | ).run_loop() 144 | 145 | 146 | def create_argparser(): 147 | defaults = dict( 148 | data_dir="", 149 | schedule_sampler="uniform", 150 | lr=1e-4, 151 | weight_decay=0.0, 152 | lr_anneal_steps=0, 153 | global_batch_size=2048, 154 | batch_size=-1, 155 | microbatch=-1, # -1 disables microbatches 156 | ema_rate="0.9999", # comma-separated list of EMA values 157 | log_interval=10, 158 | save_interval=10000, 159 | resume_checkpoint="", 160 | use_fp16=False, 161 | fp16_scale_growth=1e-3, 162 | ) 163 | defaults.update(model_and_diffusion_defaults()) 164 | defaults.update(cm_train_defaults()) 165 | parser = argparse.ArgumentParser() 166 | add_dict_to_argparser(parser, defaults) 167 | return parser 168 | 169 | 170 | if __name__ == "__main__": 171 | main() 172 | -------------------------------------------------------------------------------- /scripts/edm_train.py: -------------------------------------------------------------------------------- 1 | """ 2 | Train a diffusion model on images. 
3 | """ 4 | 5 | import argparse 6 | 7 | from cm import dist_util, logger 8 | from cm.image_datasets import load_data 9 | from cm.resample import create_named_schedule_sampler 10 | from cm.script_util import ( 11 | model_and_diffusion_defaults, 12 | create_model_and_diffusion, 13 | args_to_dict, 14 | add_dict_to_argparser, 15 | ) 16 | from cm.train_util import TrainLoop 17 | import torch.distributed as dist 18 | 19 | 20 | def main(): 21 | args = create_argparser().parse_args() 22 | 23 | dist_util.setup_dist() 24 | logger.configure() 25 | 26 | logger.log("creating model and diffusion...") 27 | model, diffusion = create_model_and_diffusion( 28 | **args_to_dict(args, model_and_diffusion_defaults().keys()) 29 | ) 30 | model.to(dist_util.dev()) 31 | schedule_sampler = create_named_schedule_sampler(args.schedule_sampler, diffusion) 32 | 33 | logger.log("creating data loader...") 34 | if args.batch_size == -1: 35 | batch_size = args.global_batch_size // dist.get_world_size() 36 | if args.global_batch_size % dist.get_world_size() != 0: 37 | logger.log( 38 | f"warning, using smaller global_batch_size of {dist.get_world_size()*batch_size} instead of {args.global_batch_size}" 39 | ) 40 | else: 41 | batch_size = args.batch_size 42 | 43 | data = load_data( 44 | data_dir=args.data_dir, 45 | batch_size=batch_size, 46 | image_size=args.image_size, 47 | class_cond=args.class_cond, 48 | ) 49 | 50 | logger.log("creating data loader...") 51 | 52 | logger.log("training...") 53 | TrainLoop( 54 | model=model, 55 | diffusion=diffusion, 56 | data=data, 57 | batch_size=batch_size, 58 | microbatch=args.microbatch, 59 | lr=args.lr, 60 | ema_rate=args.ema_rate, 61 | log_interval=args.log_interval, 62 | save_interval=args.save_interval, 63 | resume_checkpoint=args.resume_checkpoint, 64 | use_fp16=args.use_fp16, 65 | fp16_scale_growth=args.fp16_scale_growth, 66 | schedule_sampler=schedule_sampler, 67 | weight_decay=args.weight_decay, 68 | lr_anneal_steps=args.lr_anneal_steps, 69 | ).run_loop() 70 | 71 | 72 | def create_argparser(): 73 | defaults = dict( 74 | data_dir="", 75 | schedule_sampler="uniform", 76 | lr=1e-4, 77 | weight_decay=0.0, 78 | lr_anneal_steps=0, 79 | global_batch_size=2048, 80 | batch_size=-1, 81 | microbatch=-1, # -1 disables microbatches 82 | ema_rate="0.9999", # comma-separated list of EMA values 83 | log_interval=10, 84 | save_interval=10000, 85 | resume_checkpoint="", 86 | use_fp16=False, 87 | fp16_scale_growth=1e-3, 88 | ) 89 | defaults.update(model_and_diffusion_defaults()) 90 | parser = argparse.ArgumentParser() 91 | add_dict_to_argparser(parser, defaults) 92 | return parser 93 | 94 | 95 | if __name__ == "__main__": 96 | main() 97 | -------------------------------------------------------------------------------- /scripts/image_sample.py: -------------------------------------------------------------------------------- 1 | """ 2 | Generate a large batch of image samples from a model and save them as a large 3 | numpy array. This can be used to produce samples for FID evaluation. 
4 | """ 5 | 6 | import argparse 7 | import os 8 | 9 | import numpy as np 10 | import torch as th 11 | import torch.distributed as dist 12 | 13 | from cm import dist_util, logger 14 | from cm.script_util import ( 15 | NUM_CLASSES, 16 | model_and_diffusion_defaults, 17 | create_model_and_diffusion, 18 | add_dict_to_argparser, 19 | args_to_dict, 20 | ) 21 | from cm.random_util import get_generator 22 | from cm.karras_diffusion import karras_sample 23 | 24 | 25 | def main(): 26 | args = create_argparser().parse_args() 27 | 28 | dist_util.setup_dist() 29 | logger.configure() 30 | 31 | if "consistency" in args.training_mode: 32 | distillation = True 33 | else: 34 | distillation = False 35 | 36 | logger.log("creating model and diffusion...") 37 | model, diffusion = create_model_and_diffusion( 38 | **args_to_dict(args, model_and_diffusion_defaults().keys()), 39 | distillation=distillation, 40 | ) 41 | model.load_state_dict( 42 | dist_util.load_state_dict(args.model_path, map_location="cpu") 43 | ) 44 | model.to(dist_util.dev()) 45 | if args.use_fp16: 46 | model.convert_to_fp16() 47 | model.eval() 48 | 49 | logger.log("sampling...") 50 | if args.sampler == "multistep": 51 | assert len(args.ts) > 0 52 | ts = tuple(int(x) for x in args.ts.split(",")) 53 | else: 54 | ts = None 55 | 56 | all_images = [] 57 | all_labels = [] 58 | generator = get_generator(args.generator, args.num_samples, args.seed) 59 | 60 | while len(all_images) * args.batch_size < args.num_samples: 61 | model_kwargs = {} 62 | if args.class_cond: 63 | classes = th.randint( 64 | low=0, high=NUM_CLASSES, size=(args.batch_size,), device=dist_util.dev() 65 | ) 66 | model_kwargs["y"] = classes 67 | 68 | sample = karras_sample( 69 | diffusion, 70 | model, 71 | (args.batch_size, 3, args.image_size, args.image_size), 72 | steps=args.steps, 73 | model_kwargs=model_kwargs, 74 | device=dist_util.dev(), 75 | clip_denoised=args.clip_denoised, 76 | sampler=args.sampler, 77 | sigma_min=args.sigma_min, 78 | sigma_max=args.sigma_max, 79 | s_churn=args.s_churn, 80 | s_tmin=args.s_tmin, 81 | s_tmax=args.s_tmax, 82 | s_noise=args.s_noise, 83 | generator=generator, 84 | ts=ts, 85 | ) 86 | sample = ((sample + 1) * 127.5).clamp(0, 255).to(th.uint8) 87 | sample = sample.permute(0, 2, 3, 1) 88 | sample = sample.contiguous() 89 | 90 | gathered_samples = [th.zeros_like(sample) for _ in range(dist.get_world_size())] 91 | dist.all_gather(gathered_samples, sample) # gather not supported with NCCL 92 | all_images.extend([sample.cpu().numpy() for sample in gathered_samples]) 93 | if args.class_cond: 94 | gathered_labels = [ 95 | th.zeros_like(classes) for _ in range(dist.get_world_size()) 96 | ] 97 | dist.all_gather(gathered_labels, classes) 98 | all_labels.extend([labels.cpu().numpy() for labels in gathered_labels]) 99 | logger.log(f"created {len(all_images) * args.batch_size} samples") 100 | 101 | arr = np.concatenate(all_images, axis=0) 102 | arr = arr[: args.num_samples] 103 | if args.class_cond: 104 | label_arr = np.concatenate(all_labels, axis=0) 105 | label_arr = label_arr[: args.num_samples] 106 | if dist.get_rank() == 0: 107 | shape_str = "x".join([str(x) for x in arr.shape]) 108 | out_path = os.path.join(logger.get_dir(), f"samples_{shape_str}.npz") 109 | logger.log(f"saving to {out_path}") 110 | if args.class_cond: 111 | np.savez(out_path, arr, label_arr) 112 | else: 113 | np.savez(out_path, arr) 114 | 115 | dist.barrier() 116 | logger.log("sampling complete") 117 | 118 | 119 | def create_argparser(): 120 | defaults = dict( 121 | training_mode="edm", 122 
| generator="determ", 123 | clip_denoised=True, 124 | num_samples=10000, 125 | batch_size=16, 126 | sampler="heun", 127 | s_churn=0.0, 128 | s_tmin=0.0, 129 | s_tmax=float("inf"), 130 | s_noise=1.0, 131 | steps=40, 132 | model_path="", 133 | seed=42, 134 | ts="", 135 | ) 136 | defaults.update(model_and_diffusion_defaults()) 137 | parser = argparse.ArgumentParser() 138 | add_dict_to_argparser(parser, defaults) 139 | return parser 140 | 141 | 142 | if __name__ == "__main__": 143 | main() 144 | -------------------------------------------------------------------------------- /scripts/launch.sh: -------------------------------------------------------------------------------- 1 | #################################################################### 2 | # Training EDM models on class-conditional ImageNet-64, and LSUN 256 3 | #################################################################### 4 | 5 | mpiexec -n 8 python edm_train.py --attention_resolutions 32,16,8 --class_cond True --use_scale_shift_norm True --dropout 0.1 --ema_rate 0.999,0.9999,0.9999432189950708 --global_batch_size 4096 --image_size 64 --lr 0.0001 --num_channels 192 --num_head_channels 64 --num_res_blocks 3 --resblock_updown True --schedule_sampler lognormal --use_fp16 True --weight_decay 0.0 --weight_schedule karras --data_dir /path/to/imagenet 6 | 7 | mpiexec -n 8 python edm_train.py --attention_resolutions 32,16,8 --class_cond False --dropout 0.1 --ema_rate 0.999,0.9999,0.9999432189950708 --global_batch_size 256 --image_size 256 --lr 0.0001 --num_channels 256 --num_head_channels 64 --num_res_blocks 2 --resblock_updown True --schedule_sampler lognormal --use_fp16 True --use_scale_shift_norm False --weight_decay 0.0 --weight_schedule karras --data_dir /path/to/lsun_bedroom 8 | 9 | ######################################################################### 10 | # Sampling from EDM models on class-conditional ImageNet-64, and LSUN 256 11 | ######################################################################### 12 | 13 | mpiexec -n 8 python image_sample.py --training_mode edm --batch_size 64 --sigma_max 80 --sigma_min 0.002 --s_churn 0 --steps 40 --sampler heun --model_path edm_imagenet64_ema.pt --attention_resolutions 32,16,8 --class_cond True --dropout 0.1 --image_size 64 --num_channels 192 --num_head_channels 64 --num_res_blocks 3 --num_samples 50000 --resblock_updown True --use_fp16 True --use_scale_shift_norm True --weight_schedule karras 14 | 15 | mpiexec -n 8 python image_sample.py --training_mode edm --generator determ-indiv --batch_size 8 --sigma_max 80 --sigma_min 0.002 --s_churn 0 --steps 40 --sampler heun --model_path /path/to/edm_bedroom256_ema.pt --attention_resolutions 32,16,8 --class_cond False --dropout 0.1 --image_size 256 --num_channels 256 --num_head_channels 64 --num_res_blocks 2 --num_samples 50000 --resblock_updown True --use_fp16 True --use_scale_shift_norm False --weight_schedule karras 16 | 17 | ######################################################################### 18 | # Consistency distillation on class-conditional ImageNet-64, and LSUN 256 19 | ######################################################################### 20 | 21 | ## L_CD^N (l2) on ImageNet-64 22 | mpiexec -n 8 python cm_train.py --training_mode consistency_distillation --target_ema_mode fixed --start_ema 0.95 --scale_mode fixed --start_scales 40 --total_training_steps 600000 --loss_norm l2 --lr_anneal_steps 0 --teacher_model_path /path/to/edm_imagenet64_ema.pt --attention_resolutions 32,16,8 --class_cond True 
--use_scale_shift_norm True --dropout 0.0 --teacher_dropout 0.1 --ema_rate 0.999,0.9999,0.9999432189950708 --global_batch_size 2048 --image_size 64 --lr 0.000008 --num_channels 192 --num_head_channels 64 --num_res_blocks 3 --resblock_updown True --schedule_sampler uniform --use_fp16 True --weight_decay 0.0 --weight_schedule uniform --data_dir /path/to/data 23 | 24 | ## L_CD^N (LPIPS) on ImageNet-64 25 | mpiexec -n 8 python cm_train.py --training_mode consistency_distillation --target_ema_mode fixed --start_ema 0.95 --scale_mode fixed --start_scales 40 --total_training_steps 600000 --loss_norm lpips --lr_anneal_steps 0 --teacher_model_path /path/to/edm_imagenet64_ema.pt --attention_resolutions 32,16,8 --class_cond True --use_scale_shift_norm True --dropout 0.0 --teacher_dropout 0.1 --ema_rate 0.999,0.9999,0.9999432189950708 --global_batch_size 2048 --image_size 64 --lr 0.000008 --num_channels 192 --num_head_channels 64 --num_res_blocks 3 --resblock_updown True --schedule_sampler uniform --use_fp16 True --weight_decay 0.0 --weight_schedule uniform --data_dir /path/to/data 26 | 27 | ## L_CD^N (l2) on LSUN 256 28 | mpiexec -n 8 python cm_train.py --training_mode consistency_distillation --sigma_max 80 --sigma_min 0.002 --target_ema_mode fixed --start_ema 0.95 --scale_mode fixed --start_scales 40 --total_training_steps 600000 --loss_norm l2 --lr_anneal_steps 0 --teacher_model_path /path/to/edm_bedroom256_ema.pt --attention_resolutions 32,16,8 --class_cond False --use_scale_shift_norm False --dropout 0.0 --teacher_dropout 0.1 --ema_rate 0.9999,0.99994,0.9999432189950708 --global_batch_size 256 --image_size 256 --lr 0.00001 --num_channels 256 --num_head_channels 64 --num_res_blocks 2 --resblock_updown True --schedule_sampler uniform --use_fp16 True --weight_decay 0.0 --weight_schedule uniform --data_dir /path/to/bedroom256 29 | 30 | ## L_CD^N (LPIPS) on LSUN 256 31 | mpiexec -n 8 python cm_train.py --training_mode consistency_distillation --sigma_max 80 --sigma_min 0.002 --target_ema_mode fixed --start_ema 0.95 --scale_mode fixed --start_scales 40 --total_training_steps 600000 --loss_norm lpips --lr_anneal_steps 0 --teacher_model_path /path/to/edm_bedroom256_ema.pt --attention_resolutions 32,16,8 --class_cond False --use_scale_shift_norm False --dropout 0.0 --teacher_dropout 0.1 --ema_rate 0.9999,0.99994,0.9999432189950708 --global_batch_size 256 --image_size 256 --lr 0.00001 --num_channels 256 --num_head_channels 64 --num_res_blocks 2 --resblock_updown True --schedule_sampler uniform --use_fp16 True --weight_decay 0.0 --weight_schedule uniform --data_dir /path/to/bedroom256 32 | 33 | ######################################################################### 34 | # Consistency training on class-conditional ImageNet-64, and LSUN 256 35 | ######################################################################### 36 | 37 | ## L_CT^N on ImageNet-64 38 | mpiexec -n 8 python cm_train.py --training_mode consistency_training --target_ema_mode adaptive --start_ema 0.95 --scale_mode progressive --start_scales 2 --end_scales 200 --total_training_steps 800000 --loss_norm lpips --lr_anneal_steps 0 --teacher_model_path /path/to/edm_imagenet64_ema.pt --attention_resolutions 32,16,8 --class_cond True --use_scale_shift_norm True --dropout 0.0 --teacher_dropout 0.1 --ema_rate 0.999,0.9999,0.9999432189950708 --global_batch_size 2048 --image_size 64 --lr 0.0001 --num_channels 192 --num_head_channels 64 --num_res_blocks 3 --resblock_updown True --schedule_sampler uniform --use_fp16 True --weight_decay 0.0 
--weight_schedule uniform --data_dir /path/to/imagenet64 39 | 40 | ## L_CT^N on LSUN 256 41 | mpiexec -n 8 python cm_train.py --training_mode consistency_training --target_ema_mode adaptive --start_ema 0.95 --scale_mode progressive --start_scales 2 --end_scales 150 --total_training_steps 1000000 --loss_norm lpips --lr_anneal_steps 0 --teacher_model_path /path/to/edm_bedroom256_ema.pt --attention_resolutions 32,16,8 --class_cond False --use_scale_shift_norm False --dropout 0.0 --teacher_dropout 0.1 --ema_rate 0.9999,0.99994,0.9999432189950708 --global_batch_size 256 --image_size 256 --lr 0.00005 --num_channels 256 --num_head_channels 64 --num_res_blocks 2 --resblock_updown True --schedule_sampler uniform --use_fp16 True --weight_decay 0.0 --weight_schedule uniform --data_dir /path/to/bedroom256 42 | 43 | ################################################################################# 44 | # Sampling from consistency models on class-conditional ImageNet-64, and LSUN 256 45 | ################################################################################# 46 | 47 | ## ImageNet-64 48 | mpiexec -n 8 python image_sample.py --batch_size 256 --training_mode consistency_distillation --sampler onestep --model_path /path/to/checkpoint --attention_resolutions 32,16,8 --class_cond True --use_scale_shift_norm True --dropout 0.0 --image_size 64 --num_channels 192 --num_head_channels 64 --num_res_blocks 3 --num_samples 500 --resblock_updown True --use_fp16 True --weight_schedule uniform 49 | 50 | ## LSUN-256 51 | mpiexec -n 8 python image_sample.py --batch_size 32 --generator determ-indiv --training_mode consistency_distillation --sampler onestep --model_path /root/consistency/ct_bedroom256.pt --attention_resolutions 32,16,8 --class_cond False --use_scale_shift_norm False --dropout 0.0 --image_size 256 --num_channels 256 --num_head_channels 64 --num_res_blocks 2 --num_samples 100 --resblock_updown True --use_fp16 True --weight_schedule uniform 52 | 53 | ###################################################################################### 54 | # Ternary search for multi-step sampling on class-conditional ImageNet-64, and LSUN 256 55 | ###################################################################################### 56 | 57 | ## CD on ImageNet-64 58 | mpiexec -n 8 python ternary_search.py --begin 0 --end 39 --steps 40 --generator determ --ref_batch /root/consistency/ref_batches/imagenet64.npz --batch_size 256 --model_path /root/consistency/cd_imagenet64_lpips.pt --attention_resolutions 32,16,8 --class_cond True --use_scale_shift_norm True --dropout 0.0 --image_size 64 --num_channels 192 --num_head_channels 64 --num_res_blocks 3 --num_samples 50000 --resblock_updown True --use_fp16 True --weight_schedule uniform 59 | 60 | ## CT on ImageNet-64 61 | mpiexec -n 8 python ternary_search.py --begin 0 --end 200 --steps 201 --generator determ --ref_batch /root/consistency/ref_batches/imagenet64.npz --batch_size 256 --model_path /root/consistency/ct_imagenet64.pt --attention_resolutions 32,16,8 --class_cond True --use_scale_shift_norm True --dropout 0.0 --image_size 64 --num_channels 192 --num_head_channels 64 --num_res_blocks 3 --num_samples 50000 --resblock_updown True --use_fp16 True --weight_schedule uniform 62 | 63 | ## CD on LSUN-256 64 | mpiexec -n 8 python ternary_search.py --begin 0 --end 39 --steps 40 --generator determ-indiv --ref_batch /root/consistency/ref_batches/bedroom256.npz --batch_size 32 --model_path /root/consistency/cd_bedroom256_lpips.pt --attention_resolutions 32,16,8 --class_cond False 
--use_scale_shift_norm False --dropout 0.0 --image_size 256 --num_channels 256 --num_head_channels 64 --num_res_blocks 2 --num_samples 50000 --resblock_updown True --use_fp16 True --weight_schedule uniform 65 | 66 | ## CT on LSUN-256 67 | mpiexec -n 8 python ternary_search.py --begin 0 --end 150 --steps 151 --generator determ-indiv --ref_batch /root/consistency/ref_batches/bedroom256.npz --batch_size 32 --model_path /root/consistency/ct_bedroom256.pt --attention_resolutions 32,16,8 --class_cond False --use_scale_shift_norm False --dropout 0.0 --image_size 256 --num_channels 256 --num_head_channels 64 --num_res_blocks 2 --num_samples 50000 --resblock_updown True --use_fp16 True --weight_schedule uniform 68 | 69 | ################################################################### 70 | # Multistep sampling on class-conditional ImageNet-64, and LSUN 256 71 | ################################################################### 72 | 73 | ## Two-step sampling for CD (LPIPS) on ImageNet-64 74 | mpiexec -n 8 python image_sample.py --batch_size 256 --training_mode consistency_distillation --sampler multistep --ts 0,22,39 --steps 40 --model_path /path/to/cd_imagenet64_lpips.pt --attention_resolutions 32,16,8 --class_cond True --use_scale_shift_norm True --dropout 0.0 --image_size 64 --num_channels 192 --num_head_channels 64 --num_res_blocks 3 --num_samples 500 --resblock_updown True --use_fp16 True --weight_schedule uniform 75 | 76 | ## Two-step sampling for CD (L2) on ImageNet-64 77 | mpiexec -n 8 python image_sample.py --batch_size 256 --training_mode consistency_distillation --sampler multistep --ts 0,22,39 --steps 40 --model_path /path/to/cd_imagenet64_l2.pt --attention_resolutions 32,16,8 --class_cond True --use_scale_shift_norm True --dropout 0.0 --image_size 64 --num_channels 192 --num_head_channels 64 --num_res_blocks 3 --num_samples 500 --resblock_updown True --use_fp16 True --weight_schedule uniform 78 | 79 | ## Two-step sampling for CT on ImageNet-64 80 | mpiexec -n 8 python image_sample.py --batch_size 256 --training_mode consistency_training --sampler multistep --ts 0,106,200 --steps 201 --model_path /path/to/ct_imagenet64.pt --attention_resolutions 32,16,8 --class_cond True --use_scale_shift_norm True --dropout 0.0 --image_size 64 --num_channels 192 --num_head_channels 64 --num_res_blocks 3 --num_samples 500 --resblock_updown True --use_fp16 True --weight_schedule uniform 81 | 82 | ## Two-step sampling for CD (LPIPS) on LSUN-256 83 | mpiexec -n 8 python image_sample.py --batch_size 32 --training_mode consistency_distillation --sampler multistep --ts 0,17,39 --steps 40 --model_path /path/to/cd_bedroom256_lpips.pt --attention_resolutions 32,16,8 --class_cond False --use_scale_shift_norm False --dropout 0.0 --image_size 256 --num_channels 256 --num_head_channels 64 --num_res_blocks 2 --num_samples 500 --resblock_updown True --use_fp16 True --weight_schedule uniform 84 | 85 | ## Two-step sampling for CD (l2) on LSUN-256 86 | mpiexec -n 8 python image_sample.py --batch_size 32 --training_mode consistency_distillation --sampler multistep --ts 0,18,39 --steps 40 --model_path /path/to/cd_bedroom256_l2.pt --attention_resolutions 32,16,8 --class_cond False --use_scale_shift_norm False --dropout 0.0 --image_size 256 --num_channels 256 --num_head_channels 64 --num_res_blocks 2 --num_samples 500 --resblock_updown True --use_fp16 True --weight_schedule uniform 87 | 88 | ## Two-step sampling for CT on LSUN Bedroom-256 89 | mpiexec -n 8 python image_sample.py --batch_size 32 --training_mode 
consistency_distillation --sampler multistep --ts 0,67,150 --steps 151 --model_path /path/to/ct_bedroom256.pt --attention_resolutions 32,16,8 --class_cond False --use_scale_shift_norm False --dropout 0.0 --image_size 256 --num_channels 256 --num_head_channels 64 --num_res_blocks 2 --num_samples 500 --resblock_updown True --use_fp16 True --weight_schedule uniform 90 | 91 | ## Two-step sampling for CT on LSUN Cat-256 92 | mpiexec -n 8 python image_sample.py --batch_size 32 --training_mode consistency_distillation --sampler multistep --ts 0,62,150 --steps 151 --model_path /path/to/ct_cat256.pt --attention_resolutions 32,16,8 --class_cond False --use_scale_shift_norm False --dropout 0.0 --image_size 256 --num_channels 256 --num_head_channels 64 --num_res_blocks 2 --num_samples 500 --resblock_updown True --use_fp16 True --weight_schedule uniform 93 | -------------------------------------------------------------------------------- /scripts/ternary_search.py: -------------------------------------------------------------------------------- 1 | """ 2 | Generate a large batch of image samples from a model and save them as a large 3 | numpy array. This can be used to produce samples for FID evaluation. 4 | """ 5 | 6 | import argparse 7 | import os 8 | 9 | import numpy as np 10 | import torch as th 11 | import torch.distributed as dist 12 | from functools import cache 13 | from mpi4py import MPI 14 | 15 | from cm import dist_util, logger 16 | from cm.script_util import ( 17 | NUM_CLASSES, 18 | model_and_diffusion_defaults, 19 | create_model_and_diffusion, 20 | add_dict_to_argparser, 21 | args_to_dict, 22 | ) 23 | from cm.random_util import get_generator 24 | from cm.karras_diffusion import stochastic_iterative_sampler 25 | from evaluations.th_evaluator import FIDAndIS 26 | 27 | 28 | def main(): 29 | args = create_argparser().parse_args() 30 | 31 | dist_util.setup_dist() 32 | logger.configure() 33 | 34 | if "consistency" in args.training_mode: 35 | distillation = True 36 | else: 37 | distillation = False 38 | 39 | logger.log("creating model and diffusion...") 40 | model, diffusion = create_model_and_diffusion( 41 | **args_to_dict(args, model_and_diffusion_defaults().keys()), 42 | distillation=distillation, 43 | ) 44 | model.load_state_dict( 45 | dist_util.load_state_dict(args.model_path, map_location="cpu") 46 | ) 47 | model.to(dist_util.dev()) 48 | if args.use_fp16: 49 | model.convert_to_fp16() 50 | model.eval() 51 | 52 | fid_is = FIDAndIS() 53 | fid_is.set_ref_batch(args.ref_batch) 54 | ( 55 | ref_fid_stats, 56 | ref_spatial_stats, 57 | ref_clip_stats, 58 | ) = fid_is.get_ref_batch(args.ref_batch) 59 | 60 | def sample_generator(ts): 61 | logger.log("sampling...") 62 | all_images = [] 63 | all_labels = [] 64 | all_preds = [] 65 | 66 | generator = get_generator(args.generator, args.num_samples, args.seed) 67 | while len(all_images) * args.batch_size < args.num_samples: 68 | model_kwargs = {} 69 | if args.class_cond: 70 | classes = th.randint( 71 | low=0, 72 | high=NUM_CLASSES, 73 | size=(args.batch_size,), 74 | device=dist_util.dev(), 75 | ) 76 | model_kwargs["y"] = classes 77 | 78 | def denoiser(x_t, sigma): 79 | _, denoised = diffusion.denoise(model, x_t, sigma, **model_kwargs) 80 | if args.clip_denoised: 81 | denoised = denoised.clamp(-1, 1) 82 | return denoised 83 | 84 | x_T = ( 85 | generator.randn( 86 | *(args.batch_size, 3, args.image_size, args.image_size), 87 | device=dist_util.dev(), 88 | ) 89 | * args.sigma_max 90 | ) 91 | 92 | sample = stochastic_iterative_sampler( 93 | denoiser, 94 | x_T, 
95 | ts, 96 | t_min=args.sigma_min, 97 | t_max=args.sigma_max, 98 | rho=diffusion.rho, 99 | steps=args.steps, 100 | generator=generator, 101 | ) 102 | pred, spatial_pred, clip_pred, text_pred, _ = fid_is.get_preds(sample) 103 | 104 | sample = ((sample + 1) * 127.5).clamp(0, 255).to(th.uint8) 105 | sample = sample.permute(0, 2, 3, 1) 106 | sample = sample.contiguous() 107 | 108 | gathered_samples = [ 109 | th.zeros_like(sample) for _ in range(dist.get_world_size()) 110 | ] 111 | gathered_preds = [th.zeros_like(pred) for _ in range(dist.get_world_size())] 112 | dist.all_gather(gathered_samples, sample) # gather not supported with NCCL 113 | dist.all_gather(gathered_preds, pred) 114 | all_images.extend([sample.cpu().numpy() for sample in gathered_samples]) 115 | all_preds.extend([pred.cpu().numpy() for pred in gathered_preds]) 116 | if args.class_cond: 117 | gathered_labels = [ 118 | th.zeros_like(classes) for _ in range(dist.get_world_size()) 119 | ] 120 | dist.all_gather(gathered_labels, classes) 121 | all_labels.extend([labels.cpu().numpy() for labels in gathered_labels]) 122 | 123 | logger.log(f"created {len(all_images) * args.batch_size} samples") 124 | 125 | arr = np.concatenate(all_images, axis=0) 126 | arr = arr[: args.num_samples] 127 | preds = np.concatenate(all_preds, axis=0) 128 | preds = preds[: args.num_samples] 129 | if args.class_cond: 130 | label_arr = np.concatenate(all_labels, axis=0) 131 | label_arr = label_arr[: args.num_samples] 132 | 133 | dist.barrier() 134 | logger.log("sampling complete") 135 | 136 | return arr, preds 137 | 138 | @cache 139 | def get_fid(p, begin=(0,), end=(args.steps - 1,)): 140 | 141 | samples, preds = sample_generator(begin + (p,) + end) 142 | is_root = dist.get_rank() == 0 143 | if is_root: 144 | fid_stats = fid_is.get_statistics(preds, -1) 145 | fid = ref_fid_stats.frechet_distance(fid_stats) 146 | fid = MPI.COMM_WORLD.bcast(fid) 147 | # spatial_stats = fid_is.get_statistics(spatial_preds, -1) 148 | # sfid = ref_spatial_stats.frechet_distance(spatial_stats) 149 | # clip_stats = fid_is.get_statistics(clip_preds, -1) 150 | IS = fid_is.get_inception_score(preds) 151 | IS = MPI.COMM_WORLD.bcast(IS) 152 | # clip_fid = fid_is.get_clip_score(clip_preds, text_preds) 153 | # fcd = ref_clip_stats.frechet_distance(clip_stats) 154 | else: 155 | fid = MPI.COMM_WORLD.bcast(None) 156 | IS = MPI.COMM_WORLD.bcast(None) 157 | 158 | dist.barrier() 159 | return fid, IS 160 | 161 | def ternary_search(before=(0,), after=(17,)): 162 | left = before[-1] 163 | right = after[0] 164 | is_root = dist.get_rank() == 0 165 | while right - left >= 3: 166 | m1 = int(left + (right - left) / 3.0) 167 | m2 = int(right - (right - left) / 3.0) 168 | f1, is1 = get_fid(m1, before, after) 169 | if is_root: 170 | logger.log(f"fid at m1 = {m1} is {f1}, IS is {is1}") 171 | f2, is2 = get_fid(m2, before, after) 172 | if is_root: 173 | logger.log(f"fid at m2 = {m2} is {f2}, IS is {is2}") 174 | if f1 < f2: 175 | right = m2 176 | else: 177 | left = m1 178 | if is_root: 179 | logger.log(f"new interval is [{left}, {right}]") 180 | 181 | if right == left: 182 | p = right 183 | elif right - left == 1: 184 | f1, _ = get_fid(left, before, after) 185 | f2, _ = get_fid(right, before, after) 186 | p = left if f1 < f2 else right 187 | elif right - left == 2: 188 | mid = left + 1 189 | f1, _ = get_fid(left, before, after) 190 | f2, _ = get_fid(right, before, after) 191 | fmid, ismid = get_fid(mid, before, after) 192 | if is_root: 193 | logger.log(f"fmid at mid = {mid} is {fmid}, ISmid is {ismid}") 194 | if 
fmid < f1 and fmid < f2: 195 | p = mid 196 | elif f1 < f2: 197 | p = left 198 | else: 199 | p = right 200 | 201 | return p 202 | 203 | # convert comma separated numbers to tuples 204 | begin = tuple(int(x) for x in args.begin.split(",")) 205 | end = tuple(int(x) for x in args.end.split(",")) 206 | 207 | optimal_p = ternary_search(begin, end) 208 | if dist.get_rank() == 0: 209 | logger.log(f"ternary_search_results: {optimal_p}") 210 | fid, IS = get_fid(optimal_p, begin, end) 211 | logger.log(f"fid at optimal p = {optimal_p} is {fid}, IS is {IS}") 212 | 213 | 214 | def create_argparser(): 215 | defaults = dict( 216 | begin="0", 217 | end="39", 218 | training_mode="consistency_distillation", 219 | generator="determ", 220 | clip_denoised=True, 221 | num_samples=10000, 222 | batch_size=16, 223 | sampler="heun", 224 | s_churn=0.0, 225 | s_tmin=0.0, 226 | s_tmax=float("inf"), 227 | s_noise=1.0, 228 | steps=40, 229 | model_path="", 230 | ref_batch="", 231 | seed=42, 232 | ) 233 | defaults.update(model_and_diffusion_defaults()) 234 | parser = argparse.ArgumentParser() 235 | add_dict_to_argparser(parser, defaults) 236 | return parser 237 | 238 | 239 | if __name__ == "__main__": 240 | main() 241 | -------------------------------------------------------------------------------- /setup.py: -------------------------------------------------------------------------------- 1 | from setuptools import setup 2 | 3 | setup( 4 | name="consistency-models", 5 | py_modules=["cm", "evaluations"], 6 | install_requires=[ 7 | "blobfile>=1.0.5", 8 | "torch", 9 | "tqdm", 10 | "numpy", 11 | "scipy", 12 | "pandas", 13 | "Cython", 14 | "piq==0.7.0", 15 | "joblib==0.14.0", 16 | "albumentations==0.4.3", 17 | "lmdb", 18 | "clip @ git+https://github.com/openai/CLIP.git", 19 | "mpi4py", 20 | "flash-attn==0.2.8", 21 | "pillow", 22 | ], 23 | ) 24 | --------------------------------------------------------------------------------
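A note on the integer `--ts` values used by the multistep sampler in `scripts/launch.sh` and searched over by `scripts/ternary_search.py`: they are indices into a discretized noise-level grid of length `--steps`, with index 0 corresponding to `sigma_max` and index `steps - 1` to `sigma_min`. The sketch below is purely illustrative and not part of the repository; it assumes the Karras-style schedule with the defaults used in the commands above (`sigma_min=0.002`, `sigma_max=80`, `rho=7`).

```python
# Illustrative sketch (not part of the codebase): map a --ts grid index to its
# noise level, assuming the Karras schedule with sigma_min=0.002, sigma_max=80, rho=7.
def sigma_at_index(i, steps=40, sigma_min=0.002, sigma_max=80.0, rho=7.0):
    """Noise level at grid index i, where index 0 -> sigma_max and steps - 1 -> sigma_min."""
    max_inv_rho = sigma_max ** (1 / rho)
    min_inv_rho = sigma_min ** (1 / rho)
    return (max_inv_rho + i / (steps - 1) * (min_inv_rho - max_inv_rho)) ** rho

# For example, the two-step ImageNet-64 commands above use --ts 0,22,39 with --steps 40:
print([round(sigma_at_index(i, steps=40), 3) for i in (0, 22, 39)])  # roughly [80.0, 1.4, 0.002]
```

Under this reading, the ternary search above is effectively selecting the intermediate index, and hence the intermediate noise level, at which the sample is re-noised and denoised a second time during two-step sampling.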