├── .gitattributes
├── requirements.txt
├── models
│   └── srgd
│       └── conditional_continuous_linear_df8kost_dim128_epoch300.pth
├── inference_sample.sh
├── conf
│   └── conditional_continuous_linear_df8kost_dim128.yaml
├── LICENSE
├── README.md
├── config.py
├── inference.py
└── model.py

/.gitattributes:
--------------------------------------------------------------------------------
1 | *.pth filter=lfs diff=lfs merge=lfs -text
2 | 
--------------------------------------------------------------------------------
/requirements.txt:
--------------------------------------------------------------------------------
1 | denoising-diffusion-pytorch==1.8.15
2 | logzero
3 | timm
4 | torch
5 | torchvision
6 | 
--------------------------------------------------------------------------------
/models/srgd/conditional_continuous_linear_df8kost_dim128_epoch300.pth:
--------------------------------------------------------------------------------
1 | version https://git-lfs.github.com/spec/v1
2 | oid sha256:52ca34fbdba3059f5d0aa677dee06b8707d2552c06fdfaea9aa516b6ee03a3d7
3 | size 550400570
4 | 
--------------------------------------------------------------------------------
/inference_sample.sh:
--------------------------------------------------------------------------------
1 | #!/bin/bash
2 | 
3 | input_dir=path/to/input_images
4 | output_dir=path/to/output_images
5 | 
6 | conf="conf/conditional_continuous_linear_df8kost_dim128.yaml"
7 | model="models/srgd/conditional_continuous_linear_df8kost_dim128_epoch300.pth"
8 | test_label=0
9 | class_cond_scale=1.0
10 | seed=71
11 | 
12 | python inference.py -c ${conf} -m ${model} \
13 |     --class_cond_scale ${class_cond_scale} --test_label ${test_label} --seed ${seed} \
14 |     --input_dir ${input_dir} --output_dir ${output_dir}
15 | 
--------------------------------------------------------------------------------
/conf/conditional_continuous_linear_df8kost_dim128.yaml:
--------------------------------------------------------------------------------
1 | save_dir: srgd
2 | prefix: conditional_continuous_linear_df8kost_dim128
3 | 
4 | lr: 1e-4
5 | min_lr: 1e-7
6 | warmup_epochs: 30
7 | epochs: 300
8 | 
9 | ema_decay: 0.9999
10 | 
11 | class_cond_drop_prob: 0.1
12 | 
13 | conditional_task_type: realsr_denoise_sr
14 | 
15 | model: conditional_continuous
16 | noise_schedule: linear
17 | num_sample_steps: 250
18 | val_num_sample_steps: 250
19 | skip_val: true
20 | 
21 | dataset_name: cropped_df8kost_400x400_overlap200
22 | 
23 | crop_size_limit: false
24 | use_dpmpp_solver: true
25 | validation_ratio: 0.5
26 | 
27 | batch_size: 16
28 | 
29 | scale_size: 400
30 | crop_size: 256
31 | image_size: 256
32 | hr_image_size: 256
33 | lr_image_size: 64
34 | 
35 | crop_rate: 1
36 | 
37 | num_workers: 16
38 | 
39 | hflip: true
40 | rotate: true
41 | 
42 | sample_size: 16
43 | 
44 | unet_dim: 128
45 | ddpm_unet_dim_mults: '1,2,4,8'
46 | 
47 | learned_variance: false
48 | learned_sinusoidal_cond: true
49 | learned_sinusoidal_dim: 32
50 | 
51 | 
--------------------------------------------------------------------------------
/LICENSE:
--------------------------------------------------------------------------------
1 | Copyright (c) 2024 LY Corporation
2 | 
3 | Permission is hereby granted, free of charge, to any person obtaining a copy of this software and associated documentation files (the “Software”), to deal in the Software without restriction, including without limitation the rights to use, copy, modify, merge, publish, distribute, sublicense, and/or sell copies of the Software, and to permit persons to whom the Software is furnished to do so, subject to the following conditions:
4 | 
5 | The above copyright notice and this permission notice shall be included in all copies or substantial portions of the Software.
6 | 
7 | THE SOFTWARE IS PROVIDED “AS IS”, WITHOUT WARRANTY OF ANY KIND, EXPRESS OR IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE SOFTWARE.
8 | 
--------------------------------------------------------------------------------
/README.md:
--------------------------------------------------------------------------------
1 | # Real-SRGD: Enhancing Real-World Image Super-Resolution with Classifier-Free Guided Diffusion [ACCV2024]
2 | 
3 | [[Paper]](https://openaccess.thecvf.com/content/ACCV2024/html/Doi_Real-SRGD_Enhancing_Real-World_Image_Super-Resolution_with_Classifier-Free_Guided_Diffusion_ACCV_2024_paper.html)
4 | 
5 | This is the official PyTorch implementation of "Real-SRGD: Enhancing Real-World Image Super-Resolution with Classifier-Free Guided Diffusion (ACCV2024)".
6 | 
7 | ## Installation
8 | 
9 | This repository uses Git LFS (Large File Storage) to manage large files. Please ensure you have Git LFS installed before cloning the repository. Follow the steps below to install Git LFS and set up the project:
10 | 
11 | 
12 | 1. **Install Git LFS**
13 |    If you don't have Git LFS installed, you can install it by following the instructions on the [Git LFS website](https://git-lfs.github.com/).
14 |    Alternatively, you can install it using a package manager:
15 | 
16 |    - **For macOS**:
17 |      ```bash
18 |      brew install git-lfs
19 |      ```
20 | 
21 |    - **For Windows**:
22 |      Download and run the [Git LFS installer](https://git-lfs.github.com/).
23 | 
24 |    - **For Linux**:
25 |      Use your distribution's package manager. For example, on Ubuntu:
26 |      ```bash
27 |      sudo apt-get install git-lfs
28 |      ```
29 | 
30 | 2. **Initialize Git LFS**
31 |    After installing, initialize Git LFS in your repository:
32 | 
33 |    ```bash
34 |    git lfs install
35 |    ```
36 | 
37 | 3. **Clone the repository**
38 |    Use Git to clone this repository. Please note that the download may take some time due to large files managed by Git LFS:
39 | 
40 |    ```bash
41 |    git clone https://github.com/yahoojapan/srgd
42 |    cd srgd
43 |    ```
44 | 
45 | 4. **Install packages**
46 | 
47 |    ```bash
48 |    pip install -r requirements.txt
49 |    ```
50 | 
51 | ## Inference
52 | 
53 | ### Step 1: Prepare testing data
54 | 
55 | Create a directory with an appropriate name and place all the images you want to super-resolve into this directory.
56 | This will be your `input_dir`.
57 | 
58 | ### Step 2: Run the testing command
59 | 
60 | Execute the following command to run the inference script. Make sure to specify your `input_dir` and `output_dir` paths accordingly:
61 | 
62 | ```bash
63 | input_dir=path/to/input_images
64 | output_dir=path/to/output_images
65 | 
66 | conf="conf/conditional_continuous_linear_df8kost_dim128.yaml"
67 | model="models/srgd/conditional_continuous_linear_df8kost_dim128_epoch300.pth"
68 | test_label=0
69 | class_cond_scale=1.0
70 | seed=71
71 | 
72 | python inference.py -c ${conf} -m ${model} \
73 |     --class_cond_scale ${class_cond_scale} --test_label ${test_label} --seed ${seed} \
74 |     --input_dir ${input_dir} --output_dir ${output_dir}
75 | ```
76 | 
77 | Replace `path/to/input_images` with the path to your input directory and `path/to/output_images` with the path where you want the super-resolved images to be saved. The script processes every image in `input_dir` and saves the results to `output_dir`.
78 | 
79 | A sample script, `inference_sample.sh`, is provided in the repository to help you get started with inference. You can modify it to fit your specific needs.
80 | 
81 | ## Citation
82 | 
83 | ```
84 | @inproceedings{doi2024,
85 |   title={Real-SRGD: Enhancing Real-World Image Super-Resolution with Classifier-Free Guided Diffusion},
86 |   author={Kenji Doi and Shuntaro Okada and Ryota Yoshihashi and Hirokatsu Kataoka},
87 |   booktitle={Proceedings of the Asian Conference on Computer Vision (ACCV)},
88 |   year={2024},
89 | }
90 | ```
91 | 
92 | ## License
93 | 
94 | [MIT](./LICENSE)
95 | 
--------------------------------------------------------------------------------
/config.py:
--------------------------------------------------------------------------------
1 | import yaml
2 | from dataclasses import dataclass
3 | 
4 | 
5 | @dataclass
6 | class Config:
7 |     save_dir: str = 'srgd'
8 |     prefix: str = 'conditional_continuous_linear'
9 | 
10 |     base_dir: str = './input/'
11 |     dataset_name: str = 'cropped_df2kost_400x400_overlap200'
12 | 
13 |     model: str = 'continuous' # gaussian / elucidated / continuous
14 |     # conditional_gaussian / conditional_elucidated / conditional_continuous
15 |     # conditional_selfcond_gaussian / conditional_selfcond_continuous
16 | 
17 |     cond_drop_prob: float = 0.1
18 |     cond_scale: float = 1. # Classifier-free guidance scale for LR condition
19 | 
20 |     num_classes: int = 3
21 |     conditional_task_type: str = 'realsr_denoise_sr'
22 |     class_cond_drop_prob: float = 0.1
23 |     class_cond_scale: float = 1. # Classifier-free guidance scale for class condition
24 |     test_label: int = 0
25 | 
26 |     guidance_start_steps: int = 0
27 |     class_guidance_start_steps: int = 0
28 |     generation_start_steps: int = 0
29 | 
30 |     # for GaussianDiffusion
31 |     objective: str = 'pred_noise' # pred_noise / pred_x0 / pred_v
32 |     beta_schedule: str = 'linear' # linear / cosine / sigmoid
33 |     timesteps: int = 1000
34 |     sampling_timesteps: int = 250 # values below `timesteps` (1000) enable DDIM sampling
35 |     offset_noise_strength: float = 0.
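    # offset_noise_strength adds a per-channel constant offset to the training
    # noise (https://www.crosslabs.org/blog/diffusion-with-offset-noise);
    # 0. disables it, and the blog post cited in model.py suggests around 0.1.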
36 | 
37 |     loss_type: str = 'l2' # l1 / l2 / smooth_l1
38 | 
39 |     # for ElucidatedDiffusion
40 |     num_sample_steps: int = 32
41 |     sigma_min: float = 0.002
42 |     sigma_max: float = 80
43 |     sigma_data: float = 0.5
44 |     rho: float = 7
45 |     P_mean: float = -1.2
46 |     P_std: float = 1.2
47 |     S_churn: float = 80
48 |     S_tmin: float = 0.05
49 |     S_tmax: float = 50
50 |     S_noise: float = 1.003
51 |     use_dpmpp_solver: bool = True
52 | 
53 |     # for ContinuousGaussianDiffusion
54 |     noise_schedule: str = 'linear' # linear / cosine / learned
55 |     clip_sample_denoised: bool = True
56 |     learned_schedule_net_hidden_dim: int = 1024
57 |     learned_noise_schedule_frac_gradient: float = 1.
58 | 
59 |     # for GaussianDiffusion and ContinuousGaussianDiffusion
60 |     min_snr_loss_weight: bool = False
61 |     min_snr_gamma: float = 5
62 | 
63 |     val_num_sample_steps: int = 32
64 | 
65 |     n_fold: int = 10 # Currently only 10 is supported
66 |     train_fold: str = '0' # Test with fold 0 and train with the remaining folds
67 | 
68 |     skip_sample: bool = False # Skip sampling
69 |     skip_val: bool = False # Skip validation
70 | 
71 |     validation_ratio: float = 0.5 # Ratio of validation data used
72 | 
73 |     val_realsrv3: bool = False # Validate with the RealSRv3 dataset
74 |     val_drealsr: bool = False # Validate with the DRealSR dataset
75 |     val_realsrv3_scale: int = 4 # 2 / 4
76 |     val_drealsr_scale: int = 4 # 2 / 4
77 | 
78 |     image_size: int = 128 # Image size for model input
79 |     crop_size: int = 256 # Crop size from the original image
80 |     hr_image_size: int = 256 # High-resolution size
81 |     lr_image_size: int = 128 # Low-resolution size
82 |     crop_rate: int = 2 # Value of hr_image_size / image_size
83 | 
84 |     scale_size: int = 256 # Initial resize size during resize_randomcrop
85 | 
86 |     crop_size_limit: bool = False # Filter out images whose short side is less than crop_size
87 | 
88 |     pixel_shuffle_upsample: bool = True
89 | 
90 |     batch_size: int = 32
91 | 
92 |     sample_size: int = 16
93 | 
94 |     hflip: bool = False # 50% probability of horizontal flip
95 |     rotate: bool = False # Random rotation in 90-degree increments
96 |     interpolation: str = 'BICUBIC'
97 |     shuffle: bool = True
98 | 
99 |     torch_compile: bool = False # Use torch.compile
100 | 
101 |     seed: int = 71
102 | 
103 |     amp: bool = False
104 |     amp_dtype: str = 'float16' # float16 / bfloat16 (A100) / float32 # Currently unused
105 | 
106 |     # for U-Net
107 |     unet_dim: int = 64
108 |     ddpm_unet_dim_mults: str = '1,2,4,8'
109 |     full_attn: str = 'False,False,False,True'
110 |     learned_variance: bool = False # Currently only False is supported
111 |     learned_sinusoidal_cond: bool = True
112 |     learned_sinusoidal_dim: int = 32
113 | 
114 |     ema_decay: float = 0.995
115 |     ema_device: str = 'cuda'
116 | 
117 |     flash_attn: bool = False
118 | 
119 |     # load pretrained model manually
120 |     ckpt_path: str = ''
121 |     load_strict: bool = True
122 | 
123 |     # optimizer settings
124 |     optimizer: str = 'adamw'
125 |     lr: float = 1e-4
126 |     min_lr: float = 1e-4
127 |     weight_decay: float = 0.
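    # The fields below appear to be optimizer-specific: momentum/nesterov for
    # SGD-style optimizers, amsgrad for the Adam family, and
    # madgrad_decoupled_decay for MADGRAD.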
128 |     momentum: float = 0.9
129 |     nesterov: bool = False
130 |     amsgrad: bool = False
131 |     madgrad_decoupled_decay: bool = True
132 | 
133 |     # scheduler settings
134 |     epochs: int = 300
135 |     warmup_epochs: int = 0
136 |     warmup_lr_init: float = 1e-6
137 |     plateau_mode: str = 'min'
138 |     factor: float = 0.1
139 |     patience: int = 4
140 |     plateau_eps: float = 1e-8
141 |     scheduler: str = 'cosine' # ReduceLROnPlateau / CosineAnnealingLR /
142 |     # WarmupLinear / cosine
143 |     cosine_interval_type: str = 'step' # step or epoch; update frequency of CosineLRScheduler
144 | 
145 |     # Crop method from the original image: centercrop / randomcrop / justresize / resize_randomcrop
146 |     train_preprocess: str = 'randomcrop'
147 |     valid_preprocess: str = 'centercrop'
148 | 
149 |     train_trans_mode: str = 'realesrgan' # simple / aug_v1 / aug_v2 / realesrgan
150 |     # When realesrgan is specified, train_preprocess is ignored
151 |     valid_trans_mode: str = 'simple' # simple
152 | 
153 |     usm_sharpener: bool = False # Whether to apply an unsharp mask to HR images when realesrgan is specified
154 | 
155 |     # NOTE: interpolation is already defined above (line 96); choices: BILINEAR / BICUBIC / LANCZOS
156 | 
157 |     # for aug_v1 / aug_v2
158 |     blur_prob: float = 0.5 # prob of OneOf
159 |     advance_blur_prob: float = 0.5
160 |     gaussian_blur_prob: float = 0.5
161 |     sinc_blur_prob: float = 0.5
162 |     sinc_blur_factor_min: float = 0.9
163 |     sinc_blur_factor_max: float = 1.1
164 |     image_compression_prob: float = 0.5 # prob of image compression
165 |     quality_lower: int = 50
166 |     quality_upper: int = 100
167 |     noise_prob: float = 0.5 # prob of noise
168 |     gauss_noise_prob: float = 0.5
169 |     iso_noise_prob: float = 0.5
170 |     multiplicative_noise_prob: float = 0.5
171 | 
172 |     train: bool = True
173 |     test: bool = False
174 |     debug: bool = False
175 | 
176 |     save_validation_sample: bool = False # Save sample images during validation
177 |     save_validation_hr_sample: bool = False # Save HR sample images during validation
178 | 
179 |     save_every_epoch: bool = False # Save the model at every epoch
180 | 
181 |     test_target: str = 'best_loss' # best_loss / best_psnr / best_ssim / best_lpips
182 | 
183 |     num_workers: int = 4
184 |     device: str = 'cuda'
185 |     pin_memory: bool = True
186 |     model_dir: str = 'models'
187 |     log_dir: str = 'logs'
188 |     print_freq: int = 0
189 | 
190 | 
191 | def load_config(config_file):
192 |     with open(config_file, 'r') as fp:
193 |         opts = yaml.safe_load(fp)
194 |     return Config(**opts)
195 | 
--------------------------------------------------------------------------------
/inference.py:
--------------------------------------------------------------------------------
1 | import os
2 | import torch
3 | import random
4 | import glob
5 | import numpy as np
6 | import torchvision
7 | import torchvision.transforms as T
8 | 
9 | from config import load_config
10 | from model import get_model
11 | from logzero import logger
12 | 
13 | from packaging import version
14 | from argparse import ArgumentParser
15 | from PIL import Image
16 | from tqdm import tqdm
17 | 
18 | toPILImage = torchvision.transforms.ToPILImage()
19 | toTensor = torchvision.transforms.ToTensor()
20 | 
21 | def parse_args():
22 |     parser = ArgumentParser()
23 |     parser.add_argument('-c', '--conf', required=True,
24 |                         help='Path to config file')
25 |     parser.add_argument('-m', '--ckpt_path', type=str, required=True)
26 |     parser.add_argument('--input_dir', type=str, required=True)
27 |     parser.add_argument('--output_dir', type=str, required=True)
28 |     parser.add_argument('--batch_size', type=int, default=8)
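    # Note: --batch_size is the number of tiles processed per forward pass in
    # tiled_sample below; each input image is still super-resolved one at a time.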
29 |     parser.add_argument('--num_sample_steps', type=int, default=250)
30 |     parser.add_argument('--interpolation', type=str, default='bicubic')
31 |     parser.add_argument('--cond_scale', type=float, default=1.0)
32 |     parser.add_argument('--class_cond_scale', type=float, default=1.0)
33 |     parser.add_argument('--guidance_start_steps', type=int, default=0)
34 |     parser.add_argument('--class_guidance_start_steps', type=int, default=0)
35 |     parser.add_argument('--generation_start_steps', type=int, default=0)
36 |     parser.add_argument('--start_index', type=int, default=0)
37 |     parser.add_argument('--end_index', type=int, default=None)
38 |     parser.add_argument('--test_label', type=int, default=None)
39 |     parser.add_argument('--no_amp', dest='amp', action="store_false")
40 |     parser.add_argument('--no_dpmpp_solver', dest='use_dpmpp_solver', action="store_false")
41 |     parser.add_argument('--seed', type=int, default=71)
42 |     parser.add_argument('--backend', type=str, default='ddp')
43 | 
44 |     return parser.parse_args()
45 | 
46 | 
47 | def seed_everything(seed):
48 |     random.seed(seed)
49 |     os.environ['PYTHONHASHSEED'] = str(seed)
50 |     np.random.seed(seed)
51 |     torch.manual_seed(seed)
52 |     if torch.cuda.is_available():
53 |         torch.cuda.manual_seed(seed)
54 |         torch.backends.cuda.matmul.allow_tf32 = True
55 |         torch.backends.cudnn.deterministic = False
56 |         torch.backends.cudnn.benchmark = True
57 | 
58 | 
59 | def sr_target_image(image, sr_model, scale=4, batch_size=8,
60 |                     test_label=2, cond_scale=1.0, guidance_start_steps=0,
61 |                     class_cond_scale=1.0, class_guidance_start_steps=0,
62 |                     generation_start_steps=0, num_sample_steps=250,
63 |                     enable_amp=False, interpolation='bicubic', seed=71):
64 |     width, height = image.size
65 | 
66 |     if interpolation == 'bicubic':
67 |         interpolation_mode = T.InterpolationMode.BICUBIC
68 |     elif interpolation == 'lanczos':
69 |         interpolation_mode = T.InterpolationMode.LANCZOS
70 | 
71 |     resize_hr_size = T.Resize((height*scale, width*scale), interpolation=interpolation_mode)
72 | 
73 |     resized_tensor = toTensor(resize_hr_size(image)).unsqueeze(0)
74 |     condition_x = resized_tensor.to(sr_model.device)
75 | 
76 |     if test_label is not None:
77 |         test_label = torch.LongTensor([test_label]).to(sr_model.device)
78 |     else:
79 |         test_label = None
80 | 
81 |     seed_everything(seed)
82 | 
83 |     # with torch.inference_mode(), autocast(enabled=enable_amp):
84 |     with torch.inference_mode():
85 |         output = sr_model.tiled_sample(batch_size=batch_size,
86 |                                        condition_x=condition_x, class_label=test_label,
87 |                                        cond_scale=cond_scale, guidance_start_steps=guidance_start_steps,
88 |                                        class_cond_scale=class_cond_scale,
89 |                                        class_guidance_start_steps=class_guidance_start_steps,
90 |                                        generation_start_steps=generation_start_steps,
91 |                                        num_sample_steps=num_sample_steps,
92 |                                        amp=enable_amp)
93 | 
94 |     sr_img = toPILImage(output[0])
95 |     new_width, new_height = sr_img.size
96 |     assert width*scale == new_width
97 |     assert height*scale == new_height
98 |     return sr_img
99 | 
100 | 
101 | def try_open_image(image_path):
102 |     try:
103 |         img = Image.open(image_path).convert('RGB')
104 |         return img
105 |     except (IOError, SyntaxError) as e:
106 |         return None
107 | 
108 | def batch_sr_target_images(input_dir, output_dir, sr_model, scale=4,
109 |                            batch_size=8, test_label=2,
110 |                            cond_scale=1.0, guidance_start_steps=0,
111 |                            class_cond_scale=1.0, class_guidance_start_steps=0,
112 |                            generation_start_steps=0, num_sample_steps=250,
113 |                            start_index=0, end_index=None,
114 |                            enable_amp=False, interpolation='bicubic', seed=71):
115 | 
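    # Iterates over the sorted images in input_dir, skipping any file whose
    # '<name>_out.png' result already exists in output_dir and silently
    # skipping files that PIL cannot open.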
print(f"save images at: {output_dir}") 117 | 118 | os.makedirs(output_dir, exist_ok=True) 119 | 120 | image_list = sorted(glob.glob(f"{input_dir}/*"))[start_index:end_index] 121 | 122 | for filename in tqdm(image_list, disable=False): 123 | save_filename = os.path.basename(filename).replace('.png', '_out.png') 124 | save_path = os.path.join(output_dir, save_filename) 125 | 126 | if os.path.exists(save_path): 127 | print('skip') 128 | else: 129 | image = try_open_image(filename) 130 | if image is not None: 131 | cur_sr_img = sr_target_image(image, sr_model, scale=scale, 132 | batch_size=batch_size, test_label=test_label, 133 | cond_scale=cond_scale, guidance_start_steps=guidance_start_steps, 134 | class_cond_scale=class_cond_scale, 135 | class_guidance_start_steps=class_guidance_start_steps, 136 | generation_start_steps=generation_start_steps, 137 | num_sample_steps=num_sample_steps, 138 | enable_amp=enable_amp, 139 | interpolation=interpolation, seed=seed) 140 | cur_sr_img.save(save_path) 141 | else: 142 | print('Invalid image or unable to open image:', filename) 143 | 144 | 145 | if __name__ == '__main__': 146 | args = parse_args() 147 | conf = load_config(args.conf) 148 | conf.num_sample_steps = args.num_sample_steps 149 | conf.ckpt_path = args.ckpt_path 150 | 151 | if version.parse(torch.__version__) < version.parse("2.0.0"): 152 | conf.flash_attn = False 153 | 154 | ema_model = get_model(conf, logger) 155 | device = torch.device('cuda') if torch.cuda.is_available() else torch.device('cpu') 156 | sr_model = ema_model.module.eval().to(device) 157 | 158 | print(args) 159 | 160 | batch_sr_target_images(args.input_dir, args.output_dir, sr_model, 161 | scale=4, batch_size=args.batch_size, test_label=args.test_label, 162 | cond_scale=args.cond_scale, guidance_start_steps=args.guidance_start_steps, 163 | class_cond_scale=args.class_cond_scale, 164 | class_guidance_start_steps=args.class_guidance_start_steps, 165 | generation_start_steps=args.generation_start_steps, 166 | num_sample_steps=args.num_sample_steps, 167 | start_index=args.start_index, end_index=args.end_index, 168 | enable_amp=args.amp, interpolation=args.interpolation, seed=args.seed) 169 | -------------------------------------------------------------------------------- /model.py: -------------------------------------------------------------------------------- 1 | import torch 2 | from timm.utils import ModelEmaV2 3 | 4 | from functools import partial 5 | from collections import namedtuple 6 | 7 | import math 8 | import torch.nn as nn 9 | import torch.nn.functional as F 10 | 11 | from denoising_diffusion_pytorch import ( 12 | Unet, 13 | GaussianDiffusion, 14 | ElucidatedDiffusion, 15 | ) 16 | from denoising_diffusion_pytorch.attend import Attend 17 | 18 | from einops import rearrange, repeat, reduce 19 | from einops.layers.torch import Rearrange 20 | from torch.special import expm1 21 | from torch import sqrt 22 | from tqdm import tqdm 23 | from torch.cuda.amp import autocast 24 | 25 | 26 | # This code is significantly inspired by or directly copied from the 27 | # "denoising-diffusion-pytorch" implementation found at the following GitHub repository: 28 | # https://github.com/lucidrains/denoising-diffusion-pytorch 29 | # 30 | # All credit for the original implementation and concept goes to the authors 31 | # of the "denoising-diffusion-pytorch" project. Any errors or shortcomings in 32 | # this adaptation are my own. 
33 | 
34 | # constants
35 | 
36 | ModelPrediction = namedtuple('ModelPrediction', ['pred_noise', 'pred_x_start'])
37 | 
38 | # normalization functions
39 | 
40 | def normalize_to_neg_one_to_one(img):
41 |     return img * 2 - 1
42 | 
43 | def unnormalize_to_zero_to_one(t):
44 |     return (t + 1) * 0.5
45 | 
46 | 
47 | # helper functions
48 | 
49 | def exists(x):
50 |     return x is not None
51 | 
52 | def default(val, d):
53 |     if exists(val):
54 |         return val
55 |     return d() if callable(d) else d
56 | 
57 | def cast_tuple(t, length = 1):
58 |     if isinstance(t, tuple):
59 |         return t
60 |     return ((t,) * length)
61 | 
62 | def divisible_by(numer, denom):
63 |     return (numer % denom) == 0
64 | 
65 | def identity(t, *args, **kwargs):
66 |     return t
67 | 
68 | # small helper modules
69 | 
70 | class PixelShuffleUpsample(nn.Module):
71 |     """
72 |     code shared by @MalumaDev at DALLE2-pytorch for addressing checkerboard artifacts
73 |     https://arxiv.org/ftp/arxiv/papers/1707/1707.02937.pdf
74 |     """
75 |     def __init__(self, dim, dim_out = None):
76 |         super().__init__()
77 |         dim_out = default(dim_out, dim)
78 |         conv = nn.Conv2d(dim, dim_out * 4, 1)
79 | 
80 |         self.net = nn.Sequential(
81 |             conv,
82 |             nn.SiLU(),
83 |             nn.PixelShuffle(2)
84 |         )
85 | 
86 |         self.init_conv_(conv)
87 | 
88 |     def init_conv_(self, conv):
89 |         o, i, h, w = conv.weight.shape
90 |         conv_weight = torch.empty(o // 4, i, h, w)
91 |         nn.init.kaiming_uniform_(conv_weight)
92 |         conv_weight = repeat(conv_weight, 'o ... -> (o 4) ...')
93 | 
94 |         conv.weight.data.copy_(conv_weight)
95 |         nn.init.zeros_(conv.bias.data)
96 | 
97 |     def forward(self, x):
98 |         return self.net(x)
99 | 
100 | def Upsample(dim, dim_out = None):
101 |     return nn.Sequential(
102 |         nn.Upsample(scale_factor = 2, mode = 'nearest'),
103 |         nn.Conv2d(dim, default(dim_out, dim), 3, padding = 1)
104 |     )
105 | 
106 | def Downsample(dim, dim_out = None):
107 |     return nn.Sequential(
108 |         Rearrange('b c (h p1) (w p2) -> b (c p1 p2) h w', p1 = 2, p2 = 2),
109 |         nn.Conv2d(dim * 4, default(dim_out, dim), 1)
110 |     )
111 | 
112 | def kldiv_loss(pred, target, reduction = 'none'):
113 |     loss = torch.nn.KLDivLoss(reduction=reduction)(F.log_softmax(pred, dim=1), F.softmax(target, dim=1))
114 |     return loss
115 | 
116 | def get_coord_and_pad(height, width, tile_size=256):
117 |     if height <= tile_size and width <= tile_size:
118 |         new_height, new_width = tile_size, tile_size
119 |     else:
120 |         new_height = ((height-1)//tile_size + 1) * tile_size + tile_size
121 |         new_width = ((width-1)//tile_size + 1) * tile_size + tile_size
122 | 
123 |     left = (new_width - width) // 2
124 |     top = (new_height - height) // 2
125 |     right = left + width
126 |     bottom = top + height
127 |     coord = (left, top, right, bottom)
128 | 
129 |     pad_left = left
130 |     pad_right = new_width - pad_left - width
131 |     pad_top = top
132 |     pad_bottom = new_height - pad_top - height
133 |     pad = (pad_left, pad_right, pad_top, pad_bottom)
134 | 
135 |     return coord, pad
136 | 
137 | def get_coords(h, w, tile_size, tile_stride, diff=0):
138 |     hi_list = list(range(0, h - tile_size + 1, tile_stride))
139 |     if (h - tile_size) % tile_stride != 0:
140 |         hi_list.append(h - tile_size)
141 | 
142 |     wi_list = list(range(0, w - tile_size + 1, tile_stride))
143 |     if (w - tile_size) % tile_stride != 0:
144 |         wi_list.append(w - tile_size)
145 | 
146 |     coords = []
147 |     for hi in hi_list:
148 |         for wi in wi_list:
149 |             coords.append((hi + diff, hi + tile_size + diff, wi + diff, wi + tile_size + diff))
150 |     return coords
151 | 
152 | def get_area(coords, height, width):
153 |     top = height
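    # The bounding box starts inverted (top=height, left=width, bottom=0,
    # right=0) and is grown below to the union of all tile coordinates.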
bottom = 0 155 | left = width 156 | right = 0 157 | 158 | for coord in coords: 159 | hs, he, ws, we = coord 160 | if hs < top: 161 | top = hs 162 | if he > bottom: 163 | bottom = he 164 | if ws < left: 165 | left = ws 166 | if we > right: 167 | right = we 168 | coord = (left, top, right, bottom) 169 | 170 | area_height = bottom - top 171 | area_width = right - left 172 | 173 | pad_left = left 174 | pad_right = width - pad_left - area_width 175 | pad_top = top 176 | pad_bottom = height - pad_top - area_height 177 | pad = (pad_left, pad_right, pad_top, pad_bottom) 178 | 179 | return coord, pad 180 | 181 | import torch.fft as fft 182 | def Fourier_filter(x, threshold, scale): 183 | # FFT 184 | x_freq = fft.fftn(x, dim=(-2, -1)) 185 | x_freq = fft.fftshift(x_freq, dim=(-2, -1)) 186 | 187 | B, C, H, W = x_freq.shape 188 | mask = torch.ones((B, C, H, W)).to(x.device) 189 | 190 | crow, ccol = H // 2, W //2 191 | mask[..., crow - threshold:crow + threshold, ccol - threshold:ccol + threshold] = scale 192 | x_freq = x_freq * mask 193 | 194 | # IFFT 195 | x_freq = fft.ifftshift(x_freq, dim=(-2, -1)) 196 | x_filtered = fft.ifftn(x_freq, dim=(-2, -1)).real 197 | 198 | return x_filtered 199 | 200 | 201 | class RMSNorm(nn.Module): 202 | def __init__(self, dim): 203 | super().__init__() 204 | self.g = nn.Parameter(torch.ones(1, dim, 1, 1)) 205 | 206 | def forward(self, x): 207 | return F.normalize(x, dim = 1) * self.g * (x.shape[1] ** 0.5) 208 | 209 | class SinusoidalPosEmb(nn.Module): 210 | def __init__(self, dim): 211 | super().__init__() 212 | self.dim = dim 213 | 214 | def forward(self, x): 215 | device = x.device 216 | half_dim = self.dim // 2 217 | emb = math.log(10000) / (half_dim - 1) 218 | emb = torch.exp(torch.arange(half_dim, device=device) * -emb) 219 | emb = x[:, None] * emb[None, :] 220 | emb = torch.cat((emb.sin(), emb.cos()), dim=-1) 221 | return emb 222 | 223 | class RandomOrLearnedSinusoidalPosEmb(nn.Module): 224 | """ following @crowsonkb 's lead with random (learned optional) sinusoidal pos emb """ 225 | """ https://github.com/crowsonkb/v-diffusion-jax/blob/master/diffusion/models/danbooru_128.py#L8 """ 226 | 227 | def __init__(self, dim, is_random = False): 228 | super().__init__() 229 | assert divisible_by(dim, 2) 230 | half_dim = dim // 2 231 | self.weights = nn.Parameter(torch.randn(half_dim), requires_grad = not is_random) 232 | 233 | def forward(self, x): 234 | x = rearrange(x, 'b -> b 1') 235 | freqs = x * rearrange(self.weights, 'd -> 1 d') * 2 * math.pi 236 | fouriered = torch.cat((freqs.sin(), freqs.cos()), dim = -1) 237 | fouriered = torch.cat((x, fouriered), dim = -1) 238 | return fouriered 239 | 240 | 241 | # building block modules 242 | 243 | class Block(nn.Module): 244 | def __init__(self, dim, dim_out, groups = 8): 245 | super().__init__() 246 | self.proj = nn.Conv2d(dim, dim_out, 3, padding = 1) 247 | self.norm = nn.GroupNorm(groups, dim_out) 248 | self.act = nn.SiLU() 249 | 250 | def forward(self, x, scale_shift = None): 251 | x = self.proj(x) 252 | x = self.norm(x) 253 | 254 | if exists(scale_shift): 255 | scale, shift = scale_shift 256 | x = x * (scale + 1) + shift 257 | 258 | x = self.act(x) 259 | return x 260 | 261 | class ResnetBlock(nn.Module): 262 | def __init__(self, dim, dim_out, *, time_emb_dim = None, groups = 8): 263 | super().__init__() 264 | self.mlp = nn.Sequential( 265 | nn.SiLU(), 266 | nn.Linear(time_emb_dim, dim_out * 2) 267 | ) if exists(time_emb_dim) else None 268 | 269 | self.block1 = Block(dim, dim_out, groups = groups) 270 | self.block2 = 
Block(dim_out, dim_out, groups = groups) 271 | self.res_conv = nn.Conv2d(dim, dim_out, 1) if dim != dim_out else nn.Identity() 272 | 273 | def forward(self, x, time_emb = None): 274 | 275 | scale_shift = None 276 | if exists(self.mlp) and exists(time_emb): 277 | time_emb = self.mlp(time_emb) 278 | time_emb = rearrange(time_emb, 'b c -> b c 1 1') 279 | scale_shift = time_emb.chunk(2, dim = 1) 280 | 281 | h = self.block1(x, scale_shift = scale_shift) 282 | 283 | h = self.block2(h) 284 | 285 | return h + self.res_conv(x) 286 | 287 | class LinearAttention(nn.Module): 288 | def __init__( 289 | self, 290 | dim, 291 | heads = 4, 292 | dim_head = 32 293 | ): 294 | super().__init__() 295 | self.scale = dim_head ** -0.5 296 | self.heads = heads 297 | hidden_dim = dim_head * heads 298 | 299 | self.norm = RMSNorm(dim) 300 | self.to_qkv = nn.Conv2d(dim, hidden_dim * 3, 1, bias = False) 301 | 302 | self.to_out = nn.Sequential( 303 | nn.Conv2d(hidden_dim, dim, 1), 304 | RMSNorm(dim) 305 | ) 306 | 307 | def forward(self, x): 308 | b, c, h, w = x.shape 309 | 310 | x = self.norm(x) 311 | 312 | qkv = self.to_qkv(x).chunk(3, dim = 1) 313 | q, k, v = map(lambda t: rearrange(t, 'b (h c) x y -> b h c (x y)', h = self.heads), qkv) 314 | 315 | q = q.softmax(dim = -2) 316 | k = k.softmax(dim = -1) 317 | 318 | q = q * self.scale 319 | 320 | context = torch.einsum('b h d n, b h e n -> b h d e', k, v) 321 | 322 | out = torch.einsum('b h d e, b h d n -> b h e n', context, q) 323 | out = rearrange(out, 'b h c (x y) -> b (h c) x y', h = self.heads, x = h, y = w) 324 | return self.to_out(out) 325 | 326 | class Attention(nn.Module): 327 | def __init__( 328 | self, 329 | dim, 330 | heads = 4, 331 | dim_head = 32, 332 | flash = False 333 | ): 334 | super().__init__() 335 | self.heads = heads 336 | hidden_dim = dim_head * heads 337 | 338 | self.norm = RMSNorm(dim) 339 | self.attend = Attend(flash = flash) 340 | 341 | self.to_qkv = nn.Conv2d(dim, hidden_dim * 3, 1, bias = False) 342 | self.to_out = nn.Conv2d(hidden_dim, dim, 1) 343 | 344 | def forward(self, x): 345 | b, c, h, w = x.shape 346 | 347 | x = self.norm(x) 348 | 349 | qkv = self.to_qkv(x).chunk(3, dim = 1) 350 | q, k, v = map(lambda t: rearrange(t, 'b (h c) x y -> b h (x y) c', h = self.heads), qkv) 351 | 352 | out = self.attend(q, k, v) 353 | 354 | out = rearrange(out, 'b h (x y) d -> b (h d) x y', x = h, y = w) 355 | return self.to_out(out) 356 | 357 | 358 | 359 | class SRUnet(Unet): 360 | def __init__( 361 | self, 362 | dim, 363 | init_dim = None, 364 | out_dim = None, 365 | dim_mults = (1, 2, 4, 8), 366 | channels = 3, 367 | self_condition = True, # Set self_condition=True to allow input of LR images 368 | resnet_block_groups = 8, 369 | learned_variance = False, 370 | learned_sinusoidal_cond = False, 371 | random_fourier_features = False, 372 | learned_sinusoidal_dim = 16, 373 | attn_dim_head = 32, 374 | attn_heads = 4, 375 | full_attn = (False, False, False, True), 376 | flash_attn = False, 377 | pixel_shuffle_upsample = True, 378 | ): 379 | super().__init__( 380 | dim, 381 | init_dim, 382 | out_dim, 383 | dim_mults, 384 | channels, 385 | self_condition, 386 | resnet_block_groups, 387 | learned_variance, 388 | learned_sinusoidal_cond, 389 | random_fourier_features, 390 | learned_sinusoidal_dim, 391 | attn_dim_head, 392 | attn_heads, 393 | full_attn, 394 | flash_attn 395 | ) 396 | # determine dimensions 397 | 398 | self.channels = channels 399 | self.self_condition = self_condition 400 | input_channels = channels * (2 if self_condition else 1) 401 | 402 | 
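        # With self_condition=True, the (upsampled) LR condition image is
        # concatenated with the noisy input along the channel axis in forward(),
        # so the initial conv sees twice the image channels (6 for RGB).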
init_dim = default(init_dim, dim) 403 | self.init_conv = nn.Conv2d(input_channels, init_dim, 7, padding = 3) 404 | 405 | dims = [init_dim, *map(lambda m: dim * m, dim_mults)] 406 | self.dims = dims 407 | in_out = list(zip(dims[:-1], dims[1:])) 408 | 409 | block_klass = partial(ResnetBlock, groups = resnet_block_groups) 410 | 411 | # time embeddings 412 | 413 | time_dim = dim * 4 414 | 415 | self.random_or_learned_sinusoidal_cond = learned_sinusoidal_cond or random_fourier_features 416 | 417 | if self.random_or_learned_sinusoidal_cond: 418 | sinu_pos_emb = RandomOrLearnedSinusoidalPosEmb(learned_sinusoidal_dim, random_fourier_features) 419 | fourier_dim = learned_sinusoidal_dim + 1 420 | else: 421 | sinu_pos_emb = SinusoidalPosEmb(dim) 422 | fourier_dim = dim 423 | 424 | self.time_mlp = nn.Sequential( 425 | sinu_pos_emb, 426 | nn.Linear(fourier_dim, time_dim), 427 | nn.GELU(), 428 | nn.Linear(time_dim, time_dim) 429 | ) 430 | 431 | # attention 432 | 433 | num_stages = len(dim_mults) 434 | full_attn = cast_tuple(full_attn, num_stages) 435 | attn_heads = cast_tuple(attn_heads, num_stages) 436 | attn_dim_head = cast_tuple(attn_dim_head, num_stages) 437 | 438 | assert len(full_attn) == len(dim_mults) 439 | 440 | FullAttention = partial(Attention, flash = flash_attn) 441 | 442 | # layers 443 | 444 | self.downs = nn.ModuleList([]) 445 | self.ups = nn.ModuleList([]) 446 | num_resolutions = len(in_out) 447 | 448 | for ind, ((dim_in, dim_out), layer_full_attn, layer_attn_heads, layer_attn_dim_head) in enumerate(zip(in_out, full_attn, attn_heads, attn_dim_head)): 449 | is_last = ind >= (num_resolutions - 1) 450 | 451 | attn_klass = FullAttention if layer_full_attn else LinearAttention 452 | 453 | self.downs.append(nn.ModuleList([ 454 | block_klass(dim_in, dim_in, time_emb_dim = time_dim), 455 | block_klass(dim_in, dim_in, time_emb_dim = time_dim), 456 | attn_klass(dim_in, dim_head = layer_attn_dim_head, heads = layer_attn_heads), 457 | Downsample(dim_in, dim_out) if not is_last else nn.Conv2d(dim_in, dim_out, 3, padding = 1) 458 | ])) 459 | 460 | mid_dim = dims[-1] 461 | self.mid_block1 = block_klass(mid_dim, mid_dim, time_emb_dim = time_dim) 462 | self.mid_attn = FullAttention(mid_dim, heads = attn_heads[-1], dim_head = attn_dim_head[-1]) 463 | self.mid_block2 = block_klass(mid_dim, mid_dim, time_emb_dim = time_dim) 464 | 465 | # upsample klass 466 | # Modify to enable the use of PixelshuffleUpsample 467 | upsample_klass = Upsample if not pixel_shuffle_upsample else PixelShuffleUpsample 468 | 469 | for ind, ((dim_in, dim_out), layer_full_attn, layer_attn_heads, layer_attn_dim_head) in enumerate(zip(*map(reversed, (in_out, full_attn, attn_heads, attn_dim_head)))): 470 | is_last = ind == (len(in_out) - 1) 471 | 472 | attn_klass = FullAttention if layer_full_attn else LinearAttention 473 | 474 | self.ups.append(nn.ModuleList([ 475 | block_klass(dim_out + dim_in, dim_out, time_emb_dim = time_dim), 476 | block_klass(dim_out + dim_in, dim_out, time_emb_dim = time_dim), 477 | attn_klass(dim_out, dim_head = layer_attn_dim_head, heads = layer_attn_heads), 478 | upsample_klass(dim_out, dim_in) if not is_last else nn.Conv2d(dim_out, dim_in, 3, padding = 1) 479 | ])) 480 | 481 | default_out_dim = channels * (1 if not learned_variance else 2) 482 | self.out_dim = default(out_dim, default_out_dim) 483 | 484 | self.final_res_block = block_klass(dim * 2, dim, time_emb_dim = time_dim) 485 | self.final_conv = nn.Conv2d(dim, self.out_dim, 1) 486 | 487 | 488 | def forward(self, x, time, x_self_cond = None): 489 | 
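        # x: the noisy HR estimate at timestep `time`; x_self_cond: the
        # conditioning image (here the upsampled LR input rather than a true
        # self-conditioning estimate), replaced by zeros when absent.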
assert all([divisible_by(d, self.downsample_factor) for d in x.shape[-2:]]), f'your input dimensions {x.shape[-2:]} need to be divisible by {self.downsample_factor}, given the unet' 490 | 491 | if self.self_condition: 492 | x_self_cond = default(x_self_cond, lambda: torch.zeros_like(x)) 493 | # x = torch.cat((x_self_cond, x), dim = 1) 494 | x = torch.cat((x, x_self_cond), dim = 1) 495 | 496 | x = self.init_conv(x) 497 | r = x.clone() 498 | 499 | t = self.time_mlp(time) 500 | 501 | h = [] 502 | 503 | for block1, block2, attn, downsample in self.downs: 504 | x = block1(x, t) 505 | h.append(x) 506 | 507 | x = block2(x, t) 508 | x = attn(x) + x 509 | h.append(x) 510 | 511 | x = downsample(x) 512 | 513 | x = self.mid_block1(x, t) 514 | x = self.mid_attn(x) + x 515 | x = self.mid_block2(x, t) 516 | 517 | for block1, block2, attn, upsample in self.ups: 518 | h_ = h.pop() 519 | x = torch.cat((x, h_), dim = 1) 520 | x = block1(x, t) 521 | 522 | h_ = h.pop() 523 | x = torch.cat((x, h_), dim = 1) 524 | x = block2(x, t) 525 | x = attn(x) + x 526 | 527 | x = upsample(x) 528 | 529 | x = torch.cat((x, r), dim = 1) 530 | 531 | x = self.final_res_block(x, t) 532 | return self.final_conv(x) 533 | 534 | 535 | # SRUnet that can add class conditions 536 | class ConditionalSRUnet(Unet): 537 | def __init__( 538 | self, 539 | dim, 540 | init_dim = None, 541 | out_dim = None, 542 | dim_mults = (1, 2, 4, 8), 543 | channels = 3, 544 | self_condition = True, # Set self_condition=True to allow input of LR images 545 | resnet_block_groups = 8, 546 | learned_variance = False, 547 | learned_sinusoidal_cond = False, 548 | random_fourier_features = False, 549 | learned_sinusoidal_dim = 16, 550 | attn_dim_head = 32, 551 | attn_heads = 4, 552 | full_attn = (False, False, False, True), 553 | flash_attn = False, 554 | pixel_shuffle_upsample = True, 555 | num_classes = None 556 | ): 557 | super().__init__( 558 | dim, 559 | init_dim, 560 | out_dim, 561 | dim_mults, 562 | channels, 563 | self_condition, 564 | resnet_block_groups, 565 | learned_variance, 566 | learned_sinusoidal_cond, 567 | random_fourier_features, 568 | learned_sinusoidal_dim, 569 | attn_dim_head, 570 | attn_heads, 571 | full_attn, 572 | flash_attn 573 | ) 574 | # determine dimensions 575 | 576 | self.num_classes = num_classes 577 | 578 | self.channels = channels 579 | self.self_condition = self_condition 580 | input_channels = channels * (2 if self_condition else 1) 581 | 582 | init_dim = default(init_dim, dim) 583 | self.init_conv = nn.Conv2d(input_channels, init_dim, 7, padding = 3) 584 | 585 | dims = [init_dim, *map(lambda m: dim * m, dim_mults)] 586 | in_out = list(zip(dims[:-1], dims[1:])) 587 | 588 | block_klass = partial(ResnetBlock, groups = resnet_block_groups) 589 | 590 | # time embeddings 591 | 592 | time_dim = dim * 4 593 | 594 | self.random_or_learned_sinusoidal_cond = learned_sinusoidal_cond or random_fourier_features 595 | 596 | if self.random_or_learned_sinusoidal_cond: 597 | sinu_pos_emb = RandomOrLearnedSinusoidalPosEmb(learned_sinusoidal_dim, random_fourier_features) 598 | fourier_dim = learned_sinusoidal_dim + 1 599 | else: 600 | sinu_pos_emb = SinusoidalPosEmb(dim) 601 | fourier_dim = dim 602 | 603 | self.time_mlp = nn.Sequential( 604 | sinu_pos_emb, 605 | nn.Linear(fourier_dim, time_dim), 606 | nn.GELU(), 607 | nn.Linear(time_dim, time_dim) 608 | ) 609 | 610 | # class conditional embeddings 611 | # Align with time_dim 612 | if self.num_classes is not None: 613 | class_emb = nn.Embedding(self.num_classes, dim) 614 | self.class_mlp = 
nn.Sequential( 615 | class_emb, 616 | nn.Linear(dim, time_dim), 617 | nn.GELU(), 618 | nn.Linear(time_dim, time_dim) 619 | ) 620 | 621 | # attention 622 | 623 | num_stages = len(dim_mults) 624 | full_attn = cast_tuple(full_attn, num_stages) 625 | attn_heads = cast_tuple(attn_heads, num_stages) 626 | attn_dim_head = cast_tuple(attn_dim_head, num_stages) 627 | 628 | assert len(full_attn) == len(dim_mults) 629 | 630 | FullAttention = partial(Attention, flash = flash_attn) 631 | 632 | # layers 633 | 634 | self.downs = nn.ModuleList([]) 635 | self.ups = nn.ModuleList([]) 636 | num_resolutions = len(in_out) 637 | 638 | for ind, ((dim_in, dim_out), layer_full_attn, layer_attn_heads, layer_attn_dim_head) in enumerate(zip(in_out, full_attn, attn_heads, attn_dim_head)): 639 | is_last = ind >= (num_resolutions - 1) 640 | 641 | attn_klass = FullAttention if layer_full_attn else LinearAttention 642 | 643 | self.downs.append(nn.ModuleList([ 644 | block_klass(dim_in, dim_in, time_emb_dim = time_dim), 645 | block_klass(dim_in, dim_in, time_emb_dim = time_dim), 646 | attn_klass(dim_in, dim_head = layer_attn_dim_head, heads = layer_attn_heads), 647 | Downsample(dim_in, dim_out) if not is_last else nn.Conv2d(dim_in, dim_out, 3, padding = 1) 648 | ])) 649 | 650 | mid_dim = dims[-1] 651 | self.mid_block1 = block_klass(mid_dim, mid_dim, time_emb_dim = time_dim) 652 | self.mid_attn = FullAttention(mid_dim, heads = attn_heads[-1], dim_head = attn_dim_head[-1]) 653 | self.mid_block2 = block_klass(mid_dim, mid_dim, time_emb_dim = time_dim) 654 | 655 | # upsample klass 656 | # Modify to enable the use of PixelshuffleUpsample 657 | upsample_klass = Upsample if not pixel_shuffle_upsample else PixelShuffleUpsample 658 | 659 | for ind, ((dim_in, dim_out), layer_full_attn, layer_attn_heads, layer_attn_dim_head) in enumerate(zip(*map(reversed, (in_out, full_attn, attn_heads, attn_dim_head)))): 660 | is_last = ind == (len(in_out) - 1) 661 | 662 | attn_klass = FullAttention if layer_full_attn else LinearAttention 663 | 664 | self.ups.append(nn.ModuleList([ 665 | block_klass(dim_out + dim_in, dim_out, time_emb_dim = time_dim), 666 | block_klass(dim_out + dim_in, dim_out, time_emb_dim = time_dim), 667 | attn_klass(dim_out, dim_head = layer_attn_dim_head, heads = layer_attn_heads), 668 | upsample_klass(dim_out, dim_in) if not is_last else nn.Conv2d(dim_out, dim_in, 3, padding = 1) 669 | ])) 670 | 671 | default_out_dim = channels * (1 if not learned_variance else 2) 672 | self.out_dim = default(out_dim, default_out_dim) 673 | 674 | self.final_res_block = block_klass(dim * 2, dim, time_emb_dim = time_dim) 675 | self.final_conv = nn.Conv2d(dim, self.out_dim, 1) 676 | 677 | 678 | def forward(self, x, time, class_label = None, x_self_cond = None): 679 | assert all([divisible_by(d, self.downsample_factor) for d in x.shape[-2:]]), f'your input dimensions {x.shape[-2:]} need to be divisible by {self.downsample_factor}, given the unet' 680 | 681 | if self.self_condition: 682 | x_self_cond = default(x_self_cond, lambda: torch.zeros_like(x)) 683 | # x = torch.cat((x_self_cond, x), dim = 1) 684 | x = torch.cat((x, x_self_cond), dim = 1) 685 | 686 | x = self.init_conv(x) 687 | r = x.clone() 688 | 689 | t = self.time_mlp(time) 690 | 691 | # class conditional 692 | if class_label is not None: 693 | c = self.class_mlp(class_label) 694 | t = t + c 695 | 696 | h = [] 697 | 698 | for block1, block2, attn, downsample in self.downs: 699 | x = block1(x, t) 700 | h.append(x) 701 | 702 | x = block2(x, t) 703 | x = attn(x) + x 704 | h.append(x) 705 | 
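            # Two skip tensors are stored per resolution here and consumed in
            # reverse order by the upsampling path.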
706 | x = downsample(x) 707 | 708 | x = self.mid_block1(x, t) 709 | x = self.mid_attn(x) + x 710 | x = self.mid_block2(x, t) 711 | 712 | for block1, block2, attn, upsample in self.ups: 713 | x = torch.cat((x, h.pop()), dim = 1) 714 | x = block1(x, t) 715 | 716 | x = torch.cat((x, h.pop()), dim = 1) 717 | x = block2(x, t) 718 | x = attn(x) + x 719 | 720 | x = upsample(x) 721 | 722 | x = torch.cat((x, r), dim = 1) 723 | 724 | x = self.final_res_block(x, t) 725 | return self.final_conv(x) 726 | 727 | 728 | # gaussian diffusion trainer class 729 | 730 | def extract(a, t, x_shape): 731 | b, *_ = t.shape 732 | out = a.gather(-1, t) 733 | return out.reshape(b, *((1,) * (len(x_shape) - 1))) 734 | 735 | 736 | # diffusion helpers 737 | 738 | def right_pad_dims_to(x, t): 739 | padding_dims = x.ndim - t.ndim 740 | if padding_dims <= 0: 741 | return t 742 | return t.view(*t.shape, *((1,) * padding_dims)) 743 | 744 | def linear_beta_schedule(timesteps): 745 | """ 746 | linear schedule, proposed in original ddpm paper 747 | """ 748 | scale = 1000 / timesteps 749 | beta_start = scale * 0.0001 750 | beta_end = scale * 0.02 751 | return torch.linspace(beta_start, beta_end, timesteps, dtype = torch.float64) 752 | 753 | def cosine_beta_schedule(timesteps, s = 0.008): 754 | """ 755 | cosine schedule 756 | as proposed in https://openreview.net/forum?id=-NEXDKk8gZ 757 | """ 758 | steps = timesteps + 1 759 | t = torch.linspace(0, timesteps, steps, dtype = torch.float64) / timesteps 760 | alphas_cumprod = torch.cos((t + s) / (1 + s) * math.pi * 0.5) ** 2 761 | alphas_cumprod = alphas_cumprod / alphas_cumprod[0] 762 | betas = 1 - (alphas_cumprod[1:] / alphas_cumprod[:-1]) 763 | return torch.clip(betas, 0, 0.999) 764 | 765 | def sigmoid_beta_schedule(timesteps, start = -3, end = 3, tau = 1, clamp_min = 1e-5): 766 | """ 767 | sigmoid schedule 768 | proposed in https://arxiv.org/abs/2212.11972 - Figure 8 769 | better for images > 64x64, when used during training 770 | """ 771 | steps = timesteps + 1 772 | t = torch.linspace(0, timesteps, steps, dtype = torch.float64) / timesteps 773 | v_start = torch.tensor(start / tau).sigmoid() 774 | v_end = torch.tensor(end / tau).sigmoid() 775 | alphas_cumprod = (-((t * (end - start) + start) / tau).sigmoid() + v_end) / (v_end - v_start) 776 | alphas_cumprod = alphas_cumprod / alphas_cumprod[0] 777 | betas = 1 - (alphas_cumprod[1:] / alphas_cumprod[:-1]) 778 | return torch.clip(betas, 0, 0.999) 779 | 780 | 781 | class GaussianDiffusionSR(GaussianDiffusion): 782 | def set_seed(self, seed): 783 | torch.cuda.manual_seed(seed) 784 | 785 | def __init__( 786 | self, 787 | model, 788 | *, 789 | image_size, 790 | timesteps = 1000, 791 | sampling_timesteps = None, 792 | objective = 'pred_v', 793 | beta_schedule = 'sigmoid', 794 | schedule_fn_kwargs = dict(), 795 | ddim_sampling_eta = 0., 796 | auto_normalize = True, 797 | offset_noise_strength = 0., # https://www.crosslabs.org/blog/diffusion-with-offset-noise 798 | min_snr_loss_weight = False, # https://arxiv.org/abs/2303.09556 799 | min_snr_gamma = 5, 800 | cond_drop_prob = 0., 801 | loss_type = 'l2', 802 | ): 803 | super().__init__( 804 | model=model, 805 | image_size=image_size, 806 | timesteps=timesteps, 807 | sampling_timesteps=sampling_timesteps, 808 | objective=objective, 809 | beta_schedule=beta_schedule, 810 | schedule_fn_kwargs=schedule_fn_kwargs, 811 | ddim_sampling_eta=ddim_sampling_eta, 812 | auto_normalize=auto_normalize, 813 | offset_noise_strength=offset_noise_strength, 814 | min_snr_loss_weight=min_snr_loss_weight, 815 | 
min_snr_gamma=min_snr_gamma 816 | ) 817 | assert not (type(self) == GaussianDiffusion and model.channels != model.out_dim) 818 | assert not model.random_or_learned_sinusoidal_cond 819 | 820 | self.model = model 821 | 822 | self.channels = self.model.channels 823 | self.self_condition = self.model.self_condition 824 | 825 | self.image_size = image_size 826 | 827 | self.objective = objective 828 | 829 | assert objective in {'pred_noise', 'pred_x0', 'pred_v'}, 'objective must be either pred_noise (predict noise) or pred_x0 (predict image start) or pred_v (predict v [v-parameterization as defined in appendix D of progressive distillation paper, used in imagen-video successfully])' 830 | 831 | if beta_schedule == 'linear': 832 | beta_schedule_fn = linear_beta_schedule 833 | elif beta_schedule == 'cosine': 834 | beta_schedule_fn = cosine_beta_schedule 835 | elif beta_schedule == 'sigmoid': 836 | beta_schedule_fn = sigmoid_beta_schedule 837 | else: 838 | raise ValueError(f'unknown beta schedule {beta_schedule}') 839 | 840 | betas = beta_schedule_fn(timesteps, **schedule_fn_kwargs) 841 | 842 | alphas = 1. - betas 843 | alphas_cumprod = torch.cumprod(alphas, dim=0) 844 | alphas_cumprod_prev = F.pad(alphas_cumprod[:-1], (1, 0), value = 1.) 845 | 846 | timesteps, = betas.shape 847 | self.num_timesteps = int(timesteps) 848 | 849 | # sampling related parameters 850 | 851 | self.sampling_timesteps = default(sampling_timesteps, timesteps) # default num sampling timesteps to number of timesteps at training 852 | 853 | assert self.sampling_timesteps <= timesteps 854 | self.is_ddim_sampling = self.sampling_timesteps < timesteps 855 | self.ddim_sampling_eta = ddim_sampling_eta 856 | 857 | # helper function to register buffer from float64 to float32 858 | 859 | register_buffer = lambda name, val: self.register_buffer(name, val.to(torch.float32)) 860 | 861 | register_buffer('betas', betas) 862 | register_buffer('alphas_cumprod', alphas_cumprod) 863 | register_buffer('alphas_cumprod_prev', alphas_cumprod_prev) 864 | 865 | # calculations for diffusion q(x_t | x_{t-1}) and others 866 | 867 | register_buffer('sqrt_alphas_cumprod', torch.sqrt(alphas_cumprod)) 868 | register_buffer('sqrt_one_minus_alphas_cumprod', torch.sqrt(1. - alphas_cumprod)) 869 | register_buffer('log_one_minus_alphas_cumprod', torch.log(1. - alphas_cumprod)) 870 | register_buffer('sqrt_recip_alphas_cumprod', torch.sqrt(1. / alphas_cumprod)) 871 | register_buffer('sqrt_recipm1_alphas_cumprod', torch.sqrt(1. / alphas_cumprod - 1)) 872 | 873 | # calculations for posterior q(x_{t-1} | x_t, x_0) 874 | 875 | posterior_variance = betas * (1. - alphas_cumprod_prev) / (1. - alphas_cumprod) 876 | 877 | # above: equal to 1. / (1. / (1. - alpha_cumprod_tm1) + alpha_t / beta_t) 878 | 879 | register_buffer('posterior_variance', posterior_variance) 880 | 881 | # below: log calculation clipped because the posterior variance is 0 at the beginning of the diffusion chain 882 | 883 | register_buffer('posterior_log_variance_clipped', torch.log(posterior_variance.clamp(min =1e-20))) 884 | register_buffer('posterior_mean_coef1', betas * torch.sqrt(alphas_cumprod_prev) / (1. - alphas_cumprod)) 885 | register_buffer('posterior_mean_coef2', (1. - alphas_cumprod_prev) * torch.sqrt(alphas) / (1. 
- alphas_cumprod)) 886 | 887 | # offset noise strength - in blogpost, they claimed 0.1 was ideal 888 | 889 | self.offset_noise_strength = offset_noise_strength 890 | 891 | # derive loss weight 892 | # snr - signal noise ratio 893 | 894 | snr = alphas_cumprod / (1 - alphas_cumprod) 895 | 896 | # https://arxiv.org/abs/2303.09556 897 | 898 | maybe_clipped_snr = snr.clone() 899 | if min_snr_loss_weight: 900 | maybe_clipped_snr.clamp_(max = min_snr_gamma) 901 | 902 | if objective == 'pred_noise': 903 | register_buffer('loss_weight', maybe_clipped_snr / snr) 904 | elif objective == 'pred_x0': 905 | register_buffer('loss_weight', maybe_clipped_snr) 906 | elif objective == 'pred_v': 907 | register_buffer('loss_weight', maybe_clipped_snr / (snr + 1)) 908 | 909 | # auto-normalization of data [0, 1] -> [-1, 1] - can turn off by setting it to be False 910 | 911 | self.normalize = normalize_to_neg_one_to_one if auto_normalize else identity 912 | self.unnormalize = unnormalize_to_zero_to_one if auto_normalize else identity 913 | 914 | self.cond_drop_prob = cond_drop_prob 915 | self.loss_type = loss_type 916 | 917 | def model_predictions(self, x, t, condition_x = None, cond_scale = 1.0, clip_x_start = False, rederive_pred_noise = False): 918 | if cond_scale == 1.0: 919 | model_output = self.model(x, t, condition_x) 920 | else: 921 | cond_out = self.model(x, t, condition_x) 922 | null_out = self.model(x, t, None) 923 | model_output = null_out + (cond_out - null_out) * cond_scale 924 | 925 | maybe_clip = partial(torch.clamp, min = -1., max = 1.) if clip_x_start else identity 926 | 927 | if self.objective == 'pred_noise': 928 | pred_noise = model_output 929 | x_start = self.predict_start_from_noise(x, t, pred_noise) 930 | x_start = maybe_clip(x_start) 931 | 932 | if clip_x_start and rederive_pred_noise: 933 | pred_noise = self.predict_noise_from_start(x, t, x_start) 934 | 935 | elif self.objective == 'pred_x0': 936 | x_start = model_output 937 | x_start = maybe_clip(x_start) 938 | pred_noise = self.predict_noise_from_start(x, t, x_start) 939 | 940 | elif self.objective == 'pred_v': 941 | v = model_output 942 | x_start = self.predict_start_from_v(x, t, v) 943 | x_start = maybe_clip(x_start) 944 | pred_noise = self.predict_noise_from_start(x, t, x_start) 945 | 946 | return ModelPrediction(pred_noise, x_start) 947 | 948 | def p_mean_variance(self, x, t, condition_x = None, cond_scale = 1.0, clip_denoised = True): 949 | preds = self.model_predictions(x, t, condition_x, cond_scale) 950 | x_start = preds.pred_x_start 951 | 952 | if clip_denoised: 953 | x_start.clamp_(-1., 1.) 954 | 955 | model_mean, posterior_variance, posterior_log_variance = self.q_posterior(x_start = x_start, x_t = x, t = t) 956 | return model_mean, posterior_variance, posterior_log_variance, x_start 957 | 958 | @torch.inference_mode() 959 | def p_sample(self, x, t: int, condition_x = None, cond_scale = 1.0): 960 | b, *_, device = *x.shape, self.device 961 | batched_times = torch.full((b,), t, device = device, dtype = torch.long) 962 | model_mean, _, model_log_variance, x_start = self.p_mean_variance(x = x, t = batched_times, condition_x = condition_x, cond_scale = cond_scale, clip_denoised = True) 963 | noise = torch.randn_like(x) if t > 0 else 0. 
# no noise if t == 0 964 | pred_img = model_mean + (0.5 * model_log_variance).exp() * noise 965 | return pred_img, x_start 966 | 967 | @torch.inference_mode() 968 | def p_sample_loop(self, shape, condition_x, cond_scale, guidance_start_steps, generation_start_steps, sampling_timesteps, 969 | with_images, with_x0_images): 970 | batch, device = shape[0], self.device 971 | 972 | if generation_start_steps > 0: 973 | target_time = self.num_timesteps - generation_start_steps 974 | t = torch.tensor([target_time]*batch, device=device).long() 975 | img = self.q_sample(x_start=condition_x, t=t) 976 | else: 977 | img = torch.randn(shape, device = device) 978 | 979 | if with_images: 980 | image_list = [] 981 | image_list.append(img.clone().detach().cpu()) 982 | 983 | x_start = None 984 | 985 | if with_x0_images: 986 | x0_image_list = [] 987 | x0_image_list.append(img.clone().detach().cpu()) 988 | 989 | for i, t in enumerate(tqdm(reversed(range(0, self.num_timesteps)), desc = 'sampling loop time step', total = self.num_timesteps)): 990 | if i < generation_start_steps: 991 | continue 992 | if i < guidance_start_steps: 993 | cur_cond_scale = 1.0 994 | else: 995 | cur_cond_scale = cond_scale 996 | img, x_start = self.p_sample(img, t, condition_x, cur_cond_scale) 997 | if with_images: 998 | image_list.append(img.clone().detach().cpu()) 999 | if with_x0_images: 1000 | x0_image_list.append(x_start.clone().detach().cpu()) 1001 | 1002 | if with_images: 1003 | if with_x0_images: 1004 | return self.unnormalize(img), image_list, x0_image_list 1005 | else: 1006 | return self.unnormalize(img), image_list 1007 | else: 1008 | return self.unnormalize(img) 1009 | 1010 | @torch.inference_mode() 1011 | def ddim_sample(self, shape, condition_x, cond_scale, guidance_start_steps, generation_start_steps, sampling_timesteps, 1012 | with_images, with_x0_images): 1013 | batch, device, total_timesteps, eta = shape[0], self.device, self.num_timesteps, self.ddim_sampling_eta 1014 | 1015 | times = torch.linspace(-1, total_timesteps - 1, steps = sampling_timesteps + 1) # [-1, 0, 1, 2, ..., T-1] when sampling_timesteps == total_timesteps 1016 | times = list(reversed(times.int().tolist())) 1017 | time_pairs = list(zip(times[:-1], times[1:])) # [(T-1, T-2), (T-2, T-3), ..., (1, 0), (0, -1)] 1018 | 1019 | if generation_start_steps > 0: 1020 | target_time = time_pairs[generation_start_steps][0] 1021 | t = torch.tensor([target_time]*batch, device=device).long() 1022 | img = self.q_sample(x_start=condition_x, t=t) 1023 | else: 1024 | img = torch.randn(shape, device = device) 1025 | 1026 | if with_images: 1027 | image_list = [] 1028 | image_list.append(img.clone().detach().cpu()) 1029 | 1030 | x_start = None 1031 | 1032 | if with_x0_images: 1033 | x0_image_list = [] 1034 | x0_image_list.append(img.clone().detach().cpu()) 1035 | 1036 | for i, (time, time_next) in enumerate(tqdm(time_pairs, desc = 'sampling loop time step')): 1037 | if i < generation_start_steps: 1038 | continue 1039 | if i < guidance_start_steps: 1040 | cur_cond_scale = 1.0 1041 | else: 1042 | cur_cond_scale = cond_scale 1043 | time_cond = torch.full((batch,), time, device = device, dtype = torch.long) 1044 | pred_noise, x_start, *_ = self.model_predictions(img, time_cond, condition_x, cur_cond_scale, clip_x_start = True, rederive_pred_noise = True) 1045 | 1046 | if time_next < 0: 1047 | img = x_start 1048 | if with_images: 1049 | image_list.append(img.clone().detach().cpu()) 1050 | if with_x0_images: 1051 | x0_image_list.append(img.clone().detach().cpu()) 1052 | continue 
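            # DDIM update (Song et al., https://arxiv.org/abs/2010.02502):
            #   sigma = eta * sqrt((1 - alpha/alpha_next) * (1 - alpha_next) / (1 - alpha))
            #   img   = sqrt(alpha_next) * x_start
            #         + sqrt(1 - alpha_next - sigma^2) * pred_noise
            #         + sigma * noise
            # With the default eta = 0 the update is deterministic.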
1053 | 1054 | alpha = self.alphas_cumprod[time] 1055 | alpha_next = self.alphas_cumprod[time_next] 1056 | 1057 | sigma = eta * ((1 - alpha / alpha_next) * (1 - alpha_next) / (1 - alpha)).sqrt() 1058 | c = (1 - alpha_next - sigma ** 2).sqrt() 1059 | 1060 | noise = torch.randn_like(img) 1061 | 1062 | img = x_start * alpha_next.sqrt() + \ 1063 | c * pred_noise + \ 1064 | sigma * noise 1065 | 1066 | if with_images: 1067 | image_list.append(img.clone().detach().cpu()) 1068 | if with_x0_images: 1069 | x0_image_list.append(x_start.clone().detach().cpu()) 1070 | 1071 | if with_images: 1072 | if with_x0_images: 1073 | return self.unnormalize(img), image_list, x0_image_list 1074 | else: 1075 | return self.unnormalize(img), image_list 1076 | else: 1077 | return self.unnormalize(img) 1078 | 1079 | @torch.inference_mode() 1080 | def tiled_sample(self, batch_size=4, tile_size=256, tile_stride=256, 1081 | condition_x=None, class_label=None, 1082 | cond_scale=1.0, guidance_start_steps=0, 1083 | class_cond_scale=1.0, class_guidance_start_steps=0, 1084 | generation_start_steps=0, num_sample_steps=None, 1085 | with_images=False, with_x0_images=False, start_white_noise=True, amp=False): 1086 | 1087 | num_sample_steps = default(num_sample_steps, self.sampling_timesteps) 1088 | condition_x = normalize_to_neg_one_to_one(condition_x) 1089 | 1090 | batch, c, h, w = condition_x.shape 1091 | 1092 | # pad condition_x 1093 | coord, pad = get_coord_and_pad(h, w) 1094 | left, top, right, bottom = coord 1095 | condition_x = F.pad(condition_x, pad, mode='reflect') 1096 | 1097 | device, total_timesteps, eta = self.device, self.num_timesteps, self.ddim_sampling_eta 1098 | 1099 | times = torch.linspace(-1, total_timesteps - 1, steps = num_sample_steps + 1) # [-1, 0, 1, 2, ..., T-1] when sampling_timesteps == total_timesteps 1100 | times = list(reversed(times.int().tolist())) 1101 | time_pairs = list(zip(times[:-1], times[1:])) # [(T-1, T-2), (T-2, T-3), ..., (1, 0), (0, -1)] 1102 | 1103 | if generation_start_steps > 0: 1104 | target_time = time_pairs[generation_start_steps][0] 1105 | t = torch.tensor([target_time]*batch, device=device).long() 1106 | img = self.q_sample(x_start=condition_x, t=t) 1107 | else: 1108 | img = torch.randn(condition_x.shape, device = device) 1109 | 1110 | if with_images: 1111 | image_list = [] 1112 | image_list.append(img.clone().detach().cpu()) 1113 | 1114 | 1115 | if with_x0_images: 1116 | x0_image_list = [] 1117 | x0_image_list.append(img.clone().detach().cpu()) 1118 | 1119 | # Pre-calculate tile regions 1120 | _, _, height, width = condition_x.shape 1121 | coords0 = get_coords(height, width, tile_size, tile_size, diff=0) 1122 | if height <= tile_size and width <= tile_size: 1123 | coords1 = get_coords(height, width, tile_size, tile_stride, diff=0) 1124 | else: 1125 | coords1 = get_coords(height-tile_size, width-tile_size, tile_size, tile_stride, diff=tile_size//2) 1126 | coord_list = [coords0, coords1] 1127 | 1128 | # Get the region of the smaller coords 1129 | small_coord, small_pad = get_area(coords1, height, width) 1130 | sleft, stop, sright, sbottom = small_coord 1131 | 1132 | # Pad the outside of the smaller region of condition_x with 0 1133 | cropped_condition_x = condition_x[:,:,stop:sbottom,sleft:sright] 1134 | condition_x = F.pad(cropped_condition_x, small_pad, mode='constant', value=0) 1135 | 1136 | # x_start = None 1137 | x_start = img.clone() 1138 | pred_noise = torch.zeros_like(img) 1139 | 1140 | for i, (time, time_next) in enumerate(tqdm(time_pairs, desc = 'sampling loop time 
step')): 1141 | if i < generation_start_steps: 1142 | continue 1143 | if i < guidance_start_steps: 1144 | cur_cond_scale = 1.0 1145 | else: 1146 | cur_cond_scale = cond_scale 1147 | 1148 | time_cond = torch.full((batch,), time, device = device, dtype = torch.long) 1149 | 1150 | cur_coords = coord_list[i%2] 1151 | 1152 | alpha = self.alphas_cumprod[time] 1153 | alpha_next = self.alphas_cumprod[time_next] 1154 | sigma = eta * ((1 - alpha / alpha_next) * (1 - alpha_next) / (1 - alpha)).sqrt() 1155 | cc = (1 - alpha_next - sigma ** 2).sqrt() 1156 | 1157 | minibatch_index = 0 1158 | minibatch = torch.zeros((batch_size, c, tile_size, tile_size), device=condition_x.device) 1159 | minibatch_condition = torch.zeros((batch_size, c, tile_size, tile_size), device=condition_x.device) 1160 | output_indexes = [None] * batch_size 1161 | for hs, he, ws, we in cur_coords: 1162 | minibatch[minibatch_index] = img[:, :, hs:he, ws:we] 1163 | minibatch_condition[minibatch_index] = condition_x[:, :, hs:he, ws:we] 1164 | output_indexes[minibatch_index] = (hs, ws) 1165 | minibatch_index += 1 1166 | 1167 | if minibatch_index == batch_size: 1168 | with autocast(enabled=amp): 1169 | tile_pred_noise, tile_x_start, *_ = self.model_predictions(minibatch, time_cond, minibatch_condition, cur_cond_scale, clip_x_start = True, rederive_pred_noise = True) 1170 | 1171 | for k in range(minibatch_index): 1172 | hs, ws = output_indexes[k] 1173 | pred_noise[:, :, hs:hs+tile_size, ws:ws+tile_size] = tile_pred_noise[k] 1174 | x_start[:, :, hs:hs+tile_size, ws:ws+tile_size] = tile_x_start[k] 1175 | # cur_img = tile_x_start[k] * alpha_next.sqrt() + cc * tile_pred_noise[k] + sigma * noise[:, :, hs:hs+tile_size, ws:ws+tile_size] 1176 | # img[:, :, hs:hs+tile_size, ws:ws+tile_size] = cur_img 1177 | 1178 | minibatch_index = 0 1179 | 1180 | if minibatch_index > 0: 1181 | with autocast(enabled=amp): 1182 | tile_pred_noise, tile_x_start, *_ = self.model_predictions(minibatch[0:minibatch_index], time_cond, minibatch_condition[0:minibatch_index], cur_cond_scale, clip_x_start = True, rederive_pred_noise = True) 1183 | 1184 | for k in range(minibatch_index): 1185 | hs, ws = output_indexes[k] 1186 | pred_noise[:, :, hs:hs+tile_size, ws:ws+tile_size] = tile_pred_noise[k] 1187 | x_start[:, :, hs:hs+tile_size, ws:ws+tile_size] = tile_x_start[k] 1188 | # cur_img = tile_x_start[k] * alpha_next.sqrt() + cc * tile_pred_noise[k] + sigma * noise[:, :, hs:hs+tile_size, ws:ws+tile_size] 1189 | # img[:, :, hs:hs+tile_size, ws:ws+tile_size] = cur_img 1190 | 1191 | noise = torch.randn_like(img) 1192 | 1193 | img = x_start * alpha_next.sqrt() + \ 1194 | cc * pred_noise + \ 1195 | sigma * noise 1196 | 1197 | if time_next < 0: 1198 | img = x_start 1199 | if with_images: 1200 | image_list.append(img.clone().detach().cpu()) 1201 | if with_x0_images: 1202 | x0_image_list.append(img.clone().detach().cpu()) 1203 | continue 1204 | 1205 | if i%2 == 1: 1206 | # Reconstruct by removing the padding part of img when odd times 1207 | cropped_img = img[:,:,stop:sbottom,sleft:sright] 1208 | t = torch.tensor([time_next]*batch, device=device).long() 1209 | img = self.q_sample(x_start=torch.zeros_like(condition_x), t=t) 1210 | img[:,:,stop:sbottom,sleft:sright] = cropped_img 1211 | 1212 | if with_images: 1213 | image_list.append(img.clone().detach().cpu()) 1214 | if with_x0_images: 1215 | x0_image_list.append(x_start.clone().detach().cpu()) 1216 | 1217 | img = img[:,:,top:bottom,left:right] 1218 | img.clamp_(-1., 1.) 
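# NOTE (added annotation, not in the original file): tiled_sample alternates between two
# tile grids -- coords0 on even steps and the half-tile-shifted coords1 (diff = tile_size // 2)
# on odd steps -- so that seams produced by one grid fall inside tile interiors of the other,
# suppressing visible block artifacts. A minimal call sketch, assuming `diffusion` is a
# constructed instance of this class and `lr_up` is a [0, 1] conditioning batch already at
# the output resolution (both names are placeholders):
#   out = diffusion.tiled_sample(batch_size=4, tile_size=256, tile_stride=256,
#                                condition_x=lr_up, cond_scale=1.0)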
1219 | img = unnormalize_to_zero_to_one(img) 1220 | 1221 | if with_images: 1222 | if with_x0_images: 1223 | return img, image_list, x0_image_list 1224 | else: 1225 | return img, image_list 1226 | else: 1227 | return img 1228 | 1229 | @torch.inference_mode() 1230 | def sample(self, batch_size = 16, condition_x = None, cond_scale = 1.0, 1231 | guidance_start_steps = 0, generation_start_steps = 0, 1232 | num_sample_steps = None, with_images = False, with_x0_images = False): 1233 | # image_size, channels = self.image_size, self.channels 1234 | sampling_timesteps = default(num_sample_steps, self.sampling_timesteps) 1235 | 1236 | _n, _c, h, w = condition_x.shape 1237 | condition_x = normalize_to_neg_one_to_one(condition_x) 1238 | sample_fn = self.p_sample_loop if not self.is_ddim_sampling else self.ddim_sample 1239 | return sample_fn((batch_size, self.channels, h, w), condition_x, cond_scale, 1240 | guidance_start_steps, generation_start_steps, sampling_timesteps, 1241 | with_images, with_x0_images) 1242 | 1243 | @property 1244 | def loss_fn(self): 1245 | if self.loss_type == 'l1': 1246 | return F.l1_loss 1247 | elif self.loss_type == 'l2': 1248 | return F.mse_loss 1249 | elif self.loss_type == 'smooth_l1': 1250 | return F.smooth_l1_loss 1251 | else: 1252 | raise ValueError(f'invalid loss type {self.loss_type}') 1253 | 1254 | def p_losses(self, x_start, t, condition_x, 1255 | noise = None, offset_noise_strength = None, clip_x_start = True): 1256 | b, c, h, w = x_start.shape 1257 | 1258 | noise = default(noise, lambda: torch.randn_like(x_start)) 1259 | 1260 | # offset noise - https://www.crosslabs.org/blog/diffusion-with-offset-noise 1261 | 1262 | offset_noise_strength = default(offset_noise_strength, self.offset_noise_strength) 1263 | 1264 | if offset_noise_strength > 0.: 1265 | offset_noise = torch.randn(x_start.shape[:2], device = self.device) 1266 | noise += offset_noise_strength * rearrange(offset_noise, 'b c -> b c 1 1') 1267 | 1268 | # noise sample 1269 | 1270 | x = self.q_sample(x_start = x_start, t = t, noise = noise) 1271 | 1272 | x_self_cond = condition_x 1273 | 1274 | # predict and take gradient step 1275 | 1276 | model_out = self.model(x, t, x_self_cond) 1277 | 1278 | if self.objective == 'pred_noise': 1279 | target = noise 1280 | elif self.objective == 'pred_x0': 1281 | target = x_start 1282 | elif self.objective == 'pred_v': 1283 | v = self.predict_v(x_start, t, noise) 1284 | target = v 1285 | else: 1286 | raise ValueError(f'unknown objective {self.objective}') 1287 | 1288 | loss = self.loss_fn(model_out, target, reduction = 'none') 1289 | loss = reduce(loss, 'b ... -> b (...)', 'mean') 1290 | current_loss_weight = extract(self.loss_weight, t, loss.shape) 1291 | loss = loss * current_loss_weight 1292 | loss = reduce(loss, 'b ... 
-> b', 'mean') 1293 | 1294 | return loss.mean() 1295 | 1296 | def forward(self, img, condition_x, *args, **kwargs): 1297 | b, c, h, w, device, img_size, = *img.shape, img.device, self.image_size 1298 | assert h == img_size and w == img_size, f'height and width of image must be {img_size}' 1299 | t = torch.randint(0, self.num_timesteps, (b,), device=device).long() 1300 | 1301 | img = self.normalize(img) 1302 | if torch.rand(1) < self.cond_drop_prob: 1303 | condition_x = None 1304 | else: 1305 | condition_x = self.normalize(condition_x) 1306 | 1307 | return self.p_losses(img, t, condition_x, *args, **kwargs) 1308 | 1309 | 1310 | 1311 | class ConditionalGaussianDiffusionSR(GaussianDiffusion): 1312 | def set_seed(self, seed): 1313 | torch.cuda.manual_seed(seed) 1314 | 1315 | def __init__( 1316 | self, 1317 | model, 1318 | *, 1319 | image_size, 1320 | timesteps = 1000, 1321 | sampling_timesteps = None, 1322 | objective = 'pred_v', 1323 | beta_schedule = 'sigmoid', 1324 | schedule_fn_kwargs = dict(), 1325 | ddim_sampling_eta = 0., 1326 | auto_normalize = True, 1327 | offset_noise_strength = 0., # https://www.crosslabs.org/blog/diffusion-with-offset-noise 1328 | min_snr_loss_weight = False, # https://arxiv.org/abs/2303.09556 1329 | min_snr_gamma = 5, 1330 | cond_drop_prob = 0., 1331 | class_cond_drop_prob = 0., 1332 | loss_type = 'l2', 1333 | ): 1334 | super().__init__( 1335 | model=model, 1336 | image_size=image_size, 1337 | timesteps=timesteps, 1338 | sampling_timesteps=sampling_timesteps, 1339 | objective=objective, 1340 | beta_schedule=beta_schedule, 1341 | schedule_fn_kwargs=schedule_fn_kwargs, 1342 | ddim_sampling_eta=ddim_sampling_eta, 1343 | auto_normalize=auto_normalize, 1344 | offset_noise_strength=offset_noise_strength, 1345 | min_snr_loss_weight=min_snr_loss_weight, 1346 | min_snr_gamma=min_snr_gamma 1347 | ) 1348 | assert not (type(self) == GaussianDiffusion and model.channels != model.out_dim) 1349 | assert not model.random_or_learned_sinusoidal_cond 1350 | 1351 | self.model = model 1352 | 1353 | self.channels = self.model.channels 1354 | self.self_condition = self.model.self_condition 1355 | 1356 | self.image_size = image_size 1357 | 1358 | self.objective = objective 1359 | 1360 | assert objective in {'pred_noise', 'pred_x0', 'pred_v'}, 'objective must be either pred_noise (predict noise) or pred_x0 (predict image start) or pred_v (predict v [v-parameterization as defined in appendix D of progressive distillation paper, used in imagen-video successfully])' 1361 | 1362 | if beta_schedule == 'linear': 1363 | beta_schedule_fn = linear_beta_schedule 1364 | elif beta_schedule == 'cosine': 1365 | beta_schedule_fn = cosine_beta_schedule 1366 | elif beta_schedule == 'sigmoid': 1367 | beta_schedule_fn = sigmoid_beta_schedule 1368 | else: 1369 | raise ValueError(f'unknown beta schedule {beta_schedule}') 1370 | 1371 | betas = beta_schedule_fn(timesteps, **schedule_fn_kwargs) 1372 | 1373 | alphas = 1. - betas 1374 | alphas_cumprod = torch.cumprod(alphas, dim=0) 1375 | alphas_cumprod_prev = F.pad(alphas_cumprod[:-1], (1, 0), value = 1.) 
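# NOTE (added annotation, not in the original file): for orientation, a 'linear' schedule
# commonly ramps betas from about 1e-4 to 2e-2 over `timesteps` steps, so
# alphas_cumprod[t] = prod_{s<=t} (1 - betas[s]) decays monotonically from ~1 toward 0,
# and alphas_cumprod_prev is the same sequence shifted right by one with a leading 1.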
1376 | 1377 | timesteps, = betas.shape 1378 | self.num_timesteps = int(timesteps) 1379 | 1380 | # sampling related parameters 1381 | 1382 | self.sampling_timesteps = default(sampling_timesteps, timesteps) # default num sampling timesteps to number of timesteps at training 1383 | 1384 | assert self.sampling_timesteps <= timesteps 1385 | self.is_ddim_sampling = self.sampling_timesteps < timesteps 1386 | self.ddim_sampling_eta = ddim_sampling_eta 1387 | 1388 | # helper function to register buffer from float64 to float32 1389 | 1390 | register_buffer = lambda name, val: self.register_buffer(name, val.to(torch.float32)) 1391 | 1392 | register_buffer('betas', betas) 1393 | register_buffer('alphas_cumprod', alphas_cumprod) 1394 | register_buffer('alphas_cumprod_prev', alphas_cumprod_prev) 1395 | 1396 | # calculations for diffusion q(x_t | x_{t-1}) and others 1397 | 1398 | register_buffer('sqrt_alphas_cumprod', torch.sqrt(alphas_cumprod)) 1399 | register_buffer('sqrt_one_minus_alphas_cumprod', torch.sqrt(1. - alphas_cumprod)) 1400 | register_buffer('log_one_minus_alphas_cumprod', torch.log(1. - alphas_cumprod)) 1401 | register_buffer('sqrt_recip_alphas_cumprod', torch.sqrt(1. / alphas_cumprod)) 1402 | register_buffer('sqrt_recipm1_alphas_cumprod', torch.sqrt(1. / alphas_cumprod - 1)) 1403 | 1404 | # calculations for posterior q(x_{t-1} | x_t, x_0) 1405 | 1406 | posterior_variance = betas * (1. - alphas_cumprod_prev) / (1. - alphas_cumprod) 1407 | 1408 | # above: equal to 1. / (1. / (1. - alpha_cumprod_tm1) + alpha_t / beta_t) 1409 | 1410 | register_buffer('posterior_variance', posterior_variance) 1411 | 1412 | # below: log calculation clipped because the posterior variance is 0 at the beginning of the diffusion chain 1413 | 1414 | register_buffer('posterior_log_variance_clipped', torch.log(posterior_variance.clamp(min =1e-20))) 1415 | register_buffer('posterior_mean_coef1', betas * torch.sqrt(alphas_cumprod_prev) / (1. - alphas_cumprod)) 1416 | register_buffer('posterior_mean_coef2', (1. - alphas_cumprod_prev) * torch.sqrt(alphas) / (1. 
- alphas_cumprod)) 1417 | 1418 | # offset noise strength - in blogpost, they claimed 0.1 was ideal 1419 | 1420 | self.offset_noise_strength = offset_noise_strength 1421 | 1422 | # derive loss weight 1423 | # snr - signal noise ratio 1424 | 1425 | snr = alphas_cumprod / (1 - alphas_cumprod) 1426 | 1427 | # https://arxiv.org/abs/2303.09556 1428 | 1429 | maybe_clipped_snr = snr.clone() 1430 | if min_snr_loss_weight: 1431 | maybe_clipped_snr.clamp_(max = min_snr_gamma) 1432 | 1433 | if objective == 'pred_noise': 1434 | register_buffer('loss_weight', maybe_clipped_snr / snr) 1435 | elif objective == 'pred_x0': 1436 | register_buffer('loss_weight', maybe_clipped_snr) 1437 | elif objective == 'pred_v': 1438 | register_buffer('loss_weight', maybe_clipped_snr / (snr + 1)) 1439 | 1440 | # auto-normalization of data [0, 1] -> [-1, 1] - can turn off by setting it to be False 1441 | 1442 | self.normalize = normalize_to_neg_one_to_one if auto_normalize else identity 1443 | self.unnormalize = unnormalize_to_zero_to_one if auto_normalize else identity 1444 | 1445 | self.cond_drop_prob = cond_drop_prob 1446 | self.class_cond_drop_prob = class_cond_drop_prob 1447 | self.loss_type = loss_type 1448 | 1449 | def model_predictions(self, x, t, condition_x = None, class_label = None, 1450 | cond_scale = 1.0, class_cond_scale = 1.0, 1451 | clip_x_start = False, rederive_pred_noise = False): 1452 | 1453 | # Currently not supported for CFG with both condition_x and class_label 1454 | if (cond_scale != 1.0) and (class_cond_scale != 1.0): 1455 | raise NotImplementedError("Currently, you cannot specify both cond_scale and class_cond_scale at the same time.") 1456 | 1457 | if cond_scale == 1.0 and class_cond_scale == 1.0: 1458 | model_output = self.model(x, t, class_label, condition_x) 1459 | elif cond_scale != 1.0: 1460 | cond_out = self.model(x, t, class_label, condition_x) 1461 | null_out = self.model(x, t, class_label, None) 1462 | model_output = null_out + (cond_out - null_out) * cond_scale 1463 | elif class_cond_scale != 1.0: 1464 | cond_out = self.model(x, t, class_label, condition_x) 1465 | null_out = self.model(x, t, None, condition_x) 1466 | model_output = null_out + (cond_out - null_out) * class_cond_scale 1467 | 1468 | maybe_clip = partial(torch.clamp, min = -1., max = 1.) if clip_x_start else identity 1469 | 1470 | if self.objective == 'pred_noise': 1471 | pred_noise = model_output 1472 | x_start = self.predict_start_from_noise(x, t, pred_noise) 1473 | x_start = maybe_clip(x_start) 1474 | 1475 | if clip_x_start and rederive_pred_noise: 1476 | pred_noise = self.predict_noise_from_start(x, t, x_start) 1477 | 1478 | elif self.objective == 'pred_x0': 1479 | x_start = model_output 1480 | x_start = maybe_clip(x_start) 1481 | pred_noise = self.predict_noise_from_start(x, t, x_start) 1482 | 1483 | elif self.objective == 'pred_v': 1484 | v = model_output 1485 | x_start = self.predict_start_from_v(x, t, v) 1486 | x_start = maybe_clip(x_start) 1487 | pred_noise = self.predict_noise_from_start(x, t, x_start) 1488 | 1489 | return ModelPrediction(pred_noise, x_start) 1490 | 1491 | def p_mean_variance(self, x, t, condition_x = None, class_label = None, 1492 | cond_scale = 1.0, class_cond_scale = 1.0, clip_denoised = True): 1493 | preds = self.model_predictions(x, t, condition_x, class_label, cond_scale, class_cond_scale) 1494 | x_start = preds.pred_x_start 1495 | 1496 | if clip_denoised: 1497 | x_start.clamp_(-1., 1.) 
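# NOTE (added annotation, not in the original file): q_posterior below evaluates the DDPM
# posterior q(x_{t-1} | x_t, x_0) from the buffers registered in __init__:
#   mean     = posterior_mean_coef1[t] * x_start + posterior_mean_coef2[t] * x_t
#   variance = betas[t] * (1 - alphas_cumprod_prev[t]) / (1 - alphas_cumprod[t])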
1498 |
1499 | model_mean, posterior_variance, posterior_log_variance = self.q_posterior(x_start = x_start, x_t = x, t = t)
1500 | return model_mean, posterior_variance, posterior_log_variance, x_start
1501 |
1502 | @torch.inference_mode()
1503 | def p_sample(self, x, t: int, condition_x = None, class_label = None,
1504 | cond_scale = 1.0, class_cond_scale = 1.0):
1505 | b, *_, device = *x.shape, self.device
1506 | batched_times = torch.full((b,), t, device = device, dtype = torch.long)
1507 | model_mean, _, model_log_variance, x_start = self.p_mean_variance(x = x, t = batched_times,
1508 | condition_x = condition_x, class_label = class_label,
1509 | cond_scale = cond_scale,
1510 | class_cond_scale = class_cond_scale,
1511 | clip_denoised = True)
1512 | noise = torch.randn_like(x) if t > 0 else 0. # no noise if t == 0
1513 | pred_img = model_mean + (0.5 * model_log_variance).exp() * noise
1514 | return pred_img, x_start
1515 |
1516 | @torch.inference_mode()
1517 | def p_sample_loop(self, shape, condition_x, class_label,
1518 | cond_scale, guidance_start_steps,
1519 | class_cond_scale, class_guidance_start_steps,
1520 | generation_start_steps, sampling_timesteps, with_images, with_x0_images):
1521 | batch, device = shape[0], self.device
1522 |
1523 | if generation_start_steps > 0:
1524 | target_time = self.num_timesteps - generation_start_steps
1525 | t = torch.tensor([target_time]*batch, device=device).long()
1526 | img = self.q_sample(x_start=condition_x, t=t)
1527 | else:
1528 | img = torch.randn(shape, device = device)
1529 |
1530 | if with_images:
1531 | image_list = []
1532 | image_list.append(img.clone().detach().cpu())
1533 |
1534 | x_start = None
1535 |
1536 | if with_x0_images:
1537 | x0_image_list = []
1538 | x0_image_list.append(img.clone().detach().cpu())
1539 |
1540 | for i, t in enumerate(tqdm(reversed(range(0, self.num_timesteps)), desc = 'sampling loop time step', total = self.num_timesteps)):
1541 | if i < generation_start_steps:
1542 | continue
1543 | if i < guidance_start_steps:
1544 | cur_cond_scale = 1.0
1545 | else:
1546 | cur_cond_scale = cond_scale
1547 | if i < class_guidance_start_steps:
1548 | cur_class_cond_scale = 1.0
1549 | else:
1550 | cur_class_cond_scale = class_cond_scale
1551 | img, x_start = self.p_sample(img, t, condition_x, class_label, cur_cond_scale, cur_class_cond_scale)
1552 | if with_images:
1553 | image_list.append(img.clone().detach().cpu())
1554 | if with_x0_images:
1555 | x0_image_list.append(x_start.clone().detach().cpu())
1556 |
1557 | if with_images:
1558 | if with_x0_images:
1559 | return self.unnormalize(img), image_list, x0_image_list
1560 | else:
1561 | return self.unnormalize(img), image_list
1562 | else:
1563 | return self.unnormalize(img)
1564 |
1565 | @torch.inference_mode()
1566 | def ddim_sample(self, shape, condition_x, class_label,
1567 | cond_scale, guidance_start_steps,
1568 | class_cond_scale, class_guidance_start_steps,
1569 | generation_start_steps, sampling_timesteps, with_images, with_x0_images):
1570 | # batch, device, total_timesteps, sampling_timesteps, eta, objective = shape[0], self.device, self.num_timesteps, self.sampling_timesteps, self.ddim_sampling_eta, self.objective
1571 | batch, device, total_timesteps, eta = shape[0], self.device, self.num_timesteps, self.ddim_sampling_eta
1572 |
1573 | times = torch.linspace(-1, total_timesteps - 1, steps = sampling_timesteps + 1) # [-1, 0, 1, 2, ..., T-1] when sampling_timesteps == total_timesteps
1574 | times = list(reversed(times.int().tolist()))
1575 | time_pairs =
list(zip(times[:-1], times[1:])) # [(T-1, T-2), (T-2, T-3), ..., (1, 0), (0, -1)] 1576 | 1577 | if generation_start_steps > 0: 1578 | target_time = time_pairs[generation_start_steps][0] 1579 | t = torch.tensor([target_time]*batch, device=device).long() 1580 | img = self.q_sample(x_start=condition_x, t=t) 1581 | else: 1582 | img = torch.randn(shape, device = device) 1583 | 1584 | if with_images: 1585 | image_list = [] 1586 | image_list.append(img.clone().detach().cpu()) 1587 | 1588 | x_start = None 1589 | 1590 | if with_x0_images: 1591 | x0_image_list = [] 1592 | x0_image_list.append(img.clone().detach().cpu()) 1593 | 1594 | for i, (time, time_next) in enumerate(tqdm(time_pairs, desc = 'sampling loop time step')): 1595 | if i < generation_start_steps: 1596 | continue 1597 | if i < guidance_start_steps: 1598 | cur_cond_scale = 1.0 1599 | else: 1600 | cur_cond_scale = cond_scale 1601 | if i < class_guidance_start_steps: 1602 | cur_class_cond_scale = 1.0 1603 | else: 1604 | cur_class_cond_scale = class_cond_scale 1605 | time_cond = torch.full((batch,), time, device = device, dtype = torch.long) 1606 | pred_noise, x_start, *_ = self.model_predictions(img, time_cond, condition_x, class_label, 1607 | cur_cond_scale, cur_class_cond_scale, 1608 | clip_x_start = True, rederive_pred_noise = True) 1609 | 1610 | if time_next < 0: 1611 | img = x_start 1612 | if with_images: 1613 | image_list.append(img.clone().detach().cpu()) 1614 | if with_x0_images: 1615 | x0_image_list.append(img.clone().detach().cpu()) 1616 | continue 1617 | 1618 | alpha = self.alphas_cumprod[time] 1619 | alpha_next = self.alphas_cumprod[time_next] 1620 | 1621 | sigma = eta * ((1 - alpha / alpha_next) * (1 - alpha_next) / (1 - alpha)).sqrt() 1622 | c = (1 - alpha_next - sigma ** 2).sqrt() 1623 | 1624 | noise = torch.randn_like(img) 1625 | 1626 | img = x_start * alpha_next.sqrt() + \ 1627 | c * pred_noise + \ 1628 | sigma * noise 1629 | 1630 | if with_images: 1631 | image_list.append(img.clone().detach().cpu()) 1632 | if with_x0_images: 1633 | x0_image_list.append(x_start.clone().detach().cpu()) 1634 | 1635 | if with_images: 1636 | if with_x0_images: 1637 | return self.unnormalize(img), image_list, x0_image_list 1638 | else: 1639 | return self.unnormalize(img), image_list 1640 | else: 1641 | return self.unnormalize(img) 1642 | 1643 | 1644 | @torch.inference_mode() 1645 | def sample(self, batch_size = 16, condition_x = None, class_label = None, 1646 | cond_scale = 1.0, guidance_start_steps = 0, 1647 | class_cond_scale = 1.0, class_guidance_start_steps = 0, 1648 | generation_start_steps = 0, 1649 | num_sample_steps = None, with_images = False, with_x0_images = False): 1650 | # image_size, channels = self.image_size, self.channels 1651 | sampling_timesteps = default(num_sample_steps, self.sampling_timesteps) 1652 | 1653 | _n, _c, h, w = condition_x.shape 1654 | condition_x = normalize_to_neg_one_to_one(condition_x) 1655 | sample_fn = self.p_sample_loop if not self.is_ddim_sampling else self.ddim_sample 1656 | return sample_fn((batch_size, self.channels, h, w), condition_x, class_label, 1657 | cond_scale, guidance_start_steps, 1658 | class_cond_scale, class_guidance_start_steps, 1659 | generation_start_steps, sampling_timesteps, with_images, with_x0_images) 1660 | 1661 | @property 1662 | def loss_fn(self): 1663 | if self.loss_type == 'l1': 1664 | return F.l1_loss 1665 | elif self.loss_type == 'l2': 1666 | return F.mse_loss 1667 | elif self.loss_type == 'smooth_l1': 1668 | return F.smooth_l1_loss 1669 | else: 1670 | raise 
ValueError(f'invalid loss type {self.loss_type}')
1671 |
1672 | def p_losses(self, x_start, t, class_label, condition_x,
1673 | noise = None, offset_noise_strength = None, clip_x_start=True):
1674 | b, c, h, w = x_start.shape
1675 |
1676 | noise = default(noise, lambda: torch.randn_like(x_start))
1677 |
1678 | # offset noise - https://www.crosslabs.org/blog/diffusion-with-offset-noise
1679 |
1680 | offset_noise_strength = default(offset_noise_strength, self.offset_noise_strength)
1681 |
1682 | if offset_noise_strength > 0.:
1683 | offset_noise = torch.randn(x_start.shape[:2], device = self.device)
1684 | noise += offset_noise_strength * rearrange(offset_noise, 'b c -> b c 1 1')
1685 |
1686 | # noise sample
1687 |
1688 | x = self.q_sample(x_start = x_start, t = t, noise = noise)
1689 |
1690 | x_self_cond = condition_x
1691 |
1692 | # predict and take gradient step
1693 |
1694 | model_out = self.model(x, t, class_label, x_self_cond)
1695 |
1696 | if self.objective == 'pred_noise':
1697 | target = noise
1698 | elif self.objective == 'pred_x0':
1699 | target = x_start
1700 | elif self.objective == 'pred_v':
1701 | v = self.predict_v(x_start, t, noise)
1702 | target = v
1703 | else:
1704 | raise ValueError(f'unknown objective {self.objective}')
1705 |
1706 | loss = self.loss_fn(model_out, target, reduction = 'none')
1707 | loss = reduce(loss, 'b ... -> b (...)', 'mean')
1708 | current_loss_weight = extract(self.loss_weight, t, loss.shape)
1709 | loss = loss * current_loss_weight
1710 | loss = reduce(loss, 'b ... -> b', 'mean')
1711 |
1712 | return loss.mean()
1713 |
1714 | def forward(self, img, condition_x, class_label, *args, **kwargs):
1715 | b, c, h, w, device, img_size, = *img.shape, img.device, self.image_size
1716 | assert h == img_size and w == img_size, f'height and width of image must be {img_size}'
1717 | t = torch.randint(0, self.num_timesteps, (b,), device=device).long()
1718 |
1719 | img = self.normalize(img)
1720 | if torch.rand(1) < self.cond_drop_prob:
1721 | condition_x = None
1722 | else:
1723 | condition_x = self.normalize(condition_x)
1724 |
1725 | if torch.rand(1) < self.class_cond_drop_prob:
1726 | class_label = None
1727 |
1728 | return self.p_losses(img, t, class_label, condition_x, *args, **kwargs)
1729 |
1730 |
1731 | class ElucidatedDiffusionSR(ElucidatedDiffusion):
1732 | def set_seed(self, seed):
1733 | torch.cuda.manual_seed(seed)
1734 |
1735 | def __init__(
1736 | self,
1737 | net,
1738 | *,
1739 | image_size,
1740 | channels = 3,
1741 | num_sample_steps = 32, # number of sampling steps
1742 | sigma_min = 0.002, # min noise level
1743 | sigma_max = 80, # max noise level
1744 | sigma_data = 0.5, # standard deviation of data distribution
1745 | rho = 7, # controls the sampling schedule
1746 | P_mean = -1.2, # mean of log-normal distribution from which noise is drawn for training
1747 | P_std = 1.2, # standard deviation of log-normal distribution from which noise is drawn for training
1748 | S_churn = 80, # parameters for stochastic sampling - depends on dataset, Table 5 in the paper
1749 | S_tmin = 0.05,
1750 | S_tmax = 50,
1751 | S_noise = 1.003,
1752 | cond_drop_prob = 0.,
1753 | use_dpmpp_solver = False,
1754 | loss_type = 'l2'
1755 | ):
1756 | super().__init__(
1757 | net=net,
1758 | image_size=image_size,
1759 | channels=channels,
1760 | num_sample_steps=num_sample_steps,
1761 | sigma_min=sigma_min, sigma_max=sigma_max, sigma_data=sigma_data,
1762 | rho=rho,
1763 | P_mean=P_mean, P_std=P_std,
1764 | S_churn=S_churn, S_tmin=S_tmin, S_tmax=S_tmax, S_noise=S_noise
1765 | ) 1766 | #assert net.learned_sinusoidal_cond 1767 | assert net.random_or_learned_sinusoidal_cond 1768 | self.self_condition = net.self_condition 1769 | 1770 | self.net = net 1771 | 1772 | # image dimensions 1773 | 1774 | self.channels = channels 1775 | self.image_size = image_size 1776 | 1777 | # parameters 1778 | 1779 | self.sigma_min = sigma_min 1780 | self.sigma_max = sigma_max 1781 | self.sigma_data = sigma_data 1782 | 1783 | self.rho = rho 1784 | 1785 | self.P_mean = P_mean 1786 | self.P_std = P_std 1787 | 1788 | self.num_sample_steps = num_sample_steps # otherwise known as N in the paper 1789 | 1790 | self.S_churn = S_churn 1791 | self.S_tmin = S_tmin 1792 | self.S_tmax = S_tmax 1793 | self.S_noise = S_noise 1794 | 1795 | self.cond_drop_prob = cond_drop_prob 1796 | self.use_dpmpp_solver = use_dpmpp_solver 1797 | self.loss_type = loss_type 1798 | 1799 | def set_seed(self, seed): 1800 | torch.cuda.manual_seed(seed) 1801 | 1802 | def preconditioned_network_forward(self, noised_images, sigma, condition_x, cond_scale = 1.0, clamp = False): 1803 | batch, device = noised_images.shape[0], noised_images.device 1804 | 1805 | if isinstance(sigma, float): 1806 | sigma = torch.full((batch,), sigma, device = device) 1807 | 1808 | padded_sigma = rearrange(sigma, 'b -> b 1 1 1') 1809 | 1810 | net_out = self.net( 1811 | self.c_in(padded_sigma) * noised_images, 1812 | self.c_noise(sigma), 1813 | condition_x 1814 | ) 1815 | 1816 | out = self.c_skip(padded_sigma) * noised_images + self.c_out(padded_sigma) * net_out 1817 | 1818 | if cond_scale != 1.0: 1819 | null_out = self.net( 1820 | self.c_in(padded_sigma) * noised_images, 1821 | self.c_noise(sigma), 1822 | None 1823 | ) 1824 | 1825 | null_out = self.c_skip(padded_sigma) * noised_images + self.c_out(padded_sigma) * null_out 1826 | 1827 | out = null_out + (out - null_out) * cond_scale 1828 | 1829 | if clamp: 1830 | out = out.clamp(-1., 1.) 1831 | 1832 | return out 1833 | 1834 | @torch.inference_mode() 1835 | def get_noised_images(self, condition_x, target_step, num_sample_steps=None): 1836 | # Input condition_x that has been normalize_to_neg_one_to_one 1837 | num_sample_steps = default(num_sample_steps, self.num_sample_steps) 1838 | sigmas = self.sample_schedule(num_sample_steps) 1839 | n, _, _, _ = condition_x.shape 1840 | padded_sigmas = repeat(sigmas[target_step], ' -> b', b = n) 1841 | noise = torch.randn_like(condition_x) 1842 | noised_images = condition_x + padded_sigmas * noise # alphas are 1. 
in the paper 1843 | return noised_images 1844 | 1845 | @torch.inference_mode() 1846 | def sample(self, batch_size=16, condition_x=None, cond_scale=1.0, guidance_start_steps=0, 1847 | generation_start_steps=0, 1848 | num_sample_steps=None, clamp=True, with_images=False, with_x0_images=False, zero_init=False): 1849 | if self.use_dpmpp_solver: 1850 | return self.sample_using_dpmpp(batch_size, condition_x, cond_scale, guidance_start_steps, 1851 | generation_start_steps, num_sample_steps, clamp, with_images, with_x0_images, zero_init) 1852 | else: 1853 | return self.sample_org(batch_size, condition_x, cond_scale, guidance_start_steps, 1854 | generation_start_steps, num_sample_steps, clamp, with_images, with_x0_images, zero_init) 1855 | 1856 | @torch.inference_mode() 1857 | def sample_org(self, batch_size = 16, condition_x = None, cond_scale = 1.0, guidance_start_steps = 0, 1858 | generation_start_steps = 0, num_sample_steps = None, clamp = True, 1859 | with_images = False, with_x0_images = False, zero_init = False): 1860 | num_sample_steps = default(num_sample_steps, self.num_sample_steps) 1861 | 1862 | _n, _c, h, w = condition_x.shape 1863 | # image_size = h 1864 | shape = (batch_size, self.channels, h, w) 1865 | 1866 | condition_x = normalize_to_neg_one_to_one(condition_x) 1867 | 1868 | # get the schedule, which is returned as (sigma, gamma) tuple, and pair up with the next sigma and gamma 1869 | 1870 | sigmas = self.sample_schedule(num_sample_steps) 1871 | 1872 | gammas = torch.where( 1873 | (sigmas >= self.S_tmin) & (sigmas <= self.S_tmax), 1874 | min(self.S_churn / num_sample_steps, math.sqrt(2) - 1), 1875 | 0. 1876 | ) 1877 | 1878 | sigmas_and_gammas = list(zip(sigmas[:-1], sigmas[1:], gammas[:-1])) 1879 | 1880 | if generation_start_steps > 0: 1881 | images = self.get_noised_images(condition_x, generation_start_steps) 1882 | elif zero_init: 1883 | images = torch.zeros(shape, device = self.device) 1884 | else: 1885 | # images is noise at the beginning 1886 | init_sigma = sigmas[0] 1887 | images = init_sigma * torch.randn(shape, device = self.device) 1888 | 1889 | if with_images: 1890 | image_list = [] 1891 | image_list.append(images.clone().detach().cpu()) 1892 | 1893 | if with_x0_images: 1894 | x0_image_list = [] 1895 | x0_image_list.append(images.clone().detach().cpu()) 1896 | 1897 | # gradually denoise 1898 | 1899 | for i, (sigma, sigma_next, gamma) in enumerate(tqdm(sigmas_and_gammas, desc = 'sampling time step')): 1900 | if i < generation_start_steps: 1901 | continue 1902 | if i < guidance_start_steps: 1903 | cur_cond_scale = 1.0 1904 | else: 1905 | cur_cond_scale = cond_scale 1906 | sigma, sigma_next, gamma = map(lambda t: t.item(), (sigma, sigma_next, gamma)) 1907 | 1908 | eps = self.S_noise * torch.randn(shape, device = self.device) # stochastic sampling 1909 | 1910 | sigma_hat = sigma + gamma * sigma 1911 | images_hat = images + math.sqrt(sigma_hat ** 2 - sigma ** 2) * eps 1912 | 1913 | model_output = self.preconditioned_network_forward(images_hat, sigma_hat, condition_x, cur_cond_scale, clamp = clamp) 1914 | denoised_over_sigma = (images_hat - model_output) / sigma_hat 1915 | 1916 | images_next = images_hat + (sigma_next - sigma_hat) * denoised_over_sigma 1917 | 1918 | # second order correction, if not the last timestep 1919 | 1920 | if sigma_next != 0: 1921 | model_output_next = self.preconditioned_network_forward(images_next, sigma_next, condition_x, cur_cond_scale, clamp = clamp) 1922 | denoised_prime_over_sigma = (images_next - model_output_next) / sigma_next 1923 | 
images_next = images_hat + 0.5 * (sigma_next - sigma_hat) * (denoised_over_sigma + denoised_prime_over_sigma) 1924 | 1925 | images = images_next 1926 | 1927 | if with_images: 1928 | image_list.append(images.clone().detach().cpu()) 1929 | if with_x0_images: 1930 | if sigma_next != 0: 1931 | x0_image_list.append(denoised_prime_over_sigma.clone().detach().cpu()) 1932 | else: 1933 | x0_image_list.append(denoised_over_sigma.clone().detach().cpu()) 1934 | 1935 | images = images.clamp(-1., 1.) 1936 | 1937 | if with_images: 1938 | if with_x0_images: 1939 | return unnormalize_to_zero_to_one(images), image_list, x0_image_list 1940 | else: 1941 | return unnormalize_to_zero_to_one(images), image_list 1942 | else: 1943 | return unnormalize_to_zero_to_one(images) 1944 | 1945 | @torch.inference_mode() 1946 | def sample_using_dpmpp(self, batch_size = 16, condition_x = None, cond_scale = 1.0, guidance_start_steps = 0, 1947 | generation_start_steps = 0, num_sample_steps = None, clamp = True, 1948 | with_images = False, with_x0_images = False, zero_init = False): 1949 | """ 1950 | thanks to Katherine Crowson (https://github.com/crowsonkb) for figuring it all out! 1951 | https://arxiv.org/abs/2211.01095 1952 | """ 1953 | num_sample_steps = default(num_sample_steps, self.num_sample_steps) 1954 | 1955 | _n, _c, h, w = condition_x.shape 1956 | # image_size = h 1957 | shape = (batch_size, self.channels, h, w) 1958 | 1959 | condition_x = normalize_to_neg_one_to_one(condition_x) 1960 | 1961 | sigmas = self.sample_schedule(num_sample_steps) 1962 | 1963 | if generation_start_steps > 0: 1964 | images = self.get_noised_images(condition_x, generation_start_steps) 1965 | elif zero_init: 1966 | images = torch.zeros(shape, device = self.device) 1967 | else: 1968 | images = sigmas[0] * torch.randn(shape, device = self.device) 1969 | 1970 | if with_images: 1971 | image_list = [] 1972 | image_list.append(images.clone().detach().cpu()) 1973 | 1974 | if with_x0_images: 1975 | x0_image_list = [] 1976 | x0_image_list.append(images.clone().detach().cpu()) 1977 | 1978 | sigma_fn = lambda t: t.neg().exp() 1979 | t_fn = lambda sigma: sigma.log().neg() 1980 | 1981 | old_denoised = None 1982 | for i in tqdm(range(len(sigmas) - 1)): 1983 | if i < generation_start_steps: 1984 | continue 1985 | if i < guidance_start_steps: 1986 | cur_cond_scale = 1.0 1987 | else: 1988 | cur_cond_scale = cond_scale 1989 | denoised = self.preconditioned_network_forward(images, sigmas[i].item(), condition_x, cur_cond_scale, clamp = clamp) 1990 | t, t_next = t_fn(sigmas[i]), t_fn(sigmas[i + 1]) 1991 | h = t_next - t 1992 | 1993 | if not exists(old_denoised) or sigmas[i + 1] == 0: 1994 | denoised_d = denoised 1995 | else: 1996 | h_last = t - t_fn(sigmas[i - 1]) 1997 | r = h_last / h 1998 | gamma = - 1 / (2 * r) 1999 | denoised_d = (1 - gamma) * denoised + gamma * old_denoised 2000 | 2001 | images = (sigma_fn(t_next) / sigma_fn(t)) * images - (-h).expm1() * denoised_d 2002 | old_denoised = denoised 2003 | 2004 | if with_images: 2005 | image_list.append(images.clone().detach().cpu()) 2006 | if with_x0_images: 2007 | x0_image_list.append(denoised_d.clone().detach().cpu()) 2008 | 2009 | images = images.clamp(-1., 1.) 
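# NOTE (added annotation, not in the original file): the loop above implements the
# DPM-Solver++(2M) update in log-sigma time. With t = -log(sigma) and h = t_next - t:
#   x_next = (sigma_next / sigma) * x - (exp(-h) - 1) * D
# where D = (1 - gamma) * denoised + gamma * old_denoised blends the current and previous
# denoised estimates, giving second-order accuracy from one model call per step.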
2010 |
2011 | if with_images:
2012 | if with_x0_images:
2013 | return unnormalize_to_zero_to_one(images), image_list, x0_image_list
2014 | else:
2015 | return unnormalize_to_zero_to_one(images), image_list
2016 | else:
2017 | return unnormalize_to_zero_to_one(images)
2018 |
2019 | @property
2020 | def loss_fn(self):
2021 | if self.loss_type == 'l1':
2022 | return F.l1_loss
2023 | elif self.loss_type == 'l2':
2024 | return F.mse_loss
2025 | elif self.loss_type == 'smooth_l1':
2026 | return F.smooth_l1_loss
2027 | else:
2028 | raise ValueError(f'invalid loss type {self.loss_type}')
2029 |
2030 | def forward(self, images, condition_x):
2031 | batch_size, c, h, w, device, image_size, channels = *images.shape, images.device, self.image_size, self.channels
2032 |
2033 | assert h == image_size and w == image_size, f'height and width of image must be {image_size}'
2034 | assert c == channels, 'mismatch of image channels'
2035 |
2036 | images = normalize_to_neg_one_to_one(images)
2037 | if torch.rand(1) < self.cond_drop_prob: # uniform draw, so conditioning is dropped with probability cond_drop_prob
2038 | condition_x = None
2039 | else:
2040 | condition_x = normalize_to_neg_one_to_one(condition_x)
2041 |
2042 | sigmas = self.noise_distribution(batch_size)
2043 | padded_sigmas = rearrange(sigmas, 'b -> b 1 1 1')
2044 |
2045 | noise = torch.randn_like(images)
2046 |
2047 | noised_images = images + padded_sigmas * noise # alphas are 1. in the paper
2048 |
2049 | denoised = self.preconditioned_network_forward(noised_images, sigmas, condition_x, cond_scale=1.0)
2050 |
2051 | losses = self.loss_fn(denoised, images, reduction = 'none')
2052 | losses = reduce(losses, 'b ... -> b', 'mean')
2053 |
2054 | losses = losses * self.loss_weight(sigmas)
2055 |
2056 | return losses.mean()
2057 |
2058 |
2059 | class ConditionalElucidatedDiffusionSR(ElucidatedDiffusion):
2060 | def set_seed(self, seed):
2061 | torch.cuda.manual_seed(seed)
2062 |
2063 | def __init__(
2064 | self,
2065 | net,
2066 | *,
2067 | image_size,
2068 | channels = 3,
2069 | num_sample_steps = 32, # number of sampling steps
2070 | sigma_min = 0.002, # min noise level
2071 | sigma_max = 80, # max noise level
2072 | sigma_data = 0.5, # standard deviation of data distribution
2073 | rho = 7, # controls the sampling schedule
2074 | P_mean = -1.2, # mean of log-normal distribution from which noise is drawn for training
2075 | P_std = 1.2, # standard deviation of log-normal distribution from which noise is drawn for training
2076 | S_churn = 80, # parameters for stochastic sampling - depends on dataset, Table 5 in the paper
2077 | S_tmin = 0.05,
2078 | S_tmax = 50,
2079 | S_noise = 1.003,
2080 | cond_drop_prob = 0.,
2081 | class_cond_drop_prob = 0.,
2082 | use_dpmpp_solver = False,
2083 | loss_type = 'l2'
2084 | ):
2085 | super().__init__(
2086 | net=net,
2087 | image_size=image_size,
2088 | channels=channels,
2089 | num_sample_steps=num_sample_steps,
2090 | sigma_min=sigma_min, sigma_max=sigma_max, sigma_data=sigma_data,
2091 | rho=rho,
2092 | P_mean=P_mean, P_std=P_std,
2093 | S_churn=S_churn, S_tmin=S_tmin, S_tmax=S_tmax, S_noise=S_noise
2094 | )
2095 | #assert net.learned_sinusoidal_cond
2096 | assert net.random_or_learned_sinusoidal_cond
2097 | self.self_condition = net.self_condition
2098 |
2099 | self.net = net
2100 |
2101 | # image dimensions
2102 |
2103 | self.channels = channels
2104 | self.image_size = image_size
2105 |
2106 | # parameters
2107 |
2108 | self.sigma_min = sigma_min
2109 | self.sigma_max = sigma_max
2110 | self.sigma_data = sigma_data
2111 |
2112 | self.rho = rho
2113 |
2114 | self.P_mean = P_mean
2115
| self.P_std = P_std 2116 | 2117 | self.num_sample_steps = num_sample_steps # otherwise known as N in the paper 2118 | 2119 | self.S_churn = S_churn 2120 | self.S_tmin = S_tmin 2121 | self.S_tmax = S_tmax 2122 | self.S_noise = S_noise 2123 | 2124 | self.cond_drop_prob = cond_drop_prob 2125 | self.class_cond_drop_prob = class_cond_drop_prob 2126 | self.use_dpmpp_solver = use_dpmpp_solver 2127 | self.loss_type = loss_type 2128 | 2129 | def set_seed(self, seed): 2130 | torch.cuda.manual_seed(seed) 2131 | 2132 | def preconditioned_network_forward(self, noised_images, sigma, condition_x, class_label, 2133 | cond_scale = 1.0, class_cond_scale = 1.0, clamp = False): 2134 | batch, device = noised_images.shape[0], noised_images.device 2135 | 2136 | if isinstance(sigma, float): 2137 | sigma = torch.full((batch,), sigma, device = device) 2138 | 2139 | padded_sigma = rearrange(sigma, 'b -> b 1 1 1') 2140 | 2141 | net_out = self.net( 2142 | self.c_in(padded_sigma) * noised_images, 2143 | self.c_noise(sigma), 2144 | class_label, 2145 | condition_x 2146 | ) 2147 | 2148 | out = self.c_skip(padded_sigma) * noised_images + self.c_out(padded_sigma) * net_out 2149 | 2150 | # Currently not supported for CFG with both condition_x and class_label 2151 | if (cond_scale != 1.0) and (class_cond_scale != 1.0): 2152 | raise NotImplementedError("Currently, you cannot specify both cond_scale and class_cond_scale at the same time.") 2153 | 2154 | # CFG by condition_x 2155 | if cond_scale != 1.0: 2156 | null_out = self.net( 2157 | self.c_in(padded_sigma) * noised_images, 2158 | self.c_noise(sigma), 2159 | class_label, 2160 | None 2161 | ) 2162 | 2163 | null_out = self.c_skip(padded_sigma) * noised_images + self.c_out(padded_sigma) * null_out 2164 | 2165 | out = null_out + (out - null_out) * cond_scale 2166 | 2167 | # CFG by class_label 2168 | if class_cond_scale != 1.0: 2169 | null_out = self.net( 2170 | self.c_in(padded_sigma) * noised_images, 2171 | self.c_noise(sigma), 2172 | None, 2173 | condition_x 2174 | ) 2175 | 2176 | null_out = self.c_skip(padded_sigma) * noised_images + self.c_out(padded_sigma) * null_out 2177 | 2178 | out = null_out + (out - null_out) * class_cond_scale 2179 | 2180 | if clamp: 2181 | out = out.clamp(-1., 1.) 2182 | 2183 | return out 2184 | 2185 | @torch.inference_mode() 2186 | def get_noised_images(self, condition_x, target_step, num_sample_steps=None): 2187 | # Input condition_x that has been normalize_to_neg_one_to_one 2188 | num_sample_steps = default(num_sample_steps, self.num_sample_steps) 2189 | sigmas = self.sample_schedule(num_sample_steps) 2190 | n, _, _, _ = condition_x.shape 2191 | padded_sigmas = repeat(sigmas[target_step], ' -> b', b = n) 2192 | noise = torch.randn_like(condition_x) 2193 | noised_images = condition_x + padded_sigmas * noise # alphas are 1. 
in the paper 2194 | return noised_images 2195 | 2196 | @torch.inference_mode() 2197 | def sample(self, batch_size=16, condition_x=None, class_label=None, cond_scale=1.0, guidance_start_steps=0, 2198 | class_cond_scale=1.0, class_guidance_start_steps=0, generation_start_steps=0, 2199 | num_sample_steps=None, clamp=True, with_images=False, with_x0_images=False, zero_init=False): 2200 | if self.use_dpmpp_solver: 2201 | return self.sample_using_dpmpp(batch_size, condition_x, class_label, 2202 | cond_scale, guidance_start_steps, 2203 | class_cond_scale, class_guidance_start_steps, 2204 | generation_start_steps, num_sample_steps, clamp, with_images, with_x0_images, zero_init) 2205 | else: 2206 | return self.sample_org(batch_size, condition_x, class_label, 2207 | cond_scale, guidance_start_steps, 2208 | class_cond_scale, class_guidance_start_steps, 2209 | generation_start_steps, num_sample_steps, clamp, with_images, with_x0_images, zero_init) 2210 | 2211 | @torch.inference_mode() 2212 | def sample_org(self, batch_size = 16, condition_x = None, class_label = None, 2213 | cond_scale = 1.0, guidance_start_steps = 0, 2214 | class_cond_scale = 1.0, class_guidance_start_steps = 0, 2215 | generation_start_steps = 0, num_sample_steps = None, clamp = True, 2216 | with_images = False, with_x0_images = False, zero_init = False): 2217 | num_sample_steps = default(num_sample_steps, self.num_sample_steps) 2218 | 2219 | _n, _c, h, w = condition_x.shape 2220 | # image_size = h 2221 | shape = (batch_size, self.channels, h, w) 2222 | 2223 | condition_x = normalize_to_neg_one_to_one(condition_x) 2224 | 2225 | # get the schedule, which is returned as (sigma, gamma) tuple, and pair up with the next sigma and gamma 2226 | 2227 | sigmas = self.sample_schedule(num_sample_steps) 2228 | 2229 | gammas = torch.where( 2230 | (sigmas >= self.S_tmin) & (sigmas <= self.S_tmax), 2231 | min(self.S_churn / num_sample_steps, math.sqrt(2) - 1), 2232 | 0. 
2233 | ) 2234 | 2235 | sigmas_and_gammas = list(zip(sigmas[:-1], sigmas[1:], gammas[:-1])) 2236 | 2237 | if generation_start_steps > 0: 2238 | images = self.get_noised_images(condition_x, generation_start_steps) 2239 | elif zero_init: 2240 | images = torch.zeros(shape, device = self.device) 2241 | else: 2242 | # images is noise at the beginning 2243 | init_sigma = sigmas[0] 2244 | images = init_sigma * torch.randn(shape, device = self.device) 2245 | 2246 | if with_images: 2247 | image_list = [] 2248 | image_list.append(images.clone().detach().cpu()) 2249 | 2250 | if with_x0_images: 2251 | x0_image_list = [] 2252 | x0_image_list.append(images.clone().detach().cpu()) 2253 | 2254 | # gradually denoise 2255 | 2256 | for i, (sigma, sigma_next, gamma) in enumerate(tqdm(sigmas_and_gammas, desc = 'sampling time step')): 2257 | if i < generation_start_steps: 2258 | continue 2259 | if i < guidance_start_steps: 2260 | cur_cond_scale = 1.0 2261 | else: 2262 | cur_cond_scale = cond_scale 2263 | if i < class_guidance_start_steps: 2264 | cur_class_cond_scale = 1.0 2265 | else: 2266 | cur_class_cond_scale = class_cond_scale 2267 | sigma, sigma_next, gamma = map(lambda t: t.item(), (sigma, sigma_next, gamma)) 2268 | 2269 | eps = self.S_noise * torch.randn(shape, device = self.device) # stochastic sampling 2270 | 2271 | sigma_hat = sigma + gamma * sigma 2272 | images_hat = images + math.sqrt(sigma_hat ** 2 - sigma ** 2) * eps 2273 | 2274 | model_output = self.preconditioned_network_forward(images_hat, sigma_hat, condition_x, class_label, 2275 | cur_cond_scale, cur_class_cond_scale, clamp = clamp) 2276 | denoised_over_sigma = (images_hat - model_output) / sigma_hat 2277 | 2278 | images_next = images_hat + (sigma_next - sigma_hat) * denoised_over_sigma 2279 | 2280 | # second order correction, if not the last timestep 2281 | 2282 | if sigma_next != 0: 2283 | model_output_next = self.preconditioned_network_forward(images_next, sigma_next, condition_x, class_label, 2284 | cur_cond_scale, cur_class_cond_scale, clamp = clamp) 2285 | denoised_prime_over_sigma = (images_next - model_output_next) / sigma_next 2286 | images_next = images_hat + 0.5 * (sigma_next - sigma_hat) * (denoised_over_sigma + denoised_prime_over_sigma) 2287 | 2288 | images = images_next 2289 | 2290 | if with_images: 2291 | image_list.append(images.clone().detach().cpu()) 2292 | if with_x0_images: 2293 | if sigma_next != 0: 2294 | x0_image_list.append(denoised_prime_over_sigma.clone().detach().cpu()) 2295 | else: 2296 | x0_image_list.append(denoised_over_sigma.clone().detach().cpu()) 2297 | 2298 | images = images.clamp(-1., 1.) 
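# NOTE (added annotation, not in the original file): the loop above follows the EDM
# stochastic sampler. Whenever sigma lies in [S_tmin, S_tmax], the noise level is
# temporarily raised to sigma_hat = sigma * (1 + gamma), with
# gamma = min(S_churn / num_sample_steps, sqrt(2) - 1), by injecting S_noise-scaled
# Gaussian noise before each step; the Euler step is then refined by Heun's
# (trapezoidal) correction, averaging the ODE slope at sigma_hat with the slope
# re-evaluated at sigma_next.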
2299 | 2300 | if with_images: 2301 | if with_x0_images: 2302 | return unnormalize_to_zero_to_one(images), image_list, x0_image_list 2303 | else: 2304 | return unnormalize_to_zero_to_one(images), image_list 2305 | else: 2306 | return unnormalize_to_zero_to_one(images) 2307 | 2308 | @torch.inference_mode() 2309 | def tiled_sample(self, batch_size=4, tile_size=256, tile_stride=256, 2310 | condition_x=None, class_label=None, 2311 | cond_scale=1.0, guidance_start_steps=0, 2312 | class_cond_scale=1.0, class_guidance_start_steps=0, 2313 | generation_start_steps=0, num_sample_steps=None, clamp = True, zero_init = False, 2314 | with_images=False, with_x0_images=False, start_white_noise=True, amp=False): 2315 | 2316 | num_sample_steps = default(num_sample_steps, self.num_sample_steps) 2317 | condition_x = normalize_to_neg_one_to_one(condition_x) 2318 | 2319 | batch, c, h, w = condition_x.shape 2320 | 2321 | # pad condition_x 2322 | coord, pad = get_coord_and_pad(h, w) 2323 | left, top, right, bottom = coord 2324 | condition_x = F.pad(condition_x, pad, mode='reflect') 2325 | 2326 | # shape = (batch_size, self.channels, h, w) 2327 | shape = condition_x.shape 2328 | 2329 | # get the schedule, which is returned as (sigma, gamma) tuple, and pair up with the next sigma and gamma 2330 | 2331 | sigmas = self.sample_schedule(num_sample_steps) 2332 | 2333 | gammas = torch.where( 2334 | (sigmas >= self.S_tmin) & (sigmas <= self.S_tmax), 2335 | min(self.S_churn / num_sample_steps, math.sqrt(2) - 1), 2336 | 0. 2337 | ) 2338 | 2339 | sigmas_and_gammas = list(zip(sigmas[:-1], sigmas[1:], gammas[:-1])) 2340 | 2341 | if generation_start_steps > 0: 2342 | images = self.get_noised_images(condition_x, generation_start_steps) 2343 | elif zero_init: 2344 | images = torch.zeros(shape, device = self.device) 2345 | else: 2346 | # images is noise at the beginning 2347 | init_sigma = sigmas[0] 2348 | images = init_sigma * torch.randn(shape, device = self.device) 2349 | 2350 | if with_images: 2351 | image_list = [] 2352 | image_list.append(images[:,:,top:bottom,left:right].clone().detach().cpu()) 2353 | 2354 | if with_x0_images: 2355 | x0_image_list = [] 2356 | x0_image_list.append(images[:,:,top:bottom,left:right].clone().detach().cpu()) 2357 | 2358 | # Pre-calculate tile regions 2359 | _, _, height, width = condition_x.shape 2360 | coords0 = get_coords(height, width, tile_size, tile_size, diff=0) 2361 | if height <= tile_size and width <= tile_size: 2362 | coords1 = get_coords(height, width, tile_size, tile_stride, diff=0) 2363 | else: 2364 | coords1 = get_coords(height-tile_size, width-tile_size, tile_size, tile_stride, diff=tile_size//2) 2365 | coord_list = [coords0, coords1] 2366 | 2367 | # Get the region of the smaller coords 2368 | small_coord, small_pad = get_area(coords1, height, width) 2369 | sleft, stop, sright, sbottom = small_coord 2370 | 2371 | # Pad the outside of the smaller region of condition_x with 0 2372 | cropped_condition_x = condition_x[:,:,stop:sbottom,sleft:sright] 2373 | condition_x = F.pad(cropped_condition_x, small_pad, mode='constant', value=0) 2374 | 2375 | # gradually denoise 2376 | 2377 | x_start = images.clone() 2378 | 2379 | for i, (sigma, sigma_next, gamma) in enumerate(tqdm(sigmas_and_gammas, desc = 'sampling time step')): 2380 | if i < generation_start_steps: 2381 | continue 2382 | if i < guidance_start_steps: 2383 | cur_cond_scale = 1.0 2384 | else: 2385 | cur_cond_scale = cond_scale 2386 | if i < class_guidance_start_steps: 2387 | cur_class_cond_scale = 1.0 2388 | else: 2389 | 
cur_class_cond_scale = class_cond_scale 2390 | sigma, sigma_next, gamma = map(lambda t: t.item(), (sigma, sigma_next, gamma)) 2391 | 2392 | eps = self.S_noise * torch.randn(shape, device = self.device) # stochastic sampling 2393 | 2394 | sigma_hat = sigma + gamma * sigma 2395 | images_hat = images + math.sqrt(sigma_hat ** 2 - sigma ** 2) * eps 2396 | 2397 | cur_coords = coord_list[i%2] 2398 | 2399 | minibatch_index = 0 2400 | minibatch = torch.zeros((batch_size, c, tile_size, tile_size), device=condition_x.device) 2401 | minibatch_condition = torch.zeros((batch_size, c, tile_size, tile_size), device=condition_x.device) 2402 | output_indexes = [None] * batch_size 2403 | for hs, he, ws, we in cur_coords: 2404 | minibatch[minibatch_index] = images_hat[:, :, hs:he, ws:we] 2405 | minibatch_condition[minibatch_index] = condition_x[:, :, hs:he, ws:we] 2406 | output_indexes[minibatch_index] = (hs, ws) 2407 | minibatch_index += 1 2408 | 2409 | if minibatch_index == batch_size: 2410 | with autocast(enabled=amp): 2411 | tile_out = self.preconditioned_network_forward(minibatch, sigma_hat, minibatch_condition, class_label, 2412 | cur_cond_scale, cur_class_cond_scale, clamp=clamp) 2413 | denoised_over_sigma = (minibatch - tile_out) / sigma_hat 2414 | images_next = minibatch + (sigma_next - sigma_hat) * denoised_over_sigma 2415 | # second order correction, if not the last timestep 2416 | if sigma_next != 0: 2417 | tile_out_next = self.preconditioned_network_forward(images_next, sigma_next, minibatch_condition, class_label, 2418 | cur_cond_scale, cur_class_cond_scale, clamp=clamp) 2419 | denoised_prime_over_sigma = (images_next - tile_out_next) / sigma_next 2420 | images_next = minibatch + 0.5 * (sigma_next - sigma_hat) * (denoised_over_sigma + denoised_prime_over_sigma) 2421 | 2422 | for k in range(minibatch_index): 2423 | hs, ws = output_indexes[k] 2424 | images[:,:, hs:hs+tile_size, ws:ws+tile_size] = images_next[k] 2425 | if sigma_next != 0: 2426 | x_start[:, :, hs:hs+tile_size, ws:ws+tile_size] = denoised_prime_over_sigma[k] 2427 | else: 2428 | x_start[:, :, hs:hs+tile_size, ws:ws+tile_size] = denoised_over_sigma[k] 2429 | minibatch_index = 0 2430 | 2431 | if minibatch_index > 0: 2432 | with autocast(enabled=amp): 2433 | tile_out = self.preconditioned_network_forward(minibatch[0:minibatch_index], sigma_hat, 2434 | minibatch_condition[0:minibatch_index], class_label, 2435 | cur_cond_scale, cur_class_cond_scale, clamp=clamp) 2436 | denoised_over_sigma = (minibatch[0:minibatch_index] - tile_out) / sigma_hat 2437 | images_next = minibatch[0:minibatch_index] + (sigma_next - sigma_hat) * denoised_over_sigma 2438 | # second order correction, if not the last timestep 2439 | if sigma_next != 0: 2440 | tile_out_next = self.preconditioned_network_forward(images_next, sigma_next, 2441 | minibatch_condition[0:minibatch_index], class_label, 2442 | cur_cond_scale, cur_class_cond_scale, clamp=clamp) 2443 | denoised_prime_over_sigma = (images_next - tile_out_next) / sigma_next 2444 | images_next = minibatch[0:minibatch_index] + 0.5 * (sigma_next - sigma_hat) * (denoised_over_sigma + denoised_prime_over_sigma) 2445 | 2446 | for k in range(minibatch_index): 2447 | hs, ws = output_indexes[k] 2448 | images[:,:, hs:hs+tile_size, ws:ws+tile_size] = images_next[k] 2449 | if sigma_next != 0: 2450 | x_start[:, :, hs:hs+tile_size, ws:ws+tile_size] = denoised_prime_over_sigma[k] 2451 | else: 2452 | x_start[:, :, hs:hs+tile_size, ws:ws+tile_size] = denoised_over_sigma[k] 2453 | 2454 | if i%2 == 1: 2455 | # Reconstruct by 
removing the padding part of img when odd times 2456 | cropped_img = images[:,:,stop:sbottom,sleft:sright] 2457 | images = self.get_noised_images(torch.zeros_like(condition_x), i) 2458 | images[:,:,stop:sbottom,sleft:sright] = cropped_img 2459 | 2460 | if with_images: 2461 | image_list.append(images.clone().detach().cpu()) 2462 | if with_x0_images: 2463 | x0_image_list.append(x_start.clone().detach().cpu()) 2464 | 2465 | images = images[:,:,top:bottom,left:right] 2466 | images = images.clamp(-1., 1.) 2467 | images = unnormalize_to_zero_to_one(images) 2468 | 2469 | if with_images: 2470 | if with_x0_images: 2471 | return images, image_list, x0_image_list 2472 | else: 2473 | return images, image_list 2474 | else: 2475 | return images 2476 | 2477 | 2478 | @torch.inference_mode() 2479 | def sample_using_dpmpp(self, batch_size = 16, condition_x = None, class_label = None, 2480 | cond_scale = 1.0, guidance_start_steps = 0, 2481 | class_cond_scale = 1.0, class_guidance_start_steps = 0, 2482 | generation_start_steps = 0, num_sample_steps = None, clamp = True, 2483 | with_images = False, with_x0_images = False, zero_init = False): 2484 | """ 2485 | thanks to Katherine Crowson (https://github.com/crowsonkb) for figuring it all out! 2486 | https://arxiv.org/abs/2211.01095 2487 | """ 2488 | num_sample_steps = default(num_sample_steps, self.num_sample_steps) 2489 | 2490 | _n, _c, h, w = condition_x.shape 2491 | # image_size = h 2492 | shape = (batch_size, self.channels, h, w) 2493 | 2494 | condition_x = normalize_to_neg_one_to_one(condition_x) 2495 | 2496 | sigmas = self.sample_schedule(num_sample_steps) 2497 | 2498 | if generation_start_steps > 0: 2499 | images = self.get_noised_images(condition_x, generation_start_steps) 2500 | elif zero_init: 2501 | images = torch.zeros(shape, device = self.device) 2502 | else: 2503 | images = sigmas[0] * torch.randn(shape, device = self.device) 2504 | 2505 | if with_images: 2506 | image_list = [] 2507 | image_list.append(images.clone().detach().cpu()) 2508 | 2509 | if with_x0_images: 2510 | x0_image_list = [] 2511 | x0_image_list.append(images.clone().detach().cpu()) 2512 | 2513 | sigma_fn = lambda t: t.neg().exp() 2514 | t_fn = lambda sigma: sigma.log().neg() 2515 | 2516 | old_denoised = None 2517 | for i in tqdm(range(len(sigmas) - 1)): 2518 | if i < generation_start_steps: 2519 | continue 2520 | if i < guidance_start_steps: 2521 | cur_cond_scale = 1.0 2522 | else: 2523 | cur_cond_scale = cond_scale 2524 | if i < class_guidance_start_steps: 2525 | cur_class_cond_scale = 1.0 2526 | else: 2527 | cur_class_cond_scale = class_cond_scale 2528 | denoised = self.preconditioned_network_forward(images, sigmas[i].item(), condition_x, class_label, 2529 | cur_cond_scale, cur_class_cond_scale, clamp = clamp) 2530 | t, t_next = t_fn(sigmas[i]), t_fn(sigmas[i + 1]) 2531 | h = t_next - t 2532 | 2533 | if not exists(old_denoised) or sigmas[i + 1] == 0: 2534 | denoised_d = denoised 2535 | else: 2536 | h_last = t - t_fn(sigmas[i - 1]) 2537 | r = h_last / h 2538 | gamma = - 1 / (2 * r) 2539 | denoised_d = (1 - gamma) * denoised + gamma * old_denoised 2540 | 2541 | images = (sigma_fn(t_next) / sigma_fn(t)) * images - (-h).expm1() * denoised_d 2542 | old_denoised = denoised 2543 | 2544 | if with_images: 2545 | image_list.append(images.clone().detach().cpu()) 2546 | if with_x0_images: 2547 | x0_image_list.append(denoised_d.clone().detach().cpu()) 2548 | 2549 | images = images.clamp(-1., 1.) 
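# NOTE (added annotation, not in the original file): a minimal inference sketch for this
# class, assuming a conditional UNet `net`, a [0, 1] conditioning batch `lr_up` at the
# output resolution, and an integer label tensor `label` (all placeholder names):
#   ed = ConditionalElucidatedDiffusionSR(net, image_size=256, num_sample_steps=32,
#                                         use_dpmpp_solver=True)
#   sr = ed.sample(batch_size=1, condition_x=lr_up, class_label=label,
#                  class_cond_scale=1.0)
# sample() dispatches to sample_using_dpmpp when use_dpmpp_solver is set, and to the
# stochastic EDM sampler sample_org otherwise.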
2550 | 2551 | if with_images: 2552 | if with_x0_images: 2553 | return unnormalize_to_zero_to_one(images), image_list, x0_image_list 2554 | else: 2555 | return unnormalize_to_zero_to_one(images), image_list 2556 | else: 2557 | return unnormalize_to_zero_to_one(images) 2558 | 2559 | @property 2560 | def loss_fn(self): 2561 | if self.loss_type == 'l1': 2562 | return F.l1_loss 2563 | elif self.loss_type == 'l2': 2564 | return F.mse_loss 2565 | elif self.loss_type == 'smooth_l1': 2566 | return F.smooth_l1_loss 2567 | else: 2568 | raise ValueError(f'invalid loss type {self.loss_type}') 2569 | 2570 | def forward(self, images, condition_x, class_label): 2571 | batch_size, c, h, w, device, image_size, channels = *images.shape, images.device, self.image_size, self.channels 2572 | 2573 | assert h == image_size and w == image_size, f'height and width of image must be {image_size}' 2574 | assert c == channels, 'mismatch of image channels' 2575 | 2576 | images = normalize_to_neg_one_to_one(images) 2577 | if torch.rand(1) < self.cond_drop_prob: # uniform draw for conditioning dropout 2578 | condition_x = None 2579 | else: 2580 | condition_x = normalize_to_neg_one_to_one(condition_x) 2581 | 2582 | if torch.rand(1) < self.class_cond_drop_prob: 2583 | class_label = None 2584 | 2585 | sigmas = self.noise_distribution(batch_size) 2586 | padded_sigmas = rearrange(sigmas, 'b -> b 1 1 1') 2587 | 2588 | noise = torch.randn_like(images) 2589 | 2590 | noised_images = images + padded_sigmas * noise # alphas are 1. in the paper 2591 | 2592 | denoised = self.preconditioned_network_forward(noised_images, sigmas, condition_x, class_label, 2593 | cond_scale=1.0, class_cond_scale=1.0) 2594 | 2595 | losses = self.loss_fn(denoised, images, reduction = 'none') 2596 | losses = reduce(losses, 'b ... -> b', 'mean') 2597 | 2598 | losses = losses * self.loss_weight(sigmas) 2599 | 2600 | return losses.mean() 2601 | 2602 | 2603 | 2604 | # neural net helpers 2605 | 2606 | class Residual(nn.Module): 2607 | def __init__(self, fn): 2608 | super().__init__() 2609 | self.fn = fn 2610 | 2611 | def forward(self, x): 2612 | return x + self.fn(x) 2613 | 2614 | class MonotonicLinear(nn.Module): 2615 | def __init__(self, *args, **kwargs): 2616 | super().__init__() 2617 | self.net = nn.Linear(*args, **kwargs) 2618 | 2619 | def forward(self, x): 2620 | return F.linear(x, self.net.weight.abs(), self.net.bias.abs()) 2621 | 2622 | # continuous schedules 2623 | 2624 | # equations are taken from https://openreview.net/attachment?id=2LdBqxc1Yv&name=supplementary_material 2625 | # @crowsonkb Katherine's repository also helped here https://github.com/crowsonkb/v-diffusion-jax/blob/master/diffusion/utils.py 2626 | 2627 | # log(snr) that approximates the original linear schedule 2628 | 2629 | def log(t, eps = 1e-20): 2630 | return torch.log(t.clamp(min = eps)) 2631 | 2632 | def beta_linear_log_snr(t): 2633 | return -log(expm1(1e-4 + 10 * (t ** 2))) 2634 | 2635 | def alpha_cosine_log_snr(t, s = 0.008): 2636 | return -log((torch.cos((t + s) / (1 + s) * math.pi * 0.5) ** -2) - 1, eps = 1e-5) 2637 | 2638 | class learned_noise_schedule(nn.Module): 2639 | """ described in section H and then I.2 of the supplementary material for variational ddpm paper """ 2640 | 2641 | def __init__( 2642 | self, 2643 | *, 2644 | log_snr_max, 2645 | log_snr_min, 2646 | hidden_dim = 1024, 2647 | frac_gradient = 1. 2648 | ): 2649 | super().__init__() 2650 | self.slope = log_snr_min - log_snr_max 2651 | self.intercept = log_snr_max 2652 | 2653 | self.net = nn.Sequential( 2654 | Rearrange('... -> ... 
1'), 2655 | MonotonicLinear(1, 1), 2656 | Residual(nn.Sequential( 2657 | MonotonicLinear(1, hidden_dim), 2658 | nn.Sigmoid(), 2659 | MonotonicLinear(hidden_dim, 1) 2660 | )), 2661 | Rearrange('... 1 -> ...'), 2662 | ) 2663 | 2664 | self.frac_gradient = frac_gradient 2665 | 2666 | def forward(self, x): 2667 | frac_gradient = self.frac_gradient 2668 | device = x.device 2669 | 2670 | out_zero = self.net(torch.zeros_like(x)) 2671 | out_one = self.net(torch.ones_like(x)) 2672 | 2673 | x = self.net(x) 2674 | 2675 | normed = self.slope * ((x - out_zero) / (out_one - out_zero)) + self.intercept 2676 | return normed * frac_gradient + normed.detach() * (1 - frac_gradient) 2677 | 2678 | 2679 | class ContinuousTimeGaussianDiffusionSR(nn.Module): 2680 | def set_seed(self, seed): 2681 | torch.cuda.manual_seed(seed) 2682 | 2683 | def __init__( 2684 | self, 2685 | model, 2686 | *, 2687 | image_size, 2688 | channels = 3, 2689 | noise_schedule = 'linear', 2690 | num_sample_steps = 500, 2691 | clip_sample_denoised = True, 2692 | learned_schedule_net_hidden_dim = 1024, 2693 | learned_noise_schedule_frac_gradient = 1., # between 0 and 1, determines what percentage of gradients go back, so one can update the learned noise schedule more slowly 2694 | min_snr_loss_weight = False, 2695 | min_snr_gamma = 5, 2696 | cond_drop_prob = 0., 2697 | loss_type = 'l2', 2698 | ): 2699 | super().__init__() 2700 | assert model.random_or_learned_sinusoidal_cond 2701 | #assert not model.self_condition, 'not supported yet' 2702 | 2703 | self.model = model 2704 | 2705 | # image dimensions 2706 | 2707 | self.channels = channels 2708 | self.image_size = image_size 2709 | 2710 | # continuous noise schedule related stuff 2711 | 2712 | if noise_schedule == 'linear': 2713 | self.log_snr = beta_linear_log_snr 2714 | elif noise_schedule == 'cosine': 2715 | self.log_snr = alpha_cosine_log_snr 2716 | elif noise_schedule == 'learned': 2717 | log_snr_max, log_snr_min = [beta_linear_log_snr(torch.tensor([time])).item() for time in (0., 1.)] 2718 | 2719 | self.log_snr = learned_noise_schedule( 2720 | log_snr_max = log_snr_max, 2721 | log_snr_min = log_snr_min, 2722 | hidden_dim = learned_schedule_net_hidden_dim, 2723 | frac_gradient = learned_noise_schedule_frac_gradient 2724 | ) 2725 | else: 2726 | raise ValueError(f'unknown noise schedule {noise_schedule}') 2727 | 2728 | # sampling 2729 | 2730 | self.num_sample_steps = num_sample_steps 2731 | self.clip_sample_denoised = clip_sample_denoised 2732 | 2733 | # proposed https://arxiv.org/abs/2303.09556 2734 | 2735 | self.min_snr_loss_weight = min_snr_loss_weight 2736 | self.min_snr_gamma = min_snr_gamma 2737 | 2738 | self.cond_drop_prob = cond_drop_prob 2739 | self.loss_type = loss_type 2740 | 2741 | @property 2742 | def device(self): 2743 | return next(self.model.parameters()).device 2744 | 2745 | def p_mean_variance(self, x, time, condition_x, cond_scale, time_next): 2746 | # reviewer found an error in the equation in the paper (missing sigma) 2747 | # following - https://openreview.net/forum?id=2LdBqxc1Yv&noteId=rIQgH0zKsRt 2748 | 2749 | log_snr = self.log_snr(time) 2750 | log_snr_next = self.log_snr(time_next) 2751 | c = -expm1(log_snr - log_snr_next) 2752 | 2753 | squared_alpha, squared_alpha_next = log_snr.sigmoid(), log_snr_next.sigmoid() 2754 | squared_sigma, squared_sigma_next = (-log_snr).sigmoid(), (-log_snr_next).sigmoid() 2755 | 2756 | alpha, sigma, alpha_next = map(sqrt, (squared_alpha, squared_sigma, squared_alpha_next)) 2757 | 2758 | batch_log_snr = repeat(log_snr, ' -> b', b = 
x.shape[0]) 2759 | if cond_scale == 1.0: 2760 | pred_noise = self.model(x, batch_log_snr, condition_x) 2761 | else: 2762 | cond_out = self.model(x, batch_log_snr, condition_x) 2763 | null_out = self.model(x, batch_log_snr, None) 2764 | pred_noise = null_out + (cond_out - null_out) * cond_scale 2765 | 2766 | x_start = (x - sigma * pred_noise) / alpha 2767 | 2768 | if self.clip_sample_denoised: 2769 | x_start.clamp_(-1., 1.) 2770 | model_mean = alpha_next * (x * (1 - c) / alpha + c * x_start) 2771 | else: 2772 | model_mean = alpha_next / alpha * (x - c * sigma * pred_noise) 2773 | 2774 | posterior_variance = squared_sigma_next * c 2775 | 2776 | return model_mean, posterior_variance, x_start 2777 | 2778 | # sampling related functions 2779 | 2780 | @torch.inference_mode() 2781 | def p_sample(self, x, time, condition_x, cond_scale, time_next): 2782 | batch, *_, device = *x.shape, x.device 2783 | 2784 | model_mean, model_variance, x_start = self.p_mean_variance(x = x, time = time, condition_x = condition_x, cond_scale = cond_scale, time_next = time_next) 2785 | 2786 | if time_next == 0: 2787 | return model_mean, x_start 2788 | 2789 | noise = torch.randn_like(x) 2790 | return model_mean + sqrt(model_variance) * noise, x_start 2791 | 2792 | def p_sample_loop(self, shape, condition_x, cond_scale, 2793 | guidance_start_steps, generation_start_steps, num_sample_steps, 2794 | with_images, with_x0_images): 2795 | batch = shape[0] 2796 | 2797 | if generation_start_steps > 0: 2798 | start_time = 1. - torch.tensor(generation_start_steps / num_sample_steps, device=condition_x.device) 2799 | start_times = repeat(start_time, ' -> b', b = batch) 2800 | img, _log_snr = self.q_sample(condition_x, start_times) 2801 | else: 2802 | img = torch.randn(shape, device = self.device) 2803 | 2804 | if with_images: 2805 | image_list = [] 2806 | image_list.append(img.clone().detach().cpu()) 2807 | 2808 | if with_x0_images: 2809 | x0_image_list = [] 2810 | x0_image_list.append(img.clone().detach().cpu()) 2811 | 2812 | steps = torch.linspace(1., 0., num_sample_steps + 1, device = self.device) 2813 | 2814 | for i in tqdm(range(num_sample_steps), desc = 'sampling loop time step', total = num_sample_steps): 2815 | if i < generation_start_steps: 2816 | continue 2817 | if i < guidance_start_steps: 2818 | cur_cond_scale = 1.0 2819 | else: 2820 | cur_cond_scale = cond_scale 2821 | times = steps[i] 2822 | times_next = steps[i + 1] 2823 | with torch.inference_mode(): 2824 | img, x_start = self.p_sample(img, times, condition_x, cur_cond_scale, times_next) 2825 | 2826 | if with_images: 2827 | image_list.append(img.clone().detach().cpu()) 2828 | if with_x0_images: 2829 | x0_image_list.append(x_start.clone().detach().cpu()) 2830 | 2831 | img = img.clamp(-1., 1.) 
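# Recap: p_sample above draws from the continuous-time posterior computed in
# p_mean_variance. With c = -expm1(log_snr_t - log_snr_s), alpha^2 = sigmoid(log_snr)
# and sigma^2 = sigmoid(-log_snr), the step from time t to the next time s uses
#     x_hat_0 = (x_t - sigma_t * eps_hat) / alpha_t
#     mean    = alpha_s * (x_t * (1 - c) / alpha_t + c * x_hat_0)
#     var     = sigma_s^2 * c
# and adds sqrt(var) * N(0, I) noise except at the final step (time_next == 0).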
2832 | img = unnormalize_to_zero_to_one(img) 2833 | 2834 | if with_images: 2835 | if with_x0_images: 2836 | return img, image_list, x0_image_list 2837 | else: 2838 | return img, image_list 2839 | else: 2840 | return img 2841 | 2842 | @torch.inference_mode() 2843 | def tiled_sample(self, batch_size=4, tile_size=256, tile_stride=256, 2844 | condition_x=None, class_label=None, 2845 | cond_scale=1.0, guidance_start_steps=0, 2846 | class_cond_scale=1.0, class_guidance_start_steps=0, 2847 | generation_start_steps=0, num_sample_steps=None, 2848 | with_images=False, with_x0_images=False, start_white_noise=True, amp=False): 2849 | 2850 | num_sample_steps = default(num_sample_steps, self.num_sample_steps) 2851 | condition_x = normalize_to_neg_one_to_one(condition_x) 2852 | 2853 | batch, c, h, w = condition_x.shape 2854 | 2855 | # pad condition_x 2856 | coord, pad = get_coord_and_pad(h, w) 2857 | left, top, right, bottom = coord 2858 | condition_x = F.pad(condition_x, pad, mode='reflect') 2859 | 2860 | if generation_start_steps > 0: 2861 | start_time = 1. - torch.tensor(generation_start_steps / num_sample_steps, device=condition_x.device) 2862 | start_times = repeat(start_time, ' -> b', b = batch) 2863 | img, _log_snr = self.q_sample(condition_x, start_times) 2864 | else: 2865 | if start_white_noise: 2866 | img = torch.randn(condition_x.shape, device = self.device) 2867 | else: 2868 | start_time = torch.tensor(1., device=condition_x.device) 2869 | start_times = repeat(start_time, ' -> b', b = batch) 2870 | img, _log_snr = self.q_sample(condition_x, start_times) 2871 | 2872 | if with_images: 2873 | image_list = [] 2874 | image_list.append(img[:,:,top:bottom,left:right].clone().detach().cpu()) 2875 | 2876 | if with_x0_images: 2877 | x0_image_list = [] 2878 | x0_image_list.append(img[:,:,top:bottom,left:right].clone().detach().cpu()) 2879 | 2880 | steps = torch.linspace(1., 0., num_sample_steps + 1, device = self.device) 2881 | 2882 | # Pre-calculate tile regions 2883 | _, _, height, width = condition_x.shape 2884 | coords0 = get_coords(height, width, tile_size, tile_size, diff=0) 2885 | if height <= tile_size and width <= tile_size: 2886 | coords1 = get_coords(height, width, tile_size, tile_stride, diff=0) 2887 | else: 2888 | coords1 = get_coords(height-tile_size, width-tile_size, tile_size, tile_stride, diff=tile_size//2) 2889 | coord_list = [coords0, coords1] 2890 | 2891 | # Get the region of the smaller coords 2892 | small_coord, small_pad = get_area(coords1, height, width) 2893 | sleft, stop, sright, sbottom = small_coord 2894 | 2895 | # Pad the outside of the smaller region of condition_x with 0 2896 | cropped_condition_x = condition_x[:,:,stop:sbottom,sleft:sright] 2897 | condition_x = F.pad(cropped_condition_x, small_pad, mode='constant', value=0) 2898 | 2899 | x_start = img.clone() 2900 | 2901 | for i in tqdm(range(num_sample_steps), desc = 'sampling loop time step', total = num_sample_steps): 2902 | if i < generation_start_steps: 2903 | continue 2904 | if i < guidance_start_steps: 2905 | cur_cond_scale = 1.0 2906 | else: 2907 | cur_cond_scale = cond_scale 2908 | 2909 | times = steps[i] 2910 | times_next = steps[i + 1] 2911 | 2912 | cur_coords = coord_list[i%2] 2913 | 2914 | minibatch_index = 0 2915 | minibatch = torch.zeros((batch_size, c, tile_size, tile_size), device=condition_x.device) 2916 | minibatch_condition = torch.zeros((batch_size, c, tile_size, tile_size), device=condition_x.device) 2917 | # minibatch_mask = torch.zeros((batch_size, c, tile_size, tile_size), 
device=condition_x.device) 2918 | output_indexes = [None] * batch_size 2919 | for hs, he, ws, we in cur_coords: 2920 | minibatch[minibatch_index] = img[:, :, hs:he, ws:we] 2921 | minibatch_condition[minibatch_index] = condition_x[:, :, hs:he, ws:we] 2922 | output_indexes[minibatch_index] = (hs, ws) 2923 | minibatch_index += 1 2924 | 2925 | if minibatch_index == batch_size: 2926 | with autocast(enabled=amp): 2927 | tile_out, tile_x_start = self.p_sample(minibatch, times, minibatch_condition, cur_cond_scale, times_next) 2928 | for k in range(minibatch_index): 2929 | hs, ws = output_indexes[k] 2930 | img[:, :, hs:hs+tile_size, ws:ws+tile_size] = tile_out[k] 2931 | x_start[:, :, hs:hs+tile_size, ws:ws+tile_size] = tile_x_start[k] 2932 | minibatch_index = 0 2933 | 2934 | if minibatch_index > 0: 2935 | with autocast(enabled=amp): 2936 | tile_out, tile_x_start = self.p_sample(minibatch[0:minibatch_index], times, minibatch_condition[0:minibatch_index], cur_cond_scale, times_next) 2937 | for k in range(minibatch_index): 2938 | hs, ws = output_indexes[k] 2939 | img[:, :, hs:hs+tile_size, ws:ws+tile_size] = tile_out[k] 2940 | x_start[:, :, hs:hs+tile_size, ws:ws+tile_size] = tile_x_start[k] 2941 | 2942 | if i%2 == 1: 2943 | # Reconstruct by removing the padding part of img when odd times 2944 | cropped_img = img[:,:,stop:sbottom,sleft:sright] 2945 | img, _log_snr = self.q_sample(torch.zeros_like(condition_x), times_next) 2946 | img[:,:,stop:sbottom,sleft:sright] = cropped_img 2947 | 2948 | if with_images: 2949 | image_list.append(img.clone().detach().cpu()) 2950 | if with_x0_images: 2951 | x0_image_list.append(x_start.clone().detach().cpu()) 2952 | 2953 | 2954 | img = img[:,:,top:bottom,left:right] 2955 | img.clamp_(-1., 1.) 2956 | img = unnormalize_to_zero_to_one(img) 2957 | 2958 | if with_images: 2959 | if with_x0_images: 2960 | return img, image_list, x0_image_list 2961 | else: 2962 | return img, image_list 2963 | else: 2964 | return img 2965 | 2966 | def sample(self, batch_size = 16, condition_x = None, cond_scale = 1.0, 2967 | guidance_start_steps = 0, generation_start_steps = 0, num_sample_steps = None, 2968 | with_images=False, with_x0_images=False): 2969 | num_sample_steps = default(num_sample_steps, self.num_sample_steps) 2970 | condition_x = normalize_to_neg_one_to_one(condition_x) 2971 | return self.p_sample_loop((batch_size, self.channels, self.image_size, self.image_size), condition_x, 2972 | cond_scale, guidance_start_steps, generation_start_steps, num_sample_steps, 2973 | with_images, with_x0_images) 2974 | 2975 | # training related functions - noise prediction 2976 | 2977 | @autocast(enabled = False) 2978 | def q_sample(self, x_start, times, noise = None, return_alpha_sigma_sum=False): 2979 | noise = default(noise, lambda: torch.randn_like(x_start)) 2980 | 2981 | log_snr = self.log_snr(times) 2982 | 2983 | log_snr_padded = right_pad_dims_to(x_start, log_snr) 2984 | alpha, sigma = sqrt(log_snr_padded.sigmoid()), sqrt((-log_snr_padded).sigmoid()) 2985 | x_noised = x_start * alpha + noise * sigma 2986 | 2987 | if return_alpha_sigma_sum: 2988 | return x_noised, alpha+sigma 2989 | else: 2990 | return x_noised, log_snr 2991 | 2992 | def random_times(self, batch_size): 2993 | # times are now uniform from 0 to 1 2994 | return torch.zeros((batch_size,), device = self.device).float().uniform_(0, 1) 2995 | 2996 | @property 2997 | def loss_fn(self): 2998 | if self.loss_type == 'l1': 2999 | return F.l1_loss 3000 | elif self.loss_type == 'l2': 3001 | return F.mse_loss 3002 | elif self.loss_type == 
'smooth_l1': 3003 | return F.smooth_l1_loss 3004 | else: 3005 | raise ValueError(f'invalid loss type {self.loss_type}') 3006 | 3007 | def p_losses(self, x_start, times, condition_x, 3008 | scaler=None, optimizer=None, 3009 | discriminator=None, disc_scaler=None, disc_optimizer=None, img_encoder=None, 3010 | hvd=None, tb_writer=None, global_step=None, 3011 | noise = None): 3012 | noise = default(noise, lambda: torch.randn_like(x_start)) 3013 | 3014 | batch_size, _, _, _ = x_start.shape 3015 | 3016 | x, log_snr = self.q_sample(x_start = x_start, times = times, noise = noise) 3017 | model_out = self.model(x, log_snr, condition_x) 3018 | 3019 | # diffusion loss 3020 | losses = self.loss_fn(model_out, noise, reduction = 'none') 3021 | losses = reduce(losses, 'b ... -> b', 'mean') 3022 | 3023 | 3024 | # apply the min-SNR weight per sample, then average 3025 | if self.min_snr_loss_weight: 3026 | snr = log_snr.exp() 3027 | loss_weight = snr.clamp(min = self.min_snr_gamma) / snr 3028 | losses = losses * loss_weight 3029 | 3030 | return losses.mean() 3031 | 3032 | def forward(self, img, condition_x, 3033 | scaler=None, optimizer=None, 3034 | discriminator=None, disc_scaler=None, disc_optimizer=None, img_encoder=None, 3035 | hvd=None, tb_writer=None, global_step=None, 3036 | *args, **kwargs): 3037 | b, c, h, w, device, img_size, = *img.shape, img.device, self.image_size 3038 | assert h == img_size and w == img_size, f'height and width of image must be {img_size}' 3039 | 3040 | times = self.random_times(b) 3041 | img = normalize_to_neg_one_to_one(img) 3042 | condition_x = normalize_to_neg_one_to_one(condition_x) 3043 | 3044 | if torch.rand(1) < self.cond_drop_prob: 3045 | condition_x = None 3046 | 3047 | return self.p_losses(img, times, condition_x, 3048 | scaler, optimizer, 3049 | discriminator, disc_scaler, disc_optimizer, img_encoder, 3050 | hvd, tb_writer, global_step, 3051 | *args, **kwargs) 3052 | 3053 | 3054 | class ConditionalContinuousTimeGaussianDiffusionSR(nn.Module): 3055 | def set_seed(self, seed): 3056 | torch.cuda.manual_seed(seed) 3057 | 3058 | def __init__( 3059 | self, 3060 | model, 3061 | *, 3062 | image_size, 3063 | channels = 3, 3064 | noise_schedule = 'linear', 3065 | num_sample_steps = 500, 3066 | clip_sample_denoised = True, 3067 | learned_schedule_net_hidden_dim = 1024, 3068 | learned_noise_schedule_frac_gradient = 1., # between 0 and 1, determines what percentage of gradients go back, so one can update the learned noise schedule more slowly 3069 | min_snr_loss_weight = False, 3070 | min_snr_gamma = 5, 3071 | cond_drop_prob = 0., 3072 | class_cond_drop_prob = 0., 3073 | loss_type = 'l2', 3074 | ): 3075 | super().__init__() 3076 | assert model.random_or_learned_sinusoidal_cond 3077 | #assert not model.self_condition, 'not supported yet' 3078 | 3079 | self.model = model 3080 | 3081 | # image dimensions 3082 | 3083 | self.channels = channels 3084 | self.image_size = image_size 3085 | 3086 | # continuous noise schedule related stuff 3087 | 3088 | if noise_schedule == 'linear': 3089 | self.log_snr = beta_linear_log_snr 3090 | elif noise_schedule == 'cosine': 3091 | self.log_snr = alpha_cosine_log_snr 3092 | elif noise_schedule == 'learned': 3093 | log_snr_max, log_snr_min = [beta_linear_log_snr(torch.tensor([time])).item() for time in (0., 1.)] 3094 | 3095 | self.log_snr = learned_noise_schedule( 3096 | log_snr_max = log_snr_max, 3097 | log_snr_min = log_snr_min, 3098 | hidden_dim = learned_schedule_net_hidden_dim, 3099 | frac_gradient = learned_noise_schedule_frac_gradient 3100 | ) 3101 | else: 3102 | 
raise ValueError(f'unknown noise schedule {noise_schedule}') 3103 | 3104 | # sampling 3105 | 3106 | self.num_sample_steps = num_sample_steps 3107 | self.clip_sample_denoised = clip_sample_denoised 3108 | 3109 | # proposed https://arxiv.org/abs/2303.09556 3110 | 3111 | self.min_snr_loss_weight = min_snr_loss_weight 3112 | self.min_snr_gamma = min_snr_gamma 3113 | 3114 | self.cond_drop_prob = cond_drop_prob 3115 | self.class_cond_drop_prob = class_cond_drop_prob 3116 | self.loss_type = loss_type 3117 | 3118 | @property 3119 | def device(self): 3120 | return next(self.model.parameters()).device 3121 | 3122 | def p_mean_variance(self, x, time, condition_x, class_label, 3123 | cond_scale, class_cond_scale, time_next): 3124 | # reviewer found an error in the equation in the paper (missing sigma) 3125 | # following - https://openreview.net/forum?id=2LdBqxc1Yv&noteId=rIQgH0zKsRt 3126 | 3127 | log_snr = self.log_snr(time) 3128 | log_snr_next = self.log_snr(time_next) 3129 | c = -expm1(log_snr - log_snr_next) 3130 | 3131 | squared_alpha, squared_alpha_next = log_snr.sigmoid(), log_snr_next.sigmoid() 3132 | squared_sigma, squared_sigma_next = (-log_snr).sigmoid(), (-log_snr_next).sigmoid() 3133 | 3134 | alpha, sigma, alpha_next = map(sqrt, (squared_alpha, squared_sigma, squared_alpha_next)) 3135 | 3136 | batch_log_snr = repeat(log_snr, ' -> b', b = x.shape[0]) 3137 | 3138 | if (cond_scale != 1.0) and (class_cond_scale != 1.0): 3139 | raise NotImplementedError("Currently, you cannot specify both cond_scale and class_cond_scale at the same time.") 3140 | # full_null_out = self.model(x, batch_log_snr, None, None) 3141 | # class_null_out = self.model(x, batch_log_snr, None, condition_x) 3142 | # cond_null_out = self.model(x, batch_log_snr, class_label, None) 3143 | # full_cond_out = self.model(x, batch_log_snr, class_label, condition_x) 3144 | # pred_noise = full_null_out + \ 3145 | # ((full_cond_out - class_null_out) * class_cond_scale + \ 3146 | # (full_cond_out - cond_null_out) * cond_scale) / 2. 3147 | elif cond_scale != 1.0: 3148 | cond_out = self.model(x, batch_log_snr, class_label, condition_x) 3149 | null_out = self.model(x, batch_log_snr, class_label, None) 3150 | pred_noise = null_out + (cond_out - null_out) * cond_scale 3151 | elif class_cond_scale != 1.0: 3152 | cond_out = self.model(x, batch_log_snr, class_label, condition_x) 3153 | null_out = self.model(x, batch_log_snr, None, condition_x) 3154 | pred_noise = null_out + (cond_out - null_out) * class_cond_scale 3155 | elif cond_scale == 1.0 and class_cond_scale == 1.0: 3156 | pred_noise = self.model(x, batch_log_snr, class_label, condition_x) 3157 | else: 3158 | raise NotImplementedError() 3159 | 3160 | x_start = (x - sigma * pred_noise) / alpha 3161 | 3162 | if self.clip_sample_denoised: 3163 | x_start.clamp_(-1., 1.) 
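# Note: the guided branches above are standard classifier-free guidance,
#     eps_hat = eps_null + scale * (eps_cond - eps_null),
# where the null pass drops condition_x (for cond_scale) or class_label
# (for class_cond_scale). A scale of 1.0 reduces to a single conditional
# forward pass, and guiding on both conditions at once is unsupported here.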
3164 | model_mean = alpha_next * (x * (1 - c) / alpha + c * x_start) 3165 | else: 3166 | model_mean = alpha_next / alpha * (x - c * sigma * pred_noise) 3167 | 3168 | posterior_variance = squared_sigma_next * c 3169 | 3170 | return model_mean, posterior_variance, x_start 3171 | 3172 | # sampling related functions 3173 | 3174 | @torch.inference_mode() 3175 | def p_sample(self, x, time, condition_x, class_label, 3176 | cond_scale, class_cond_scale, time_next): 3177 | batch, *_, device = *x.shape, x.device 3178 | 3179 | model_mean, model_variance, x_start = self.p_mean_variance(x = x, time = time, 3180 | condition_x = condition_x, class_label = class_label, 3181 | cond_scale = cond_scale, class_cond_scale = class_cond_scale, 3182 | time_next = time_next) 3183 | 3184 | if time_next == 0: 3185 | return model_mean, x_start 3186 | 3187 | noise = torch.randn_like(x) 3188 | return model_mean + sqrt(model_variance) * noise, x_start 3189 | 3190 | # @torch.inference_mode() 3191 | def p_sample_loop(self, shape, condition_x, class_label, 3192 | cond_scale, guidance_start_steps, 3193 | class_cond_scale, class_guidance_start_steps, 3194 | generation_start_steps, num_sample_steps, 3195 | with_images, with_x0_images): 3196 | batch = shape[0] 3197 | 3198 | if generation_start_steps > 0: 3199 | start_time = 1. - torch.tensor(generation_start_steps / num_sample_steps, device=condition_x.device) 3200 | start_times = repeat(start_time, ' -> b', b = batch) 3201 | img, _log_snr = self.q_sample(condition_x, start_times) 3202 | else: 3203 | img = torch.randn(shape, device = self.device) 3204 | 3205 | if with_images: 3206 | image_list = [] 3207 | image_list.append(img.clone().detach().cpu()) 3208 | 3209 | if with_x0_images: 3210 | x0_image_list = [] 3211 | x0_image_list.append(img.clone().detach().cpu()) 3212 | 3213 | steps = torch.linspace(1., 0., num_sample_steps + 1, device = self.device) 3214 | 3215 | for i in tqdm(range(num_sample_steps), desc = 'sampling loop time step', total = num_sample_steps): 3216 | if i < generation_start_steps: 3217 | continue 3218 | if i < guidance_start_steps: 3219 | cur_cond_scale = 1.0 3220 | else: 3221 | cur_cond_scale = cond_scale 3222 | if i < class_guidance_start_steps: 3223 | cur_class_cond_scale = 1.0 3224 | else: 3225 | cur_class_cond_scale = class_cond_scale 3226 | times = steps[i] 3227 | times_next = steps[i + 1] 3228 | with torch.inference_mode(): 3229 | img, x_start = self.p_sample(img, times, condition_x, class_label, 3230 | cur_cond_scale, cur_class_cond_scale, times_next) 3231 | 3232 | if with_images: 3233 | image_list.append(img.clone().detach().cpu()) 3234 | if with_x0_images: 3235 | x0_image_list.append(x_start.clone().detach().cpu()) 3236 | 3237 | img = img.clamp(-1., 1.) 
3238 | img = unnormalize_to_zero_to_one(img) 3239 | 3240 | if with_images: 3241 | if with_x0_images: 3242 | return img, image_list, x0_image_list 3243 | else: 3244 | return img, image_list 3245 | else: 3246 | return img 3247 | 3248 | def delta_p_mean_variance(self, x, time, condition_x, class_label, 3249 | cond_scale, class_cond_scale, time_next, x0): # x is the current noisy sample, mirroring p_mean_variance; x0 is currently unused 3250 | # reviewer found an error in the equation in the paper (missing sigma) 3251 | # following - https://openreview.net/forum?id=2LdBqxc1Yv&noteId=rIQgH0zKsRt 3252 | 3253 | log_snr = self.log_snr(time) 3254 | log_snr_next = self.log_snr(time_next) 3255 | c = -expm1(log_snr - log_snr_next) 3256 | 3257 | squared_alpha, squared_alpha_next = log_snr.sigmoid(), log_snr_next.sigmoid() 3258 | squared_sigma, squared_sigma_next = (-log_snr).sigmoid(), (-log_snr_next).sigmoid() 3259 | 3260 | alpha, sigma, alpha_next = map(sqrt, (squared_alpha, squared_sigma, squared_alpha_next)) 3261 | 3262 | batch_log_snr = repeat(log_snr, ' -> b', b = x.shape[0]) 3263 | 3264 | # Currently not supported for CFG with both condition_x and class_label 3265 | if (cond_scale != 1.0) and (class_cond_scale != 1.0): 3266 | raise NotImplementedError("Currently, you cannot specify both cond_scale and class_cond_scale at the same time.") 3267 | elif cond_scale != 1.0: 3268 | raise NotImplementedError() 3269 | elif class_cond_scale != 1.0: 3270 | raise NotImplementedError() 3271 | elif cond_scale == 1.0 and class_cond_scale == 1.0: 3272 | pred_noise = self.model(x, batch_log_snr, class_label, condition_x) 3273 | else: 3274 | raise NotImplementedError() 3275 | 3276 | x_start = (x - sigma * pred_noise) / alpha 3277 | 3278 | if self.clip_sample_denoised: 3279 | x_start.clamp_(-1., 1.) 3280 | model_mean = alpha_next * (x * (1 - c) / alpha + c * x_start) 3281 | else: 3282 | model_mean = alpha_next / alpha * (x - c * sigma * pred_noise) 3283 | 3284 | posterior_variance = squared_sigma_next * c 3285 | 3286 | return model_mean, posterior_variance, x_start 3287 | 3288 | def tiled_sample(self, batch_size=4, tile_size=256, tile_stride=256, 3289 | condition_x=None, class_label=None, 3290 | cond_scale=1.0, guidance_start_steps=0, 3291 | class_cond_scale=1.0, class_guidance_start_steps=0, 3292 | generation_start_steps=0, num_sample_steps=None, 3293 | with_images=False, with_x0_images=False, start_white_noise=True, amp=False): 3294 | 3295 | num_sample_steps = default(num_sample_steps, self.num_sample_steps) 3296 | condition_x = normalize_to_neg_one_to_one(condition_x) 3297 | 3298 | batch, c, h, w = condition_x.shape 3299 | 3300 | # pad condition_x 3301 | coord, pad = get_coord_and_pad(h, w) 3302 | left, top, right, bottom = coord 3303 | condition_x = F.pad(condition_x, pad, mode='reflect') 3304 | 3305 | if generation_start_steps > 0: 3306 | start_time = 1. 
- torch.tensor(generation_start_steps / num_sample_steps, device=condition_x.device) 3307 | start_times = repeat(start_time, ' -> b', b = batch) 3308 | img, _log_snr = self.q_sample(condition_x, start_times) 3309 | else: 3310 | if start_white_noise: 3311 | img = torch.randn(condition_x.shape, device = self.device) 3312 | else: 3313 | start_time = torch.tensor(1., device=condition_x.device) 3314 | start_times = repeat(start_time, ' -> b', b = batch) 3315 | img, _log_snr = self.q_sample(condition_x, start_times) 3316 | 3317 | if with_images: 3318 | image_list = [] 3319 | image_list.append(img[:,:,top:bottom,left:right].clone().detach().cpu()) 3320 | 3321 | if with_x0_images: 3322 | x0_image_list = [] 3323 | x0_image_list.append(img[:,:,top:bottom,left:right].clone().detach().cpu()) 3324 | 3325 | steps = torch.linspace(1., 0., num_sample_steps + 1, device = self.device) 3326 | 3327 | # Pre-calculate tile regions 3328 | _, _, height, width = condition_x.shape 3329 | coords0 = get_coords(height, width, tile_size, tile_size, diff=0) 3330 | if height <= tile_size and width <= tile_size: 3331 | coords1 = get_coords(height, width, tile_size, tile_stride, diff=0) 3332 | else: 3333 | coords1 = get_coords(height-tile_size, width-tile_size, tile_size, tile_stride, diff=tile_size//2) 3334 | coord_list = [coords0, coords1] 3335 | 3336 | # Get the region of the smaller coords 3337 | small_coord, small_pad = get_area(coords1, height, width) 3338 | sleft, stop, sright, sbottom = small_coord 3339 | 3340 | # Pad the outside of the smaller region of condition_x with 0 3341 | cropped_condition_x = condition_x[:,:,stop:sbottom,sleft:sright] 3342 | condition_x = F.pad(cropped_condition_x, small_pad, mode='constant', value=0) 3343 | 3344 | x_start = img.clone() 3345 | 3346 | for i in tqdm(range(num_sample_steps), desc = 'sampling loop time step', total = num_sample_steps): 3347 | if i < generation_start_steps: 3348 | continue 3349 | if i < guidance_start_steps: 3350 | cur_cond_scale = 1.0 3351 | else: 3352 | cur_cond_scale = cond_scale 3353 | if i < class_guidance_start_steps: 3354 | cur_class_cond_scale = 1.0 3355 | else: 3356 | cur_class_cond_scale = class_cond_scale 3357 | 3358 | times = steps[i] 3359 | times_next = steps[i + 1] 3360 | 3361 | cur_coords = coord_list[i%2] 3362 | 3363 | minibatch_index = 0 3364 | minibatch = torch.zeros((batch_size, c, tile_size, tile_size), device=condition_x.device) 3365 | minibatch_condition = torch.zeros((batch_size, c, tile_size, tile_size), device=condition_x.device) 3366 | output_indexes = [None] * batch_size 3367 | for hs, he, ws, we in cur_coords: 3368 | minibatch[minibatch_index] = img[:, :, hs:he, ws:we] 3369 | minibatch_condition[minibatch_index] = condition_x[:, :, hs:he, ws:we] 3370 | output_indexes[minibatch_index] = (hs, ws) 3371 | minibatch_index += 1 3372 | 3373 | if minibatch_index == batch_size: 3374 | with torch.inference_mode(): 3375 | tile_out, tile_x_start = self.p_sample(minibatch, times, minibatch_condition, class_label, 3376 | cur_cond_scale, cur_class_cond_scale, times_next) 3377 | for k in range(minibatch_index): 3378 | hs, ws = output_indexes[k] 3379 | img[:, :, hs:hs+tile_size, ws:ws+tile_size] = tile_out[k] 3380 | x_start[:, :, hs:hs+tile_size, ws:ws+tile_size] = tile_x_start[k] 3381 | minibatch_index = 0 3382 | 3383 | if minibatch_index > 0: 3384 | with torch.inference_mode(): 3385 | tile_out, tile_x_start = self.p_sample(minibatch[0:minibatch_index], times, minibatch_condition[0:minibatch_index], class_label, 3386 | cur_cond_scale, 
cur_class_cond_scale, times_next) 3387 | for k in range(minibatch_index): 3388 | hs, ws = output_indexes[k] 3389 | img[:, :, hs:hs+tile_size, ws:ws+tile_size] = tile_out[k] 3390 | x_start[:, :, hs:hs+tile_size, ws:ws+tile_size] = tile_x_start[k] 3391 | 3392 | if i%2 == 1: 3393 | # Reconstruct by removing the padding part of img when odd times 3394 | cropped_img = img[:,:,stop:sbottom,sleft:sright] 3395 | img, _log_snr = self.q_sample(torch.zeros_like(condition_x), times_next) 3396 | img[:,:,stop:sbottom,sleft:sright] = cropped_img 3397 | 3398 | if with_images: 3399 | image_list.append(img.clone().detach().cpu()) 3400 | if with_x0_images: 3401 | x0_image_list.append(x_start.clone().detach().cpu()) 3402 | 3403 | img = img[:,:,top:bottom,left:right] 3404 | img.clamp_(-1., 1.) 3405 | img = unnormalize_to_zero_to_one(img) 3406 | 3407 | if with_images: 3408 | if with_x0_images: 3409 | return img, image_list, x0_image_list 3410 | else: 3411 | return img, image_list 3412 | else: 3413 | return img 3414 | 3415 | 3416 | # @torch.inference_mode() 3417 | def sample(self, batch_size = 16, condition_x = None, class_label = None, 3418 | cond_scale = 1.0, guidance_start_steps = 0, 3419 | class_cond_scale = 1.0, class_guidance_start_steps = 0, 3420 | generation_start_steps = 0, num_sample_steps = None, 3421 | with_images=False, with_x0_images=False, x0=None): 3422 | num_sample_steps = default(num_sample_steps, self.num_sample_steps) 3423 | condition_x = normalize_to_neg_one_to_one(condition_x) 3424 | 3425 | return self.p_sample_loop((batch_size, self.channels, self.image_size, self.image_size), 3426 | condition_x, class_label, 3427 | cond_scale, guidance_start_steps, 3428 | class_cond_scale, class_guidance_start_steps, 3429 | generation_start_steps, num_sample_steps, 3430 | with_images, with_x0_images) 3431 | 3432 | # training related functions - noise prediction 3433 | 3434 | @autocast(enabled = False) 3435 | def q_sample(self, x_start, times, noise = None, return_alpha_sigma_sum=False): 3436 | noise = default(noise, lambda: torch.randn_like(x_start)) 3437 | 3438 | log_snr = self.log_snr(times) 3439 | 3440 | log_snr_padded = right_pad_dims_to(x_start, log_snr) 3441 | alpha, sigma = sqrt(log_snr_padded.sigmoid()), sqrt((-log_snr_padded).sigmoid()) 3442 | x_noised = x_start * alpha + noise * sigma 3443 | 3444 | if return_alpha_sigma_sum: 3445 | return x_noised, alpha+sigma 3446 | else: 3447 | return x_noised, log_snr 3448 | 3449 | def random_times(self, batch_size): 3450 | # times are now uniform from 0 to 1 3451 | return torch.zeros((batch_size,), device = self.device).float().uniform_(0, 1) 3452 | 3453 | @property 3454 | def loss_fn(self): 3455 | if self.loss_type == 'l1': 3456 | return F.l1_loss 3457 | elif self.loss_type == 'l2': 3458 | return F.mse_loss 3459 | elif self.loss_type == 'smooth_l1': 3460 | return F.smooth_l1_loss 3461 | else: 3462 | raise ValueError(f'invalid loss type {self.loss_type}') 3463 | 3464 | def p_losses(self, x_start, times, class_label, condition_x, noise = None): 3465 | noise = default(noise, lambda: torch.randn_like(x_start)) 3466 | 3467 | x, log_snr = self.q_sample(x_start = x_start, times = times, noise = noise) 3468 | model_out = self.model(x, log_snr, class_label, condition_x) 3469 | 3470 | losses = self.loss_fn(model_out, noise, reduction = 'none') 3471 | losses = reduce(losses, 'b ... 
-> b', 'mean') 3472 | 3473 | if self.min_snr_loss_weight: 3474 | snr = log_snr.exp() 3475 | loss_weight = snr.clamp(min = self.min_snr_gamma) / snr 3476 | losses = losses * loss_weight 3477 | 3478 | return losses.mean() 3479 | 3480 | def forward(self, img, condition_x, class_label, *args, **kwargs): 3481 | b, c, h, w, device, img_size, = *img.shape, img.device, self.image_size 3482 | assert h == img_size and w == img_size, f'height and width of image must be {img_size}' 3483 | 3484 | times = self.random_times(b) 3485 | img = normalize_to_neg_one_to_one(img) 3486 | 3487 | if torch.rand(1) < self.cond_drop_prob: 3488 | condition_x = None 3489 | else: 3490 | condition_x = normalize_to_neg_one_to_one(condition_x) 3491 | 3492 | if torch.rand(1) < self.class_cond_drop_prob: 3493 | class_label = None 3494 | 3495 | return self.p_losses(img, times, class_label, condition_x, *args, **kwargs) 3496 | 3497 | 3498 | 3499 | 3500 | def get_model(conf, logger): 3501 | dim_mults = tuple([int(elem) for elem in conf.ddpm_unet_dim_mults.split(',')]) 3502 | full_attn = tuple([eval(elem) for elem in conf.full_attn.split(',')]) 3503 | if 'conditional' in conf.model: 3504 | unet = ConditionalSRUnet( 3505 | dim = conf.unet_dim, 3506 | dim_mults = dim_mults, 3507 | full_attn = full_attn, 3508 | learned_variance = conf.learned_variance, 3509 | learned_sinusoidal_cond = conf.learned_sinusoidal_cond, 3510 | learned_sinusoidal_dim = conf.learned_sinusoidal_dim, 3511 | flash_attn = conf.flash_attn, 3512 | pixel_shuffle_upsample = conf.pixel_shuffle_upsample, 3513 | num_classes = conf.num_classes 3514 | ) 3515 | logger.info(f"ConditionalSRUnet: channels=6 dim={conf.unet_dim} dim_mults={conf.ddpm_unet_dim_mults} num_classes={conf.num_classes}") 3516 | else: 3517 | unet = SRUnet( 3518 | dim = conf.unet_dim, 3519 | dim_mults = dim_mults, 3520 | full_attn = full_attn, 3521 | learned_variance = conf.learned_variance, 3522 | learned_sinusoidal_cond = conf.learned_sinusoidal_cond, 3523 | learned_sinusoidal_dim = conf.learned_sinusoidal_dim, 3524 | flash_attn = conf.flash_attn, 3525 | pixel_shuffle_upsample = conf.pixel_shuffle_upsample, 3526 | use_free_u = conf.use_free_u, 3527 | free_u_b1 = conf.free_u_b1, 3528 | free_u_b2 = conf.free_u_b2, 3529 | free_u_s1 = conf.free_u_s1, 3530 | free_u_s2 = conf.free_u_s2 3531 | ) 3532 | logger.info(f"SRUnet: channels=6 dim={conf.unet_dim} dim_mults={conf.ddpm_unet_dim_mults}") 3533 | 3534 | if conf.model == 'gaussian': 3535 | assert not conf.learned_sinusoidal_cond 3536 | conf.use_dpmpp_solver = False 3537 | model = GaussianDiffusionSR( 3538 | model = unet, 3539 | image_size = conf.image_size, 3540 | timesteps = conf.timesteps, 3541 | sampling_timesteps = conf.sampling_timesteps, 3542 | objective = conf.objective, 3543 | beta_schedule = conf.beta_schedule, 3544 | offset_noise_strength = conf.offset_noise_strength, 3545 | min_snr_loss_weight = conf.min_snr_loss_weight, 3546 | min_snr_gamma = conf.min_snr_gamma, 3547 | cond_drop_prob = conf.cond_drop_prob, 3548 | loss_type = conf.loss_type, 3549 | ) 3550 | logger.info(f"GaussianDiffusionSR: image_size={conf.image_size} timesteps={conf.timesteps} sampling_timesteps={conf.sampling_timesteps}") 3551 | 3552 | elif conf.model == 'conditional_gaussian': 3553 | assert not conf.learned_sinusoidal_cond 3554 | conf.use_dpmpp_solver = False 3555 | model = ConditionalGaussianDiffusionSR( 3556 | model = unet, 3557 | image_size = conf.image_size, 3558 | timesteps = conf.timesteps, 3559 | sampling_timesteps = conf.sampling_timesteps, 3560 | objective = 
conf.objective, 3561 | beta_schedule = conf.beta_schedule, 3562 | offset_noise_strength = conf.offset_noise_strength, 3563 | min_snr_loss_weight = conf.min_snr_loss_weight, 3564 | min_snr_gamma = conf.min_snr_gamma, 3565 | cond_drop_prob = conf.cond_drop_prob, 3566 | class_cond_drop_prob = conf.class_cond_drop_prob, 3567 | loss_type = conf.loss_type, 3568 | ) 3569 | logger.info(f"ConditionalGaussianDiffusionSR: image_size={conf.image_size} timesteps={conf.timesteps} sampling_timesteps={conf.sampling_timesteps}") 3570 | 3571 | elif conf.model == 'elucidated': 3572 | assert conf.learned_sinusoidal_cond 3573 | model = ElucidatedDiffusionSR( 3574 | net = unet, 3575 | image_size = conf.image_size, 3576 | num_sample_steps = conf.num_sample_steps, 3577 | sigma_min = conf.sigma_min, 3578 | sigma_max = conf.sigma_max, 3579 | sigma_data = conf.sigma_data, 3580 | rho = conf.rho, 3581 | P_mean = conf.P_mean, 3582 | P_std = conf.P_std, 3583 | S_churn = conf.S_churn, 3584 | S_tmin = conf.S_tmin, 3585 | S_tmax = conf.S_tmax, 3586 | S_noise = conf.S_noise, 3587 | cond_drop_prob = conf.cond_drop_prob, 3588 | use_dpmpp_solver = conf.use_dpmpp_solver, 3589 | loss_type = conf.loss_type 3590 | ) 3591 | logger.info(f"ElucidatedDiffusionSR: image_size={conf.image_size} num_sample_steps={conf.num_sample_steps}") 3592 | 3593 | elif conf.model == 'conditional_elucidated': 3594 | assert conf.learned_sinusoidal_cond 3595 | model = ConditionalElucidatedDiffusionSR( 3596 | net = unet, 3597 | image_size = conf.image_size, 3598 | num_sample_steps = conf.num_sample_steps, 3599 | sigma_min = conf.sigma_min, 3600 | sigma_max = conf.sigma_max, 3601 | sigma_data = conf.sigma_data, 3602 | rho = conf.rho, 3603 | P_mean = conf.P_mean, 3604 | P_std = conf.P_std, 3605 | S_churn = conf.S_churn, 3606 | S_tmin = conf.S_tmin, 3607 | S_tmax = conf.S_tmax, 3608 | S_noise = conf.S_noise, 3609 | cond_drop_prob = conf.cond_drop_prob, 3610 | class_cond_drop_prob = conf.class_cond_drop_prob, 3611 | use_dpmpp_solver = conf.use_dpmpp_solver, 3612 | loss_type = conf.loss_type 3613 | ) 3614 | logger.info(f"ConditionalElucidatedDiffusionSR: image_size={conf.image_size} num_sample_steps={conf.num_sample_steps}") 3615 | 3616 | elif conf.model == 'continuous': 3617 | assert conf.learned_sinusoidal_cond 3618 | conf.use_dpmpp_solver = False 3619 | model = ContinuousTimeGaussianDiffusionSR( 3620 | model = unet, 3621 | image_size = conf.image_size, 3622 | noise_schedule = conf.noise_schedule, 3623 | num_sample_steps = conf.num_sample_steps, 3624 | clip_sample_denoised = conf.clip_sample_denoised, 3625 | learned_schedule_net_hidden_dim = conf.learned_schedule_net_hidden_dim, 3626 | learned_noise_schedule_frac_gradient = conf.learned_noise_schedule_frac_gradient, 3627 | min_snr_loss_weight = conf.min_snr_loss_weight, 3628 | min_snr_gamma = conf.min_snr_gamma, 3629 | cond_drop_prob = conf.cond_drop_prob, 3630 | loss_type = conf.loss_type, 3631 | ) 3632 | logger.info(f"ContinuousTimeGaussianDiffusionSR: image_size={conf.image_size} num_sample_steps={conf.num_sample_steps}") 3633 | 3634 | elif conf.model == 'conditional_continuous': 3635 | assert conf.learned_sinusoidal_cond 3636 | conf.use_dpmpp_solver = False 3637 | model = ConditionalContinuousTimeGaussianDiffusionSR( 3638 | model = unet, 3639 | image_size = conf.image_size, 3640 | noise_schedule = conf.noise_schedule, 3641 | num_sample_steps = conf.num_sample_steps, 3642 | clip_sample_denoised = conf.clip_sample_denoised, 3643 | learned_schedule_net_hidden_dim = conf.learned_schedule_net_hidden_dim, 3644 
| learned_noise_schedule_frac_gradient = conf.learned_noise_schedule_frac_gradient, 3645 | min_snr_loss_weight = conf.min_snr_loss_weight, 3646 | min_snr_gamma = conf.min_snr_gamma, 3647 | cond_drop_prob = conf.cond_drop_prob, 3648 | class_cond_drop_prob = conf.class_cond_drop_prob, 3649 | loss_type = conf.loss_type, 3650 | ) 3651 | logger.info(f"ConditionalContinuousTimeGaussianDiffusionSR: image_size={conf.image_size} num_sample_steps={conf.num_sample_steps}") 3652 | 3653 | else: 3654 | raise NotImplementedError(conf.model) 3655 | 3656 | # ema model 3657 | ema_model = ModelEmaV2(model, decay=conf.ema_decay) 3658 | 3659 | if conf.ckpt_path: 3660 | ckpt = torch.load(conf.ckpt_path, map_location='cpu', weights_only=True) 3661 | # ema model 3662 | check = ema_model.module.load_state_dict(ckpt['ema_model'], strict=conf.load_strict) 3663 | logger.info(f"load ema_model weight from : {conf.ckpt_path}") 3664 | logger.info(f"check: {check}") 3665 | 3666 | return ema_model 3667 | 3668 | --------------------------------------------------------------------------------
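A minimal driver sketch showing how `get_model` is typically used at inference time. The config helper name (`get_config`) and the conf fields it fills in are assumptions for illustration; the repository's actual entry point is `inference.py`, and the condition image is assumed to be the low-resolution input already resized to the target resolution:

import torch
from logzero import logger

from config import get_config  # hypothetical helper; see config.py for the real loader
from model import get_model

# Load the shipped conf and point it at the released checkpoint.
conf = get_config('conf/conditional_continuous_linear_df8kost_dim128.yaml')
conf.ckpt_path = 'models/srgd/conditional_continuous_linear_df8kost_dim128_epoch300.pth'

# get_model returns a ModelEmaV2 wrapping the diffusion model; sample from the EMA weights.
ema_model = get_model(conf, logger)
diffusion = ema_model.module.eval()

# Stand-in conditioning image in [0, 1], assumed pre-upscaled to the output size.
condition = torch.rand(1, 3, 1024, 1024)
with torch.inference_mode():
    sr = diffusion.tiled_sample(batch_size=4, tile_size=256, tile_stride=128,
                                condition_x=condition,
                                class_label=torch.tensor([0]),  # test_label from inference_sample.sh
                                class_cond_scale=1.0)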