├── functions ├── __init__.py ├── ckpt_util.py ├── denoising.py └── svd_replacement.py ├── runners ├── __init__.py └── diffusion.py ├── .gitignore ├── inp_masks ├── lorem3.npy └── lolcat_extra.npy ├── figures └── ddrm-overview.png ├── configs ├── celeba_hq.yml ├── imagenet_256.yml ├── cat.yml ├── bedroom.yml ├── church.yml ├── imagenet_256_cc.yml └── imagenet_512_cc.yml ├── LICENSE ├── datasets ├── vision.py ├── imagenet_subset.py ├── lsun.py ├── utils.py ├── celeba.py └── __init__.py ├── environment.yml ├── main.py ├── guided_diffusion ├── nn.py ├── fp16_util.py ├── script_util.py ├── logger.py └── unet.py ├── README.md └── models └── diffusion.py /functions/__init__.py: -------------------------------------------------------------------------------- 1 | -------------------------------------------------------------------------------- /runners/__init__.py: -------------------------------------------------------------------------------- 1 | -------------------------------------------------------------------------------- /.gitignore: -------------------------------------------------------------------------------- 1 | .vscode 2 | __pycache__/ 3 | *.pth 4 | exp/ 5 | -------------------------------------------------------------------------------- /inp_masks/lorem3.npy: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/bahjat-kawar/ddrm/HEAD/inp_masks/lorem3.npy -------------------------------------------------------------------------------- /figures/ddrm-overview.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/bahjat-kawar/ddrm/HEAD/figures/ddrm-overview.png -------------------------------------------------------------------------------- /inp_masks/lolcat_extra.npy: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/bahjat-kawar/ddrm/HEAD/inp_masks/lolcat_extra.npy -------------------------------------------------------------------------------- /configs/celeba_hq.yml: -------------------------------------------------------------------------------- 1 | data: 2 | dataset: "CelebA_HQ" 3 | category: "" 4 | image_size: 256 5 | channels: 3 6 | logit_transform: false 7 | uniform_dequantization: false 8 | gaussian_dequantization: false 9 | random_flip: true 10 | rescaled: true 11 | num_workers: 32 12 | out_of_dist: True 13 | 14 | model: 15 | type: "simple" 16 | in_channels: 3 17 | out_ch: 3 18 | ch: 128 19 | ch_mult: [1, 1, 2, 2, 4, 4] 20 | num_res_blocks: 2 21 | attn_resolutions: [16, ] 22 | dropout: 0.0 23 | var_type: fixedsmall 24 | ema_rate: 0.999 25 | ema: True 26 | resamp_with_conv: True 27 | 28 | diffusion: 29 | beta_schedule: linear 30 | beta_start: 0.0001 31 | beta_end: 0.02 32 | num_diffusion_timesteps: 1000 33 | 34 | sampling: 35 | batch_size: 4 36 | last_only: True -------------------------------------------------------------------------------- /configs/imagenet_256.yml: -------------------------------------------------------------------------------- 1 | data: 2 | dataset: "ImageNet" 3 | image_size: 256 4 | channels: 3 5 | logit_transform: false 6 | uniform_dequantization: false 7 | gaussian_dequantization: false 8 | random_flip: true 9 | rescaled: true 10 | num_workers: 32 11 | subset_1k: True 12 | out_of_dist: False 13 | 14 | model: 15 | type: "openai" 16 | in_channels: 3 17 | out_channels: 3 18 | num_channels: 256 19 | num_heads: 4 20 | num_res_blocks: 2 21 | attention_resolutions: 
"32,16,8" 22 | dropout: 0.0 23 | resamp_with_conv: True 24 | learn_sigma: True 25 | use_scale_shift_norm: true 26 | use_fp16: true 27 | resblock_updown: true 28 | num_heads_upsample: -1 29 | var_type: 'fixedsmall' 30 | num_head_channels: 64 31 | image_size: 256 32 | class_cond: false 33 | use_new_attention_order: false 34 | 35 | diffusion: 36 | beta_schedule: linear 37 | beta_start: 0.0001 38 | beta_end: 0.02 39 | num_diffusion_timesteps: 1000 40 | 41 | sampling: 42 | batch_size: 8 43 | last_only: True 44 | -------------------------------------------------------------------------------- /LICENSE: -------------------------------------------------------------------------------- 1 | MIT License 2 | 3 | Copyright (c) 2022 Bahjat Kawar 4 | 5 | Permission is hereby granted, free of charge, to any person obtaining a copy 6 | of this software and associated documentation files (the "Software"), to deal 7 | in the Software without restriction, including without limitation the rights 8 | to use, copy, modify, merge, publish, distribute, sublicense, and/or sell 9 | copies of the Software, and to permit persons to whom the Software is 10 | furnished to do so, subject to the following conditions: 11 | 12 | The above copyright notice and this permission notice shall be included in all 13 | copies or substantial portions of the Software. 14 | 15 | THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR 16 | IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, 17 | FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE 18 | AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER 19 | LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, 20 | OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE 21 | SOFTWARE. 
22 | -------------------------------------------------------------------------------- /configs/cat.yml: -------------------------------------------------------------------------------- 1 | data: 2 | dataset: "LSUN" 3 | category: "cat" 4 | image_size: 256 5 | channels: 3 6 | logit_transform: false 7 | uniform_dequantization: false 8 | gaussian_dequantization: false 9 | random_flip: true 10 | rescaled: true 11 | num_workers: 32 12 | out_of_dist: false 13 | 14 | model: 15 | type: "simple" 16 | in_channels: 3 17 | out_ch: 3 18 | ch: 128 19 | ch_mult: [1, 1, 2, 2, 4, 4] 20 | num_res_blocks: 2 21 | attn_resolutions: [16, ] 22 | dropout: 0.0 23 | var_type: fixedsmall 24 | ema_rate: 0.999 25 | ema: True 26 | resamp_with_conv: True 27 | 28 | diffusion: 29 | beta_schedule: linear 30 | beta_start: 0.0001 31 | beta_end: 0.02 32 | num_diffusion_timesteps: 1000 33 | 34 | training: 35 | batch_size: 64 36 | n_epochs: 10000 37 | n_iters: 5000000 38 | snapshot_freq: 5000 39 | validation_freq: 2000 40 | 41 | sampling: 42 | batch_size: 32 43 | last_only: True 44 | 45 | optim: 46 | weight_decay: 0.000 47 | optimizer: "Adam" 48 | lr: 0.00002 49 | beta1: 0.9 50 | amsgrad: false 51 | eps: 0.00000001 52 | -------------------------------------------------------------------------------- /configs/bedroom.yml: -------------------------------------------------------------------------------- 1 | data: 2 | dataset: "LSUN" 3 | category: "bedroom" 4 | image_size: 256 5 | channels: 3 6 | logit_transform: false 7 | uniform_dequantization: false 8 | gaussian_dequantization: false 9 | random_flip: true 10 | rescaled: true 11 | num_workers: 32 12 | out_of_dist: true 13 | 14 | model: 15 | type: "simple" 16 | in_channels: 3 17 | out_ch: 3 18 | ch: 128 19 | ch_mult: [1, 1, 2, 2, 4, 4] 20 | num_res_blocks: 2 21 | attn_resolutions: [16, ] 22 | dropout: 0.0 23 | var_type: fixedsmall 24 | ema_rate: 0.999 25 | ema: True 26 | resamp_with_conv: True 27 | 28 | diffusion: 29 | beta_schedule: linear 30 | beta_start: 0.0001 31 | beta_end: 0.02 32 | num_diffusion_timesteps: 1000 33 | 34 | training: 35 | batch_size: 64 36 | n_epochs: 10000 37 | n_iters: 5000000 38 | snapshot_freq: 5000 39 | validation_freq: 2000 40 | 41 | sampling: 42 | batch_size: 6 43 | last_only: True 44 | 45 | optim: 46 | weight_decay: 0.000 47 | optimizer: "Adam" 48 | lr: 0.00002 49 | beta1: 0.9 50 | amsgrad: false 51 | eps: 0.00000001 52 | -------------------------------------------------------------------------------- /configs/church.yml: -------------------------------------------------------------------------------- 1 | data: 2 | dataset: "LSUN" 3 | category: "church_outdoor" 4 | image_size: 256 5 | channels: 3 6 | logit_transform: false 7 | uniform_dequantization: false 8 | gaussian_dequantization: false 9 | random_flip: true 10 | rescaled: true 11 | num_workers: 32 12 | out_of_dist: true 13 | 14 | model: 15 | type: "simple" 16 | in_channels: 3 17 | out_ch: 3 18 | ch: 128 19 | ch_mult: [1, 1, 2, 2, 4, 4] 20 | num_res_blocks: 2 21 | attn_resolutions: [16, ] 22 | dropout: 0.0 23 | var_type: fixedsmall 24 | ema_rate: 0.999 25 | ema: True 26 | resamp_with_conv: True 27 | 28 | diffusion: 29 | beta_schedule: linear 30 | beta_start: 0.0001 31 | beta_end: 0.02 32 | num_diffusion_timesteps: 1000 33 | 34 | training: 35 | batch_size: 64 36 | n_epochs: 10000 37 | n_iters: 5000000 38 | snapshot_freq: 5000 39 | validation_freq: 2000 40 | 41 | sampling: 42 | batch_size: 6 43 | last_only: True 44 | 45 | optim: 46 | weight_decay: 0.000 47 | optimizer: "Adam" 48 | lr: 0.00002 49 | 
beta1: 0.9 50 | amsgrad: false 51 | eps: 0.00000001 52 | -------------------------------------------------------------------------------- /configs/imagenet_256_cc.yml: -------------------------------------------------------------------------------- 1 | data: 2 | dataset: "ImageNet" 3 | image_size: 256 4 | channels: 3 5 | logit_transform: false 6 | uniform_dequantization: false 7 | gaussian_dequantization: false 8 | random_flip: true 9 | rescaled: true 10 | num_workers: 32 11 | subset_1k: False 12 | out_of_dist: False 13 | 14 | model: 15 | type: "openai" 16 | in_channels: 3 17 | out_channels: 3 18 | num_channels: 256 19 | num_heads: 4 20 | num_res_blocks: 2 21 | attention_resolutions: "32,16,8" 22 | dropout: 0.0 23 | resamp_with_conv: True 24 | learn_sigma: True 25 | use_scale_shift_norm: true 26 | use_fp16: true 27 | resblock_updown: true 28 | num_heads_upsample: -1 29 | var_type: 'fixedsmall' 30 | num_head_channels: 64 31 | image_size: 256 32 | class_cond: True 33 | use_new_attention_order: false 34 | 35 | classifier: 36 | image_size: 256 37 | classifier_attention_resolutions: "32,16,8" 38 | classifier_depth: 2 39 | classifier_pool: "attention" 40 | classifier_resblock_updown: True 41 | classifier_width: 128 42 | classifier_use_scale_shift_norm: True 43 | classifier_scale: 1.0 44 | classifier_use_fp16: True 45 | 46 | 47 | diffusion: 48 | beta_schedule: linear 49 | beta_start: 0.0001 50 | beta_end: 0.02 51 | num_diffusion_timesteps: 1000 52 | 53 | sampling: 54 | batch_size: 8 55 | last_only: True -------------------------------------------------------------------------------- /configs/imagenet_512_cc.yml: -------------------------------------------------------------------------------- 1 | data: 2 | dataset: "ImageNet" 3 | image_size: 512 4 | channels: 3 5 | logit_transform: false 6 | uniform_dequantization: false 7 | gaussian_dequantization: false 8 | random_flip: true 9 | rescaled: true 10 | num_workers: 32 11 | subset_1k: False 12 | out_of_dist: False 13 | 14 | model: 15 | type: "openai" 16 | in_channels: 3 17 | out_channels: 3 18 | num_channels: 256 19 | num_heads: 4 20 | num_res_blocks: 2 21 | attention_resolutions: "32,16,8" 22 | dropout: 0.0 23 | resamp_with_conv: True 24 | learn_sigma: True 25 | use_scale_shift_norm: true 26 | use_fp16: false 27 | resblock_updown: true 28 | num_heads_upsample: -1 29 | var_type: 'fixedsmall' 30 | num_head_channels: 64 31 | image_size: 512 32 | class_cond: True 33 | use_new_attention_order: false 34 | 35 | classifier: 36 | image_size: 512 37 | classifier_attention_resolutions: "32,16,8" 38 | classifier_depth: 2 39 | classifier_pool: "attention" 40 | classifier_resblock_updown: True 41 | classifier_width: 128 42 | classifier_use_scale_shift_norm: True 43 | classifier_scale: 1.0 44 | classifier_use_fp16: false 45 | 46 | 47 | diffusion: 48 | beta_schedule: linear 49 | beta_start: 0.0001 50 | beta_end: 0.02 51 | num_diffusion_timesteps: 1000 52 | 53 | sampling: 54 | batch_size: 1 55 | last_only: True -------------------------------------------------------------------------------- /functions/ckpt_util.py: -------------------------------------------------------------------------------- 1 | import os, hashlib 2 | import requests 3 | from tqdm import tqdm 4 | 5 | URL_MAP = { 6 | "cifar10": "https://heibox.uni-heidelberg.de/f/869980b53bf5416c8a28/?dl=1", 7 | "ema_cifar10": "https://heibox.uni-heidelberg.de/f/2e4f01e2d9ee49bab1d5/?dl=1", 8 | "lsun_bedroom": "https://heibox.uni-heidelberg.de/f/f179d4f21ebc4d43bbfe/?dl=1", 9 | "ema_lsun_bedroom": 
"https://heibox.uni-heidelberg.de/f/b95206528f384185889b/?dl=1", 10 | "lsun_cat": "https://heibox.uni-heidelberg.de/f/fac870bd988348eab88e/?dl=1", 11 | "ema_lsun_cat": "https://heibox.uni-heidelberg.de/f/0701aac3aa69457bbe34/?dl=1", 12 | "lsun_church": "https://heibox.uni-heidelberg.de/f/2711a6f712e34b06b9d8/?dl=1", 13 | "ema_lsun_church": "https://heibox.uni-heidelberg.de/f/44ccb50ef3c6436db52e/?dl=1", 14 | } 15 | CKPT_MAP = { 16 | "cifar10": "diffusion_cifar10_model/model-790000.ckpt", 17 | "ema_cifar10": "ema_diffusion_cifar10_model/model-790000.ckpt", 18 | "lsun_bedroom": "diffusion_lsun_bedroom_model/model-2388000.ckpt", 19 | "ema_lsun_bedroom": "ema_diffusion_lsun_bedroom_model/model-2388000.ckpt", 20 | "lsun_cat": "diffusion_lsun_cat_model/model-1761000.ckpt", 21 | "ema_lsun_cat": "ema_diffusion_lsun_cat_model/model-1761000.ckpt", 22 | "lsun_church": "diffusion_lsun_church_model/model-4432000.ckpt", 23 | "ema_lsun_church": "ema_diffusion_lsun_church_model/model-4432000.ckpt", 24 | } 25 | MD5_MAP = { 26 | "cifar10": "82ed3067fd1002f5cf4c339fb80c4669", 27 | "ema_cifar10": "1fa350b952534ae442b1d5235cce5cd3", 28 | "lsun_bedroom": "f70280ac0e08b8e696f42cb8e948ff1c", 29 | "ema_lsun_bedroom": "1921fa46b66a3665e450e42f36c2720f", 30 | "lsun_cat": "bbee0e7c3d7abfb6e2539eaf2fb9987b", 31 | "ema_lsun_cat": "646f23f4821f2459b8bafc57fd824558", 32 | "lsun_church": "eb619b8a5ab95ef80f94ce8a5488dae3", 33 | "ema_lsun_church": "fdc68a23938c2397caba4a260bc2445f", 34 | } 35 | 36 | 37 | def download(url, local_path, chunk_size=1024): 38 | os.makedirs(os.path.split(local_path)[0], exist_ok=True) 39 | with requests.get(url, stream=True) as r: 40 | total_size = int(r.headers.get("content-length", 0)) 41 | with tqdm(total=total_size, unit="B", unit_scale=True) as pbar: 42 | with open(local_path, "wb") as f: 43 | for data in r.iter_content(chunk_size=chunk_size): 44 | if data: 45 | f.write(data) 46 | pbar.update(chunk_size) 47 | 48 | 49 | def md5_hash(path): 50 | with open(path, "rb") as f: 51 | content = f.read() 52 | return hashlib.md5(content).hexdigest() 53 | 54 | 55 | def get_ckpt_path(name, root=None, check=False, prefix='exp'): 56 | if 'church_outdoor' in name: 57 | name = name.replace('church_outdoor', 'church') 58 | assert name in URL_MAP 59 | # Modify the path when necessary 60 | cachedir = os.environ.get("XDG_CACHE_HOME", os.path.join(prefix, "logs/")) 61 | root = ( 62 | root 63 | if root is not None 64 | else os.path.join(cachedir, "diffusion_models_converted") 65 | ) 66 | path = os.path.join(root, CKPT_MAP[name]) 67 | if not os.path.exists(path) or (check and not md5_hash(path) == MD5_MAP[name]): 68 | print("Downloading {} model from {} to {}".format(name, URL_MAP[name], path)) 69 | download(URL_MAP[name], path) 70 | md5 = md5_hash(path) 71 | assert md5 == MD5_MAP[name], md5 72 | return path 73 | -------------------------------------------------------------------------------- /datasets/vision.py: -------------------------------------------------------------------------------- 1 | import os 2 | import torch 3 | import torch.utils.data as data 4 | 5 | 6 | class VisionDataset(data.Dataset): 7 | _repr_indent = 4 8 | 9 | def __init__(self, root, transforms=None, transform=None, target_transform=None): 10 | if isinstance(root, torch._six.string_classes): 11 | root = os.path.expanduser(root) 12 | self.root = root 13 | 14 | has_transforms = transforms is not None 15 | has_separate_transform = transform is not None or target_transform is not None 16 | if has_transforms and has_separate_transform: 17 | 
raise ValueError("Only transforms or transform/target_transform can " 18 | "be passed as argument") 19 | 20 | # for backwards-compatibility 21 | self.transform = transform 22 | self.target_transform = target_transform 23 | 24 | if has_separate_transform: 25 | transforms = StandardTransform(transform, target_transform) 26 | self.transforms = transforms 27 | 28 | def __getitem__(self, index): 29 | raise NotImplementedError 30 | 31 | def __len__(self): 32 | raise NotImplementedError 33 | 34 | def __repr__(self): 35 | head = "Dataset " + self.__class__.__name__ 36 | body = ["Number of datapoints: {}".format(self.__len__())] 37 | if self.root is not None: 38 | body.append("Root location: {}".format(self.root)) 39 | body += self.extra_repr().splitlines() 40 | if hasattr(self, 'transform') and self.transform is not None: 41 | body += self._format_transform_repr(self.transform, 42 | "Transforms: ") 43 | if hasattr(self, 'target_transform') and self.target_transform is not None: 44 | body += self._format_transform_repr(self.target_transform, 45 | "Target transforms: ") 46 | lines = [head] + [" " * self._repr_indent + line for line in body] 47 | return '\n'.join(lines) 48 | 49 | def _format_transform_repr(self, transform, head): 50 | lines = transform.__repr__().splitlines() 51 | return (["{}{}".format(head, lines[0])] + 52 | ["{}{}".format(" " * len(head), line) for line in lines[1:]]) 53 | 54 | def extra_repr(self): 55 | return "" 56 | 57 | 58 | class StandardTransform(object): 59 | def __init__(self, transform=None, target_transform=None): 60 | self.transform = transform 61 | self.target_transform = target_transform 62 | 63 | def __call__(self, input, target): 64 | if self.transform is not None: 65 | input = self.transform(input) 66 | if self.target_transform is not None: 67 | target = self.target_transform(target) 68 | return input, target 69 | 70 | def _format_transform_repr(self, transform, head): 71 | lines = transform.__repr__().splitlines() 72 | return (["{}{}".format(head, lines[0])] + 73 | ["{}{}".format(" " * len(head), line) for line in lines[1:]]) 74 | 75 | def __repr__(self): 76 | body = [self.__class__.__name__] 77 | if self.transform is not None: 78 | body += self._format_transform_repr(self.transform, 79 | "Transform: ") 80 | if self.target_transform is not None: 81 | body += self._format_transform_repr(self.target_transform, 82 | "Target transform: ") 83 | 84 | return '\n'.join(body) 85 | -------------------------------------------------------------------------------- /datasets/imagenet_subset.py: -------------------------------------------------------------------------------- 1 | import torch.utils.data as data 2 | import torchvision.transforms as transforms 3 | from PIL import Image 4 | 5 | class CenterCropLongEdge(object): 6 | """Crops the given PIL Image on the long edge. 7 | Args: 8 | size (sequence or int): Desired output size of the crop. If size is an 9 | int instead of sequence like (h, w), a square crop (size, size) is 10 | made. 11 | """ 12 | 13 | def __call__(self, img): 14 | """ 15 | Args: 16 | img (PIL Image): Image to be cropped. 17 | Returns: 18 | PIL Image: Cropped image. 
19 | """ 20 | return transforms.functional.center_crop(img, min(img.size)) 21 | 22 | def __repr__(self): 23 | return self.__class__.__name__ 24 | 25 | def pil_loader(path): 26 | # open path as file to avoid ResourceWarning 27 | # (https://github.com/python-pillow/Pillow/issues/835) 28 | with open(path, 'rb') as f: 29 | img = Image.open(f) 30 | return img.convert('RGB') 31 | 32 | 33 | def accimage_loader(path): 34 | import accimage 35 | try: 36 | return accimage.Image(path) 37 | except IOError: 38 | # Potentially a decoding problem, fall back to PIL.Image 39 | return pil_loader(path) 40 | 41 | def default_loader(path): 42 | from torchvision import get_image_backend 43 | if get_image_backend() == 'accimage': 44 | return accimage_loader(path) 45 | else: 46 | return pil_loader(path) 47 | 48 | class ImageDataset(data.Dataset): 49 | 50 | def __init__(self, 51 | root_dir, 52 | meta_file, 53 | transform=None, 54 | image_size=128, 55 | normalize=True): 56 | self.root_dir = root_dir 57 | if transform is not None: 58 | self.transform = transform 59 | else: 60 | norm_mean = [0.5, 0.5, 0.5] 61 | norm_std = [0.5, 0.5, 0.5] 62 | if normalize: 63 | self.transform = transforms.Compose([ 64 | CenterCropLongEdge(), 65 | transforms.Resize(image_size), 66 | transforms.ToTensor(), 67 | transforms.Normalize(norm_mean, norm_std) 68 | ]) 69 | else: 70 | self.transform = transforms.Compose([ 71 | CenterCropLongEdge(), 72 | transforms.Resize(image_size), 73 | transforms.ToTensor() 74 | ]) 75 | with open(meta_file) as f: 76 | lines = f.readlines() 77 | print("building dataset from %s" % meta_file) 78 | self.num = len(lines) 79 | self.metas = [] 80 | self.classifier = None 81 | suffix = ".jpeg" 82 | for line in lines: 83 | line_split = line.rstrip().split() 84 | if len(line_split) == 2: 85 | self.metas.append((line_split[0] + suffix, int(line_split[1]))) 86 | else: 87 | self.metas.append((line_split[0] + suffix, -1)) 88 | print("read meta done") 89 | 90 | def __len__(self): 91 | return self.num 92 | 93 | def __getitem__(self, idx): 94 | filename = self.root_dir + '/' + self.metas[idx][0] 95 | cls = self.metas[idx][1] 96 | img = default_loader(filename) 97 | 98 | # transform 99 | if self.transform is not None: 100 | img = self.transform(img) 101 | 102 | return img, cls #, self.metas[idx][0] -------------------------------------------------------------------------------- /environment.yml: -------------------------------------------------------------------------------- 1 | name: ddrm 2 | channels: 3 | - pytorch 4 | - conda-forge 5 | - defaults 6 | dependencies: 7 | - _libgcc_mutex=0.1=conda_forge 8 | - _openmp_mutex=4.5=1_gnu 9 | - absl-py=1.0.0=pyhd8ed1ab_0 10 | - aiohttp=3.8.1=py39h3811e60_0 11 | - aiosignal=1.2.0=pyhd8ed1ab_0 12 | - async-timeout=4.0.2=pyhd8ed1ab_0 13 | - attrs=21.4.0=pyhd8ed1ab_0 14 | - backcall=0.2.0=pyh9f0ad1d_0 15 | - backports=1.0=py_2 16 | - backports.functools_lru_cache=1.6.4=pyhd8ed1ab_0 17 | - blas=1.0=mkl 18 | - blinker=1.4=py_1 19 | - brotlipy=0.7.0=py39h3811e60_1003 20 | - bzip2=1.0.8=h7b6447c_0 21 | - c-ares=1.18.1=h7f98852_0 22 | - ca-certificates=2021.10.26=h06a4308_2 23 | - cachetools=4.2.4=pyhd8ed1ab_0 24 | - certifi=2021.10.8=py39hf3d152e_1 25 | - cffi=1.15.0=py39h4bc2ebd_0 26 | - charset-normalizer=2.0.9=pyhd8ed1ab_0 27 | - click=8.0.3=py39hf3d152e_1 28 | - colorama=0.4.4=pyh9f0ad1d_0 29 | - cryptography=36.0.1=py39h95dcef6_0 30 | - cudatoolkit=11.3.1=h2bc3f7f_2 31 | - dataclasses=0.8=pyhc8e2a94_3 32 | - decorator=5.1.0=pyhd8ed1ab_0 33 | - ffmpeg=4.3=hf484d3e_0 34 | - 
freetype=2.11.0=h70c0345_0 35 | - frozenlist=1.2.0=py39h3811e60_1 36 | - giflib=5.2.1=h7b6447c_0 37 | - gmp=6.2.1=h2531618_2 38 | - gnutls=3.6.15=he1e5248_0 39 | - google-auth=2.3.3=pyh6c4a22f_0 40 | - google-auth-oauthlib=0.4.6=pyhd8ed1ab_0 41 | - grpcio=1.42.0=py39hce63b2e_0 42 | - idna=3.3=pyhd3eb1b0_0 43 | - importlib-metadata=4.10.0=py39hf3d152e_0 44 | - intel-openmp=2021.4.0=h06a4308_3561 45 | - ipdb=0.13.9=pyhd8ed1ab_0 46 | - ipython=7.30.1=py39hf3d152e_0 47 | - jedi=0.18.1=py39hf3d152e_0 48 | - jpeg=9d=h7f8727e_0 49 | - lame=3.100=h7b6447c_0 50 | - lcms2=2.12=h3be6417_0 51 | - ld_impl_linux-64=2.36.1=hea4e1c9_2 52 | - libblas=3.9.0=12_linux64_mkl 53 | - libcblas=3.9.0=12_linux64_mkl 54 | - libffi=3.4.2=h7f98852_5 55 | - libgcc-ng=11.2.0=h1d223b6_11 56 | - libgfortran-ng=11.2.0=h69a702a_11 57 | - libgfortran5=11.2.0=h5c6108e_11 58 | - libgomp=11.2.0=h1d223b6_11 59 | - libiconv=1.15=h63c8f33_5 60 | - libidn2=2.3.2=h7f8727e_0 61 | - liblapack=3.9.0=12_linux64_mkl 62 | - libpng=1.6.37=hbc83047_0 63 | - libprotobuf=3.17.2=h780b84a_1 64 | - libstdcxx-ng=11.2.0=he4da1e4_11 65 | - libtasn1=4.16.0=h27cfd23_0 66 | - libtiff=4.2.0=h85742a9_0 67 | - libunistring=0.9.10=h27cfd23_0 68 | - libuv=1.40.0=h7b6447c_0 69 | - libwebp=1.2.0=h89dd481_0 70 | - libwebp-base=1.2.0=h27cfd23_0 71 | - lmdb=0.9.29=h2531618_0 72 | - lz4-c=1.9.3=h295c915_1 73 | - markdown=3.3.6=pyhd8ed1ab_0 74 | - matplotlib-inline=0.1.3=pyhd8ed1ab_0 75 | - mkl=2021.4.0=h06a4308_640 76 | - mkl-service=2.4.0=py39h7f8727e_0 77 | - mkl_fft=1.3.1=py39hd3c417c_0 78 | - mkl_random=1.2.2=py39h51133e4_0 79 | - multidict=5.2.0=py39h3811e60_1 80 | - ncurses=6.2=h58526e2_4 81 | - nettle=3.7.3=hbbd107a_1 82 | - numpy=1.21.2=py39h20f2e39_0 83 | - numpy-base=1.21.2=py39h79a1101_0 84 | - oauthlib=3.1.1=pyhd8ed1ab_0 85 | - olefile=0.46=pyhd3eb1b0_0 86 | - openh264=2.1.1=h4ff587b_0 87 | - openssl=1.1.1l=h7f98852_0 88 | - parso=0.8.3=pyhd8ed1ab_0 89 | - pexpect=4.8.0=pyh9f0ad1d_2 90 | - pickleshare=0.7.5=py39hde42818_1002 91 | - pillow=8.4.0=py39h5aabda8_0 92 | - pip=21.2.4=py39h06a4308_0 93 | - prompt-toolkit=3.0.24=pyha770c72_0 94 | - protobuf=3.17.2=py39he80948d_0 95 | - ptyprocess=0.7.0=pyhd3deb0d_0 96 | - pyasn1=0.4.8=py_0 97 | - pyasn1-modules=0.2.8=py_0 98 | - pycparser=2.21=pyhd8ed1ab_0 99 | - pygments=2.11.1=pyhd8ed1ab_0 100 | - pyjwt=2.3.0=pyhd8ed1ab_1 101 | - pyopenssl=21.0.0=pyhd8ed1ab_0 102 | - pysocks=1.7.1=py39hf3d152e_4 103 | - python=3.9.7=hb7a2778_3_cpython 104 | - python-lmdb=1.2.1=py39h2531618_1 105 | - python_abi=3.9=2_cp39 106 | - pytorch=1.10.1=py3.9_cuda11.3_cudnn8.2.0_0 107 | - pytorch-mutex=1.0=cuda 108 | - pyu2f=0.1.5=pyhd8ed1ab_0 109 | - pyyaml=6.0=py39h3811e60_3 110 | - readline=8.1=h27cfd23_0 111 | - requests=2.26.0=pyhd8ed1ab_1 112 | - requests-oauthlib=1.3.0=pyh9f0ad1d_0 113 | - rsa=4.8=pyhd8ed1ab_0 114 | - scipy=1.7.3=py39hee8e79c_0 115 | - setuptools=58.0.4=py39h06a4308_0 116 | - six=1.16.0=pyhd3eb1b0_0 117 | - sqlite=3.37.0=hc218d9a_0 118 | - tensorboard=2.7.0=pyhd8ed1ab_0 119 | - tensorboard-data-server=0.6.0=py39h95dcef6_1 120 | - tensorboard-plugin-wit=1.8.0=pyh44b312d_0 121 | - tk=8.6.11=h1ccaba5_0 122 | - torchaudio=0.10.1=py39_cu113 123 | - torchvision=0.11.2=py39_cu113 124 | - tqdm=4.62.3=pyhd8ed1ab_0 125 | - traitlets=5.1.1=pyhd8ed1ab_0 126 | - typing-extensions=3.10.0.2=hd3eb1b0_0 127 | - typing_extensions=3.10.0.2=pyh06a4308_0 128 | - tzdata=2021e=hda174b7_0 129 | - urllib3=1.26.7=pyhd8ed1ab_0 130 | - wcwidth=0.2.5=pyh9f0ad1d_2 131 | - werkzeug=2.0.2=pyhd3eb1b0_0 132 | - wheel=0.37.0=pyhd3eb1b0_1 133 | 
- xz=5.2.5=h7b6447c_0 134 | - yaml=0.2.5=h516909a_0 135 | - yarl=1.7.2=py39h3811e60_1 136 | - zipp=3.6.0=pyhd8ed1ab_0 137 | - zlib=1.2.11=h7f8727e_4 138 | - zstd=1.4.9=haebb681_0 139 | - pip: 140 | - torch-fidelity==0.3.0 141 | -------------------------------------------------------------------------------- /functions/denoising.py: -------------------------------------------------------------------------------- 1 | import torch 2 | from tqdm import tqdm 3 | import torchvision.utils as tvu 4 | import os 5 | 6 | def compute_alpha(beta, t): 7 | beta = torch.cat([torch.zeros(1).to(beta.device), beta], dim=0) 8 | a = (1 - beta).cumprod(dim=0).index_select(0, t + 1).view(-1, 1, 1, 1) 9 | return a 10 | 11 | def efficient_generalized_steps(x, seq, model, b, H_funcs, y_0, sigma_0, etaB, etaA, etaC, cls_fn=None, classes=None): 12 | with torch.no_grad(): 13 | #setup vectors used in the algorithm 14 | singulars = H_funcs.singulars() 15 | Sigma = torch.zeros(x.shape[1]*x.shape[2]*x.shape[3], device=x.device) 16 | Sigma[:singulars.shape[0]] = singulars 17 | U_t_y = H_funcs.Ut(y_0) 18 | Sig_inv_U_t_y = U_t_y / singulars[:U_t_y.shape[-1]] 19 | 20 | #initialize x_T as given in the paper 21 | largest_alphas = compute_alpha(b, (torch.ones(x.size(0)) * seq[-1]).to(x.device).long()) 22 | largest_sigmas = (1 - largest_alphas).sqrt() / largest_alphas.sqrt() 23 | large_singulars_index = torch.where(singulars * largest_sigmas[0, 0, 0, 0] > sigma_0) 24 | inv_singulars_and_zero = torch.zeros(x.shape[1] * x.shape[2] * x.shape[3]).to(singulars.device) 25 | inv_singulars_and_zero[large_singulars_index] = sigma_0 / singulars[large_singulars_index] 26 | inv_singulars_and_zero = inv_singulars_and_zero.view(1, -1) 27 | 28 | # implement p(x_T | x_0, y) as given in the paper 29 | # if eigenvalue is too small, we just treat it as zero (only for init) 30 | init_y = torch.zeros(x.shape[0], x.shape[1] * x.shape[2] * x.shape[3]).to(x.device) 31 | init_y[:, large_singulars_index[0]] = U_t_y[:, large_singulars_index[0]] / singulars[large_singulars_index].view(1, -1) 32 | init_y = init_y.view(*x.size()) 33 | remaining_s = largest_sigmas.view(-1, 1) ** 2 - inv_singulars_and_zero ** 2 34 | remaining_s = remaining_s.view(x.shape[0], x.shape[1], x.shape[2], x.shape[3]).clamp_min(0.0).sqrt() 35 | init_y = init_y + remaining_s * x 36 | init_y = init_y / largest_sigmas 37 | 38 | #setup iteration variables 39 | x = H_funcs.V(init_y.view(x.size(0), -1)).view(*x.size()) 40 | n = x.size(0) 41 | seq_next = [-1] + list(seq[:-1]) 42 | x0_preds = [] 43 | xs = [x] 44 | 45 | #iterate over the timesteps 46 | for i, j in tqdm(zip(reversed(seq), reversed(seq_next))): 47 | t = (torch.ones(n) * i).to(x.device) 48 | next_t = (torch.ones(n) * j).to(x.device) 49 | at = compute_alpha(b, t.long()) 50 | at_next = compute_alpha(b, next_t.long()) 51 | xt = xs[-1].to('cuda') 52 | if cls_fn == None: 53 | et = model(xt, t) 54 | else: 55 | et = model(xt, t, classes) 56 | et = et[:, :3] 57 | et = et - (1 - at).sqrt()[0,0,0,0] * cls_fn(x,t,classes) 58 | 59 | if et.size(1) == 6: 60 | et = et[:, :3] 61 | 62 | x0_t = (xt - et * (1 - at).sqrt()) / at.sqrt() 63 | 64 | #variational inference conditioned on y 65 | sigma = (1 - at).sqrt()[0, 0, 0, 0] / at.sqrt()[0, 0, 0, 0] 66 | sigma_next = (1 - at_next).sqrt()[0, 0, 0, 0] / at_next.sqrt()[0, 0, 0, 0] 67 | xt_mod = xt / at.sqrt()[0, 0, 0, 0] 68 | V_t_x = H_funcs.Vt(xt_mod) 69 | SVt_x = (V_t_x * Sigma)[:, :U_t_y.shape[1]] 70 | V_t_x0 = H_funcs.Vt(x0_t) 71 | SVt_x0 = (V_t_x0 * Sigma)[:, :U_t_y.shape[1]] 72 | 73 | falses = 
torch.zeros(V_t_x0.shape[1] - singulars.shape[0], dtype=torch.bool, device=xt.device) 74 | cond_before_lite = singulars * sigma_next > sigma_0 75 | cond_after_lite = singulars * sigma_next < sigma_0 76 | cond_before = torch.hstack((cond_before_lite, falses)) 77 | cond_after = torch.hstack((cond_after_lite, falses)) 78 | 79 | std_nextC = sigma_next * etaC 80 | sigma_tilde_nextC = torch.sqrt(sigma_next ** 2 - std_nextC ** 2) 81 | 82 | std_nextA = sigma_next * etaA 83 | sigma_tilde_nextA = torch.sqrt(sigma_next**2 - std_nextA**2) 84 | 85 | diff_sigma_t_nextB = torch.sqrt(sigma_next ** 2 - sigma_0 ** 2 / singulars[cond_before_lite] ** 2 * (etaB ** 2)) 86 | 87 | #missing pixels 88 | Vt_xt_mod_next = V_t_x0 + sigma_tilde_nextC * H_funcs.Vt(et) + std_nextC * torch.randn_like(V_t_x0) 89 | 90 | #less noisy than y (after) 91 | Vt_xt_mod_next[:, cond_after] = \ 92 | V_t_x0[:, cond_after] + sigma_tilde_nextA * ((U_t_y - SVt_x0) / sigma_0)[:, cond_after_lite] + std_nextA * torch.randn_like(V_t_x0[:, cond_after]) 93 | 94 | #noisier than y (before) 95 | Vt_xt_mod_next[:, cond_before] = \ 96 | (Sig_inv_U_t_y[:, cond_before_lite] * etaB + (1 - etaB) * V_t_x0[:, cond_before] + diff_sigma_t_nextB * torch.randn_like(U_t_y)[:, cond_before_lite]) 97 | 98 | #aggregate all 3 cases and give next prediction 99 | xt_mod_next = H_funcs.V(Vt_xt_mod_next) 100 | xt_next = (at_next.sqrt()[0, 0, 0, 0] * xt_mod_next).view(*x.shape) 101 | 102 | x0_preds.append(x0_t.to('cpu')) 103 | xs.append(xt_next.to('cpu')) 104 | 105 | 106 | return xs, x0_preds -------------------------------------------------------------------------------- /main.py: -------------------------------------------------------------------------------- 1 | import argparse 2 | import traceback 3 | import shutil 4 | import logging 5 | import yaml 6 | import sys 7 | import os 8 | import torch 9 | import numpy as np 10 | import torch.utils.tensorboard as tb 11 | 12 | from runners.diffusion import Diffusion 13 | 14 | torch.set_printoptions(sci_mode=False) 15 | 16 | 17 | def parse_args_and_config(): 18 | parser = argparse.ArgumentParser(description=globals()["__doc__"]) 19 | 20 | parser.add_argument( 21 | "--config", type=str, required=True, help="Path to the config file" 22 | ) 23 | parser.add_argument("--seed", type=int, default=1234, help="Random seed") 24 | parser.add_argument( 25 | "--exp", type=str, default="exp", help="Path for saving running related data." 26 | ) 27 | parser.add_argument( 28 | "--doc", 29 | type=str, 30 | required=True, 31 | help="A string for documentation purpose. " 32 | "Will be the name of the log folder.", 33 | ) 34 | parser.add_argument( 35 | "--comment", type=str, default="", help="A string for experiment comment" 36 | ) 37 | parser.add_argument( 38 | "--verbose", 39 | type=str, 40 | default="info", 41 | help="Verbose level: info | debug | warning | critical", 42 | ) 43 | parser.add_argument( 44 | "--sample", 45 | action="store_true", 46 | help="Whether to produce samples from the model", 47 | ) 48 | parser.add_argument( 49 | "-i", 50 | "--image_folder", 51 | type=str, 52 | default="images", 53 | help="The folder name of samples", 54 | ) 55 | parser.add_argument( 56 | "--ni", 57 | action="store_true", 58 | help="No interaction. 
Suitable for Slurm Job launcher", 59 | ) 60 | parser.add_argument( 61 | "--timesteps", type=int, default=1000, help="number of steps involved" 62 | ) 63 | parser.add_argument( 64 | "--deg", type=str, required=True, help="Degradation" 65 | ) 66 | parser.add_argument( 67 | "--sigma_0", type=float, required=True, help="Sigma_0" 68 | ) 69 | parser.add_argument( 70 | "--eta", type=float, default=0.85, help="Eta" 71 | ) 72 | parser.add_argument( 73 | "--etaB", type=float, default=1, help="Eta_b (before)" 74 | ) 75 | parser.add_argument( 76 | '--subset_start', type=int, default=-1 77 | ) 78 | parser.add_argument( 79 | '--subset_end', type=int, default=-1 80 | ) 81 | 82 | args = parser.parse_args() 83 | args.log_path = os.path.join(args.exp, "logs", args.doc) 84 | 85 | # parse config file 86 | with open(os.path.join("configs", args.config), "r") as f: 87 | config = yaml.safe_load(f) 88 | new_config = dict2namespace(config) 89 | 90 | tb_path = os.path.join(args.exp, "tensorboard", args.doc) 91 | 92 | level = getattr(logging, args.verbose.upper(), None) 93 | if not isinstance(level, int): 94 | raise ValueError("level {} not supported".format(args.verbose)) 95 | 96 | handler1 = logging.StreamHandler() 97 | formatter = logging.Formatter( 98 | "%(levelname)s - %(filename)s - %(asctime)s - %(message)s" 99 | ) 100 | handler1.setFormatter(formatter) 101 | logger = logging.getLogger() 102 | logger.addHandler(handler1) 103 | logger.setLevel(level) 104 | 105 | os.makedirs(os.path.join(args.exp, "image_samples"), exist_ok=True) 106 | args.image_folder = os.path.join( 107 | args.exp, "image_samples", args.image_folder 108 | ) 109 | if not os.path.exists(args.image_folder): 110 | os.makedirs(args.image_folder) 111 | else: 112 | overwrite = False 113 | if args.ni: 114 | overwrite = True 115 | else: 116 | response = input( 117 | f"Image folder {args.image_folder} already exists. Overwrite? (Y/N)" 118 | ) 119 | if response.upper() == "Y": 120 | overwrite = True 121 | 122 | if overwrite: 123 | shutil.rmtree(args.image_folder) 124 | os.makedirs(args.image_folder) 125 | else: 126 | print("Output image folder exists. 
Program halted.") 127 | sys.exit(0) 128 | 129 | # add device 130 | device = torch.device("cuda") if torch.cuda.is_available() else torch.device("cpu") 131 | logging.info("Using device: {}".format(device)) 132 | new_config.device = device 133 | 134 | # set random seed 135 | torch.manual_seed(args.seed) 136 | np.random.seed(args.seed) 137 | if torch.cuda.is_available(): 138 | torch.cuda.manual_seed_all(args.seed) 139 | 140 | torch.backends.cudnn.benchmark = True 141 | 142 | return args, new_config 143 | 144 | 145 | def dict2namespace(config): 146 | namespace = argparse.Namespace() 147 | for key, value in config.items(): 148 | if isinstance(value, dict): 149 | new_value = dict2namespace(value) 150 | else: 151 | new_value = value 152 | setattr(namespace, key, new_value) 153 | return namespace 154 | 155 | 156 | def main(): 157 | args, config = parse_args_and_config() 158 | logging.info("Writing log file to {}".format(args.log_path)) 159 | logging.info("Exp instance id = {}".format(os.getpid())) 160 | logging.info("Exp comment = {}".format(args.comment)) 161 | 162 | try: 163 | runner = Diffusion(args, config) 164 | runner.sample() 165 | except Exception: 166 | logging.error(traceback.format_exc()) 167 | 168 | return 0 169 | 170 | 171 | if __name__ == "__main__": 172 | sys.exit(main()) 173 | -------------------------------------------------------------------------------- /guided_diffusion/nn.py: -------------------------------------------------------------------------------- 1 | """ 2 | Various utilities for neural networks. 3 | """ 4 | 5 | import math 6 | 7 | import torch as th 8 | import torch.nn as nn 9 | 10 | 11 | # PyTorch 1.7 has SiLU, but we support PyTorch 1.5. 12 | class SiLU(nn.Module): 13 | def forward(self, x): 14 | return x * th.sigmoid(x) 15 | 16 | 17 | class GroupNorm32(nn.GroupNorm): 18 | def forward(self, x): 19 | return super().forward(x.float()).type(x.dtype) 20 | 21 | 22 | def conv_nd(dims, *args, **kwargs): 23 | """ 24 | Create a 1D, 2D, or 3D convolution module. 25 | """ 26 | if dims == 1: 27 | return nn.Conv1d(*args, **kwargs) 28 | elif dims == 2: 29 | return nn.Conv2d(*args, **kwargs) 30 | elif dims == 3: 31 | return nn.Conv3d(*args, **kwargs) 32 | raise ValueError(f"unsupported dimensions: {dims}") 33 | 34 | 35 | def linear(*args, **kwargs): 36 | """ 37 | Create a linear module. 38 | """ 39 | return nn.Linear(*args, **kwargs) 40 | 41 | 42 | def avg_pool_nd(dims, *args, **kwargs): 43 | """ 44 | Create a 1D, 2D, or 3D average pooling module. 45 | """ 46 | if dims == 1: 47 | return nn.AvgPool1d(*args, **kwargs) 48 | elif dims == 2: 49 | return nn.AvgPool2d(*args, **kwargs) 50 | elif dims == 3: 51 | return nn.AvgPool3d(*args, **kwargs) 52 | raise ValueError(f"unsupported dimensions: {dims}") 53 | 54 | 55 | def update_ema(target_params, source_params, rate=0.99): 56 | """ 57 | Update target parameters to be closer to those of source parameters using 58 | an exponential moving average. 59 | 60 | :param target_params: the target parameter sequence. 61 | :param source_params: the source parameter sequence. 62 | :param rate: the EMA rate (closer to 1 means slower). 63 | """ 64 | for targ, src in zip(target_params, source_params): 65 | targ.detach().mul_(rate).add_(src, alpha=1 - rate) 66 | 67 | 68 | def zero_module(module): 69 | """ 70 | Zero out the parameters of a module and return it. 71 | """ 72 | for p in module.parameters(): 73 | p.detach().zero_() 74 | return module 75 | 76 | 77 | def scale_module(module, scale): 78 | """ 79 | Scale the parameters of a module and return it. 
80 | """ 81 | for p in module.parameters(): 82 | p.detach().mul_(scale) 83 | return module 84 | 85 | 86 | def mean_flat(tensor): 87 | """ 88 | Take the mean over all non-batch dimensions. 89 | """ 90 | return tensor.mean(dim=list(range(1, len(tensor.shape)))) 91 | 92 | 93 | def normalization(channels): 94 | """ 95 | Make a standard normalization layer. 96 | 97 | :param channels: number of input channels. 98 | :return: an nn.Module for normalization. 99 | """ 100 | return GroupNorm32(32, channels) 101 | 102 | 103 | def timestep_embedding(timesteps, dim, max_period=10000): 104 | """ 105 | Create sinusoidal timestep embeddings. 106 | 107 | :param timesteps: a 1-D Tensor of N indices, one per batch element. 108 | These may be fractional. 109 | :param dim: the dimension of the output. 110 | :param max_period: controls the minimum frequency of the embeddings. 111 | :return: an [N x dim] Tensor of positional embeddings. 112 | """ 113 | half = dim // 2 114 | freqs = th.exp( 115 | -math.log(max_period) * th.arange(start=0, end=half, dtype=th.float32) / half 116 | ).to(device=timesteps.device) 117 | args = timesteps[:, None].float() * freqs[None] 118 | embedding = th.cat([th.cos(args), th.sin(args)], dim=-1) 119 | if dim % 2: 120 | embedding = th.cat([embedding, th.zeros_like(embedding[:, :1])], dim=-1) 121 | return embedding 122 | 123 | 124 | def checkpoint(func, inputs, params, flag): 125 | """ 126 | Evaluate a function without caching intermediate activations, allowing for 127 | reduced memory at the expense of extra compute in the backward pass. 128 | 129 | :param func: the function to evaluate. 130 | :param inputs: the argument sequence to pass to `func`. 131 | :param params: a sequence of parameters `func` depends on but does not 132 | explicitly take as arguments. 133 | :param flag: if False, disable gradient checkpointing. 134 | """ 135 | if flag: 136 | args = tuple(inputs) + tuple(params) 137 | return CheckpointFunction.apply(func, len(inputs), *args) 138 | else: 139 | return func(*inputs) 140 | 141 | 142 | class CheckpointFunction(th.autograd.Function): 143 | @staticmethod 144 | def forward(ctx, run_function, length, *args): 145 | ctx.run_function = run_function 146 | ctx.input_tensors = list(args[:length]) 147 | ctx.input_params = list(args[length:]) 148 | with th.no_grad(): 149 | output_tensors = ctx.run_function(*ctx.input_tensors) 150 | return output_tensors 151 | 152 | @staticmethod 153 | def backward(ctx, *output_grads): 154 | ctx.input_tensors = [x.detach().requires_grad_(True) for x in ctx.input_tensors] 155 | with th.enable_grad(): 156 | # Fixes a bug where the first op in run_function modifies the 157 | # Tensor storage in place, which is not allowed for detach()'d 158 | # Tensors. 
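            # view_as(x) produces a non-leaf tensor that shares x's storage, so in-place
            # ops inside run_function are allowed, while gradients can still flow back
            # to ctx.input_tensors through the view operation.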
159 | shallow_copies = [x.view_as(x) for x in ctx.input_tensors] 160 | output_tensors = ctx.run_function(*shallow_copies) 161 | input_grads = th.autograd.grad( 162 | output_tensors, 163 | ctx.input_tensors + ctx.input_params, 164 | output_grads, 165 | allow_unused=True, 166 | ) 167 | del ctx.input_tensors 168 | del ctx.input_params 169 | del output_tensors 170 | return (None, None) + input_grads 171 | -------------------------------------------------------------------------------- /datasets/lsun.py: -------------------------------------------------------------------------------- 1 | from .vision import VisionDataset 2 | from PIL import Image 3 | import os 4 | import os.path 5 | import io 6 | from collections.abc import Iterable 7 | import pickle 8 | from torchvision.datasets.utils import verify_str_arg, iterable_to_str 9 | 10 | 11 | class LSUNClass(VisionDataset): 12 | def __init__(self, root, transform=None, target_transform=None): 13 | import lmdb 14 | 15 | super(LSUNClass, self).__init__( 16 | root, transform=transform, target_transform=target_transform 17 | ) 18 | 19 | self.env = lmdb.open( 20 | root, 21 | max_readers=1, 22 | readonly=True, 23 | lock=False, 24 | readahead=False, 25 | meminit=False, 26 | ) 27 | with self.env.begin(write=False) as txn: 28 | self.length = txn.stat()["entries"] 29 | root_split = root.split("/") 30 | cache_file = os.path.join("/".join(root_split[:-1]), f"_cache_{root_split[-1]}") 31 | if os.path.isfile(cache_file): 32 | self.keys = pickle.load(open(cache_file, "rb")) 33 | else: 34 | with self.env.begin(write=False) as txn: 35 | self.keys = [key for key, _ in txn.cursor()] 36 | pickle.dump(self.keys, open(cache_file, "wb")) 37 | 38 | def __getitem__(self, index): 39 | img, target = None, None 40 | env = self.env 41 | with env.begin(write=False) as txn: 42 | imgbuf = txn.get(self.keys[index]) 43 | 44 | buf = io.BytesIO() 45 | buf.write(imgbuf) 46 | buf.seek(0) 47 | img = Image.open(buf).convert("RGB") 48 | 49 | if self.transform is not None: 50 | img = self.transform(img) 51 | 52 | if self.target_transform is not None: 53 | target = self.target_transform(target) 54 | 55 | return img, target 56 | 57 | def __len__(self): 58 | return self.length 59 | 60 | 61 | class LSUN(VisionDataset): 62 | """ 63 | `LSUN `_ dataset. 64 | 65 | Args: 66 | root (string): Root directory for the database files. 67 | classes (string or list): One of {'train', 'val', 'test'} or a list of 68 | categories to load. e,g. ['bedroom_train', 'church_outdoor_train']. 69 | transform (callable, optional): A function/transform that takes in an PIL image 70 | and returns a transformed version. E.g, ``transforms.RandomCrop`` 71 | target_transform (callable, optional): A function/transform that takes in the 72 | target and transforms it. 
73 | """ 74 | 75 | def __init__(self, root, classes="train", transform=None, target_transform=None): 76 | super(LSUN, self).__init__( 77 | root, transform=transform, target_transform=target_transform 78 | ) 79 | self.classes = self._verify_classes(classes) 80 | 81 | # for each class, create an LSUNClassDataset 82 | self.dbs = [] 83 | for c in self.classes: 84 | self.dbs.append( 85 | LSUNClass(root=root + "/" + c + "_lmdb", transform=transform) 86 | ) 87 | 88 | self.indices = [] 89 | count = 0 90 | for db in self.dbs: 91 | count += len(db) 92 | self.indices.append(count) 93 | 94 | self.length = count 95 | 96 | def _verify_classes(self, classes): 97 | categories = [ 98 | "bedroom", 99 | "bridge", 100 | "church_outdoor", 101 | "classroom", 102 | "conference_room", 103 | "dining_room", 104 | "kitchen", 105 | "living_room", 106 | "restaurant", 107 | "tower", 108 | "cat", 109 | ] 110 | dset_opts = ["train", "val", "test"] 111 | 112 | try: 113 | verify_str_arg(classes, "classes", dset_opts) 114 | if classes == "test": 115 | classes = [classes] 116 | else: 117 | classes = [c + "_" + classes for c in categories] 118 | except ValueError: 119 | if not isinstance(classes, Iterable): 120 | msg = ( 121 | "Expected type str or Iterable for argument classes, " 122 | "but got type {}." 123 | ) 124 | raise ValueError(msg.format(type(classes))) 125 | 126 | classes = list(classes) 127 | msg_fmtstr = ( 128 | "Expected type str for elements in argument classes, " 129 | "but got type {}." 130 | ) 131 | for c in classes: 132 | verify_str_arg(c, custom_msg=msg_fmtstr.format(type(c))) 133 | c_short = c.split("_") 134 | category, dset_opt = "_".join(c_short[:-1]), c_short[-1] 135 | 136 | msg_fmtstr = "Unknown value '{}' for {}. Valid values are {{{}}}." 137 | msg = msg_fmtstr.format( 138 | category, "LSUN class", iterable_to_str(categories) 139 | ) 140 | verify_str_arg(category, valid_values=categories, custom_msg=msg) 141 | 142 | msg = msg_fmtstr.format(dset_opt, "postfix", iterable_to_str(dset_opts)) 143 | verify_str_arg(dset_opt, valid_values=dset_opts, custom_msg=msg) 144 | 145 | return classes 146 | 147 | def __getitem__(self, index): 148 | """ 149 | Args: 150 | index (int): Index 151 | 152 | Returns: 153 | tuple: Tuple (image, target) where target is the index of the target category. 154 | """ 155 | target = 0 156 | sub = 0 157 | for ind in self.indices: 158 | if index < ind: 159 | break 160 | target += 1 161 | sub = ind 162 | 163 | db = self.dbs[target] 164 | index = index - sub 165 | 166 | if self.target_transform is not None: 167 | target = self.target_transform(target) 168 | 169 | img, _ = db[index] 170 | return img, target 171 | 172 | def __len__(self): 173 | return self.length 174 | 175 | def extra_repr(self): 176 | return "Classes: {classes}".format(**self.__dict__) 177 | -------------------------------------------------------------------------------- /README.md: -------------------------------------------------------------------------------- 1 | # Denoising Diffusion Restoration Models (DDRM) 2 | 3 | [arXiv](https://arxiv.org/abs/2201.11793) | [PDF](https://ddrm-ml.github.io/DDRM-paper.pdf) | [Project Website](https://ddrm-ml.github.io/) 4 | 5 | [Bahjat Kawar](https://bahjat-kawar.github.io/)1, [Michael Elad](https://elad.cs.technion.ac.il/)1, [Stefano Ermon](http://cs.stanford.edu/~ermon)2, [Jiaming Song](http://tsong.me)2
6 | 1 Technion, 2 Stanford University 7 | 8 | DDRM uses pre-trained [DDPMs](https://hojonathanho.github.io/diffusion/) for solving general linear inverse problems. It does so efficiently and without problem-specific supervised training. 9 | 10 | ![ddrm-overview](figures/ddrm-overview.png) 11 | 12 | ## Running the Experiments 13 | The code has been tested on PyTorch 1.8 and PyTorch 1.10. Please refer to `environment.yml` for a list of conda/mamba environments that can be used to run the code. 14 | 15 | ### Pretrained models 16 | We use pretrained models from [https://github.com/openai/guided-diffusion](https://github.com/openai/guided-diffusion), [https://github.com/pesser/pytorch_diffusion](https://github.com/pesser/pytorch_diffusion) and [https://github.com/ermongroup/SDEdit](https://github.com/ermongroup/SDEdit). 17 | 18 | We use 1,000 images from the ImageNet validation set for comparison with other methods. The list of images is taken from [https://github.com/XingangPan/deep-generative-prior/](https://github.com/XingangPan/deep-generative-prior/). 19 | 20 | The models and datasets are placed in the `exp/` folder as follows: 21 | ```bash 22 | # a folder named by the argument `--exp` given to main.py 23 | ├── datasets # all dataset files 24 | │ ├── celeba # all CelebA files 25 | │ ├── imagenet # all ImageNet files 26 | │ ├── ood # out of distribution ImageNet images 27 | │ ├── ood_bedroom # out of distribution bedroom images 28 | │ ├── ood_cat # out of distribution cat images 29 | │ └── ood_celeba # out of distribution CelebA images 30 | ├── logs # contains checkpoints and samples produced during training 31 | │ ├── celeba 32 | │ │ └── celeba_hq.ckpt # the checkpoint file for CelebA-HQ 33 | │ ├── diffusion_models_converted 34 | │ │ └── ema_diffusion_lsun_{category}_model 35 | │ │ └── model-x.ckpt # the checkpoint file saved at the x-th training iteration 36 | │ ├── imagenet # ImageNet checkpoint files 37 | │ │ ├── 256x256_classifier.pt 38 | │ │ ├── 256x256_diffusion.pt 39 | │ │ ├── 256x256_diffusion_uncond.pt 40 | │ │ ├── 512x512_classifier.pt 41 | │ │ └── 512x512_diffusion.pt 42 | ├── image_samples # contains generated samples 43 | └── imagenet_val_1k.txt # list of the 1k images used in ImageNet-1K. 44 | ``` 45 | 46 | We note that some models may not generate high-quality samples in unconditional image synthesis; this is especially the case for the pre-trained CelebA model. 47 | 48 | ### Sampling from the model 49 | 50 | The general command to sample from the model is as follows: 51 | ``` 52 | python main.py --ni --config {CONFIG}.yml --doc {DATASET} --timesteps {STEPS} --eta {ETA} --etaB {ETA_B} --deg {DEGRADATION} --sigma_0 {SIGMA_0} -i {IMAGE_FOLDER} 53 | ``` 54 | where the following are options: 55 | - `ETA` is the eta hyperparameter in the paper. (default: `0.85`) 56 | - `ETA_B` is the eta_b hyperparameter in the paper. (default: `1`) 57 | - `STEPS` controls how many timesteps are used in the process. 58 | - `DEGRADATION` is the type of degradation used. (One of: `cs2`, `cs4`, `inp`, `inp_lolcat`, `inp_lorem`, `deno`, `deblur_uni`, `deblur_gauss`, `deblur_aniso`, `sr2`, `sr4`, `sr8`, `sr16`, `sr_bicubic4`, `sr_bicubic8`, `sr_bicubic16`, `color`) 59 | - `SIGMA_0` is the noise level observed in y. 60 | - `CONFIG` is the name of the config file (see `configs/` for a list), including hyperparameters such as batch size and network architectures. 61 | - `DATASET` is the name of the dataset used, to determine where the checkpoint file is found. 
62 | - `IMAGE_FOLDER` is the name of the folder the resulting images will be placed in (default: `images`) 63 | 64 | For example, for sampling noisy 4x super resolution from the ImageNet 256x256 unconditional model using 20 steps: 65 | ``` 66 | python main.py --ni --config imagenet_256.yml --doc imagenet --timesteps 20 --eta 0.85 --etaB 1 --deg sr4 --sigma_0 0.05 67 | ``` 68 | The generated images are placed in the `exp/image_samples/{IMAGE_FOLDER}` folder, where `orig_{id}.png`, `y0_{id}.png`, `{id}_-1.png` refer to the original, degraded, and restored images, respectively. 69 | 70 | The config files contain a setting controlling whether to test on samples from the trained dataset's distribution or not. 71 | 72 | ### Images for Demonstration Purposes 73 | A list of images for demonstration purposes can be found here: [https://github.com/jiamings/ddrm-exp-datasets](https://github.com/jiamings/ddrm-exp-datasets). Place them under the `exp/datasets` folder, and these commands can be executed directly: 74 | 75 | CelebA noisy 4x super-resolution: 76 | ``` 77 | python main.py --ni --config celeba_hq.yml --doc celeba --timesteps 20 --eta 0.85 --etaB 1 --deg sr4 --sigma_0 0.05 -i celeba_hq_sr4_sigma_0.05 78 | ``` 79 | 80 | General content images uniform deblurring: 81 | ``` 82 | python main.py --ni --config imagenet_256.yml --doc imagenet_ood --timesteps 20 --eta 0.85 --etaB 1 --deg deblur_uni --sigma_0 0.0 -i imagenet_sr4_sigma_0.0 83 | ``` 84 | 85 | Bedroom noisy 4x super-resolution: 86 | ``` 87 | python main.py --ni --config bedroom.yml --doc bedroom --timesteps 20 --eta 0.85 --etaB 1 --deg sr4 --sigma_0 0.05 -i bedroom_sr4_sigma_0.05 88 | ``` 89 | 90 | ## References and Acknowledgements 91 | ``` 92 | @inproceedings{kawar2022denoising, 93 | title={Denoising Diffusion Restoration Models}, 94 | author={Bahjat Kawar and Michael Elad and Stefano Ermon and Jiaming Song}, 95 | booktitle={Advances in Neural Information Processing Systems}, 96 | year={2022} 97 | } 98 | ``` 99 | 100 | This implementation is based on / inspired by: 101 | - [https://github.com/hojonathanho/diffusion](https://github.com/hojonathanho/diffusion) (the DDPM TensorFlow repo), 102 | - [https://github.com/pesser/pytorch_diffusion](https://github.com/pesser/pytorch_diffusion) (PyTorch helper that loads the DDPM model), and 103 | - [https://github.com/ermongroup/ddim](https://github.com/ermongroup/ddim) (code structure) 104 | -------------------------------------------------------------------------------- /datasets/utils.py: -------------------------------------------------------------------------------- 1 | import os 2 | import os.path 3 | import hashlib 4 | import errno 5 | from torch.utils.model_zoo import tqdm 6 | 7 | 8 | def gen_bar_updater(): 9 | pbar = tqdm(total=None) 10 | 11 | def bar_update(count, block_size, total_size): 12 | if pbar.total is None and total_size: 13 | pbar.total = total_size 14 | progress_bytes = count * block_size 15 | pbar.update(progress_bytes - pbar.n) 16 | 17 | return bar_update 18 | 19 | 20 | def check_integrity(fpath, md5=None): 21 | if md5 is None: 22 | return True 23 | if not os.path.isfile(fpath): 24 | return False 25 | md5o = hashlib.md5() 26 | with open(fpath, 'rb') as f: 27 | # read in 1MB chunks 28 | for chunk in iter(lambda: f.read(1024 * 1024), b''): 29 | md5o.update(chunk) 30 | md5c = md5o.hexdigest() 31 | if md5c != md5: 32 | return False 33 | return True 34 | 35 | 36 | def makedir_exist_ok(dirpath): 37 | """ 38 | Python2 support for os.makedirs(.., exist_ok=True) 39 | """ 40 | try: 41 | 
os.makedirs(dirpath) 42 | except OSError as e: 43 | if e.errno == errno.EEXIST: 44 | pass 45 | else: 46 | raise 47 | 48 | 49 | def download_url(url, root, filename=None, md5=None): 50 | """Download a file from a url and place it in root. 51 | 52 | Args: 53 | url (str): URL to download file from 54 | root (str): Directory to place downloaded file in 55 | filename (str, optional): Name to save the file under. If None, use the basename of the URL 56 | md5 (str, optional): MD5 checksum of the download. If None, do not check 57 | """ 58 | from six.moves import urllib 59 | 60 | root = os.path.expanduser(root) 61 | if not filename: 62 | filename = os.path.basename(url) 63 | fpath = os.path.join(root, filename) 64 | 65 | makedir_exist_ok(root) 66 | 67 | # downloads file 68 | if os.path.isfile(fpath) and check_integrity(fpath, md5): 69 | print('Using downloaded and verified file: ' + fpath) 70 | else: 71 | try: 72 | print('Downloading ' + url + ' to ' + fpath) 73 | urllib.request.urlretrieve( 74 | url, fpath, 75 | reporthook=gen_bar_updater() 76 | ) 77 | except OSError: 78 | if url[:5] == 'https': 79 | url = url.replace('https:', 'http:') 80 | print('Failed download. Trying https -> http instead.' 81 | ' Downloading ' + url + ' to ' + fpath) 82 | urllib.request.urlretrieve( 83 | url, fpath, 84 | reporthook=gen_bar_updater() 85 | ) 86 | 87 | 88 | def list_dir(root, prefix=False): 89 | """List all directories at a given root 90 | 91 | Args: 92 | root (str): Path to directory whose folders need to be listed 93 | prefix (bool, optional): If true, prepends the path to each result, otherwise 94 | only returns the name of the directories found 95 | """ 96 | root = os.path.expanduser(root) 97 | directories = list( 98 | filter( 99 | lambda p: os.path.isdir(os.path.join(root, p)), 100 | os.listdir(root) 101 | ) 102 | ) 103 | 104 | if prefix is True: 105 | directories = [os.path.join(root, d) for d in directories] 106 | 107 | return directories 108 | 109 | 110 | def list_files(root, suffix, prefix=False): 111 | """List all files ending with a suffix at a given root 112 | 113 | Args: 114 | root (str): Path to directory whose folders need to be listed 115 | suffix (str or tuple): Suffix of the files to match, e.g. '.png' or ('.jpg', '.png'). 116 | It uses the Python "str.endswith" method and is passed directly 117 | prefix (bool, optional): If true, prepends the path to each result, otherwise 118 | only returns the name of the files found 119 | """ 120 | root = os.path.expanduser(root) 121 | files = list( 122 | filter( 123 | lambda p: os.path.isfile(os.path.join(root, p)) and p.endswith(suffix), 124 | os.listdir(root) 125 | ) 126 | ) 127 | 128 | if prefix is True: 129 | files = [os.path.join(root, d) for d in files] 130 | 131 | return files 132 | 133 | 134 | def download_file_from_google_drive(file_id, root, filename=None, md5=None): 135 | """Download a Google Drive file from and place it in root. 136 | 137 | Args: 138 | file_id (str): id of file to be downloaded 139 | root (str): Directory to place downloaded file in 140 | filename (str, optional): Name to save the file under. If None, use the id of the file. 141 | md5 (str, optional): MD5 checksum of the download. 
If None, do not check 142 | """ 143 | # Based on https://stackoverflow.com/questions/38511444/python-download-files-from-google-drive-using-url 144 | import requests 145 | url = "https://docs.google.com/uc?export=download" 146 | 147 | root = os.path.expanduser(root) 148 | if not filename: 149 | filename = file_id 150 | fpath = os.path.join(root, filename) 151 | 152 | makedir_exist_ok(root) 153 | 154 | if os.path.isfile(fpath) and check_integrity(fpath, md5): 155 | print('Using downloaded and verified file: ' + fpath) 156 | else: 157 | session = requests.Session() 158 | 159 | response = session.get(url, params={'id': file_id}, stream=True) 160 | token = _get_confirm_token(response) 161 | 162 | if token: 163 | params = {'id': file_id, 'confirm': token} 164 | response = session.get(url, params=params, stream=True) 165 | 166 | _save_response_content(response, fpath) 167 | 168 | 169 | def _get_confirm_token(response): 170 | for key, value in response.cookies.items(): 171 | if key.startswith('download_warning'): 172 | return value 173 | 174 | return None 175 | 176 | 177 | def _save_response_content(response, destination, chunk_size=32768): 178 | with open(destination, "wb") as f: 179 | pbar = tqdm(total=None) 180 | progress = 0 181 | for chunk in response.iter_content(chunk_size): 182 | if chunk: # filter out keep-alive new chunks 183 | f.write(chunk) 184 | progress += len(chunk) 185 | pbar.update(progress - pbar.n) 186 | pbar.close() 187 | -------------------------------------------------------------------------------- /datasets/celeba.py: -------------------------------------------------------------------------------- 1 | import torch 2 | import os 3 | import PIL 4 | from .vision import VisionDataset 5 | from .utils import download_file_from_google_drive, check_integrity 6 | 7 | 8 | class CelebA(VisionDataset): 9 | """`Large-scale CelebFaces Attributes (CelebA) Dataset `_ Dataset. 10 | 11 | Args: 12 | root (string): Root directory where images are downloaded to. 13 | split (string): One of {'train', 'valid', 'test'}. 14 | Accordingly dataset is selected. 15 | target_type (string or list, optional): Type of target to use, ``attr``, ``identity``, ``bbox``, 16 | or ``landmarks``. Can also be a list to output a tuple with all specified target types. 17 | The targets represent: 18 | ``attr`` (np.array shape=(40,) dtype=int): binary (0, 1) labels for attributes 19 | ``identity`` (int): label for each person (data points with the same identity are the same person) 20 | ``bbox`` (np.array shape=(4,) dtype=int): bounding box (x, y, width, height) 21 | ``landmarks`` (np.array shape=(10,) dtype=int): landmark points (lefteye_x, lefteye_y, righteye_x, 22 | righteye_y, nose_x, nose_y, leftmouth_x, leftmouth_y, rightmouth_x, rightmouth_y) 23 | Defaults to ``attr``. 24 | transform (callable, optional): A function/transform that takes in an PIL image 25 | and returns a transformed version. E.g, ``transforms.ToTensor`` 26 | target_transform (callable, optional): A function/transform that takes in the 27 | target and transforms it. 28 | download (bool, optional): If true, downloads the dataset from the internet and 29 | puts it in root directory. If dataset is already downloaded, it is not 30 | downloaded again. 31 | """ 32 | 33 | base_folder = "celeba" 34 | # There currently does not appear to be a easy way to extract 7z in python (without introducing additional 35 | # dependencies). The "in-the-wild" (not aligned+cropped) images are only in 7z, so they are not available 36 | # right now. 
37 | file_list = [ 38 | # File ID MD5 Hash Filename 39 | ("0B7EVK8r0v71pZjFTYXZWM3FlRnM", "00d2c5bc6d35e252742224ab0c1e8fcb", "img_align_celeba.zip"), 40 | # ("0B7EVK8r0v71pbWNEUjJKdDQ3dGc", "b6cd7e93bc7a96c2dc33f819aa3ac651", "img_align_celeba_png.7z"), 41 | # ("0B7EVK8r0v71peklHb0pGdDl6R28", "b6cd7e93bc7a96c2dc33f819aa3ac651", "img_celeba.7z"), 42 | ("0B7EVK8r0v71pblRyaVFSWGxPY0U", "75e246fa4810816ffd6ee81facbd244c", "list_attr_celeba.txt"), 43 | ("1_ee_0u7vcNLOfNLegJRHmolfH5ICW-XS", "32bd1bd63d3c78cd57e08160ec5ed1e2", "identity_CelebA.txt"), 44 | ("0B7EVK8r0v71pbThiMVRxWXZ4dU0", "00566efa6fedff7a56946cd1c10f1c16", "list_bbox_celeba.txt"), 45 | ("0B7EVK8r0v71pd0FJY3Blby1HUTQ", "cc24ecafdb5b50baae59b03474781f8c", "list_landmarks_align_celeba.txt"), 46 | # ("0B7EVK8r0v71pTzJIdlJWdHczRlU", "063ee6ddb681f96bc9ca28c6febb9d1a", "list_landmarks_celeba.txt"), 47 | ("0B7EVK8r0v71pY0NSMzRuSXJEVkk", "d32c9cbf5e040fd4025c592c306e6668", "list_eval_partition.txt"), 48 | ] 49 | 50 | def __init__(self, root, 51 | split="train", 52 | target_type="attr", 53 | transform=None, target_transform=None, 54 | download=False): 55 | import pandas 56 | super(CelebA, self).__init__(root) 57 | self.split = split 58 | if isinstance(target_type, list): 59 | self.target_type = target_type 60 | else: 61 | self.target_type = [target_type] 62 | self.transform = transform 63 | self.target_transform = target_transform 64 | 65 | if download: 66 | self.download() 67 | 68 | if not self._check_integrity(): 69 | raise RuntimeError('Dataset not found or corrupted.' + 70 | ' You can use download=True to download it') 71 | 72 | self.transform = transform 73 | self.target_transform = target_transform 74 | 75 | if split.lower() == "train": 76 | split = 0 77 | elif split.lower() == "valid": 78 | split = 1 79 | elif split.lower() == "test": 80 | split = 2 81 | else: 82 | raise ValueError('Wrong split entered! 
Please use split="train" ' 83 | 'or split="valid" or split="test"') 84 | 85 | with open(os.path.join(self.root, self.base_folder, "list_eval_partition.txt"), "r") as f: 86 | splits = pandas.read_csv(f, delim_whitespace=True, header=None, index_col=0) 87 | 88 | with open(os.path.join(self.root, self.base_folder, "identity_CelebA.txt"), "r") as f: 89 | self.identity = pandas.read_csv(f, delim_whitespace=True, header=None, index_col=0) 90 | 91 | with open(os.path.join(self.root, self.base_folder, "list_bbox_celeba.txt"), "r") as f: 92 | self.bbox = pandas.read_csv(f, delim_whitespace=True, header=1, index_col=0) 93 | 94 | with open(os.path.join(self.root, self.base_folder, "list_landmarks_align_celeba.txt"), "r") as f: 95 | self.landmarks_align = pandas.read_csv(f, delim_whitespace=True, header=1) 96 | 97 | with open(os.path.join(self.root, self.base_folder, "list_attr_celeba.txt"), "r") as f: 98 | self.attr = pandas.read_csv(f, delim_whitespace=True, header=1) 99 | 100 | mask = (splits[1] == split) 101 | self.filename = splits[mask].index.values 102 | self.identity = torch.as_tensor(self.identity[mask].values) 103 | self.bbox = torch.as_tensor(self.bbox[mask].values) 104 | self.landmarks_align = torch.as_tensor(self.landmarks_align[mask].values) 105 | self.attr = torch.as_tensor(self.attr[mask].values) 106 | self.attr = (self.attr + 1) // 2 # map from {-1, 1} to {0, 1} 107 | 108 | def _check_integrity(self): 109 | for (_, md5, filename) in self.file_list: 110 | fpath = os.path.join(self.root, self.base_folder, filename) 111 | _, ext = os.path.splitext(filename) 112 | # Allow original archive to be deleted (zip and 7z) 113 | # Only need the extracted images 114 | if ext not in [".zip", ".7z"] and not check_integrity(fpath, md5): 115 | return False 116 | 117 | # Should check a hash of the images 118 | return os.path.isdir(os.path.join(self.root, self.base_folder, "img_align_celeba")) 119 | 120 | def download(self): 121 | import zipfile 122 | 123 | if self._check_integrity(): 124 | print('Files already downloaded and verified') 125 | return 126 | 127 | for (file_id, md5, filename) in self.file_list: 128 | download_file_from_google_drive(file_id, os.path.join(self.root, self.base_folder), filename, md5) 129 | 130 | with zipfile.ZipFile(os.path.join(self.root, self.base_folder, "img_align_celeba.zip"), "r") as f: 131 | f.extractall(os.path.join(self.root, self.base_folder)) 132 | 133 | def __getitem__(self, index): 134 | X = PIL.Image.open(os.path.join(self.root, self.base_folder, "img_align_celeba", self.filename[index])) 135 | 136 | target = [] 137 | for t in self.target_type: 138 | if t == "attr": 139 | target.append(self.attr[index, :]) 140 | elif t == "identity": 141 | target.append(self.identity[index, 0]) 142 | elif t == "bbox": 143 | target.append(self.bbox[index, :]) 144 | elif t == "landmarks": 145 | target.append(self.landmarks_align[index, :]) 146 | else: 147 | raise ValueError("Target type \"{}\" is not recognized.".format(t)) 148 | target = tuple(target) if len(target) > 1 else target[0] 149 | 150 | if self.transform is not None: 151 | X = self.transform(X) 152 | 153 | if self.target_transform is not None: 154 | target = self.target_transform(target) 155 | 156 | return X, target 157 | 158 | def __len__(self): 159 | return len(self.attr) 160 | 161 | def extra_repr(self): 162 | lines = ["Target type: {target_type}", "Split: {split}"] 163 | return '\n'.join(lines).format(**self.__dict__) 164 | -------------------------------------------------------------------------------- 
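As a brief usage sketch (not part of the repository), the `CelebA` class above can be instantiated directly; the root path below is an illustrative assumption, and `download=True` fetches and extracts the Google Drive files listed in `file_list`:

```
import torchvision.transforms as transforms
from datasets.celeba import CelebA

# hypothetical root; the repository itself passes os.path.join(args.exp, "datasets", "celeba")
dataset = CelebA(
    root="exp/datasets/celeba",
    split="train",
    target_type="attr",
    transform=transforms.ToTensor(),
    download=True,
)
image, attr = dataset[0]  # image: (C, H, W) tensor; attr: the 40 binary attribute labels
print(len(dataset), image.shape, attr.shape)
```

In `get_dataset` (`datasets/__init__.py`, next), this class is wrapped with a `Crop` around the face region followed by a `Resize` to `config.data.image_size`.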
/datasets/__init__.py: -------------------------------------------------------------------------------- 1 | import os 2 | import torch 3 | import numbers 4 | import torchvision.transforms as transforms 5 | import torchvision.transforms.functional as F 6 | from datasets.celeba import CelebA 7 | from datasets.lsun import LSUN 8 | from torch.utils.data import Subset 9 | import numpy as np 10 | import torchvision 11 | from PIL import Image 12 | from functools import partial 13 | 14 | class Crop(object): 15 | def __init__(self, x1, x2, y1, y2): 16 | self.x1 = x1 17 | self.x2 = x2 18 | self.y1 = y1 19 | self.y2 = y2 20 | 21 | def __call__(self, img): 22 | return F.crop(img, self.x1, self.y1, self.x2 - self.x1, self.y2 - self.y1) 23 | 24 | def __repr__(self): 25 | return self.__class__.__name__ + "(x1={}, x2={}, y1={}, y2={})".format( 26 | self.x1, self.x2, self.y1, self.y2 27 | ) 28 | 29 | def center_crop_arr(pil_image, image_size = 256): 30 | # Imported from openai/guided-diffusion 31 | while min(*pil_image.size) >= 2 * image_size: 32 | pil_image = pil_image.resize( 33 | tuple(x // 2 for x in pil_image.size), resample=Image.BOX 34 | ) 35 | 36 | scale = image_size / min(*pil_image.size) 37 | pil_image = pil_image.resize( 38 | tuple(round(x * scale) for x in pil_image.size), resample=Image.BICUBIC 39 | ) 40 | 41 | arr = np.array(pil_image) 42 | crop_y = (arr.shape[0] - image_size) // 2 43 | crop_x = (arr.shape[1] - image_size) // 2 44 | return arr[crop_y : crop_y + image_size, crop_x : crop_x + image_size] 45 | 46 | 47 | def get_dataset(args, config): 48 | if config.data.random_flip is False: 49 | tran_transform = test_transform = transforms.Compose( 50 | [transforms.Resize(config.data.image_size), transforms.ToTensor()] 51 | ) 52 | else: 53 | tran_transform = transforms.Compose( 54 | [ 55 | transforms.Resize(config.data.image_size), 56 | transforms.RandomHorizontalFlip(p=0.5), 57 | transforms.ToTensor(), 58 | ] 59 | ) 60 | test_transform = transforms.Compose( 61 | [transforms.Resize(config.data.image_size), transforms.ToTensor()] 62 | ) 63 | 64 | if config.data.dataset == "CELEBA": 65 | cx = 89 66 | cy = 121 67 | x1 = cy - 64 68 | x2 = cy + 64 69 | y1 = cx - 64 70 | y2 = cx + 64 71 | if config.data.random_flip: 72 | dataset = CelebA( 73 | root=os.path.join(args.exp, "datasets", "celeba"), 74 | split="train", 75 | transform=transforms.Compose( 76 | [ 77 | Crop(x1, x2, y1, y2), 78 | transforms.Resize(config.data.image_size), 79 | transforms.RandomHorizontalFlip(), 80 | transforms.ToTensor(), 81 | ] 82 | ), 83 | download=True, 84 | ) 85 | else: 86 | dataset = CelebA( 87 | root=os.path.join(args.exp, "datasets", "celeba"), 88 | split="train", 89 | transform=transforms.Compose( 90 | [ 91 | Crop(x1, x2, y1, y2), 92 | transforms.Resize(config.data.image_size), 93 | transforms.ToTensor(), 94 | ] 95 | ), 96 | download=True, 97 | ) 98 | 99 | test_dataset = CelebA( 100 | root=os.path.join(args.exp, "datasets", "celeba"), 101 | split="test", 102 | transform=transforms.Compose( 103 | [ 104 | Crop(x1, x2, y1, y2), 105 | transforms.Resize(config.data.image_size), 106 | transforms.ToTensor(), 107 | ] 108 | ), 109 | download=True, 110 | ) 111 | 112 | elif config.data.dataset == "LSUN": 113 | if config.data.out_of_dist: 114 | dataset = torchvision.datasets.ImageFolder( 115 | os.path.join(args.exp, 'datasets', "ood_{}".format(config.data.category)), 116 | transform=transforms.Compose([partial(center_crop_arr, image_size=config.data.image_size), 117 | transforms.ToTensor()]) 118 | ) 119 | test_dataset = dataset 120 
| else: 121 | train_folder = "{}_train".format(config.data.category) 122 | val_folder = "{}_val".format(config.data.category) 123 | test_dataset = LSUN( 124 | root=os.path.join(args.exp, "datasets", "lsun"), 125 | classes=[val_folder], 126 | transform=transforms.Compose( 127 | [ 128 | transforms.Resize(config.data.image_size), 129 | transforms.CenterCrop(config.data.image_size), 130 | transforms.ToTensor(), 131 | ] 132 | ) 133 | ) 134 | dataset = test_dataset 135 | 136 | elif config.data.dataset == "CelebA_HQ" or config.data.dataset == 'FFHQ': 137 | if config.data.out_of_dist: 138 | dataset = torchvision.datasets.ImageFolder( 139 | os.path.join(args.exp, "datasets", "ood_celeba"), 140 | transform=transforms.Compose([transforms.Resize([config.data.image_size, config.data.image_size]), 141 | transforms.ToTensor()]) 142 | ) 143 | test_dataset = dataset 144 | else: 145 | dataset = torchvision.datasets.ImageFolder( 146 | os.path.join(args.exp, "datasets", "celeba_hq"), 147 | transform=transforms.Compose([transforms.Resize([config.data.image_size, config.data.image_size]), 148 | transforms.ToTensor()]) 149 | ) 150 | num_items = len(dataset) 151 | indices = list(range(num_items)) 152 | random_state = np.random.get_state() 153 | np.random.seed(2019) 154 | np.random.shuffle(indices) 155 | np.random.set_state(random_state) 156 | train_indices, test_indices = ( 157 | indices[: int(num_items * 0.9)], 158 | indices[int(num_items * 0.9) :], 159 | ) 160 | test_dataset = Subset(dataset, test_indices) 161 | 162 | elif config.data.dataset == 'ImageNet': 163 | # only use validation dataset here 164 | 165 | if config.data.subset_1k: 166 | from datasets.imagenet_subset import ImageDataset 167 | dataset = ImageDataset(os.path.join(args.exp, 'datasets', 'imagenet', 'imagenet'), 168 | os.path.join(args.exp, 'imagenet_val_1k.txt'), 169 | image_size=config.data.image_size, 170 | normalize=False) 171 | test_dataset = dataset 172 | elif config.data.out_of_dist: 173 | dataset = torchvision.datasets.ImageFolder( 174 | os.path.join(args.exp, 'datasets', 'ood'), 175 | transform=transforms.Compose([partial(center_crop_arr, image_size=config.data.image_size), 176 | transforms.ToTensor()]) 177 | ) 178 | test_dataset = dataset 179 | else: 180 | dataset = torchvision.datasets.ImageNet( 181 | os.path.join(args.exp, 'datasets', 'imagenet'), split='val', 182 | transform=transforms.Compose([partial(center_crop_arr, image_size=config.data.image_size), 183 | transforms.ToTensor()]) 184 | ) 185 | test_dataset = dataset 186 | else: 187 | dataset, test_dataset = None, None 188 | 189 | return dataset, test_dataset 190 | 191 | 192 | def logit_transform(image, lam=1e-6): 193 | image = lam + (1 - 2 * lam) * image 194 | return torch.log(image) - torch.log1p(-image) 195 | 196 | 197 | def data_transform(config, X): 198 | if config.data.uniform_dequantization: 199 | X = X / 256.0 * 255.0 + torch.rand_like(X) / 256.0 200 | if config.data.gaussian_dequantization: 201 | X = X + torch.randn_like(X) * 0.01 202 | 203 | if config.data.rescaled: 204 | X = 2 * X - 1.0 205 | elif config.data.logit_transform: 206 | X = logit_transform(X) 207 | 208 | if hasattr(config, "image_mean"): 209 | return X - config.image_mean.to(X.device)[None, ...] 210 | 211 | return X 212 | 213 | 214 | def inverse_data_transform(config, X): 215 | if hasattr(config, "image_mean"): 216 | X = X + config.image_mean.to(X.device)[None, ...] 
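    # next, invert the logit or [-1, 1] rescaling that data_transform applied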
217 | 218 | if config.data.logit_transform: 219 | X = torch.sigmoid(X) 220 | elif config.data.rescaled: 221 | X = (X + 1.0) / 2.0 222 | 223 | return torch.clamp(X, 0.0, 1.0) 224 | -------------------------------------------------------------------------------- /guided_diffusion/fp16_util.py: -------------------------------------------------------------------------------- 1 | """ 2 | Helpers to train with 16-bit precision. 3 | """ 4 | 5 | import numpy as np 6 | import torch as th 7 | import torch.nn as nn 8 | from torch._utils import _flatten_dense_tensors, _unflatten_dense_tensors 9 | 10 | from . import logger 11 | 12 | INITIAL_LOG_LOSS_SCALE = 20.0 13 | 14 | 15 | def convert_module_to_f16(l): 16 | """ 17 | Convert primitive modules to float16. 18 | """ 19 | if isinstance(l, (nn.Conv1d, nn.Conv2d, nn.Conv3d)): 20 | l.weight.data = l.weight.data.half() 21 | if l.bias is not None: 22 | l.bias.data = l.bias.data.half() 23 | 24 | 25 | def convert_module_to_f32(l): 26 | """ 27 | Convert primitive modules to float32, undoing convert_module_to_f16(). 28 | """ 29 | if isinstance(l, (nn.Conv1d, nn.Conv2d, nn.Conv3d)): 30 | l.weight.data = l.weight.data.float() 31 | if l.bias is not None: 32 | l.bias.data = l.bias.data.float() 33 | 34 | 35 | def make_master_params(param_groups_and_shapes): 36 | """ 37 | Copy model parameters into a (differently-shaped) list of full-precision 38 | parameters. 39 | """ 40 | master_params = [] 41 | for param_group, shape in param_groups_and_shapes: 42 | master_param = nn.Parameter( 43 | _flatten_dense_tensors( 44 | [param.detach().float() for (_, param) in param_group] 45 | ).view(shape) 46 | ) 47 | master_param.requires_grad = True 48 | master_params.append(master_param) 49 | return master_params 50 | 51 | 52 | def model_grads_to_master_grads(param_groups_and_shapes, master_params): 53 | """ 54 | Copy the gradients from the model parameters into the master parameters 55 | from make_master_params(). 56 | """ 57 | for master_param, (param_group, shape) in zip( 58 | master_params, param_groups_and_shapes 59 | ): 60 | master_param.grad = _flatten_dense_tensors( 61 | [param_grad_or_zeros(param) for (_, param) in param_group] 62 | ).view(shape) 63 | 64 | 65 | def master_params_to_model_params(param_groups_and_shapes, master_params): 66 | """ 67 | Copy the master parameter data back into the model parameters. 68 | """ 69 | # Without copying to a list, if a generator is passed, this will 70 | # silently not copy any parameters. 
71 | for master_param, (param_group, _) in zip(master_params, param_groups_and_shapes): 72 | for (_, param), unflat_master_param in zip( 73 | param_group, unflatten_master_params(param_group, master_param.view(-1)) 74 | ): 75 | param.detach().copy_(unflat_master_param) 76 | 77 | 78 | def unflatten_master_params(param_group, master_param): 79 | return _unflatten_dense_tensors(master_param, [param for (_, param) in param_group]) 80 | 81 | 82 | def get_param_groups_and_shapes(named_model_params): 83 | named_model_params = list(named_model_params) 84 | scalar_vector_named_params = ( 85 | [(n, p) for (n, p) in named_model_params if p.ndim <= 1], 86 | (-1), 87 | ) 88 | matrix_named_params = ( 89 | [(n, p) for (n, p) in named_model_params if p.ndim > 1], 90 | (1, -1), 91 | ) 92 | return [scalar_vector_named_params, matrix_named_params] 93 | 94 | 95 | def master_params_to_state_dict( 96 | model, param_groups_and_shapes, master_params, use_fp16 97 | ): 98 | if use_fp16: 99 | state_dict = model.state_dict() 100 | for master_param, (param_group, _) in zip( 101 | master_params, param_groups_and_shapes 102 | ): 103 | for (name, _), unflat_master_param in zip( 104 | param_group, unflatten_master_params(param_group, master_param.view(-1)) 105 | ): 106 | assert name in state_dict 107 | state_dict[name] = unflat_master_param 108 | else: 109 | state_dict = model.state_dict() 110 | for i, (name, _value) in enumerate(model.named_parameters()): 111 | assert name in state_dict 112 | state_dict[name] = master_params[i] 113 | return state_dict 114 | 115 | 116 | def state_dict_to_master_params(model, state_dict, use_fp16): 117 | if use_fp16: 118 | named_model_params = [ 119 | (name, state_dict[name]) for name, _ in model.named_parameters() 120 | ] 121 | param_groups_and_shapes = get_param_groups_and_shapes(named_model_params) 122 | master_params = make_master_params(param_groups_and_shapes) 123 | else: 124 | master_params = [state_dict[name] for name, _ in model.named_parameters()] 125 | return master_params 126 | 127 | 128 | def zero_master_grads(master_params): 129 | for param in master_params: 130 | param.grad = None 131 | 132 | 133 | def zero_grad(model_params): 134 | for param in model_params: 135 | # Taken from https://pytorch.org/docs/stable/_modules/torch/optim/optimizer.html#Optimizer.add_param_group 136 | if param.grad is not None: 137 | param.grad.detach_() 138 | param.grad.zero_() 139 | 140 | 141 | def param_grad_or_zeros(param): 142 | if param.grad is not None: 143 | return param.grad.data.detach() 144 | else: 145 | return th.zeros_like(param) 146 | 147 | 148 | class MixedPrecisionTrainer: 149 | def __init__( 150 | self, 151 | *, 152 | model, 153 | use_fp16=False, 154 | fp16_scale_growth=1e-3, 155 | initial_lg_loss_scale=INITIAL_LOG_LOSS_SCALE, 156 | ): 157 | self.model = model 158 | self.use_fp16 = use_fp16 159 | self.fp16_scale_growth = fp16_scale_growth 160 | 161 | self.model_params = list(self.model.parameters()) 162 | self.master_params = self.model_params 163 | self.param_groups_and_shapes = None 164 | self.lg_loss_scale = initial_lg_loss_scale 165 | 166 | if self.use_fp16: 167 | self.param_groups_and_shapes = get_param_groups_and_shapes( 168 | self.model.named_parameters() 169 | ) 170 | self.master_params = make_master_params(self.param_groups_and_shapes) 171 | self.model.convert_to_fp16() 172 | 173 | def zero_grad(self): 174 | zero_grad(self.model_params) 175 | 176 | def backward(self, loss: th.Tensor): 177 | if self.use_fp16: 178 | loss_scale = 2 ** self.lg_loss_scale 179 | (loss * 
loss_scale).backward() 180 | else: 181 | loss.backward() 182 | 183 | def optimize(self, opt: th.optim.Optimizer): 184 | if self.use_fp16: 185 | return self._optimize_fp16(opt) 186 | else: 187 | return self._optimize_normal(opt) 188 | 189 | def _optimize_fp16(self, opt: th.optim.Optimizer): 190 | logger.logkv_mean("lg_loss_scale", self.lg_loss_scale) 191 | model_grads_to_master_grads(self.param_groups_and_shapes, self.master_params) 192 | grad_norm, param_norm = self._compute_norms(grad_scale=2 ** self.lg_loss_scale) 193 | if check_overflow(grad_norm): 194 | self.lg_loss_scale -= 1 195 | logger.log(f"Found NaN, decreased lg_loss_scale to {self.lg_loss_scale}") 196 | zero_master_grads(self.master_params) 197 | return False 198 | 199 | logger.logkv_mean("grad_norm", grad_norm) 200 | logger.logkv_mean("param_norm", param_norm) 201 | 202 | self.master_params[0].grad.mul_(1.0 / (2 ** self.lg_loss_scale)) 203 | opt.step() 204 | zero_master_grads(self.master_params) 205 | master_params_to_model_params(self.param_groups_and_shapes, self.master_params) 206 | self.lg_loss_scale += self.fp16_scale_growth 207 | return True 208 | 209 | def _optimize_normal(self, opt: th.optim.Optimizer): 210 | grad_norm, param_norm = self._compute_norms() 211 | logger.logkv_mean("grad_norm", grad_norm) 212 | logger.logkv_mean("param_norm", param_norm) 213 | opt.step() 214 | return True 215 | 216 | def _compute_norms(self, grad_scale=1.0): 217 | grad_norm = 0.0 218 | param_norm = 0.0 219 | for p in self.master_params: 220 | with th.no_grad(): 221 | param_norm += th.norm(p, p=2, dtype=th.float32).item() ** 2 222 | if p.grad is not None: 223 | grad_norm += th.norm(p.grad, p=2, dtype=th.float32).item() ** 2 224 | return np.sqrt(grad_norm) / grad_scale, np.sqrt(param_norm) 225 | 226 | def master_params_to_state_dict(self, master_params): 227 | return master_params_to_state_dict( 228 | self.model, self.param_groups_and_shapes, master_params, self.use_fp16 229 | ) 230 | 231 | def state_dict_to_master_params(self, state_dict): 232 | return state_dict_to_master_params(self.model, state_dict, self.use_fp16) 233 | 234 | 235 | def check_overflow(value): 236 | return (value == float("inf")) or (value == -float("inf")) or (value != value) 237 | -------------------------------------------------------------------------------- /models/diffusion.py: -------------------------------------------------------------------------------- 1 | import math 2 | import torch 3 | import torch.nn as nn 4 | 5 | 6 | def get_timestep_embedding(timesteps, embedding_dim): 7 | """ 8 | This matches the implementation in Denoising Diffusion Probabilistic Models: 9 | From Fairseq. 10 | Build sinusoidal embeddings. 11 | This matches the implementation in tensor2tensor, but differs slightly 12 | from the description in Section 3.5 of "Attention Is All You Need". 
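    Concretely, with half_dim = embedding_dim // 2, the j-th frequency is
    10000^(-j / (half_dim - 1)); the embedding concatenates sin(t * freq_j)
    and cos(t * freq_j) along the last dimension, as implemented below.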
13 | """ 14 | assert len(timesteps.shape) == 1 15 | 16 | half_dim = embedding_dim // 2 17 | emb = math.log(10000) / (half_dim - 1) 18 | emb = torch.exp(torch.arange(half_dim, dtype=torch.float32) * -emb) 19 | emb = emb.to(device=timesteps.device) 20 | emb = timesteps.float()[:, None] * emb[None, :] 21 | emb = torch.cat([torch.sin(emb), torch.cos(emb)], dim=1) 22 | if embedding_dim % 2 == 1: # zero pad 23 | emb = torch.nn.functional.pad(emb, (0, 1, 0, 0)) 24 | return emb 25 | 26 | 27 | def nonlinearity(x): 28 | # swish 29 | return x*torch.sigmoid(x) 30 | 31 | 32 | def Normalize(in_channels): 33 | return torch.nn.GroupNorm(num_groups=32, num_channels=in_channels, eps=1e-6, affine=True) 34 | 35 | 36 | class Upsample(nn.Module): 37 | def __init__(self, in_channels, with_conv): 38 | super().__init__() 39 | self.with_conv = with_conv 40 | if self.with_conv: 41 | self.conv = torch.nn.Conv2d(in_channels, 42 | in_channels, 43 | kernel_size=3, 44 | stride=1, 45 | padding=1) 46 | 47 | def forward(self, x): 48 | x = torch.nn.functional.interpolate( 49 | x, scale_factor=2.0, mode="nearest") 50 | if self.with_conv: 51 | x = self.conv(x) 52 | return x 53 | 54 | 55 | class Downsample(nn.Module): 56 | def __init__(self, in_channels, with_conv): 57 | super().__init__() 58 | self.with_conv = with_conv 59 | if self.with_conv: 60 | # no asymmetric padding in torch conv, must do it ourselves 61 | self.conv = torch.nn.Conv2d(in_channels, 62 | in_channels, 63 | kernel_size=3, 64 | stride=2, 65 | padding=0) 66 | 67 | def forward(self, x): 68 | if self.with_conv: 69 | pad = (0, 1, 0, 1) 70 | x = torch.nn.functional.pad(x, pad, mode="constant", value=0) 71 | x = self.conv(x) 72 | else: 73 | x = torch.nn.functional.avg_pool2d(x, kernel_size=2, stride=2) 74 | return x 75 | 76 | 77 | class ResnetBlock(nn.Module): 78 | def __init__(self, *, in_channels, out_channels=None, conv_shortcut=False, 79 | dropout, temb_channels=512): 80 | super().__init__() 81 | self.in_channels = in_channels 82 | out_channels = in_channels if out_channels is None else out_channels 83 | self.out_channels = out_channels 84 | self.use_conv_shortcut = conv_shortcut 85 | 86 | self.norm1 = Normalize(in_channels) 87 | self.conv1 = torch.nn.Conv2d(in_channels, 88 | out_channels, 89 | kernel_size=3, 90 | stride=1, 91 | padding=1) 92 | self.temb_proj = torch.nn.Linear(temb_channels, 93 | out_channels) 94 | self.norm2 = Normalize(out_channels) 95 | self.dropout = torch.nn.Dropout(dropout) 96 | self.conv2 = torch.nn.Conv2d(out_channels, 97 | out_channels, 98 | kernel_size=3, 99 | stride=1, 100 | padding=1) 101 | if self.in_channels != self.out_channels: 102 | if self.use_conv_shortcut: 103 | self.conv_shortcut = torch.nn.Conv2d(in_channels, 104 | out_channels, 105 | kernel_size=3, 106 | stride=1, 107 | padding=1) 108 | else: 109 | self.nin_shortcut = torch.nn.Conv2d(in_channels, 110 | out_channels, 111 | kernel_size=1, 112 | stride=1, 113 | padding=0) 114 | 115 | def forward(self, x, temb): 116 | h = x 117 | h = self.norm1(h) 118 | h = nonlinearity(h) 119 | h = self.conv1(h) 120 | 121 | h = h + self.temb_proj(nonlinearity(temb))[:, :, None, None] 122 | 123 | h = self.norm2(h) 124 | h = nonlinearity(h) 125 | h = self.dropout(h) 126 | h = self.conv2(h) 127 | 128 | if self.in_channels != self.out_channels: 129 | if self.use_conv_shortcut: 130 | x = self.conv_shortcut(x) 131 | else: 132 | x = self.nin_shortcut(x) 133 | 134 | return x+h 135 | 136 | 137 | class AttnBlock(nn.Module): 138 | def __init__(self, in_channels): 139 | super().__init__() 140 | 
self.in_channels = in_channels 141 | 142 | self.norm = Normalize(in_channels) 143 | self.q = torch.nn.Conv2d(in_channels, 144 | in_channels, 145 | kernel_size=1, 146 | stride=1, 147 | padding=0) 148 | self.k = torch.nn.Conv2d(in_channels, 149 | in_channels, 150 | kernel_size=1, 151 | stride=1, 152 | padding=0) 153 | self.v = torch.nn.Conv2d(in_channels, 154 | in_channels, 155 | kernel_size=1, 156 | stride=1, 157 | padding=0) 158 | self.proj_out = torch.nn.Conv2d(in_channels, 159 | in_channels, 160 | kernel_size=1, 161 | stride=1, 162 | padding=0) 163 | 164 | def forward(self, x): 165 | h_ = x 166 | h_ = self.norm(h_) 167 | q = self.q(h_) 168 | k = self.k(h_) 169 | v = self.v(h_) 170 | 171 | # compute attention 172 | b, c, h, w = q.shape 173 | q = q.reshape(b, c, h*w) 174 | q = q.permute(0, 2, 1) # b,hw,c 175 | k = k.reshape(b, c, h*w) # b,c,hw 176 | w_ = torch.bmm(q, k) # b,hw,hw w[b,i,j]=sum_c q[b,i,c]k[b,c,j] 177 | w_ = w_ * (int(c)**(-0.5)) 178 | w_ = torch.nn.functional.softmax(w_, dim=2) 179 | 180 | # attend to values 181 | v = v.reshape(b, c, h*w) 182 | w_ = w_.permute(0, 2, 1) # b,hw,hw (first hw of k, second of q) 183 | # b, c,hw (hw of q) h_[b,c,j] = sum_i v[b,c,i] w_[b,i,j] 184 | h_ = torch.bmm(v, w_) 185 | h_ = h_.reshape(b, c, h, w) 186 | 187 | h_ = self.proj_out(h_) 188 | 189 | return x+h_ 190 | 191 | 192 | class Model(nn.Module): 193 | def __init__(self, config): 194 | super().__init__() 195 | self.config = config 196 | ch, out_ch, ch_mult = config.model.ch, config.model.out_ch, tuple(config.model.ch_mult) 197 | num_res_blocks = config.model.num_res_blocks 198 | attn_resolutions = config.model.attn_resolutions 199 | dropout = config.model.dropout 200 | in_channels = config.model.in_channels 201 | resolution = config.data.image_size 202 | resamp_with_conv = config.model.resamp_with_conv 203 | num_timesteps = config.diffusion.num_diffusion_timesteps 204 | 205 | if config.model.type == 'bayesian': 206 | self.logvar = nn.Parameter(torch.zeros(num_timesteps)) 207 | 208 | self.ch = ch 209 | self.temb_ch = self.ch*4 210 | self.num_resolutions = len(ch_mult) 211 | self.num_res_blocks = num_res_blocks 212 | self.resolution = resolution 213 | self.in_channels = in_channels 214 | 215 | # timestep embedding 216 | self.temb = nn.Module() 217 | self.temb.dense = nn.ModuleList([ 218 | torch.nn.Linear(self.ch, 219 | self.temb_ch), 220 | torch.nn.Linear(self.temb_ch, 221 | self.temb_ch), 222 | ]) 223 | 224 | # downsampling 225 | self.conv_in = torch.nn.Conv2d(in_channels, 226 | self.ch, 227 | kernel_size=3, 228 | stride=1, 229 | padding=1) 230 | 231 | curr_res = resolution 232 | in_ch_mult = (1,)+ch_mult 233 | self.down = nn.ModuleList() 234 | block_in = None 235 | for i_level in range(self.num_resolutions): 236 | block = nn.ModuleList() 237 | attn = nn.ModuleList() 238 | block_in = ch*in_ch_mult[i_level] 239 | block_out = ch*ch_mult[i_level] 240 | for i_block in range(self.num_res_blocks): 241 | block.append(ResnetBlock(in_channels=block_in, 242 | out_channels=block_out, 243 | temb_channels=self.temb_ch, 244 | dropout=dropout)) 245 | block_in = block_out 246 | if curr_res in attn_resolutions: 247 | attn.append(AttnBlock(block_in)) 248 | down = nn.Module() 249 | down.block = block 250 | down.attn = attn 251 | if i_level != self.num_resolutions-1: 252 | down.downsample = Downsample(block_in, resamp_with_conv) 253 | curr_res = curr_res // 2 254 | self.down.append(down) 255 | 256 | # middle 257 | self.mid = nn.Module() 258 | self.mid.block_1 = ResnetBlock(in_channels=block_in, 259 | 
out_channels=block_in, 260 | temb_channels=self.temb_ch, 261 | dropout=dropout) 262 | self.mid.attn_1 = AttnBlock(block_in) 263 | self.mid.block_2 = ResnetBlock(in_channels=block_in, 264 | out_channels=block_in, 265 | temb_channels=self.temb_ch, 266 | dropout=dropout) 267 | 268 | # upsampling 269 | self.up = nn.ModuleList() 270 | for i_level in reversed(range(self.num_resolutions)): 271 | block = nn.ModuleList() 272 | attn = nn.ModuleList() 273 | block_out = ch*ch_mult[i_level] 274 | skip_in = ch*ch_mult[i_level] 275 | for i_block in range(self.num_res_blocks+1): 276 | if i_block == self.num_res_blocks: 277 | skip_in = ch*in_ch_mult[i_level] 278 | block.append(ResnetBlock(in_channels=block_in+skip_in, 279 | out_channels=block_out, 280 | temb_channels=self.temb_ch, 281 | dropout=dropout)) 282 | block_in = block_out 283 | if curr_res in attn_resolutions: 284 | attn.append(AttnBlock(block_in)) 285 | up = nn.Module() 286 | up.block = block 287 | up.attn = attn 288 | if i_level != 0: 289 | up.upsample = Upsample(block_in, resamp_with_conv) 290 | curr_res = curr_res * 2 291 | self.up.insert(0, up) # prepend to get consistent order 292 | 293 | # end 294 | self.norm_out = Normalize(block_in) 295 | self.conv_out = torch.nn.Conv2d(block_in, 296 | out_ch, 297 | kernel_size=3, 298 | stride=1, 299 | padding=1) 300 | 301 | def forward(self, x, t): 302 | assert x.shape[2] == x.shape[3] == self.resolution 303 | 304 | # timestep embedding 305 | temb = get_timestep_embedding(t, self.ch) 306 | temb = self.temb.dense[0](temb) 307 | temb = nonlinearity(temb) 308 | temb = self.temb.dense[1](temb) 309 | 310 | # downsampling 311 | hs = [self.conv_in(x)] 312 | for i_level in range(self.num_resolutions): 313 | for i_block in range(self.num_res_blocks): 314 | h = self.down[i_level].block[i_block](hs[-1], temb) 315 | if len(self.down[i_level].attn) > 0: 316 | h = self.down[i_level].attn[i_block](h) 317 | hs.append(h) 318 | if i_level != self.num_resolutions-1: 319 | hs.append(self.down[i_level].downsample(hs[-1])) 320 | 321 | # middle 322 | h = hs[-1] 323 | h = self.mid.block_1(h, temb) 324 | h = self.mid.attn_1(h) 325 | h = self.mid.block_2(h, temb) 326 | 327 | # upsampling 328 | for i_level in reversed(range(self.num_resolutions)): 329 | for i_block in range(self.num_res_blocks+1): 330 | h = self.up[i_level].block[i_block]( 331 | torch.cat([h, hs.pop()], dim=1), temb) 332 | if len(self.up[i_level].attn) > 0: 333 | h = self.up[i_level].attn[i_block](h) 334 | if i_level != 0: 335 | h = self.up[i_level].upsample(h) 336 | 337 | # end 338 | h = self.norm_out(h) 339 | h = nonlinearity(h) 340 | h = self.conv_out(h) 341 | return h 342 | -------------------------------------------------------------------------------- /guided_diffusion/script_util.py: -------------------------------------------------------------------------------- 1 | import argparse 2 | import inspect 3 | 4 | #from . import gaussian_diffusion as gd 5 | #from .respace import SpacedDiffusion, space_timesteps 6 | from .unet import SuperResModel, UNetModel, EncoderUNetModel 7 | 8 | NUM_CLASSES = 1000 9 | 10 | 11 | def diffusion_defaults(): 12 | """ 13 | Defaults for image and classifier training. 14 | """ 15 | return dict( 16 | learn_sigma=False, 17 | diffusion_steps=1000, 18 | noise_schedule="linear", 19 | timestep_respacing="", 20 | use_kl=False, 21 | predict_xstart=False, 22 | rescale_timesteps=False, 23 | rescale_learned_sigmas=False, 24 | ) 25 | 26 | 27 | def classifier_defaults(): 28 | """ 29 | Defaults for classifier models. 
30 | """ 31 | return dict( 32 | image_size=64, 33 | classifier_use_fp16=False, 34 | classifier_width=128, 35 | classifier_depth=2, 36 | classifier_attention_resolutions="32,16,8", # 16 37 | classifier_use_scale_shift_norm=True, # False 38 | classifier_resblock_updown=True, # False 39 | classifier_pool="attention", 40 | ) 41 | 42 | 43 | def model_and_diffusion_defaults(): 44 | """ 45 | Defaults for image training. 46 | """ 47 | res = dict( 48 | image_size=64, 49 | num_channels=128, 50 | num_res_blocks=2, 51 | num_heads=4, 52 | num_heads_upsample=-1, 53 | num_head_channels=-1, 54 | attention_resolutions="16,8", 55 | channel_mult="", 56 | dropout=0.0, 57 | class_cond=False, 58 | use_checkpoint=False, 59 | use_scale_shift_norm=True, 60 | resblock_updown=False, 61 | use_fp16=False, 62 | use_new_attention_order=False, 63 | ) 64 | res.update(diffusion_defaults()) 65 | return res 66 | 67 | 68 | def classifier_and_diffusion_defaults(): 69 | res = classifier_defaults() 70 | res.update(diffusion_defaults()) 71 | return res 72 | 73 | 74 | def create_model_and_diffusion( 75 | image_size, 76 | class_cond, 77 | learn_sigma, 78 | num_channels, 79 | num_res_blocks, 80 | channel_mult, 81 | num_heads, 82 | num_head_channels, 83 | num_heads_upsample, 84 | attention_resolutions, 85 | dropout, 86 | diffusion_steps, 87 | noise_schedule, 88 | timestep_respacing, 89 | use_kl, 90 | predict_xstart, 91 | rescale_timesteps, 92 | rescale_learned_sigmas, 93 | use_checkpoint, 94 | use_scale_shift_norm, 95 | resblock_updown, 96 | use_fp16, 97 | use_new_attention_order, 98 | ): 99 | model = create_model( 100 | image_size, 101 | num_channels, 102 | num_res_blocks, 103 | channel_mult=channel_mult, 104 | learn_sigma=learn_sigma, 105 | class_cond=class_cond, 106 | use_checkpoint=use_checkpoint, 107 | attention_resolutions=attention_resolutions, 108 | num_heads=num_heads, 109 | num_head_channels=num_head_channels, 110 | num_heads_upsample=num_heads_upsample, 111 | use_scale_shift_norm=use_scale_shift_norm, 112 | dropout=dropout, 113 | resblock_updown=resblock_updown, 114 | use_fp16=use_fp16, 115 | use_new_attention_order=use_new_attention_order, 116 | ) 117 | diffusion = create_gaussian_diffusion( 118 | steps=diffusion_steps, 119 | learn_sigma=learn_sigma, 120 | noise_schedule=noise_schedule, 121 | use_kl=use_kl, 122 | predict_xstart=predict_xstart, 123 | rescale_timesteps=rescale_timesteps, 124 | rescale_learned_sigmas=rescale_learned_sigmas, 125 | timestep_respacing=timestep_respacing, 126 | ) 127 | return model, diffusion 128 | 129 | 130 | def create_model( 131 | image_size, 132 | num_channels, 133 | num_res_blocks, 134 | channel_mult="", 135 | learn_sigma=False, 136 | class_cond=False, 137 | use_checkpoint=False, 138 | attention_resolutions="16", 139 | num_heads=1, 140 | num_head_channels=-1, 141 | num_heads_upsample=-1, 142 | use_scale_shift_norm=False, 143 | dropout=0, 144 | resblock_updown=False, 145 | use_fp16=False, 146 | use_new_attention_order=False, 147 | **kwargs 148 | ): 149 | if channel_mult == "": 150 | if image_size == 512: 151 | channel_mult = (0.5, 1, 1, 2, 2, 4, 4) 152 | elif image_size == 256: 153 | channel_mult = (1, 1, 2, 2, 4, 4) 154 | elif image_size == 128: 155 | channel_mult = (1, 1, 2, 3, 4) 156 | elif image_size == 64: 157 | channel_mult = (1, 2, 3, 4) 158 | else: 159 | raise ValueError(f"unsupported image size: {image_size}") 160 | else: 161 | channel_mult = tuple(int(ch_mult) for ch_mult in channel_mult.split(",")) 162 | 163 | attention_ds = [] 164 | for res in attention_resolutions.split(","): 
165 | attention_ds.append(image_size // int(res)) 166 | 167 | return UNetModel( 168 | image_size=image_size, 169 | in_channels=3, 170 | model_channels=num_channels, 171 | out_channels=(3 if not learn_sigma else 6), 172 | num_res_blocks=num_res_blocks, 173 | attention_resolutions=tuple(attention_ds), 174 | dropout=dropout, 175 | channel_mult=channel_mult, 176 | num_classes=(NUM_CLASSES if class_cond else None), 177 | use_checkpoint=use_checkpoint, 178 | use_fp16=use_fp16, 179 | num_heads=num_heads, 180 | num_head_channels=num_head_channels, 181 | num_heads_upsample=num_heads_upsample, 182 | use_scale_shift_norm=use_scale_shift_norm, 183 | resblock_updown=resblock_updown, 184 | use_new_attention_order=use_new_attention_order, 185 | ) 186 | 187 | 188 | def create_classifier_and_diffusion( 189 | image_size, 190 | classifier_use_fp16, 191 | classifier_width, 192 | classifier_depth, 193 | classifier_attention_resolutions, 194 | classifier_use_scale_shift_norm, 195 | classifier_resblock_updown, 196 | classifier_pool, 197 | learn_sigma, 198 | diffusion_steps, 199 | noise_schedule, 200 | timestep_respacing, 201 | use_kl, 202 | predict_xstart, 203 | rescale_timesteps, 204 | rescale_learned_sigmas, 205 | ): 206 | classifier = create_classifier( 207 | image_size, 208 | classifier_use_fp16, 209 | classifier_width, 210 | classifier_depth, 211 | classifier_attention_resolutions, 212 | classifier_use_scale_shift_norm, 213 | classifier_resblock_updown, 214 | classifier_pool, 215 | ) 216 | diffusion = create_gaussian_diffusion( 217 | steps=diffusion_steps, 218 | learn_sigma=learn_sigma, 219 | noise_schedule=noise_schedule, 220 | use_kl=use_kl, 221 | predict_xstart=predict_xstart, 222 | rescale_timesteps=rescale_timesteps, 223 | rescale_learned_sigmas=rescale_learned_sigmas, 224 | timestep_respacing=timestep_respacing, 225 | ) 226 | return classifier, diffusion 227 | 228 | 229 | def create_classifier( 230 | image_size, 231 | classifier_use_fp16, 232 | classifier_width, 233 | classifier_depth, 234 | classifier_attention_resolutions, 235 | classifier_use_scale_shift_norm, 236 | classifier_resblock_updown, 237 | classifier_pool, 238 | ): 239 | if image_size == 512: 240 | channel_mult = (0.5, 1, 1, 2, 2, 4, 4) 241 | elif image_size == 256: 242 | channel_mult = (1, 1, 2, 2, 4, 4) 243 | elif image_size == 128: 244 | channel_mult = (1, 1, 2, 3, 4) 245 | elif image_size == 64: 246 | channel_mult = (1, 2, 3, 4) 247 | else: 248 | raise ValueError(f"unsupported image size: {image_size}") 249 | 250 | attention_ds = [] 251 | for res in classifier_attention_resolutions.split(","): 252 | attention_ds.append(image_size // int(res)) 253 | 254 | return EncoderUNetModel( 255 | image_size=image_size, 256 | in_channels=3, 257 | model_channels=classifier_width, 258 | out_channels=1000, 259 | num_res_blocks=classifier_depth, 260 | attention_resolutions=tuple(attention_ds), 261 | channel_mult=channel_mult, 262 | use_fp16=classifier_use_fp16, 263 | num_head_channels=64, 264 | use_scale_shift_norm=classifier_use_scale_shift_norm, 265 | resblock_updown=classifier_resblock_updown, 266 | pool=classifier_pool, 267 | ) 268 | 269 | 270 | def sr_model_and_diffusion_defaults(): 271 | res = model_and_diffusion_defaults() 272 | res["large_size"] = 256 273 | res["small_size"] = 64 274 | arg_names = inspect.getfullargspec(sr_create_model_and_diffusion)[0] 275 | for k in res.copy().keys(): 276 | if k not in arg_names: 277 | del res[k] 278 | return res 279 | 280 | 281 | def sr_create_model_and_diffusion( 282 | large_size, 283 | small_size, 284 | 
class_cond, 285 | learn_sigma, 286 | num_channels, 287 | num_res_blocks, 288 | num_heads, 289 | num_head_channels, 290 | num_heads_upsample, 291 | attention_resolutions, 292 | dropout, 293 | diffusion_steps, 294 | noise_schedule, 295 | timestep_respacing, 296 | use_kl, 297 | predict_xstart, 298 | rescale_timesteps, 299 | rescale_learned_sigmas, 300 | use_checkpoint, 301 | use_scale_shift_norm, 302 | resblock_updown, 303 | use_fp16, 304 | ): 305 | model = sr_create_model( 306 | large_size, 307 | small_size, 308 | num_channels, 309 | num_res_blocks, 310 | learn_sigma=learn_sigma, 311 | class_cond=class_cond, 312 | use_checkpoint=use_checkpoint, 313 | attention_resolutions=attention_resolutions, 314 | num_heads=num_heads, 315 | num_head_channels=num_head_channels, 316 | num_heads_upsample=num_heads_upsample, 317 | use_scale_shift_norm=use_scale_shift_norm, 318 | dropout=dropout, 319 | resblock_updown=resblock_updown, 320 | use_fp16=use_fp16, 321 | ) 322 | diffusion = create_gaussian_diffusion( 323 | steps=diffusion_steps, 324 | learn_sigma=learn_sigma, 325 | noise_schedule=noise_schedule, 326 | use_kl=use_kl, 327 | predict_xstart=predict_xstart, 328 | rescale_timesteps=rescale_timesteps, 329 | rescale_learned_sigmas=rescale_learned_sigmas, 330 | timestep_respacing=timestep_respacing, 331 | ) 332 | return model, diffusion 333 | 334 | 335 | def sr_create_model( 336 | large_size, 337 | small_size, 338 | num_channels, 339 | num_res_blocks, 340 | learn_sigma, 341 | class_cond, 342 | use_checkpoint, 343 | attention_resolutions, 344 | num_heads, 345 | num_head_channels, 346 | num_heads_upsample, 347 | use_scale_shift_norm, 348 | dropout, 349 | resblock_updown, 350 | use_fp16, 351 | ): 352 | _ = small_size # hack to prevent unused variable 353 | 354 | if large_size == 512: 355 | channel_mult = (1, 1, 2, 2, 4, 4) 356 | elif large_size == 256: 357 | channel_mult = (1, 1, 2, 2, 4, 4) 358 | elif large_size == 64: 359 | channel_mult = (1, 2, 3, 4) 360 | else: 361 | raise ValueError(f"unsupported large size: {large_size}") 362 | 363 | attention_ds = [] 364 | for res in attention_resolutions.split(","): 365 | attention_ds.append(large_size // int(res)) 366 | 367 | return SuperResModel( 368 | image_size=large_size, 369 | in_channels=3, 370 | model_channels=num_channels, 371 | out_channels=(3 if not learn_sigma else 6), 372 | num_res_blocks=num_res_blocks, 373 | attention_resolutions=tuple(attention_ds), 374 | dropout=dropout, 375 | channel_mult=channel_mult, 376 | num_classes=(NUM_CLASSES if class_cond else None), 377 | use_checkpoint=use_checkpoint, 378 | num_heads=num_heads, 379 | num_head_channels=num_head_channels, 380 | num_heads_upsample=num_heads_upsample, 381 | use_scale_shift_norm=use_scale_shift_norm, 382 | resblock_updown=resblock_updown, 383 | use_fp16=use_fp16, 384 | ) 385 | 386 | 387 | def create_gaussian_diffusion( 388 | *, 389 | steps=1000, 390 | learn_sigma=False, 391 | sigma_small=False, 392 | noise_schedule="linear", 393 | use_kl=False, 394 | predict_xstart=False, 395 | rescale_timesteps=False, 396 | rescale_learned_sigmas=False, 397 | timestep_respacing="", 398 | ): 399 | betas = gd.get_named_beta_schedule(noise_schedule, steps) 400 | if use_kl: 401 | loss_type = gd.LossType.RESCALED_KL 402 | elif rescale_learned_sigmas: 403 | loss_type = gd.LossType.RESCALED_MSE 404 | else: 405 | loss_type = gd.LossType.MSE 406 | if not timestep_respacing: 407 | timestep_respacing = [steps] 408 | return SpacedDiffusion( 409 | use_timesteps=space_timesteps(steps, timestep_respacing), 410 | betas=betas, 
411 | model_mean_type=( 412 | gd.ModelMeanType.EPSILON if not predict_xstart else gd.ModelMeanType.START_X 413 | ), 414 | model_var_type=( 415 | ( 416 | gd.ModelVarType.FIXED_LARGE 417 | if not sigma_small 418 | else gd.ModelVarType.FIXED_SMALL 419 | ) 420 | if not learn_sigma 421 | else gd.ModelVarType.LEARNED_RANGE 422 | ), 423 | loss_type=loss_type, 424 | rescale_timesteps=rescale_timesteps, 425 | ) 426 | 427 | 428 | def add_dict_to_argparser(parser, default_dict): 429 | for k, v in default_dict.items(): 430 | v_type = type(v) 431 | if v is None: 432 | v_type = str 433 | elif isinstance(v, bool): 434 | v_type = str2bool 435 | parser.add_argument(f"--{k}", default=v, type=v_type) 436 | 437 | 438 | def args_to_dict(args, keys): 439 | return {k: getattr(args, k) for k in keys} 440 | 441 | 442 | def str2bool(v): 443 | """ 444 | https://stackoverflow.com/questions/15008758/parsing-boolean-values-with-argparse 445 | """ 446 | if isinstance(v, bool): 447 | return v 448 | if v.lower() in ("yes", "true", "t", "y", "1"): 449 | return True 450 | elif v.lower() in ("no", "false", "f", "n", "0"): 451 | return False 452 | else: 453 | raise argparse.ArgumentTypeError("boolean value expected") 454 | -------------------------------------------------------------------------------- /guided_diffusion/logger.py: -------------------------------------------------------------------------------- 1 | """ 2 | Logger copied from OpenAI baselines to avoid extra RL-based dependencies: 3 | https://github.com/openai/baselines/blob/ea25b9e8b234e6ee1bca43083f8f3cf974143998/baselines/logger.py 4 | """ 5 | 6 | import os 7 | import sys 8 | import shutil 9 | import os.path as osp 10 | import json 11 | import time 12 | import datetime 13 | import tempfile 14 | import warnings 15 | from collections import defaultdict 16 | from contextlib import contextmanager 17 | 18 | DEBUG = 10 19 | INFO = 20 20 | WARN = 30 21 | ERROR = 40 22 | 23 | DISABLED = 50 24 | 25 | 26 | class KVWriter(object): 27 | def writekvs(self, kvs): 28 | raise NotImplementedError 29 | 30 | 31 | class SeqWriter(object): 32 | def writeseq(self, seq): 33 | raise NotImplementedError 34 | 35 | 36 | class HumanOutputFormat(KVWriter, SeqWriter): 37 | def __init__(self, filename_or_file): 38 | if isinstance(filename_or_file, str): 39 | self.file = open(filename_or_file, "wt") 40 | self.own_file = True 41 | else: 42 | assert hasattr(filename_or_file, "read"), ( 43 | "expected file or str, got %s" % filename_or_file 44 | ) 45 | self.file = filename_or_file 46 | self.own_file = False 47 | 48 | def writekvs(self, kvs): 49 | # Create strings for printing 50 | key2str = {} 51 | for (key, val) in sorted(kvs.items()): 52 | if hasattr(val, "__float__"): 53 | valstr = "%-8.3g" % val 54 | else: 55 | valstr = str(val) 56 | key2str[self._truncate(key)] = self._truncate(valstr) 57 | 58 | # Find max widths 59 | if len(key2str) == 0: 60 | print("WARNING: tried to write empty key-value dict") 61 | return 62 | else: 63 | keywidth = max(map(len, key2str.keys())) 64 | valwidth = max(map(len, key2str.values())) 65 | 66 | # Write out the data 67 | dashes = "-" * (keywidth + valwidth + 7) 68 | lines = [dashes] 69 | for (key, val) in sorted(key2str.items(), key=lambda kv: kv[0].lower()): 70 | lines.append( 71 | "| %s%s | %s%s |" 72 | % (key, " " * (keywidth - len(key)), val, " " * (valwidth - len(val))) 73 | ) 74 | lines.append(dashes) 75 | self.file.write("\n".join(lines) + "\n") 76 | 77 | # Flush the output to the file 78 | self.file.flush() 79 | 80 | def _truncate(self, s): 81 | maxlen = 
30 82 | return s[: maxlen - 3] + "..." if len(s) > maxlen else s 83 | 84 | def writeseq(self, seq): 85 | seq = list(seq) 86 | for (i, elem) in enumerate(seq): 87 | self.file.write(elem) 88 | if i < len(seq) - 1: # add space unless this is the last one 89 | self.file.write(" ") 90 | self.file.write("\n") 91 | self.file.flush() 92 | 93 | def close(self): 94 | if self.own_file: 95 | self.file.close() 96 | 97 | 98 | class JSONOutputFormat(KVWriter): 99 | def __init__(self, filename): 100 | self.file = open(filename, "wt") 101 | 102 | def writekvs(self, kvs): 103 | for k, v in sorted(kvs.items()): 104 | if hasattr(v, "dtype"): 105 | kvs[k] = float(v) 106 | self.file.write(json.dumps(kvs) + "\n") 107 | self.file.flush() 108 | 109 | def close(self): 110 | self.file.close() 111 | 112 | 113 | class CSVOutputFormat(KVWriter): 114 | def __init__(self, filename): 115 | self.file = open(filename, "w+t") 116 | self.keys = [] 117 | self.sep = "," 118 | 119 | def writekvs(self, kvs): 120 | # Add our current row to the history 121 | extra_keys = list(kvs.keys() - self.keys) 122 | extra_keys.sort() 123 | if extra_keys: 124 | self.keys.extend(extra_keys) 125 | self.file.seek(0) 126 | lines = self.file.readlines() 127 | self.file.seek(0) 128 | for (i, k) in enumerate(self.keys): 129 | if i > 0: 130 | self.file.write(",") 131 | self.file.write(k) 132 | self.file.write("\n") 133 | for line in lines[1:]: 134 | self.file.write(line[:-1]) 135 | self.file.write(self.sep * len(extra_keys)) 136 | self.file.write("\n") 137 | for (i, k) in enumerate(self.keys): 138 | if i > 0: 139 | self.file.write(",") 140 | v = kvs.get(k) 141 | if v is not None: 142 | self.file.write(str(v)) 143 | self.file.write("\n") 144 | self.file.flush() 145 | 146 | def close(self): 147 | self.file.close() 148 | 149 | 150 | class TensorBoardOutputFormat(KVWriter): 151 | """ 152 | Dumps key/value pairs into TensorBoard's numeric format. 153 | """ 154 | 155 | def __init__(self, dir): 156 | os.makedirs(dir, exist_ok=True) 157 | self.dir = dir 158 | self.step = 1 159 | prefix = "events" 160 | path = osp.join(osp.abspath(dir), prefix) 161 | import tensorflow as tf 162 | from tensorflow.python import pywrap_tensorflow 163 | from tensorflow.core.util import event_pb2 164 | from tensorflow.python.util import compat 165 | 166 | self.tf = tf 167 | self.event_pb2 = event_pb2 168 | self.pywrap_tensorflow = pywrap_tensorflow 169 | self.writer = pywrap_tensorflow.EventsWriter(compat.as_bytes(path)) 170 | 171 | def writekvs(self, kvs): 172 | def summary_val(k, v): 173 | kwargs = {"tag": k, "simple_value": float(v)} 174 | return self.tf.Summary.Value(**kwargs) 175 | 176 | summary = self.tf.Summary(value=[summary_val(k, v) for k, v in kvs.items()]) 177 | event = self.event_pb2.Event(wall_time=time.time(), summary=summary) 178 | event.step = ( 179 | self.step 180 | ) # is there any reason why you'd want to specify the step? 
181 | self.writer.WriteEvent(event) 182 | self.writer.Flush() 183 | self.step += 1 184 | 185 | def close(self): 186 | if self.writer: 187 | self.writer.Close() 188 | self.writer = None 189 | 190 | 191 | def make_output_format(format, ev_dir, log_suffix=""): 192 | os.makedirs(ev_dir, exist_ok=True) 193 | if format == "stdout": 194 | return HumanOutputFormat(sys.stdout) 195 | elif format == "log": 196 | return HumanOutputFormat(osp.join(ev_dir, "log%s.txt" % log_suffix)) 197 | elif format == "json": 198 | return JSONOutputFormat(osp.join(ev_dir, "progress%s.json" % log_suffix)) 199 | elif format == "csv": 200 | return CSVOutputFormat(osp.join(ev_dir, "progress%s.csv" % log_suffix)) 201 | elif format == "tensorboard": 202 | return TensorBoardOutputFormat(osp.join(ev_dir, "tb%s" % log_suffix)) 203 | else: 204 | raise ValueError("Unknown format specified: %s" % (format,)) 205 | 206 | 207 | # ================================================================ 208 | # API 209 | # ================================================================ 210 | 211 | 212 | def logkv(key, val): 213 | """ 214 | Log a value of some diagnostic 215 | Call this once for each diagnostic quantity, each iteration 216 | If called many times, last value will be used. 217 | """ 218 | get_current().logkv(key, val) 219 | 220 | 221 | def logkv_mean(key, val): 222 | """ 223 | The same as logkv(), but if called many times, values averaged. 224 | """ 225 | get_current().logkv_mean(key, val) 226 | 227 | 228 | def logkvs(d): 229 | """ 230 | Log a dictionary of key-value pairs 231 | """ 232 | for (k, v) in d.items(): 233 | logkv(k, v) 234 | 235 | 236 | def dumpkvs(): 237 | """ 238 | Write all of the diagnostics from the current iteration 239 | """ 240 | return get_current().dumpkvs() 241 | 242 | 243 | def getkvs(): 244 | return get_current().name2val 245 | 246 | 247 | def log(*args, level=INFO): 248 | """ 249 | Write the sequence of args, with no separators, to the console and output files (if you've configured an output file). 250 | """ 251 | get_current().log(*args, level=level) 252 | 253 | 254 | def debug(*args): 255 | log(*args, level=DEBUG) 256 | 257 | 258 | def info(*args): 259 | log(*args, level=INFO) 260 | 261 | 262 | def warn(*args): 263 | log(*args, level=WARN) 264 | 265 | 266 | def error(*args): 267 | log(*args, level=ERROR) 268 | 269 | 270 | def set_level(level): 271 | """ 272 | Set logging threshold on current logger. 273 | """ 274 | get_current().set_level(level) 275 | 276 | 277 | def set_comm(comm): 278 | get_current().set_comm(comm) 279 | 280 | 281 | def get_dir(): 282 | """ 283 | Get directory that log files are being written to. 
284 | will be None if there is no output directory (i.e., if you didn't call start) 285 | """ 286 | return get_current().get_dir() 287 | 288 | 289 | record_tabular = logkv 290 | dump_tabular = dumpkvs 291 | 292 | 293 | @contextmanager 294 | def profile_kv(scopename): 295 | logkey = "wait_" + scopename 296 | tstart = time.time() 297 | try: 298 | yield 299 | finally: 300 | get_current().name2val[logkey] += time.time() - tstart 301 | 302 | 303 | def profile(n): 304 | """ 305 | Usage: 306 | @profile("my_func") 307 | def my_func(): code 308 | """ 309 | 310 | def decorator_with_name(func): 311 | def func_wrapper(*args, **kwargs): 312 | with profile_kv(n): 313 | return func(*args, **kwargs) 314 | 315 | return func_wrapper 316 | 317 | return decorator_with_name 318 | 319 | 320 | # ================================================================ 321 | # Backend 322 | # ================================================================ 323 | 324 | 325 | def get_current(): 326 | if Logger.CURRENT is None: 327 | _configure_default_logger() 328 | 329 | return Logger.CURRENT 330 | 331 | 332 | class Logger(object): 333 | DEFAULT = None # A logger with no output files. (See right below class definition) 334 | # So that you can still log to the terminal without setting up any output files 335 | CURRENT = None # Current logger being used by the free functions above 336 | 337 | def __init__(self, dir, output_formats, comm=None): 338 | self.name2val = defaultdict(float) # values this iteration 339 | self.name2cnt = defaultdict(int) 340 | self.level = INFO 341 | self.dir = dir 342 | self.output_formats = output_formats 343 | self.comm = comm 344 | 345 | # Logging API, forwarded 346 | # ---------------------------------------- 347 | def logkv(self, key, val): 348 | self.name2val[key] = val 349 | 350 | def logkv_mean(self, key, val): 351 | oldval, cnt = self.name2val[key], self.name2cnt[key] 352 | self.name2val[key] = oldval * cnt / (cnt + 1) + val / (cnt + 1) 353 | self.name2cnt[key] = cnt + 1 354 | 355 | def dumpkvs(self): 356 | if self.comm is None: 357 | d = self.name2val 358 | else: 359 | d = mpi_weighted_mean( 360 | self.comm, 361 | { 362 | name: (val, self.name2cnt.get(name, 1)) 363 | for (name, val) in self.name2val.items() 364 | }, 365 | ) 366 | if self.comm.rank != 0: 367 | d["dummy"] = 1 # so we don't get a warning about empty dict 368 | out = d.copy() # Return the dict for unit testing purposes 369 | for fmt in self.output_formats: 370 | if isinstance(fmt, KVWriter): 371 | fmt.writekvs(d) 372 | self.name2val.clear() 373 | self.name2cnt.clear() 374 | return out 375 | 376 | def log(self, *args, level=INFO): 377 | if self.level <= level: 378 | self._do_log(args) 379 | 380 | # Configuration 381 | # ---------------------------------------- 382 | def set_level(self, level): 383 | self.level = level 384 | 385 | def set_comm(self, comm): 386 | self.comm = comm 387 | 388 | def get_dir(self): 389 | return self.dir 390 | 391 | def close(self): 392 | for fmt in self.output_formats: 393 | fmt.close() 394 | 395 | # Misc 396 | # ---------------------------------------- 397 | def _do_log(self, args): 398 | for fmt in self.output_formats: 399 | if isinstance(fmt, SeqWriter): 400 | fmt.writeseq(map(str, args)) 401 | 402 | 403 | def get_rank_without_mpi_import(): 404 | # check environment variables here instead of importing mpi4py 405 | # to avoid calling MPI_Init() when this module is imported 406 | for varname in ["PMI_RANK", "OMPI_COMM_WORLD_RANK"]: 407 | if varname in os.environ: 408 | return int(os.environ[varname]) 
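    # no MPI-related environment variable found: assume a single-process run (rank 0)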
409 | return 0 410 | 411 | 412 | def mpi_weighted_mean(comm, local_name2valcount): 413 | """ 414 | Copied from: https://github.com/openai/baselines/blob/ea25b9e8b234e6ee1bca43083f8f3cf974143998/baselines/common/mpi_util.py#L110 415 | Perform a weighted average over dicts that are each on a different node 416 | Input: local_name2valcount: dict mapping key -> (value, count) 417 | Returns: key -> mean 418 | """ 419 | all_name2valcount = comm.gather(local_name2valcount) 420 | if comm.rank == 0: 421 | name2sum = defaultdict(float) 422 | name2count = defaultdict(float) 423 | for n2vc in all_name2valcount: 424 | for (name, (val, count)) in n2vc.items(): 425 | try: 426 | val = float(val) 427 | except ValueError: 428 | if comm.rank == 0: 429 | warnings.warn( 430 | "WARNING: tried to compute mean on non-float {}={}".format( 431 | name, val 432 | ) 433 | ) 434 | else: 435 | name2sum[name] += val * count 436 | name2count[name] += count 437 | return {name: name2sum[name] / name2count[name] for name in name2sum} 438 | else: 439 | return {} 440 | 441 | 442 | def configure(dir=None, format_strs=None, comm=None, log_suffix=""): 443 | """ 444 | If comm is provided, average all numerical stats across that comm 445 | """ 446 | if dir is None: 447 | dir = os.getenv("OPENAI_LOGDIR") 448 | if dir is None: 449 | dir = osp.join( 450 | tempfile.gettempdir(), 451 | datetime.datetime.now().strftime("openai-%Y-%m-%d-%H-%M-%S-%f"), 452 | ) 453 | assert isinstance(dir, str) 454 | dir = os.path.expanduser(dir) 455 | os.makedirs(os.path.expanduser(dir), exist_ok=True) 456 | 457 | rank = get_rank_without_mpi_import() 458 | if rank > 0: 459 | log_suffix = log_suffix + "-rank%03i" % rank 460 | 461 | if format_strs is None: 462 | if rank == 0: 463 | format_strs = os.getenv("OPENAI_LOG_FORMAT", "stdout,log,csv").split(",") 464 | else: 465 | format_strs = os.getenv("OPENAI_LOG_FORMAT_MPI", "log").split(",") 466 | format_strs = filter(None, format_strs) 467 | output_formats = [make_output_format(f, dir, log_suffix) for f in format_strs] 468 | 469 | Logger.CURRENT = Logger(dir=dir, output_formats=output_formats, comm=comm) 470 | if output_formats: 471 | log("Logging to %s" % dir) 472 | 473 | 474 | def _configure_default_logger(): 475 | configure() 476 | Logger.DEFAULT = Logger.CURRENT 477 | 478 | 479 | def reset(): 480 | if Logger.CURRENT is not Logger.DEFAULT: 481 | Logger.CURRENT.close() 482 | Logger.CURRENT = Logger.DEFAULT 483 | log("Reset logger") 484 | 485 | 486 | @contextmanager 487 | def scoped_configure(dir=None, format_strs=None, comm=None): 488 | prevlogger = Logger.CURRENT 489 | configure(dir=dir, format_strs=format_strs, comm=comm) 490 | try: 491 | yield 492 | finally: 493 | Logger.CURRENT.close() 494 | Logger.CURRENT = prevlogger 495 | 496 | -------------------------------------------------------------------------------- /runners/diffusion.py: -------------------------------------------------------------------------------- 1 | import os 2 | import logging 3 | import time 4 | import glob 5 | 6 | import numpy as np 7 | import tqdm 8 | import torch 9 | import torch.utils.data as data 10 | 11 | from models.diffusion import Model 12 | from datasets import get_dataset, data_transform, inverse_data_transform 13 | from functions.ckpt_util import get_ckpt_path, download 14 | from functions.denoising import efficient_generalized_steps 15 | 16 | import torchvision.utils as tvu 17 | 18 | from guided_diffusion.unet import UNetModel 19 | from guided_diffusion.script_util import create_model, create_classifier, 
classifier_defaults, args_to_dict 20 | import random 21 | 22 | def get_beta_schedule(beta_schedule, *, beta_start, beta_end, num_diffusion_timesteps): 23 | def sigmoid(x): 24 | return 1 / (np.exp(-x) + 1) 25 | 26 | if beta_schedule == "quad": 27 | betas = ( 28 | np.linspace( 29 | beta_start ** 0.5, 30 | beta_end ** 0.5, 31 | num_diffusion_timesteps, 32 | dtype=np.float64, 33 | ) 34 | ** 2 35 | ) 36 | elif beta_schedule == "linear": 37 | betas = np.linspace( 38 | beta_start, beta_end, num_diffusion_timesteps, dtype=np.float64 39 | ) 40 | elif beta_schedule == "const": 41 | betas = beta_end * np.ones(num_diffusion_timesteps, dtype=np.float64) 42 | elif beta_schedule == "jsd": # 1/T, 1/(T-1), 1/(T-2), ..., 1 43 | betas = 1.0 / np.linspace( 44 | num_diffusion_timesteps, 1, num_diffusion_timesteps, dtype=np.float64 45 | ) 46 | elif beta_schedule == "sigmoid": 47 | betas = np.linspace(-6, 6, num_diffusion_timesteps) 48 | betas = sigmoid(betas) * (beta_end - beta_start) + beta_start 49 | else: 50 | raise NotImplementedError(beta_schedule) 51 | assert betas.shape == (num_diffusion_timesteps,) 52 | return betas 53 | 54 | 55 | class Diffusion(object): 56 | def __init__(self, args, config, device=None): 57 | self.args = args 58 | self.config = config 59 | if device is None: 60 | device = ( 61 | torch.device("cuda") 62 | if torch.cuda.is_available() 63 | else torch.device("cpu") 64 | ) 65 | self.device = device 66 | 67 | self.model_var_type = config.model.var_type 68 | betas = get_beta_schedule( 69 | beta_schedule=config.diffusion.beta_schedule, 70 | beta_start=config.diffusion.beta_start, 71 | beta_end=config.diffusion.beta_end, 72 | num_diffusion_timesteps=config.diffusion.num_diffusion_timesteps, 73 | ) 74 | betas = self.betas = torch.from_numpy(betas).float().to(self.device) 75 | self.num_timesteps = betas.shape[0] 76 | 77 | alphas = 1.0 - betas 78 | alphas_cumprod = alphas.cumprod(dim=0) 79 | alphas_cumprod_prev = torch.cat( 80 | [torch.ones(1).to(device), alphas_cumprod[:-1]], dim=0 81 | ) 82 | self.alphas_cumprod_prev = alphas_cumprod_prev 83 | posterior_variance = ( 84 | betas * (1.0 - alphas_cumprod_prev) / (1.0 - alphas_cumprod) 85 | ) 86 | if self.model_var_type == "fixedlarge": 87 | self.logvar = betas.log() 88 | # torch.cat( 89 | # [posterior_variance[1:2], betas[1:]], dim=0).log() 90 | elif self.model_var_type == "fixedsmall": 91 | self.logvar = posterior_variance.clamp(min=1e-20).log() 92 | 93 | def sample(self): 94 | cls_fn = None 95 | if self.config.model.type == 'simple': 96 | model = Model(self.config) 97 | # This used the pretrained DDPM model, see https://github.com/pesser/pytorch_diffusion 98 | if self.config.data.dataset == "CIFAR10": 99 | name = "cifar10" 100 | elif self.config.data.dataset == "LSUN": 101 | name = f"lsun_{self.config.data.category}" 102 | elif self.config.data.dataset == 'CelebA_HQ': 103 | name = 'celeba_hq' 104 | else: 105 | raise ValueError 106 | if name != 'celeba_hq': 107 | ckpt = get_ckpt_path(f"ema_{name}", prefix=self.args.exp) 108 | print("Loading checkpoint {}".format(ckpt)) 109 | elif name == 'celeba_hq': 110 | #ckpt = '~/.cache/diffusion_models_converted/celeba_hq.ckpt' 111 | ckpt = os.path.join(self.args.exp, "logs/celeba/celeba_hq.ckpt") 112 | if not os.path.exists(ckpt): 113 | download('https://image-editing-test-12345.s3-us-west-2.amazonaws.com/checkpoints/celeba_hq.ckpt', ckpt) 114 | else: 115 | raise ValueError 116 | model.load_state_dict(torch.load(ckpt, map_location=self.device)) 117 | model.to(self.device) 118 | model = 
torch.nn.DataParallel(model) 119 | 120 | elif self.config.model.type == 'openai': 121 | config_dict = vars(self.config.model) 122 | model = create_model(**config_dict) 123 | if self.config.model.use_fp16: 124 | model.convert_to_fp16() 125 | if self.config.model.class_cond: 126 | ckpt = os.path.join(self.args.exp, 'logs/imagenet/%dx%d_diffusion.pt' % (self.config.data.image_size, self.config.data.image_size)) 127 | if not os.path.exists(ckpt): 128 | download('https://openaipublic.blob.core.windows.net/diffusion/jul-2021/%dx%d_diffusion_uncond.pt' % (self.config.data.image_size, self.config.data.image_size), ckpt) 129 | else: 130 | ckpt = os.path.join(self.args.exp, "logs/imagenet/256x256_diffusion_uncond.pt") 131 | if not os.path.exists(ckpt): 132 | download('https://openaipublic.blob.core.windows.net/diffusion/jul-2021/256x256_diffusion_uncond.pt', ckpt) 133 | 134 | 135 | model.load_state_dict(torch.load(ckpt, map_location=self.device)) 136 | model.to(self.device) 137 | model.eval() 138 | model = torch.nn.DataParallel(model) 139 | 140 | if self.config.model.class_cond: 141 | ckpt = os.path.join(self.args.exp, 'logs/imagenet/%dx%d_classifier.pt' % (self.config.data.image_size, self.config.data.image_size)) 142 | if not os.path.exists(ckpt): 143 | image_size = self.config.data.image_size 144 | download('https://openaipublic.blob.core.windows.net/diffusion/jul-2021/%dx%d_classifier.pt' % image_size, ckpt) 145 | classifier = create_classifier(**args_to_dict(self.config.classifier, classifier_defaults().keys())) 146 | classifier.load_state_dict(torch.load(ckpt, map_location=self.device)) 147 | classifier.to(self.device) 148 | if self.config.classifier.classifier_use_fp16: 149 | classifier.convert_to_fp16() 150 | classifier.eval() 151 | classifier = torch.nn.DataParallel(classifier) 152 | 153 | import torch.nn.functional as F 154 | def cond_fn(x, t, y): 155 | with torch.enable_grad(): 156 | x_in = x.detach().requires_grad_(True) 157 | logits = classifier(x_in, t) 158 | log_probs = F.log_softmax(logits, dim=-1) 159 | selected = log_probs[range(len(logits)), y.view(-1)] 160 | return torch.autograd.grad(selected.sum(), x_in)[0] * self.config.classifier.classifier_scale 161 | cls_fn = cond_fn 162 | 163 | self.sample_sequence(model, cls_fn) 164 | 165 | def sample_sequence(self, model, cls_fn=None): 166 | args, config = self.args, self.config 167 | 168 | #get original images and corrupted y_0 169 | dataset, test_dataset = get_dataset(args, config) 170 | 171 | device_count = torch.cuda.device_count() 172 | 173 | if args.subset_start >= 0 and args.subset_end > 0: 174 | assert args.subset_end > args.subset_start 175 | test_dataset = torch.utils.data.Subset(test_dataset, range(args.subset_start, args.subset_end)) 176 | else: 177 | args.subset_start = 0 178 | args.subset_end = len(test_dataset) 179 | 180 | print(f'Dataset has size {len(test_dataset)}') 181 | 182 | def seed_worker(worker_id): 183 | worker_seed = args.seed % 2**32 184 | np.random.seed(worker_seed) 185 | random.seed(worker_seed) 186 | 187 | g = torch.Generator() 188 | g.manual_seed(args.seed) 189 | val_loader = data.DataLoader( 190 | test_dataset, 191 | batch_size=config.sampling.batch_size, 192 | shuffle=True, 193 | num_workers=config.data.num_workers, 194 | worker_init_fn=seed_worker, 195 | generator=g, 196 | ) 197 | 198 | 199 | ## get degradation matrix ## 200 | deg = args.deg 201 | H_funcs = None 202 | if deg[:2] == 'cs': 203 | compress_by = int(deg[2:]) 204 | from functions.svd_replacement import WalshHadamardCS 205 | H_funcs = 
WalshHadamardCS(config.data.channels, self.config.data.image_size, compress_by, torch.randperm(self.config.data.image_size**2, device=self.device), self.device) 206 | elif deg[:3] == 'inp': 207 | from functions.svd_replacement import Inpainting 208 | if deg == 'inp_lolcat': 209 | loaded = np.load("inp_masks/lolcat_extra.npy") 210 | mask = torch.from_numpy(loaded).to(self.device).reshape(-1) 211 | missing_r = torch.nonzero(mask == 0).long().reshape(-1) * 3 212 | elif deg == 'inp_lorem': 213 | loaded = np.load("inp_masks/lorem3.npy") 214 | mask = torch.from_numpy(loaded).to(self.device).reshape(-1) 215 | missing_r = torch.nonzero(mask == 0).long().reshape(-1) * 3 216 | else: 217 | missing_r = torch.randperm(config.data.image_size**2)[:config.data.image_size**2 // 2].to(self.device).long() * 3 218 | missing_g = missing_r + 1 219 | missing_b = missing_g + 1 220 | missing = torch.cat([missing_r, missing_g, missing_b], dim=0) 221 | H_funcs = Inpainting(config.data.channels, config.data.image_size, missing, self.device) 222 | elif deg == 'deno': 223 | from functions.svd_replacement import Denoising 224 | H_funcs = Denoising(config.data.channels, self.config.data.image_size, self.device) 225 | elif deg[:10] == 'sr_bicubic': 226 | factor = int(deg[10:]) 227 | from functions.svd_replacement import SRConv 228 | def bicubic_kernel(x, a=-0.5): 229 | if abs(x) <= 1: 230 | return (a + 2)*abs(x)**3 - (a + 3)*abs(x)**2 + 1 231 | elif 1 < abs(x) and abs(x) < 2: 232 | return a*abs(x)**3 - 5*a*abs(x)**2 + 8*a*abs(x) - 4*a 233 | else: 234 | return 0 235 | k = np.zeros((factor * 4)) 236 | for i in range(factor * 4): 237 | x = (1/factor)*(i- np.floor(factor*4/2) +0.5) 238 | k[i] = bicubic_kernel(x) 239 | k = k / np.sum(k) 240 | kernel = torch.from_numpy(k).float().to(self.device) 241 | H_funcs = SRConv(kernel / kernel.sum(), \ 242 | config.data.channels, self.config.data.image_size, self.device, stride = factor) 243 | elif deg == 'deblur_uni': 244 | from functions.svd_replacement import Deblurring 245 | H_funcs = Deblurring(torch.Tensor([1/9] * 9).to(self.device), config.data.channels, self.config.data.image_size, self.device) 246 | elif deg == 'deblur_gauss': 247 | from functions.svd_replacement import Deblurring 248 | sigma = 10 249 | pdf = lambda x: torch.exp(torch.Tensor([-0.5 * (x/sigma)**2])) 250 | kernel = torch.Tensor([pdf(-2), pdf(-1), pdf(0), pdf(1), pdf(2)]).to(self.device) 251 | H_funcs = Deblurring(kernel / kernel.sum(), config.data.channels, self.config.data.image_size, self.device) 252 | elif deg == 'deblur_aniso': 253 | from functions.svd_replacement import Deblurring2D 254 | sigma = 20 255 | pdf = lambda x: torch.exp(torch.Tensor([-0.5 * (x/sigma)**2])) 256 | kernel2 = torch.Tensor([pdf(-4), pdf(-3), pdf(-2), pdf(-1), pdf(0), pdf(1), pdf(2), pdf(3), pdf(4)]).to(self.device) 257 | sigma = 1 258 | pdf = lambda x: torch.exp(torch.Tensor([-0.5 * (x/sigma)**2])) 259 | kernel1 = torch.Tensor([pdf(-4), pdf(-3), pdf(-2), pdf(-1), pdf(0), pdf(1), pdf(2), pdf(3), pdf(4)]).to(self.device) 260 | H_funcs = Deblurring2D(kernel1 / kernel1.sum(), kernel2 / kernel2.sum(), config.data.channels, self.config.data.image_size, self.device) 261 | elif deg[:2] == 'sr': 262 | blur_by = int(deg[2:]) 263 | from functions.svd_replacement import SuperResolution 264 | H_funcs = SuperResolution(config.data.channels, config.data.image_size, blur_by, self.device) 265 | elif deg == 'color': 266 | from functions.svd_replacement import Colorization 267 | H_funcs = Colorization(config.data.image_size, self.device) 268 | else: 269 | 
print("ERROR: degradation type not supported") 270 | quit() 271 | args.sigma_0 = 2 * args.sigma_0 #to account for scaling to [-1,1] 272 | sigma_0 = args.sigma_0 273 | 274 | print(f'Start from {args.subset_start}') 275 | idx_init = args.subset_start 276 | idx_so_far = args.subset_start 277 | avg_psnr = 0.0 278 | pbar = tqdm.tqdm(val_loader) 279 | for x_orig, classes in pbar: 280 | x_orig = x_orig.to(self.device) 281 | x_orig = data_transform(self.config, x_orig) 282 | 283 | y_0 = H_funcs.H(x_orig) 284 | y_0 = y_0 + sigma_0 * torch.randn_like(y_0) 285 | 286 | pinv_y_0 = H_funcs.H_pinv(y_0).view(y_0.shape[0], config.data.channels, self.config.data.image_size, self.config.data.image_size) 287 | if deg[:6] == 'deblur': pinv_y_0 = y_0.view(y_0.shape[0], config.data.channels, self.config.data.image_size, self.config.data.image_size) 288 | elif deg == 'color': pinv_y_0 = y_0.view(y_0.shape[0], 1, self.config.data.image_size, self.config.data.image_size).repeat(1, 3, 1, 1) 289 | elif deg[:3] == 'inp': pinv_y_0 += H_funcs.H_pinv(H_funcs.H(torch.ones_like(pinv_y_0))).reshape(*pinv_y_0.shape) - 1 290 | 291 | for i in range(len(pinv_y_0)): 292 | tvu.save_image( 293 | inverse_data_transform(config, pinv_y_0[i]), os.path.join(self.args.image_folder, f"y0_{idx_so_far + i}.png") 294 | ) 295 | tvu.save_image( 296 | inverse_data_transform(config, x_orig[i]), os.path.join(self.args.image_folder, f"orig_{idx_so_far + i}.png") 297 | ) 298 | 299 | ##Begin DDIM 300 | x = torch.randn( 301 | y_0.shape[0], 302 | config.data.channels, 303 | config.data.image_size, 304 | config.data.image_size, 305 | device=self.device, 306 | ) 307 | 308 | # NOTE: This means that we are producing each predicted x0, not x_{t-1} at timestep t. 309 | with torch.no_grad(): 310 | x, _ = self.sample_image(x, model, H_funcs, y_0, sigma_0, last=False, cls_fn=cls_fn, classes=classes) 311 | 312 | x = [inverse_data_transform(config, y) for y in x] 313 | 314 | for i in [-1]: #range(len(x)): 315 | for j in range(x[i].size(0)): 316 | tvu.save_image( 317 | x[i][j], os.path.join(self.args.image_folder, f"{idx_so_far + j}_{i}.png") 318 | ) 319 | if i == len(x)-1 or i == -1: 320 | orig = inverse_data_transform(config, x_orig[j]) 321 | mse = torch.mean((x[i][j].to(self.device) - orig) ** 2) 322 | psnr = 10 * torch.log10(1 / mse) 323 | avg_psnr += psnr 324 | 325 | idx_so_far += y_0.shape[0] 326 | 327 | pbar.set_description("PSNR: %.2f" % (avg_psnr / (idx_so_far - idx_init))) 328 | 329 | avg_psnr = avg_psnr / (idx_so_far - idx_init) 330 | print("Total Average PSNR: %.2f" % avg_psnr) 331 | print("Number of samples: %d" % (idx_so_far - idx_init)) 332 | 333 | def sample_image(self, x, model, H_funcs, y_0, sigma_0, last=True, cls_fn=None, classes=None): 334 | skip = self.num_timesteps // self.args.timesteps 335 | seq = range(0, self.num_timesteps, skip) 336 | 337 | x = efficient_generalized_steps(x, seq, model, self.betas, H_funcs, y_0, sigma_0, \ 338 | etaB=self.args.etaB, etaA=self.args.eta, etaC=self.args.eta, cls_fn=cls_fn, classes=classes) 339 | if last: 340 | x = x[0][-1] 341 | return x -------------------------------------------------------------------------------- /functions/svd_replacement.py: -------------------------------------------------------------------------------- 1 | import torch 2 | 3 | class H_functions: 4 | """ 5 | A class replacing the SVD of a matrix H, perhaps efficiently. 6 | All input vectors are of shape (Batch, ...). 7 | All output vectors are of shape (Batch, DataDimension). 
8 | """ 9 | 10 | def V(self, vec): 11 | """ 12 | Multiplies the input vector by V 13 | """ 14 | raise NotImplementedError() 15 | 16 | def Vt(self, vec): 17 | """ 18 | Multiplies the input vector by V transposed 19 | """ 20 | raise NotImplementedError() 21 | 22 | def U(self, vec): 23 | """ 24 | Multiplies the input vector by U 25 | """ 26 | raise NotImplementedError() 27 | 28 | def Ut(self, vec): 29 | """ 30 | Multiplies the input vector by U transposed 31 | """ 32 | raise NotImplementedError() 33 | 34 | def singulars(self): 35 | """ 36 | Returns a vector containing the singular values. The shape of the vector should be the same as the smaller dimension (like U) 37 | """ 38 | raise NotImplementedError() 39 | 40 | def add_zeros(self, vec): 41 | """ 42 | Adds trailing zeros to turn a vector from the small dimension (U) to the big dimension (V) 43 | """ 44 | raise NotImplementedError() 45 | 46 | def H(self, vec): 47 | """ 48 | Multiplies the input vector by H 49 | """ 50 | temp = self.Vt(vec) 51 | singulars = self.singulars() 52 | return self.U(singulars * temp[:, :singulars.shape[0]]) 53 | 54 | def Ht(self, vec): 55 | """ 56 | Multiplies the input vector by H transposed 57 | """ 58 | temp = self.Ut(vec) 59 | singulars = self.singulars() 60 | return self.V(self.add_zeros(singulars * temp[:, :singulars.shape[0]])) 61 | 62 | def H_pinv(self, vec): 63 | """ 64 | Multiplies the input vector by the pseudo inverse of H 65 | """ 66 | temp = self.Ut(vec) 67 | singulars = self.singulars() 68 | temp[:, :singulars.shape[0]] = temp[:, :singulars.shape[0]] / singulars 69 | return self.V(self.add_zeros(temp)) 70 | 71 | #a memory inefficient implementation for any general degradation H 72 | class GeneralH(H_functions): 73 | def mat_by_vec(self, M, v): 74 | vshape = v.shape[1] 75 | if len(v.shape) > 2: vshape = vshape * v.shape[2] 76 | if len(v.shape) > 3: vshape = vshape * v.shape[3] 77 | return torch.matmul(M, v.view(v.shape[0], vshape, 78 | 1)).view(v.shape[0], M.shape[0]) 79 | 80 | def __init__(self, H): 81 | self._U, self._singulars, self._V = torch.svd(H, some=False) 82 | self._Vt = self._V.transpose(0, 1) 83 | self._Ut = self._U.transpose(0, 1) 84 | 85 | ZERO = 1e-3 86 | self._singulars[self._singulars < ZERO] = 0 87 | print(len([x.item() for x in self._singulars if x == 0])) 88 | 89 | def V(self, vec): 90 | return self.mat_by_vec(self._V, vec.clone()) 91 | 92 | def Vt(self, vec): 93 | return self.mat_by_vec(self._Vt, vec.clone()) 94 | 95 | def U(self, vec): 96 | return self.mat_by_vec(self._U, vec.clone()) 97 | 98 | def Ut(self, vec): 99 | return self.mat_by_vec(self._Ut, vec.clone()) 100 | 101 | def singulars(self): 102 | return self._singulars 103 | 104 | def add_zeros(self, vec): 105 | out = torch.zeros(vec.shape[0], self._V.shape[0], device=vec.device) 106 | out[:, :self._U.shape[0]] = vec.clone().reshape(vec.shape[0], -1) 107 | return out 108 | 109 | #Inpainting 110 | class Inpainting(H_functions): 111 | def __init__(self, channels, img_dim, missing_indices, device): 112 | self.channels = channels 113 | self.img_dim = img_dim 114 | self._singulars = torch.ones(channels * img_dim**2 - missing_indices.shape[0]).to(device) 115 | self.missing_indices = missing_indices 116 | self.kept_indices = torch.Tensor([i for i in range(channels * img_dim**2) if i not in missing_indices]).to(device).long() 117 | 118 | def V(self, vec): 119 | temp = vec.clone().reshape(vec.shape[0], -1) 120 | out = torch.zeros_like(temp) 121 | out[:, self.kept_indices] = temp[:, :self.kept_indices.shape[0]] 122 | out[:, 
self.missing_indices] = temp[:, self.kept_indices.shape[0]:] 123 | return out.reshape(vec.shape[0], -1, self.channels).permute(0, 2, 1).reshape(vec.shape[0], -1) 124 | 125 | def Vt(self, vec): 126 | temp = vec.clone().reshape(vec.shape[0], self.channels, -1).permute(0, 2, 1).reshape(vec.shape[0], -1) 127 | out = torch.zeros_like(temp) 128 | out[:, :self.kept_indices.shape[0]] = temp[:, self.kept_indices] 129 | out[:, self.kept_indices.shape[0]:] = temp[:, self.missing_indices] 130 | return out 131 | 132 | def U(self, vec): 133 | return vec.clone().reshape(vec.shape[0], -1) 134 | 135 | def Ut(self, vec): 136 | return vec.clone().reshape(vec.shape[0], -1) 137 | 138 | def singulars(self): 139 | return self._singulars 140 | 141 | def add_zeros(self, vec): 142 | temp = torch.zeros((vec.shape[0], self.channels * self.img_dim**2), device=vec.device) 143 | reshaped = vec.clone().reshape(vec.shape[0], -1) 144 | temp[:, :reshaped.shape[1]] = reshaped 145 | return temp 146 | 147 | #Denoising 148 | class Denoising(H_functions): 149 | def __init__(self, channels, img_dim, device): 150 | self._singulars = torch.ones(channels * img_dim**2, device=device) 151 | 152 | def V(self, vec): 153 | return vec.clone().reshape(vec.shape[0], -1) 154 | 155 | def Vt(self, vec): 156 | return vec.clone().reshape(vec.shape[0], -1) 157 | 158 | def U(self, vec): 159 | return vec.clone().reshape(vec.shape[0], -1) 160 | 161 | def Ut(self, vec): 162 | return vec.clone().reshape(vec.shape[0], -1) 163 | 164 | def singulars(self): 165 | return self._singulars 166 | 167 | def add_zeros(self, vec): 168 | return vec.clone().reshape(vec.shape[0], -1) 169 | 170 | #Super Resolution 171 | class SuperResolution(H_functions): 172 | def __init__(self, channels, img_dim, ratio, device): #ratio = 2 or 4 173 | assert img_dim % ratio == 0 174 | self.img_dim = img_dim 175 | self.channels = channels 176 | self.y_dim = img_dim // ratio 177 | self.ratio = ratio 178 | H = torch.Tensor([[1 / ratio**2] * ratio**2]).to(device) 179 | self.U_small, self.singulars_small, self.V_small = torch.svd(H, some=False) 180 | self.Vt_small = self.V_small.transpose(0, 1) 181 | 182 | def V(self, vec): 183 | #reorder the vector back into patches (because singulars are ordered descendingly) 184 | temp = vec.clone().reshape(vec.shape[0], -1) 185 | patches = torch.zeros(vec.shape[0], self.channels, self.y_dim**2, self.ratio**2, device=vec.device) 186 | patches[:, :, :, 0] = temp[:, :self.channels * self.y_dim**2].view(vec.shape[0], self.channels, -1) 187 | for idx in range(self.ratio**2-1): 188 | patches[:, :, :, idx+1] = temp[:, (self.channels*self.y_dim**2+idx)::self.ratio**2-1].view(vec.shape[0], self.channels, -1) 189 | #multiply each patch by the small V 190 | patches = torch.matmul(self.V_small, patches.reshape(-1, self.ratio**2, 1)).reshape(vec.shape[0], self.channels, -1, self.ratio**2) 191 | #repatch the patches into an image 192 | patches_orig = patches.reshape(vec.shape[0], self.channels, self.y_dim, self.y_dim, self.ratio, self.ratio) 193 | recon = patches_orig.permute(0, 1, 2, 4, 3, 5).contiguous() 194 | recon = recon.reshape(vec.shape[0], self.channels * self.img_dim ** 2) 195 | return recon 196 | 197 | def Vt(self, vec): 198 | #extract flattened patches 199 | patches = vec.clone().reshape(vec.shape[0], self.channels, self.img_dim, self.img_dim) 200 | patches = patches.unfold(2, self.ratio, self.ratio).unfold(3, self.ratio, self.ratio) 201 | unfold_shape = patches.shape 202 | patches = patches.contiguous().reshape(vec.shape[0], self.channels, -1, 
self.ratio**2) 203 | #multiply each by the small V transposed 204 | patches = torch.matmul(self.Vt_small, patches.reshape(-1, self.ratio**2, 1)).reshape(vec.shape[0], self.channels, -1, self.ratio**2) 205 | #reorder the vector to have the first entry first (because singulars are ordered descendingly) 206 | recon = torch.zeros(vec.shape[0], self.channels * self.img_dim**2, device=vec.device) 207 | recon[:, :self.channels * self.y_dim**2] = patches[:, :, :, 0].view(vec.shape[0], self.channels * self.y_dim**2) 208 | for idx in range(self.ratio**2-1): 209 | recon[:, (self.channels*self.y_dim**2+idx)::self.ratio**2-1] = patches[:, :, :, idx+1].view(vec.shape[0], self.channels * self.y_dim**2) 210 | return recon 211 | 212 | def U(self, vec): 213 | return self.U_small[0, 0] * vec.clone().reshape(vec.shape[0], -1) 214 | 215 | def Ut(self, vec): #U is 1x1, so U^T = U 216 | return self.U_small[0, 0] * vec.clone().reshape(vec.shape[0], -1) 217 | 218 | def singulars(self): 219 | return self.singulars_small.repeat(self.channels * self.y_dim**2) 220 | 221 | def add_zeros(self, vec): 222 | reshaped = vec.clone().reshape(vec.shape[0], -1) 223 | temp = torch.zeros((vec.shape[0], reshaped.shape[1] * self.ratio**2), device=vec.device) 224 | temp[:, :reshaped.shape[1]] = reshaped 225 | return temp 226 | 227 | #Colorization 228 | class Colorization(H_functions): 229 | def __init__(self, img_dim, device): 230 | self.channels = 3 231 | self.img_dim = img_dim 232 | #Do the SVD for the per-pixel matrix 233 | H = torch.Tensor([[0.3333, 0.3334, 0.3333]]).to(device) 234 | self.U_small, self.singulars_small, self.V_small = torch.svd(H, some=False) 235 | self.Vt_small = self.V_small.transpose(0, 1) 236 | 237 | def V(self, vec): 238 | #get the needles 239 | needles = vec.clone().reshape(vec.shape[0], self.channels, -1).permute(0, 2, 1) #shape: B, WH, C' 240 | #multiply each needle by the small V 241 | needles = torch.matmul(self.V_small, needles.reshape(-1, self.channels, 1)).reshape(vec.shape[0], -1, self.channels) #shape: B, WH, C 242 | #permute back to vector representation 243 | recon = needles.permute(0, 2, 1) #shape: B, C, WH 244 | return recon.reshape(vec.shape[0], -1) 245 | 246 | def Vt(self, vec): 247 | #get the needles 248 | needles = vec.clone().reshape(vec.shape[0], self.channels, -1).permute(0, 2, 1) #shape: B, WH, C 249 | #multiply each needle by the small V transposed 250 | needles = torch.matmul(self.Vt_small, needles.reshape(-1, self.channels, 1)).reshape(vec.shape[0], -1, self.channels) #shape: B, WH, C' 251 | #reorder the vector so that the first entry of each needle is at the top 252 | recon = needles.permute(0, 2, 1).reshape(vec.shape[0], -1) 253 | return recon 254 | 255 | def U(self, vec): 256 | return self.U_small[0, 0] * vec.clone().reshape(vec.shape[0], -1) 257 | 258 | def Ut(self, vec): #U is 1x1, so U^T = U 259 | return self.U_small[0, 0] * vec.clone().reshape(vec.shape[0], -1) 260 | 261 | def singulars(self): 262 | return self.singulars_small.repeat(self.img_dim**2) 263 | 264 | def add_zeros(self, vec): 265 | reshaped = vec.clone().reshape(vec.shape[0], -1) 266 | temp = torch.zeros((vec.shape[0], self.channels * self.img_dim**2), device=vec.device) 267 | temp[:, :self.img_dim**2] = reshaped 268 | return temp 269 | 270 | #Walsh-Hadamard Compressive Sensing 271 | class WalshHadamardCS(H_functions): 272 | def fwht(self, vec): #the Fast Walsh Hadamard Transform is the same as its inverse 273 | a = vec.reshape(vec.shape[0], self.channels, self.img_dim**2) 274 | h = 1 275 | while h < 
self.img_dim**2: 276 | a = a.reshape(vec.shape[0], self.channels, -1, h * 2) 277 | b = a.clone() 278 | a[:, :, :, :h] = b[:, :, :, :h] + b[:, :, :, h:2*h] 279 | a[:, :, :, h:2*h] = b[:, :, :, :h] - b[:, :, :, h:2*h] 280 | h *= 2 281 | a = a.reshape(vec.shape[0], self.channels, self.img_dim**2) / self.img_dim 282 | return a 283 | 284 | def __init__(self, channels, img_dim, ratio, perm, device): 285 | self.channels = channels 286 | self.img_dim = img_dim 287 | self.ratio = ratio 288 | self.perm = perm 289 | self._singulars = torch.ones(channels * img_dim**2 // ratio, device=device) 290 | 291 | def V(self, vec): 292 | temp = torch.zeros(vec.shape[0], self.channels, self.img_dim**2, device=vec.device) 293 | temp[:, :, self.perm] = vec.clone().reshape(vec.shape[0], -1, self.channels).permute(0, 2, 1) 294 | return self.fwht(temp).reshape(vec.shape[0], -1) 295 | 296 | def Vt(self, vec): 297 | return self.fwht(vec.clone())[:, :, self.perm].permute(0, 2, 1).reshape(vec.shape[0], -1) 298 | 299 | def U(self, vec): 300 | return vec.clone().reshape(vec.shape[0], -1) 301 | 302 | def Ut(self, vec): 303 | return vec.clone().reshape(vec.shape[0], -1) 304 | 305 | def singulars(self): 306 | return self._singulars 307 | 308 | def add_zeros(self, vec): 309 | out = torch.zeros(vec.shape[0], self.channels * self.img_dim**2, device=vec.device) 310 | out[:, :self.channels * self.img_dim**2 // self.ratio] = vec.clone().reshape(vec.shape[0], -1) 311 | return out 312 | 313 | #Convolution-based super-resolution 314 | class SRConv(H_functions): 315 | def mat_by_img(self, M, v, dim): 316 | return torch.matmul(M, v.reshape(v.shape[0] * self.channels, dim, 317 | dim)).reshape(v.shape[0], self.channels, M.shape[0], dim) 318 | 319 | def img_by_mat(self, v, M, dim): 320 | return torch.matmul(v.reshape(v.shape[0] * self.channels, dim, 321 | dim), M).reshape(v.shape[0], self.channels, dim, M.shape[1]) 322 | 323 | def __init__(self, kernel, channels, img_dim, device, stride = 1): 324 | self.img_dim = img_dim 325 | self.channels = channels 326 | self.ratio = stride 327 | small_dim = img_dim // stride 328 | self.small_dim = small_dim 329 | #build 1D conv matrix 330 | H_small = torch.zeros(small_dim, img_dim, device=device) 331 | for i in range(stride//2, img_dim + stride//2, stride): 332 | for j in range(i - kernel.shape[0]//2, i + kernel.shape[0]//2): 333 | j_effective = j 334 | #reflective padding 335 | if j_effective < 0: j_effective = -j_effective-1 336 | if j_effective >= img_dim: j_effective = (img_dim - 1) - (j_effective - img_dim) 337 | #matrix building 338 | H_small[i // stride, j_effective] += kernel[j - i + kernel.shape[0]//2] 339 | #get the svd of the 1D conv 340 | self.U_small, self.singulars_small, self.V_small = torch.svd(H_small, some=False) 341 | ZERO = 3e-2 342 | self.singulars_small[self.singulars_small < ZERO] = 0 343 | #calculate the singular values of the big matrix 344 | self._singulars = torch.matmul(self.singulars_small.reshape(small_dim, 1), self.singulars_small.reshape(1, small_dim)).reshape(small_dim**2) 345 | #permutation for matching the singular values. See P_1 in Appendix D.5. 
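# (concretely: first the indices img_dim*i + j with both i, j < small_dim,
#  i.e. the top-left small_dim x small_dim block of grid positions, then those
#  with i < small_dim and small_dim <= j < img_dim)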
346 | self._perm = torch.Tensor([self.img_dim * i + j for i in range(self.small_dim) for j in range(self.small_dim)] + \ 347 | [self.img_dim * i + j for i in range(self.small_dim) for j in range(self.small_dim, self.img_dim)]).to(device).long() 348 | 349 | def V(self, vec): 350 | #invert the permutation 351 | temp = torch.zeros(vec.shape[0], self.img_dim**2, self.channels, device=vec.device) 352 | temp[:, self._perm, :] = vec.clone().reshape(vec.shape[0], self.img_dim**2, self.channels)[:, :self._perm.shape[0], :] 353 | temp[:, self._perm.shape[0]:, :] = vec.clone().reshape(vec.shape[0], self.img_dim**2, self.channels)[:, self._perm.shape[0]:, :] 354 | temp = temp.permute(0, 2, 1) 355 | #multiply the image by V from the left and by V^T from the right 356 | out = self.mat_by_img(self.V_small, temp, self.img_dim) 357 | out = self.img_by_mat(out, self.V_small.transpose(0, 1), self.img_dim).reshape(vec.shape[0], -1) 358 | return out 359 | 360 | def Vt(self, vec): 361 | #multiply the image by V^T from the left and by V from the right 362 | temp = self.mat_by_img(self.V_small.transpose(0, 1), vec.clone(), self.img_dim) 363 | temp = self.img_by_mat(temp, self.V_small, self.img_dim).reshape(vec.shape[0], self.channels, -1) 364 | #permute the entries 365 | temp[:, :, :self._perm.shape[0]] = temp[:, :, self._perm] 366 | temp = temp.permute(0, 2, 1) 367 | return temp.reshape(vec.shape[0], -1) 368 | 369 | def U(self, vec): 370 | #invert the permutation 371 | temp = torch.zeros(vec.shape[0], self.small_dim**2, self.channels, device=vec.device) 372 | temp[:, :self.small_dim**2, :] = vec.clone().reshape(vec.shape[0], self.small_dim**2, self.channels) 373 | temp = temp.permute(0, 2, 1) 374 | #multiply the image by U from the left and by U^T from the right 375 | out = self.mat_by_img(self.U_small, temp, self.small_dim) 376 | out = self.img_by_mat(out, self.U_small.transpose(0, 1), self.small_dim).reshape(vec.shape[0], -1) 377 | return out 378 | 379 | def Ut(self, vec): 380 | #multiply the image by U^T from the left and by U from the right 381 | temp = self.mat_by_img(self.U_small.transpose(0, 1), vec.clone(), self.small_dim) 382 | temp = self.img_by_mat(temp, self.U_small, self.small_dim).reshape(vec.shape[0], self.channels, -1) 383 | #permute the entries 384 | temp = temp.permute(0, 2, 1) 385 | return temp.reshape(vec.shape[0], -1) 386 | 387 | def singulars(self): 388 | return self._singulars.repeat_interleave(3).reshape(-1) 389 | 390 | def add_zeros(self, vec): 391 | reshaped = vec.clone().reshape(vec.shape[0], -1) 392 | temp = torch.zeros((vec.shape[0], reshaped.shape[1] * self.ratio**2), device=vec.device) 393 | temp[:, :reshaped.shape[1]] = reshaped 394 | return temp 395 | 396 | #Deblurring 397 | class Deblurring(H_functions): 398 | def mat_by_img(self, M, v): 399 | return torch.matmul(M, v.reshape(v.shape[0] * self.channels, self.img_dim, 400 | self.img_dim)).reshape(v.shape[0], self.channels, M.shape[0], self.img_dim) 401 | 402 | def img_by_mat(self, v, M): 403 | return torch.matmul(v.reshape(v.shape[0] * self.channels, self.img_dim, 404 | self.img_dim), M).reshape(v.shape[0], self.channels, self.img_dim, M.shape[1]) 405 | 406 | def __init__(self, kernel, channels, img_dim, device, ZERO = 3e-2): 407 | self.img_dim = img_dim 408 | self.channels = channels 409 | #build 1D conv matrix 410 | H_small = torch.zeros(img_dim, img_dim, device=device) 411 | for i in range(img_dim): 412 | for j in range(i - kernel.shape[0]//2, i + kernel.shape[0]//2): 413 | if j < 0 or j >= img_dim: continue 414 | H_small[i, 
j] = kernel[j - i + kernel.shape[0]//2] 415 | #get the svd of the 1D conv 416 | self.U_small, self.singulars_small, self.V_small = torch.svd(H_small, some=False) 417 | #ZERO = 3e-2 418 | self.singulars_small[self.singulars_small < ZERO] = 0 419 | #calculate the singular values of the big matrix 420 | self._singulars = torch.matmul(self.singulars_small.reshape(img_dim, 1), self.singulars_small.reshape(1, img_dim)).reshape(img_dim**2) 421 | #sort the big matrix singulars and save the permutation 422 | self._singulars, self._perm = self._singulars.sort(descending=True) #, stable=True) 423 | 424 | def V(self, vec): 425 | #invert the permutation 426 | temp = torch.zeros(vec.shape[0], self.img_dim**2, self.channels, device=vec.device) 427 | temp[:, self._perm, :] = vec.clone().reshape(vec.shape[0], self.img_dim**2, self.channels) 428 | temp = temp.permute(0, 2, 1) 429 | #multiply the image by V from the left and by V^T from the right 430 | out = self.mat_by_img(self.V_small, temp) 431 | out = self.img_by_mat(out, self.V_small.transpose(0, 1)).reshape(vec.shape[0], -1) 432 | return out 433 | 434 | def Vt(self, vec): 435 | #multiply the image by V^T from the left and by V from the right 436 | temp = self.mat_by_img(self.V_small.transpose(0, 1), vec.clone()) 437 | temp = self.img_by_mat(temp, self.V_small).reshape(vec.shape[0], self.channels, -1) 438 | #permute the entries according to the singular values 439 | temp = temp[:, :, self._perm].permute(0, 2, 1) 440 | return temp.reshape(vec.shape[0], -1) 441 | 442 | def U(self, vec): 443 | #invert the permutation 444 | temp = torch.zeros(vec.shape[0], self.img_dim**2, self.channels, device=vec.device) 445 | temp[:, self._perm, :] = vec.clone().reshape(vec.shape[0], self.img_dim**2, self.channels) 446 | temp = temp.permute(0, 2, 1) 447 | #multiply the image by U from the left and by U^T from the right 448 | out = self.mat_by_img(self.U_small, temp) 449 | out = self.img_by_mat(out, self.U_small.transpose(0, 1)).reshape(vec.shape[0], -1) 450 | return out 451 | 452 | def Ut(self, vec): 453 | #multiply the image by U^T from the left and by U from the right 454 | temp = self.mat_by_img(self.U_small.transpose(0, 1), vec.clone()) 455 | temp = self.img_by_mat(temp, self.U_small).reshape(vec.shape[0], self.channels, -1) 456 | #permute the entries according to the singular values 457 | temp = temp[:, :, self._perm].permute(0, 2, 1) 458 | return temp.reshape(vec.shape[0], -1) 459 | 460 | def singulars(self): 461 | return self._singulars.repeat(1, 3).reshape(-1) 462 | 463 | def add_zeros(self, vec): 464 | return vec.clone().reshape(vec.shape[0], -1) 465 | 466 | #Anisotropic Deblurring 467 | class Deblurring2D(H_functions): 468 | def mat_by_img(self, M, v): 469 | return torch.matmul(M, v.reshape(v.shape[0] * self.channels, self.img_dim, 470 | self.img_dim)).reshape(v.shape[0], self.channels, M.shape[0], self.img_dim) 471 | 472 | def img_by_mat(self, v, M): 473 | return torch.matmul(v.reshape(v.shape[0] * self.channels, self.img_dim, 474 | self.img_dim), M).reshape(v.shape[0], self.channels, self.img_dim, M.shape[1]) 475 | 476 | def __init__(self, kernel1, kernel2, channels, img_dim, device): 477 | self.img_dim = img_dim 478 | self.channels = channels 479 | #build 1D conv matrix - kernel1 480 | H_small1 = torch.zeros(img_dim, img_dim, device=device) 481 | for i in range(img_dim): 482 | for j in range(i - kernel1.shape[0]//2, i + kernel1.shape[0]//2): 483 | if j < 0 or j >= img_dim: continue 484 | H_small1[i, j] = kernel1[j - i + kernel1.shape[0]//2] 485 | #build 1D 
conv matrix - kernel2 486 | H_small2 = torch.zeros(img_dim, img_dim, device=device) 487 | for i in range(img_dim): 488 | for j in range(i - kernel2.shape[0]//2, i + kernel2.shape[0]//2): 489 | if j < 0 or j >= img_dim: continue 490 | H_small2[i, j] = kernel2[j - i + kernel2.shape[0]//2] 491 | #get the svd of the 1D conv 492 | self.U_small1, self.singulars_small1, self.V_small1 = torch.svd(H_small1, some=False) 493 | self.U_small2, self.singulars_small2, self.V_small2 = torch.svd(H_small2, some=False) 494 | ZERO = 3e-2 495 | self.singulars_small1[self.singulars_small1 < ZERO] = 0 496 | self.singulars_small2[self.singulars_small2 < ZERO] = 0 497 | #calculate the singular values of the big matrix 498 | self._singulars = torch.matmul(self.singulars_small1.reshape(img_dim, 1), self.singulars_small2.reshape(1, img_dim)).reshape(img_dim**2) 499 | #sort the big matrix singulars and save the permutation 500 | self._singulars, self._perm = self._singulars.sort(descending=True) #, stable=True) 501 | 502 | def V(self, vec): 503 | #invert the permutation 504 | temp = torch.zeros(vec.shape[0], self.img_dim**2, self.channels, device=vec.device) 505 | temp[:, self._perm, :] = vec.clone().reshape(vec.shape[0], self.img_dim**2, self.channels) 506 | temp = temp.permute(0, 2, 1) 507 | #multiply the image by V from the left and by V^T from the right 508 | out = self.mat_by_img(self.V_small1, temp) 509 | out = self.img_by_mat(out, self.V_small2.transpose(0, 1)).reshape(vec.shape[0], -1) 510 | return out 511 | 512 | def Vt(self, vec): 513 | #multiply the image by V^T from the left and by V from the right 514 | temp = self.mat_by_img(self.V_small1.transpose(0, 1), vec.clone()) 515 | temp = self.img_by_mat(temp, self.V_small2).reshape(vec.shape[0], self.channels, -1) 516 | #permute the entries according to the singular values 517 | temp = temp[:, :, self._perm].permute(0, 2, 1) 518 | return temp.reshape(vec.shape[0], -1) 519 | 520 | def U(self, vec): 521 | #invert the permutation 522 | temp = torch.zeros(vec.shape[0], self.img_dim**2, self.channels, device=vec.device) 523 | temp[:, self._perm, :] = vec.clone().reshape(vec.shape[0], self.img_dim**2, self.channels) 524 | temp = temp.permute(0, 2, 1) 525 | #multiply the image by U from the left and by U^T from the right 526 | out = self.mat_by_img(self.U_small1, temp) 527 | out = self.img_by_mat(out, self.U_small2.transpose(0, 1)).reshape(vec.shape[0], -1) 528 | return out 529 | 530 | def Ut(self, vec): 531 | #multiply the image by U^T from the left and by U from the right 532 | temp = self.mat_by_img(self.U_small1.transpose(0, 1), vec.clone()) 533 | temp = self.img_by_mat(temp, self.U_small2).reshape(vec.shape[0], self.channels, -1) 534 | #permute the entries according to the singular values 535 | temp = temp[:, :, self._perm].permute(0, 2, 1) 536 | return temp.reshape(vec.shape[0], -1) 537 | 538 | def singulars(self): 539 | return self._singulars.repeat(1, 3).reshape(-1) 540 | 541 | def add_zeros(self, vec): 542 | return vec.clone().reshape(vec.shape[0], -1) -------------------------------------------------------------------------------- /guided_diffusion/unet.py: -------------------------------------------------------------------------------- 1 | from abc import abstractmethod 2 | 3 | import math 4 | 5 | import numpy as np 6 | import torch as th 7 | import torch.nn as nn 8 | import torch.nn.functional as F 9 | 10 | from .fp16_util import convert_module_to_f16, convert_module_to_f32 11 | from .nn import ( 12 | checkpoint, 13 | conv_nd, 14 | linear, 15 | 
avg_pool_nd, 16 | zero_module, 17 | normalization, 18 | timestep_embedding, 19 | ) 20 | 21 | 22 | class AttentionPool2d(nn.Module): 23 | """ 24 | Adapted from CLIP: https://github.com/openai/CLIP/blob/main/clip/model.py 25 | """ 26 | 27 | def __init__( 28 | self, 29 | spacial_dim: int, 30 | embed_dim: int, 31 | num_heads_channels: int, 32 | output_dim: int = None, 33 | ): 34 | super().__init__() 35 | self.positional_embedding = nn.Parameter( 36 | th.randn(embed_dim, spacial_dim ** 2 + 1) / embed_dim ** 0.5 37 | ) 38 | self.qkv_proj = conv_nd(1, embed_dim, 3 * embed_dim, 1) 39 | self.c_proj = conv_nd(1, embed_dim, output_dim or embed_dim, 1) 40 | self.num_heads = embed_dim // num_heads_channels 41 | self.attention = QKVAttention(self.num_heads) 42 | 43 | def forward(self, x): 44 | b, c, *_spatial = x.shape 45 | x = x.reshape(b, c, -1) # NC(HW) 46 | x = th.cat([x.mean(dim=-1, keepdim=True), x], dim=-1) # NC(HW+1) 47 | x = x + self.positional_embedding[None, :, :].to(x.dtype) # NC(HW+1) 48 | x = self.qkv_proj(x) 49 | x = self.attention(x) 50 | x = self.c_proj(x) 51 | return x[:, :, 0] 52 | 53 | 54 | class TimestepBlock(nn.Module): 55 | """ 56 | Any module where forward() takes timestep embeddings as a second argument. 57 | """ 58 | 59 | @abstractmethod 60 | def forward(self, x, emb): 61 | """ 62 | Apply the module to `x` given `emb` timestep embeddings. 63 | """ 64 | 65 | 66 | class TimestepEmbedSequential(nn.Sequential, TimestepBlock): 67 | """ 68 | A sequential module that passes timestep embeddings to the children that 69 | support it as an extra input. 70 | """ 71 | 72 | def forward(self, x, emb): 73 | for layer in self: 74 | if isinstance(layer, TimestepBlock): 75 | x = layer(x, emb) 76 | else: 77 | x = layer(x) 78 | return x 79 | 80 | 81 | class Upsample(nn.Module): 82 | """ 83 | An upsampling layer with an optional convolution. 84 | 85 | :param channels: channels in the inputs and outputs. 86 | :param use_conv: a bool determining if a convolution is applied. 87 | :param dims: determines if the signal is 1D, 2D, or 3D. If 3D, then 88 | upsampling occurs in the inner-two dimensions. 89 | """ 90 | 91 | def __init__(self, channels, use_conv, dims=2, out_channels=None): 92 | super().__init__() 93 | self.channels = channels 94 | self.out_channels = out_channels or channels 95 | self.use_conv = use_conv 96 | self.dims = dims 97 | if use_conv: 98 | self.conv = conv_nd(dims, self.channels, self.out_channels, 3, padding=1) 99 | 100 | def forward(self, x): 101 | assert x.shape[1] == self.channels 102 | if self.dims == 3: 103 | x = F.interpolate( 104 | x, (x.shape[2], x.shape[3] * 2, x.shape[4] * 2), mode="nearest" 105 | ) 106 | else: 107 | x = F.interpolate(x, scale_factor=2, mode="nearest") 108 | if self.use_conv: 109 | x = self.conv(x) 110 | return x 111 | 112 | 113 | class Downsample(nn.Module): 114 | """ 115 | A downsampling layer with an optional convolution. 116 | 117 | :param channels: channels in the inputs and outputs. 118 | :param use_conv: a bool determining if a convolution is applied. 119 | :param dims: determines if the signal is 1D, 2D, or 3D. If 3D, then 120 | downsampling occurs in the inner-two dimensions. 
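    When use_conv is False the layer falls back to average pooling, in which
    case out_channels must equal channels.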
121 | """ 122 | 123 | def __init__(self, channels, use_conv, dims=2, out_channels=None): 124 | super().__init__() 125 | self.channels = channels 126 | self.out_channels = out_channels or channels 127 | self.use_conv = use_conv 128 | self.dims = dims 129 | stride = 2 if dims != 3 else (1, 2, 2) 130 | if use_conv: 131 | self.op = conv_nd( 132 | dims, self.channels, self.out_channels, 3, stride=stride, padding=1 133 | ) 134 | else: 135 | assert self.channels == self.out_channels 136 | self.op = avg_pool_nd(dims, kernel_size=stride, stride=stride) 137 | 138 | def forward(self, x): 139 | assert x.shape[1] == self.channels 140 | return self.op(x) 141 | 142 | 143 | class ResBlock(TimestepBlock): 144 | """ 145 | A residual block that can optionally change the number of channels. 146 | 147 | :param channels: the number of input channels. 148 | :param emb_channels: the number of timestep embedding channels. 149 | :param dropout: the rate of dropout. 150 | :param out_channels: if specified, the number of out channels. 151 | :param use_conv: if True and out_channels is specified, use a spatial 152 | convolution instead of a smaller 1x1 convolution to change the 153 | channels in the skip connection. 154 | :param dims: determines if the signal is 1D, 2D, or 3D. 155 | :param use_checkpoint: if True, use gradient checkpointing on this module. 156 | :param up: if True, use this block for upsampling. 157 | :param down: if True, use this block for downsampling. 158 | """ 159 | 160 | def __init__( 161 | self, 162 | channels, 163 | emb_channels, 164 | dropout, 165 | out_channels=None, 166 | use_conv=False, 167 | use_scale_shift_norm=False, 168 | dims=2, 169 | use_checkpoint=False, 170 | up=False, 171 | down=False, 172 | ): 173 | super().__init__() 174 | self.channels = channels 175 | self.emb_channels = emb_channels 176 | self.dropout = dropout 177 | self.out_channels = out_channels or channels 178 | self.use_conv = use_conv 179 | self.use_checkpoint = use_checkpoint 180 | self.use_scale_shift_norm = use_scale_shift_norm 181 | 182 | self.in_layers = nn.Sequential( 183 | normalization(channels), 184 | nn.SiLU(), 185 | conv_nd(dims, channels, self.out_channels, 3, padding=1), 186 | ) 187 | 188 | self.updown = up or down 189 | 190 | if up: 191 | self.h_upd = Upsample(channels, False, dims) 192 | self.x_upd = Upsample(channels, False, dims) 193 | elif down: 194 | self.h_upd = Downsample(channels, False, dims) 195 | self.x_upd = Downsample(channels, False, dims) 196 | else: 197 | self.h_upd = self.x_upd = nn.Identity() 198 | 199 | self.emb_layers = nn.Sequential( 200 | nn.SiLU(), 201 | linear( 202 | emb_channels, 203 | 2 * self.out_channels if use_scale_shift_norm else self.out_channels, 204 | ), 205 | ) 206 | self.out_layers = nn.Sequential( 207 | normalization(self.out_channels), 208 | nn.SiLU(), 209 | nn.Dropout(p=dropout), 210 | zero_module( 211 | conv_nd(dims, self.out_channels, self.out_channels, 3, padding=1) 212 | ), 213 | ) 214 | 215 | if self.out_channels == channels: 216 | self.skip_connection = nn.Identity() 217 | elif use_conv: 218 | self.skip_connection = conv_nd( 219 | dims, channels, self.out_channels, 3, padding=1 220 | ) 221 | else: 222 | self.skip_connection = conv_nd(dims, channels, self.out_channels, 1) 223 | 224 | def forward(self, x, emb): 225 | """ 226 | Apply the block to a Tensor, conditioned on a timestep embedding. 227 | 228 | :param x: an [N x C x ...] Tensor of features. 229 | :param emb: an [N x emb_channels] Tensor of timestep embeddings. 230 | :return: an [N x C x ...] 
Tensor of outputs. 231 | """ 232 | return checkpoint( 233 | self._forward, (x, emb), self.parameters(), self.use_checkpoint 234 | ) 235 | 236 | def _forward(self, x, emb): 237 | if self.updown: 238 | in_rest, in_conv = self.in_layers[:-1], self.in_layers[-1] 239 | h = in_rest(x) 240 | h = self.h_upd(h) 241 | x = self.x_upd(x) 242 | h = in_conv(h) 243 | else: 244 | h = self.in_layers(x) 245 | emb_out = self.emb_layers(emb).type(h.dtype) 246 | while len(emb_out.shape) < len(h.shape): 247 | emb_out = emb_out[..., None] 248 | if self.use_scale_shift_norm: 249 | out_norm, out_rest = self.out_layers[0], self.out_layers[1:] 250 | scale, shift = th.chunk(emb_out, 2, dim=1) 251 | h = out_norm(h) * (1 + scale) + shift 252 | h = out_rest(h) 253 | else: 254 | h = h + emb_out 255 | h = self.out_layers(h) 256 | return self.skip_connection(x) + h 257 | 258 | 259 | class AttentionBlock(nn.Module): 260 | """ 261 | An attention block that allows spatial positions to attend to each other. 262 | 263 | Originally ported from here, but adapted to the N-d case. 264 | https://github.com/hojonathanho/diffusion/blob/1e0dceb3b3495bbe19116a5e1b3596cd0706c543/diffusion_tf/models/unet.py#L66. 265 | """ 266 | 267 | def __init__( 268 | self, 269 | channels, 270 | num_heads=1, 271 | num_head_channels=-1, 272 | use_checkpoint=False, 273 | use_new_attention_order=False, 274 | ): 275 | super().__init__() 276 | self.channels = channels 277 | if num_head_channels == -1: 278 | self.num_heads = num_heads 279 | else: 280 | assert ( 281 | channels % num_head_channels == 0 282 | ), f"q,k,v channels {channels} is not divisible by num_head_channels {num_head_channels}" 283 | self.num_heads = channels // num_head_channels 284 | self.use_checkpoint = use_checkpoint 285 | self.norm = normalization(channels) 286 | self.qkv = conv_nd(1, channels, channels * 3, 1) 287 | if use_new_attention_order: 288 | # split qkv before split heads 289 | self.attention = QKVAttention(self.num_heads) 290 | else: 291 | # split heads before split qkv 292 | self.attention = QKVAttentionLegacy(self.num_heads) 293 | 294 | self.proj_out = zero_module(conv_nd(1, channels, channels, 1)) 295 | 296 | def forward(self, x): 297 | return checkpoint(self._forward, (x,), self.parameters(), True) 298 | 299 | def _forward(self, x): 300 | b, c, *spatial = x.shape 301 | x = x.reshape(b, c, -1) 302 | qkv = self.qkv(self.norm(x)) 303 | h = self.attention(qkv) 304 | h = self.proj_out(h) 305 | return (x + h).reshape(b, c, *spatial) 306 | 307 | 308 | def count_flops_attn(model, _x, y): 309 | """ 310 | A counter for the `thop` package to count the operations in an 311 | attention operation. 312 | Meant to be used like: 313 | macs, params = thop.profile( 314 | model, 315 | inputs=(inputs, timestamps), 316 | custom_ops={QKVAttention: QKVAttention.count_flops}, 317 | ) 318 | """ 319 | b, c, *spatial = y[0].shape 320 | num_spatial = int(np.prod(spatial)) 321 | # We perform two matmuls with the same number of ops. 322 | # The first computes the weight matrix, the second computes 323 | # the combination of the value vectors. 324 | matmul_ops = 2 * b * (num_spatial ** 2) * c 325 | model.total_ops += th.DoubleTensor([matmul_ops]) 326 | 327 | 328 | class QKVAttentionLegacy(nn.Module): 329 | """ 330 | A module which performs QKV attention. Matches legacy QKVAttention + input/ouput heads shaping 331 | """ 332 | 333 | def __init__(self, n_heads): 334 | super().__init__() 335 | self.n_heads = n_heads 336 | 337 | def forward(self, qkv): 338 | """ 339 | Apply QKV attention. 
340 | 341 | :param qkv: an [N x (H * 3 * C) x T] tensor of Qs, Ks, and Vs. 342 | :return: an [N x (H * C) x T] tensor after attention. 343 | """ 344 | bs, width, length = qkv.shape 345 | assert width % (3 * self.n_heads) == 0 346 | ch = width // (3 * self.n_heads) 347 | q, k, v = qkv.reshape(bs * self.n_heads, ch * 3, length).split(ch, dim=1) 348 | scale = 1 / math.sqrt(math.sqrt(ch)) 349 | weight = th.einsum( 350 | "bct,bcs->bts", q * scale, k * scale 351 | ) # More stable with f16 than dividing afterwards 352 | weight = th.softmax(weight.float(), dim=-1).type(weight.dtype) 353 | a = th.einsum("bts,bcs->bct", weight, v) 354 | return a.reshape(bs, -1, length) 355 | 356 | @staticmethod 357 | def count_flops(model, _x, y): 358 | return count_flops_attn(model, _x, y) 359 | 360 | 361 | class QKVAttention(nn.Module): 362 | """ 363 | A module which performs QKV attention and splits in a different order. 364 | """ 365 | 366 | def __init__(self, n_heads): 367 | super().__init__() 368 | self.n_heads = n_heads 369 | 370 | def forward(self, qkv): 371 | """ 372 | Apply QKV attention. 373 | 374 | :param qkv: an [N x (3 * H * C) x T] tensor of Qs, Ks, and Vs. 375 | :return: an [N x (H * C) x T] tensor after attention. 376 | """ 377 | bs, width, length = qkv.shape 378 | assert width % (3 * self.n_heads) == 0 379 | ch = width // (3 * self.n_heads) 380 | q, k, v = qkv.chunk(3, dim=1) 381 | scale = 1 / math.sqrt(math.sqrt(ch)) 382 | weight = th.einsum( 383 | "bct,bcs->bts", 384 | (q * scale).view(bs * self.n_heads, ch, length), 385 | (k * scale).view(bs * self.n_heads, ch, length), 386 | ) # More stable with f16 than dividing afterwards 387 | weight = th.softmax(weight.float(), dim=-1).type(weight.dtype) 388 | a = th.einsum("bts,bcs->bct", weight, v.reshape(bs * self.n_heads, ch, length)) 389 | return a.reshape(bs, -1, length) 390 | 391 | @staticmethod 392 | def count_flops(model, _x, y): 393 | return count_flops_attn(model, _x, y) 394 | 395 | 396 | class UNetModel(nn.Module): 397 | """ 398 | The full UNet model with attention and timestep embedding. 399 | 400 | :param in_channels: channels in the input Tensor. 401 | :param model_channels: base channel count for the model. 402 | :param out_channels: channels in the output Tensor. 403 | :param num_res_blocks: number of residual blocks per downsample. 404 | :param attention_resolutions: a collection of downsample rates at which 405 | attention will take place. May be a set, list, or tuple. 406 | For example, if this contains 4, then at 4x downsampling, attention 407 | will be used. 408 | :param dropout: the dropout probability. 409 | :param channel_mult: channel multiplier for each level of the UNet. 410 | :param conv_resample: if True, use learned convolutions for upsampling and 411 | downsampling. 412 | :param dims: determines if the signal is 1D, 2D, or 3D. 413 | :param num_classes: if specified (as an int), then this model will be 414 | class-conditional with `num_classes` classes. 415 | :param use_checkpoint: use gradient checkpointing to reduce memory usage. 416 | :param num_heads: the number of attention heads in each attention layer. 417 | :param num_heads_channels: if specified, ignore num_heads and instead use 418 | a fixed channel width per attention head. 419 | :param num_heads_upsample: works with num_heads to set a different number 420 | of heads for upsampling. Deprecated. 421 | :param use_scale_shift_norm: use a FiLM-like conditioning mechanism. 422 | :param resblock_updown: use residual blocks for up/downsampling. 
423 | :param use_new_attention_order: use a different attention pattern for potentially 424 | increased efficiency. 425 | """ 426 | 427 | def __init__( 428 | self, 429 | image_size, 430 | in_channels, 431 | model_channels, 432 | out_channels, 433 | num_res_blocks, 434 | attention_resolutions, 435 | dropout=0, 436 | channel_mult=(1, 2, 4, 8), 437 | conv_resample=True, 438 | dims=2, 439 | num_classes=None, 440 | use_checkpoint=False, 441 | use_fp16=False, 442 | num_heads=1, 443 | num_head_channels=-1, 444 | num_heads_upsample=-1, 445 | use_scale_shift_norm=False, 446 | resblock_updown=False, 447 | use_new_attention_order=False, 448 | **kwargs 449 | ): 450 | super().__init__() 451 | 452 | if num_heads_upsample == -1: 453 | num_heads_upsample = num_heads 454 | 455 | self.image_size = image_size 456 | self.in_channels = in_channels 457 | self.model_channels = model_channels 458 | self.out_channels = out_channels 459 | self.num_res_blocks = num_res_blocks 460 | self.attention_resolutions = attention_resolutions 461 | self.dropout = dropout 462 | self.channel_mult = channel_mult 463 | self.conv_resample = conv_resample 464 | self.num_classes = num_classes 465 | self.use_checkpoint = use_checkpoint 466 | self.dtype = th.float16 if use_fp16 else th.float32 467 | self.num_heads = num_heads 468 | self.num_head_channels = num_head_channels 469 | self.num_heads_upsample = num_heads_upsample 470 | 471 | time_embed_dim = model_channels * 4 472 | self.time_embed = nn.Sequential( 473 | linear(model_channels, time_embed_dim), 474 | nn.SiLU(), 475 | linear(time_embed_dim, time_embed_dim), 476 | ) 477 | 478 | if self.num_classes is not None: 479 | self.label_emb = nn.Embedding(num_classes, time_embed_dim) 480 | 481 | ch = input_ch = int(channel_mult[0] * model_channels) 482 | self.input_blocks = nn.ModuleList( 483 | [TimestepEmbedSequential(conv_nd(dims, in_channels, ch, 3, padding=1))] 484 | ) 485 | self._feature_size = ch 486 | input_block_chans = [ch] 487 | ds = 1 488 | for level, mult in enumerate(channel_mult): 489 | for _ in range(num_res_blocks): 490 | layers = [ 491 | ResBlock( 492 | ch, 493 | time_embed_dim, 494 | dropout, 495 | out_channels=int(mult * model_channels), 496 | dims=dims, 497 | use_checkpoint=use_checkpoint, 498 | use_scale_shift_norm=use_scale_shift_norm, 499 | ) 500 | ] 501 | ch = int(mult * model_channels) 502 | if ds in attention_resolutions: 503 | layers.append( 504 | AttentionBlock( 505 | ch, 506 | use_checkpoint=use_checkpoint, 507 | num_heads=num_heads, 508 | num_head_channels=num_head_channels, 509 | use_new_attention_order=use_new_attention_order, 510 | ) 511 | ) 512 | self.input_blocks.append(TimestepEmbedSequential(*layers)) 513 | self._feature_size += ch 514 | input_block_chans.append(ch) 515 | if level != len(channel_mult) - 1: 516 | out_ch = ch 517 | self.input_blocks.append( 518 | TimestepEmbedSequential( 519 | ResBlock( 520 | ch, 521 | time_embed_dim, 522 | dropout, 523 | out_channels=out_ch, 524 | dims=dims, 525 | use_checkpoint=use_checkpoint, 526 | use_scale_shift_norm=use_scale_shift_norm, 527 | down=True, 528 | ) 529 | if resblock_updown 530 | else Downsample( 531 | ch, conv_resample, dims=dims, out_channels=out_ch 532 | ) 533 | ) 534 | ) 535 | ch = out_ch 536 | input_block_chans.append(ch) 537 | ds *= 2 538 | self._feature_size += ch 539 | 540 | self.middle_block = TimestepEmbedSequential( 541 | ResBlock( 542 | ch, 543 | time_embed_dim, 544 | dropout, 545 | dims=dims, 546 | use_checkpoint=use_checkpoint, 547 | use_scale_shift_norm=use_scale_shift_norm, 548 | ), 
549 | AttentionBlock( 550 | ch, 551 | use_checkpoint=use_checkpoint, 552 | num_heads=num_heads, 553 | num_head_channels=num_head_channels, 554 | use_new_attention_order=use_new_attention_order, 555 | ), 556 | ResBlock( 557 | ch, 558 | time_embed_dim, 559 | dropout, 560 | dims=dims, 561 | use_checkpoint=use_checkpoint, 562 | use_scale_shift_norm=use_scale_shift_norm, 563 | ), 564 | ) 565 | self._feature_size += ch 566 | 567 | self.output_blocks = nn.ModuleList([]) 568 | for level, mult in list(enumerate(channel_mult))[::-1]: 569 | for i in range(num_res_blocks + 1): 570 | ich = input_block_chans.pop() 571 | layers = [ 572 | ResBlock( 573 | ch + ich, 574 | time_embed_dim, 575 | dropout, 576 | out_channels=int(model_channels * mult), 577 | dims=dims, 578 | use_checkpoint=use_checkpoint, 579 | use_scale_shift_norm=use_scale_shift_norm, 580 | ) 581 | ] 582 | ch = int(model_channels * mult) 583 | if ds in attention_resolutions: 584 | layers.append( 585 | AttentionBlock( 586 | ch, 587 | use_checkpoint=use_checkpoint, 588 | num_heads=num_heads_upsample, 589 | num_head_channels=num_head_channels, 590 | use_new_attention_order=use_new_attention_order, 591 | ) 592 | ) 593 | if level and i == num_res_blocks: 594 | out_ch = ch 595 | layers.append( 596 | ResBlock( 597 | ch, 598 | time_embed_dim, 599 | dropout, 600 | out_channels=out_ch, 601 | dims=dims, 602 | use_checkpoint=use_checkpoint, 603 | use_scale_shift_norm=use_scale_shift_norm, 604 | up=True, 605 | ) 606 | if resblock_updown 607 | else Upsample(ch, conv_resample, dims=dims, out_channels=out_ch) 608 | ) 609 | ds //= 2 610 | self.output_blocks.append(TimestepEmbedSequential(*layers)) 611 | self._feature_size += ch 612 | 613 | self.out = nn.Sequential( 614 | normalization(ch), 615 | nn.SiLU(), 616 | zero_module(conv_nd(dims, input_ch, out_channels, 3, padding=1)), 617 | ) 618 | 619 | def convert_to_fp16(self): 620 | """ 621 | Convert the torso of the model to float16. 622 | """ 623 | self.input_blocks.apply(convert_module_to_f16) 624 | self.middle_block.apply(convert_module_to_f16) 625 | self.output_blocks.apply(convert_module_to_f16) 626 | 627 | def convert_to_fp32(self): 628 | """ 629 | Convert the torso of the model to float32. 630 | """ 631 | self.input_blocks.apply(convert_module_to_f32) 632 | self.middle_block.apply(convert_module_to_f32) 633 | self.output_blocks.apply(convert_module_to_f32) 634 | 635 | def forward(self, x, timesteps, y=None): 636 | """ 637 | Apply the model to an input batch. 638 | 639 | :param x: an [N x C x ...] Tensor of inputs. 640 | :param timesteps: a 1-D batch of timesteps. 641 | :param y: an [N] Tensor of labels, if class-conditional. 642 | :return: an [N x C x ...] Tensor of outputs. 643 | """ 644 | assert (y is not None) == ( 645 | self.num_classes is not None 646 | ), "must specify y if and only if the model is class-conditional" 647 | 648 | hs = [] 649 | emb = self.time_embed(timestep_embedding(timesteps, self.model_channels)) 650 | 651 | if self.num_classes is not None: 652 | assert y.shape == (x.shape[0],) 653 | emb = emb + self.label_emb(y) 654 | 655 | h = x.type(self.dtype) 656 | for module in self.input_blocks: 657 | h = module(h, emb) 658 | hs.append(h) 659 | h = self.middle_block(h, emb) 660 | for module in self.output_blocks: 661 | h = th.cat([h, hs.pop()], dim=1) 662 | h = module(h, emb) 663 | h = h.type(x.dtype) 664 | return self.out(h) 665 | 666 | 667 | class SuperResModel(UNetModel): 668 | """ 669 | A UNetModel that performs super-resolution. 
670 | 671 | Expects an extra kwarg `low_res` to condition on a low-resolution image. 672 | """ 673 | 674 | def __init__(self, image_size, in_channels, *args, **kwargs): 675 | super().__init__(image_size, in_channels * 2, *args, **kwargs) 676 | 677 | def forward(self, x, timesteps, low_res=None, **kwargs): 678 | _, _, new_height, new_width = x.shape 679 | upsampled = F.interpolate(low_res, (new_height, new_width), mode="bilinear") 680 | x = th.cat([x, upsampled], dim=1) 681 | return super().forward(x, timesteps, **kwargs) 682 | 683 | 684 | class EncoderUNetModel(nn.Module): 685 | """ 686 | The half UNet model with attention and timestep embedding. 687 | 688 | For usage, see UNet. 689 | """ 690 | 691 | def __init__( 692 | self, 693 | image_size, 694 | in_channels, 695 | model_channels, 696 | out_channels, 697 | num_res_blocks, 698 | attention_resolutions, 699 | dropout=0, 700 | channel_mult=(1, 2, 4, 8), 701 | conv_resample=True, 702 | dims=2, 703 | use_checkpoint=False, 704 | use_fp16=False, 705 | num_heads=1, 706 | num_head_channels=-1, 707 | num_heads_upsample=-1, 708 | use_scale_shift_norm=False, 709 | resblock_updown=False, 710 | use_new_attention_order=False, 711 | pool="adaptive", 712 | ): 713 | super().__init__() 714 | 715 | if num_heads_upsample == -1: 716 | num_heads_upsample = num_heads 717 | 718 | self.in_channels = in_channels 719 | self.model_channels = model_channels 720 | self.out_channels = out_channels 721 | self.num_res_blocks = num_res_blocks 722 | self.attention_resolutions = attention_resolutions 723 | self.dropout = dropout 724 | self.channel_mult = channel_mult 725 | self.conv_resample = conv_resample 726 | self.use_checkpoint = use_checkpoint 727 | self.dtype = th.float16 if use_fp16 else th.float32 728 | self.num_heads = num_heads 729 | self.num_head_channels = num_head_channels 730 | self.num_heads_upsample = num_heads_upsample 731 | 732 | time_embed_dim = model_channels * 4 733 | self.time_embed = nn.Sequential( 734 | linear(model_channels, time_embed_dim), 735 | nn.SiLU(), 736 | linear(time_embed_dim, time_embed_dim), 737 | ) 738 | 739 | ch = int(channel_mult[0] * model_channels) 740 | self.input_blocks = nn.ModuleList( 741 | [TimestepEmbedSequential(conv_nd(dims, in_channels, ch, 3, padding=1))] 742 | ) 743 | self._feature_size = ch 744 | input_block_chans = [ch] 745 | ds = 1 746 | for level, mult in enumerate(channel_mult): 747 | for _ in range(num_res_blocks): 748 | layers = [ 749 | ResBlock( 750 | ch, 751 | time_embed_dim, 752 | dropout, 753 | out_channels=int(mult * model_channels), 754 | dims=dims, 755 | use_checkpoint=use_checkpoint, 756 | use_scale_shift_norm=use_scale_shift_norm, 757 | ) 758 | ] 759 | ch = int(mult * model_channels) 760 | if ds in attention_resolutions: 761 | layers.append( 762 | AttentionBlock( 763 | ch, 764 | use_checkpoint=use_checkpoint, 765 | num_heads=num_heads, 766 | num_head_channels=num_head_channels, 767 | use_new_attention_order=use_new_attention_order, 768 | ) 769 | ) 770 | self.input_blocks.append(TimestepEmbedSequential(*layers)) 771 | self._feature_size += ch 772 | input_block_chans.append(ch) 773 | if level != len(channel_mult) - 1: 774 | out_ch = ch 775 | self.input_blocks.append( 776 | TimestepEmbedSequential( 777 | ResBlock( 778 | ch, 779 | time_embed_dim, 780 | dropout, 781 | out_channels=out_ch, 782 | dims=dims, 783 | use_checkpoint=use_checkpoint, 784 | use_scale_shift_norm=use_scale_shift_norm, 785 | down=True, 786 | ) 787 | if resblock_updown 788 | else Downsample( 789 | ch, conv_resample, dims=dims, 
out_channels=out_ch 790 | ) 791 | ) 792 | ) 793 | ch = out_ch 794 | input_block_chans.append(ch) 795 | ds *= 2 796 | self._feature_size += ch 797 | 798 | self.middle_block = TimestepEmbedSequential( 799 | ResBlock( 800 | ch, 801 | time_embed_dim, 802 | dropout, 803 | dims=dims, 804 | use_checkpoint=use_checkpoint, 805 | use_scale_shift_norm=use_scale_shift_norm, 806 | ), 807 | AttentionBlock( 808 | ch, 809 | use_checkpoint=use_checkpoint, 810 | num_heads=num_heads, 811 | num_head_channels=num_head_channels, 812 | use_new_attention_order=use_new_attention_order, 813 | ), 814 | ResBlock( 815 | ch, 816 | time_embed_dim, 817 | dropout, 818 | dims=dims, 819 | use_checkpoint=use_checkpoint, 820 | use_scale_shift_norm=use_scale_shift_norm, 821 | ), 822 | ) 823 | self._feature_size += ch 824 | self.pool = pool 825 | if pool == "adaptive": 826 | self.out = nn.Sequential( 827 | normalization(ch), 828 | nn.SiLU(), 829 | nn.AdaptiveAvgPool2d((1, 1)), 830 | zero_module(conv_nd(dims, ch, out_channels, 1)), 831 | nn.Flatten(), 832 | ) 833 | elif pool == "attention": 834 | assert num_head_channels != -1 835 | self.out = nn.Sequential( 836 | normalization(ch), 837 | nn.SiLU(), 838 | AttentionPool2d( 839 | (image_size // ds), ch, num_head_channels, out_channels 840 | ), 841 | ) 842 | elif pool == "spatial": 843 | self.out = nn.Sequential( 844 | nn.Linear(self._feature_size, 2048), 845 | nn.ReLU(), 846 | nn.Linear(2048, self.out_channels), 847 | ) 848 | elif pool == "spatial_v2": 849 | self.out = nn.Sequential( 850 | nn.Linear(self._feature_size, 2048), 851 | normalization(2048), 852 | nn.SiLU(), 853 | nn.Linear(2048, self.out_channels), 854 | ) 855 | else: 856 | raise NotImplementedError(f"Unexpected {pool} pooling") 857 | 858 | def convert_to_fp16(self): 859 | """ 860 | Convert the torso of the model to float16. 861 | """ 862 | self.input_blocks.apply(convert_module_to_f16) 863 | self.middle_block.apply(convert_module_to_f16) 864 | 865 | def convert_to_fp32(self): 866 | """ 867 | Convert the torso of the model to float32. 868 | """ 869 | self.input_blocks.apply(convert_module_to_f32) 870 | self.middle_block.apply(convert_module_to_f32) 871 | 872 | def forward(self, x, timesteps): 873 | """ 874 | Apply the model to an input batch. 875 | 876 | :param x: an [N x C x ...] Tensor of inputs. 877 | :param timesteps: a 1-D batch of timesteps. 878 | :return: an [N x K] Tensor of outputs. 879 | """ 880 | emb = self.time_embed(timestep_embedding(timesteps, self.model_channels)) 881 | 882 | results = [] 883 | h = x.type(self.dtype) 884 | for module in self.input_blocks: 885 | h = module(h, emb) 886 | if self.pool.startswith("spatial"): 887 | results.append(h.type(x.dtype).mean(dim=(2, 3))) 888 | h = self.middle_block(h, emb) 889 | if self.pool.startswith("spatial"): 890 | results.append(h.type(x.dtype).mean(dim=(2, 3))) 891 | h = th.cat(results, axis=-1) 892 | return self.out(h) 893 | else: 894 | h = h.type(x.dtype) 895 | return self.out(h) 896 | --------------------------------------------------------------------------------
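Note on the two attention modules in this file: QKVAttentionLegacy and QKVAttention compute the same scaled dot-product attention and differ only in how the fused qkv tensor is laid out along the channel axis ([N x (H * 3 * C) x T] versus [N x (3 * H * C) x T]); the use_new_attention_order flag earlier in the file selects between them, and switching splitters on a trained checkpoint silently scrambles the head/channel grouping. A minimal standalone sketch (not part of the repository; shapes are made up) of the two split orders:

import torch as th

N, H, C, T = 2, 4, 8, 16                       # batch, heads, channels per head, tokens
qkv = th.randn(N, 3 * H * C, T)                # fused Q/K/V projection output

# Legacy layout [N x (H * 3 * C) x T]: fold heads into the batch first,
# then split each head's 3*C channels into q, k, v.
q1, k1, v1 = qkv.reshape(N * H, 3 * C, T).split(C, dim=1)

# New layout [N x (3 * H * C) x T]: chunk into q, k, v of H*C channels first,
# then fold heads into the batch.
q2, k2, v2 = (t.reshape(N * H, C, T) for t in qkv.chunk(3, dim=1))

print(q1.shape, q2.shape)                      # both torch.Size([8, 8, 16])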
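For orientation, a hedged usage sketch of the UNetModel defined above: it builds a small model with illustrative hyperparameters (deliberately not the values from the shipped configs) and runs one forward pass. It assumes the repository root is on PYTHONPATH so that guided_diffusion.unet is importable.

import torch as th
from guided_diffusion.unet import UNetModel

model = UNetModel(
    image_size=64,
    in_channels=3,
    model_channels=64,                  # base channel count
    out_channels=3,                     # channels of the predicted output
    num_res_blocks=2,
    attention_resolutions=(4, 8),       # downsample rates at which attention runs
    channel_mult=(1, 2, 4),
    num_heads=4,
)

x = th.randn(2, 3, 64, 64)              # [N x C x H x W] noisy inputs
t = th.randint(0, 1000, (2,))           # a 1-D batch of timesteps
with th.no_grad():
    out = model(x, t)                   # [N x 3 x 64 x 64]
print(out.shape)

In practice the repository builds these models from its YAML configs rather than by hand; the values above are for illustration only.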