├── README.md ├── __pycache__ ├── ddim_inversion_utils.cpython-37.pyc ├── inversion_utils.cpython-37.pyc ├── inversion_utils.cpython-39.pyc ├── inversion_utils1.cpython-37.pyc ├── inversion_utils2.cpython-37.pyc ├── inversion_utils3.cpython-37.pyc ├── utils.cpython-37.pyc └── utils.cpython-39.pyc ├── blind_deblurring.py ├── blind_non_uniform_deblurring.py ├── configs ├── blind_deblurring.yml ├── blind_non_uniform_deblurring.yml ├── denoising.yml ├── inpainting.yml └── super_resolution.yml ├── data ├── celeba_hq.yml ├── imgs │ ├── 00201.png │ ├── 00205.png │ ├── 00243.png │ └── 00287.png ├── kernel.npy └── motion_blur │ ├── 00870_gt.png │ └── 00870_input.png ├── ddim_inversion_utils.py ├── denoising.py ├── guided_diffusion ├── __init__.py ├── __pycache__ │ ├── __init__.cpython-37.pyc │ ├── __init__.cpython-39.pyc │ ├── diffusion.cpython-37.pyc │ ├── fp16_util.cpython-37.pyc │ ├── logger.cpython-37.pyc │ ├── models.cpython-37.pyc │ ├── models.cpython-39.pyc │ ├── nn.cpython-37.pyc │ ├── script_util.cpython-37.pyc │ └── unet.cpython-37.pyc ├── diffusion.py ├── fp16_util.py ├── logger.py ├── models.py ├── nn.py ├── script_util.py └── unet.py ├── inpainting.py ├── push_to_hf.py ├── results ├── blind_deblurring.png ├── denoised.png ├── inpainted.png ├── non_uniform_deblurring.png └── super_resolution.png ├── super_resolution.py └── utils.py /README.md: -------------------------------------------------------------------------------- 1 | # Blind Image Restoration via Fast Diffusion Inversion (BIRD) 2 | 3 | This is the official implementation of "Blind Image Restoration via Fast Diffusion Inversion". 4 | [arXiv](https://arxiv.org/abs/2405.19572) 5 | 6 | ## Environment 7 | ``` 8 | pip install numpy torch blobfile tqdm pyyaml pillow diffusers 9 | ``` 10 | 11 | ## Pre-Trained Models 12 | 13 | For face restoration, download the pretrained [model](https://drive.google.com/file/d/1qMs7tNGV3tkOZNKH5L130dkwsmobEJdh/view?usp=sharing) and put it into ```checkpoints/```. 14 | 15 | 16 | ## Blind Deblurring 17 | 18 | ``` 19 | python blind_deblurring.py 20 | ``` 21 | ![image info](results/blind_deblurring.png) 22 | 23 | ## Non-uniform Deblurring 24 | 25 | ``` 26 | python blind_non_uniform_deblurring.py 27 | ``` 28 | ![image info](results/non_uniform_deblurring.png) 29 | 30 | ## Inpainting 31 | 32 | ``` 33 | python inpainting.py 34 | ``` 35 | 36 | ![image info](results/inpainted.png) 37 | 38 | ## Denoising 39 | 40 | ``` 41 | python denoising.py 42 | ``` 43 | ![image info](results/denoised.png) 44 | 45 | ## Super-resolution 46 | 47 | ``` 48 | python super_resolution.py 49 | ``` 50 | ![image info](results/super_resolution.png) 51 | 52 | 53 | ## References 54 | 55 | If you find this repository useful for your research, please cite the following work.
56 | 57 | 58 | 59 | ``` 60 | @article{chihaoui2024blind, 61 | title={Blind Image Restoration via Fast Diffusion Inversion}, 62 | author={Chihaoui, Hamadi and Lemkhenter, Abdelhak and Favaro, Paolo}, 63 | journal={arXiv preprint arXiv:2405.19572}, 64 | year={2024} 65 | } 66 | 67 | ``` 68 | 69 | 70 | -------------------------------------------------------------------------------- /__pycache__/ddim_inversion_utils.cpython-37.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/hamadichihaoui/BIRD/713dab700a0192b224641ecd5c0e264a96e47fab/__pycache__/ddim_inversion_utils.cpython-37.pyc -------------------------------------------------------------------------------- /__pycache__/inversion_utils.cpython-37.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/hamadichihaoui/BIRD/713dab700a0192b224641ecd5c0e264a96e47fab/__pycache__/inversion_utils.cpython-37.pyc -------------------------------------------------------------------------------- /__pycache__/inversion_utils.cpython-39.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/hamadichihaoui/BIRD/713dab700a0192b224641ecd5c0e264a96e47fab/__pycache__/inversion_utils.cpython-39.pyc -------------------------------------------------------------------------------- /__pycache__/inversion_utils1.cpython-37.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/hamadichihaoui/BIRD/713dab700a0192b224641ecd5c0e264a96e47fab/__pycache__/inversion_utils1.cpython-37.pyc -------------------------------------------------------------------------------- /__pycache__/inversion_utils2.cpython-37.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/hamadichihaoui/BIRD/713dab700a0192b224641ecd5c0e264a96e47fab/__pycache__/inversion_utils2.cpython-37.pyc -------------------------------------------------------------------------------- /__pycache__/inversion_utils3.cpython-37.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/hamadichihaoui/BIRD/713dab700a0192b224641ecd5c0e264a96e47fab/__pycache__/inversion_utils3.cpython-37.pyc -------------------------------------------------------------------------------- /__pycache__/utils.cpython-37.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/hamadichihaoui/BIRD/713dab700a0192b224641ecd5c0e264a96e47fab/__pycache__/utils.cpython-37.pyc -------------------------------------------------------------------------------- /__pycache__/utils.cpython-39.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/hamadichihaoui/BIRD/713dab700a0192b224641ecd5c0e264a96e47fab/__pycache__/utils.cpython-39.pyc -------------------------------------------------------------------------------- /blind_deblurring.py: -------------------------------------------------------------------------------- 1 | import os 2 | import yaml 3 | import numpy as np 4 | import tqdm 5 | import torch 6 | from torch import nn 7 | import sys 8 | sys.path.insert(0,'./') 9 | from guided_diffusion.models import Model 10 | import random 11 | from ddim_inversion_utils import * 12 | from utils import * 13 | 14 | with 
open('configs/blind_deblurring.yml', 'r') as f: 15 | task_config = yaml.safe_load(f) 16 | 17 | 18 | ### Reproducibility 19 | torch.set_printoptions(sci_mode=False) 20 | ensure_reproducibility(task_config['seed']) 21 | 22 | 23 | with open( "data/celeba_hq.yml", "r") as f: 24 | config1 = yaml.safe_load(f) 25 | config = dict2namespace(config1) 26 | model, device = load_pretrained_diffusion_model(config) 27 | 28 | ### Define the DDIM scheduler 29 | ddim_scheduler=DDIMScheduler(beta_start=config.diffusion.beta_start, beta_end=config.diffusion.beta_end, beta_schedule=config.diffusion.beta_schedule) 30 | ddim_scheduler.set_timesteps(config.diffusion.num_diffusion_timesteps // task_config['delta_t'])#task_config['Denoising_steps'] 31 | 32 | #scale=41 33 | l2_loss= nn.MSELoss() #nn.L1Loss() 34 | net_kernel = fcn(200, task_config['kernel_size'] * task_config['kernel_size']).cuda() 35 | net_input_kernel = get_noise(200, 'noise', (1, 1)).cuda() 36 | net_input_kernel.squeeze_() 37 | 38 | 39 | img_pil, downsampled_torch = generate_blurry_image('data/imgs/00287.png') 40 | radii = torch.ones([1, 1, 1]).cuda() * (np.sqrt(256*256*3)) 41 | 42 | latent = torch.nn.parameter.Parameter(torch.randn( 1, config.model.in_channels, config.data.image_size, config.data.image_size).to(device)) 43 | optimizer = torch.optim.Adam([{'params':latent,'lr':task_config['lr_img']}, {'params':net_kernel.parameters(),'lr':task_config['lr_blur']}]) 44 | 45 | 46 | for iteration in range(task_config['Optimization_steps']): 47 | optimizer.zero_grad() 48 | x_0_hat = DDIM_efficient_feed_forward(latent, model, ddim_scheduler) 49 | out_k = net_kernel(net_input_kernel) 50 | out_k_m = out_k.view(-1, 1, task_config['kernel_size'], task_config['kernel_size']) 51 | 52 | blurred_xt = nn.functional.conv2d(x_0_hat.view(-1, 1, config.data.image_size, config.data.image_size), out_k_m, padding="same", bias=None).view(1, 3, config.data.image_size, config.data.image_size) 53 | loss = l2_loss(blurred_xt, downsampled_torch) 54 | loss.backward() 55 | optimizer.step() 56 | 57 | #Project to the Sphere of radius sqrt(D) 58 | for param in latent: 59 | param.data.div_((param.pow(2).sum(tuple(range(0, param.ndim)), keepdim=True) + 1e-9).sqrt()) 60 | param.data.mul_(radii) 61 | 62 | if iteration % 10 == 0: 63 | #psnr = psnr_orig(np.array(img_pil).astype(np.float32), process(x_0_hat, 0)) 64 | #print(iteration, 'loss:', loss.item(), torch.norm(latent.detach()), psnr) 65 | Image.fromarray(np.concatenate([ process(downsampled_torch, 0), process(x_0_hat, 0), np.array(img_pil).astype(np.uint8)], 1)).save('results/blind_deblurring.png') 66 | 67 | 68 | 69 | 70 | -------------------------------------------------------------------------------- /blind_non_uniform_deblurring.py: -------------------------------------------------------------------------------- 1 | import os 2 | import yaml 3 | import numpy as np 4 | import tqdm 5 | import torch 6 | from torch import nn 7 | import sys 8 | sys.path.insert(0,'./') 9 | from guided_diffusion.models import Model 10 | import random 11 | from ddim_inversion_utils import * 12 | from utils import * 13 | 14 | with open('configs/blind_non_uniform_deblurring.yml', 'r') as f: 15 | task_config = yaml.safe_load(f) 16 | 17 | 18 | ### Reproducibility 19 | torch.set_printoptions(sci_mode=False) 20 | ensure_reproducibility(task_config['seed']) 21 | 22 | 23 | with open( "data/celeba_hq.yml", "r") as f: 24 | config1 = yaml.safe_load(f) 25 | config = dict2namespace(config1) 26 | model, device = load_pretrained_diffusion_model(config) 27 | 28 | ### 
Define the DDIM scheduler 29 | ddim_scheduler=DDIMScheduler(beta_start=config.diffusion.beta_start, beta_end=config.diffusion.beta_end, beta_schedule=config.diffusion.beta_schedule) 30 | ddim_scheduler.set_timesteps(config.diffusion.num_diffusion_timesteps // task_config['delta_t'])#task_config['Denoising_steps'] 31 | 32 | 33 | img_np = (np.array(Image.open('data/motion_blur/00870_input.png')) / 255.)*2. - 1. 34 | img_pil = Image.open('data/motion_blur/00870_gt.png') 35 | img_torch = torch.tensor(img_np).permute(2, 0, 1).unsqueeze(0).float() 36 | radii = torch.ones([1, 1, 1]).cuda() * (np.sqrt(config.data.image_size*config.data.image_size*config.model.in_channels)) 37 | 38 | latent = torch.nn.parameter.Parameter(torch.randn( 1, config.model.in_channels, config.data.image_size, config.data.image_size).to(device)) 39 | l2_loss=nn.MSELoss() #nn.L1Loss() 40 | optimizer = torch.optim.Adam([{'params':latent,'lr':task_config['lr']}])# 41 | 42 | 43 | for iteration in range(task_config['Optimization_steps']): 44 | optimizer.zero_grad() 45 | x_0_hat = DDIM_efficient_feed_forward(latent, model, ddim_scheduler) 46 | loss = l2_loss(x_0_hat, img_torch.cuda()) 47 | loss.backward() 48 | optimizer.step() 49 | 50 | #Project to the Sphere of radius sqrt(D) 51 | for param in latent: 52 | param.data.div_((param.pow(2).sum(tuple(range(0, param.ndim)), keepdim=True) + 1e-9).sqrt()) 53 | param.data.mul_(radii) 54 | 55 | if iteration % 10 == 0: 56 | #psnr = psnr_orig(np.array(img_pil).astype(np.float32), process(x_0_hat, 0)) 57 | #print(iteration, 'loss:', loss.item(), torch.norm(latent.detach()), psnr) 58 | Image.fromarray(np.concatenate([ process(img_torch.cuda(), 0), process(x_0_hat, 0), np.array(img_pil).astype(np.uint8)], 1)).save('results/non_uniform_deblurring.png') 59 | 60 | 61 | 62 | 63 | -------------------------------------------------------------------------------- /configs/blind_deblurring.yml: -------------------------------------------------------------------------------- 1 | Denoising_steps: 10 2 | Optimization_steps: 100 3 | lr_blur: 0.0002 4 | lr_img: 0.003 #0.003 5 | delta_t: 100 6 | kernel_size: 41 7 | seed: 0 -------------------------------------------------------------------------------- /configs/blind_non_uniform_deblurring.yml: -------------------------------------------------------------------------------- 1 | Optimization_steps: 25 2 | lr: 0.002 #0.003 3 | delta_t: 100 4 | seed: 0 -------------------------------------------------------------------------------- /configs/denoising.yml: -------------------------------------------------------------------------------- 1 | Denoising_steps: 10 2 | Optimization_steps: 200 3 | lr: 0.01 4 | delta_t: 100 5 | seed: 0 6 | 7 | -------------------------------------------------------------------------------- /configs/inpainting.yml: -------------------------------------------------------------------------------- 1 | Optimization_steps: 200 2 | lr: 0.01 3 | delta_t: 100 4 | seed: 0 5 | #Denoising_steps: 10 6 | 7 | 8 | -------------------------------------------------------------------------------- /configs/super_resolution.yml: -------------------------------------------------------------------------------- 1 | Optimization_steps: 100 #60 2 | lr: 0.001 #0.003 3 | delta_t: 100 4 | downsampling_ratio: 8 5 | seed: 0 6 | #Denoising_steps: 10 7 | -------------------------------------------------------------------------------- /data/celeba_hq.yml: -------------------------------------------------------------------------------- 1 | data: 2 | dataset: "CelebA_HQ" 
3 | category: "" 4 | image_size: 256 5 | channels: 3 6 | logit_transform: false 7 | uniform_dequantization: false 8 | gaussian_dequantization: false 9 | random_flip: true 10 | rescaled: true 11 | num_workers: 32 12 | out_of_dist: false 13 | 14 | model: 15 | type: "simple" 16 | in_channels: 3 17 | out_ch: 3 18 | ch: 128 19 | ch_mult: [1, 1, 2, 2, 4, 4] 20 | num_res_blocks: 2 21 | attn_resolutions: [16, ] 22 | dropout: 0.0 23 | var_type: fixedsmall 24 | ema_rate: 0.999 25 | ema: True 26 | resamp_with_conv: True 27 | 28 | diffusion: 29 | beta_schedule: linear 30 | beta_start: 0.0001 31 | beta_end: 0.02 32 | num_diffusion_timesteps: 1000 33 | 34 | sampling: 35 | batch_size: 1 36 | 37 | time_travel: 38 | T_sampling: 100 39 | travel_length: 1 40 | travel_repeat: 1 -------------------------------------------------------------------------------- /data/imgs/00201.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/hamadichihaoui/BIRD/713dab700a0192b224641ecd5c0e264a96e47fab/data/imgs/00201.png -------------------------------------------------------------------------------- /data/imgs/00205.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/hamadichihaoui/BIRD/713dab700a0192b224641ecd5c0e264a96e47fab/data/imgs/00205.png -------------------------------------------------------------------------------- /data/imgs/00243.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/hamadichihaoui/BIRD/713dab700a0192b224641ecd5c0e264a96e47fab/data/imgs/00243.png -------------------------------------------------------------------------------- /data/imgs/00287.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/hamadichihaoui/BIRD/713dab700a0192b224641ecd5c0e264a96e47fab/data/imgs/00287.png -------------------------------------------------------------------------------- /data/kernel.npy: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/hamadichihaoui/BIRD/713dab700a0192b224641ecd5c0e264a96e47fab/data/kernel.npy -------------------------------------------------------------------------------- /data/motion_blur/00870_gt.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/hamadichihaoui/BIRD/713dab700a0192b224641ecd5c0e264a96e47fab/data/motion_blur/00870_gt.png -------------------------------------------------------------------------------- /data/motion_blur/00870_input.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/hamadichihaoui/BIRD/713dab700a0192b224641ecd5c0e264a96e47fab/data/motion_blur/00870_input.png -------------------------------------------------------------------------------- /ddim_inversion_utils.py: -------------------------------------------------------------------------------- 1 | # adapted and updated from https://github.com/huggingface/diffusers/blob/main/src/diffusers/schedulers/scheduling_ddim.py 2 | 3 | import math 4 | from dataclasses import dataclass 5 | from typing import Optional, Tuple, Union 6 | 7 | import numpy as np 8 | import torch 9 | #import matplotlib.pyplot as plt 10 | from tqdm import tqdm as tqdm1 11 | from tqdm.auto import tqdm 12 | from PIL import Image 13 | 14 | from diffusers.configuration_utils import 
ConfigMixin, register_to_config 15 | from diffusers.utils import BaseOutput, deprecate 16 | from diffusers.schedulers.scheduling_utils import SchedulerMixin 17 | 18 | 19 | 20 | def betas_for_alpha_bar(num_diffusion_timesteps, max_beta=0.999) -> torch.Tensor: 21 | 22 | def alpha_bar(time_step): 23 | return math.cos((time_step + 0.008) / 1.008 * math.pi / 2) ** 2 24 | 25 | betas = [] 26 | for i in range(num_diffusion_timesteps): 27 | t1 = i / num_diffusion_timesteps 28 | t2 = (i + 1) / num_diffusion_timesteps 29 | betas.append(min(1 - alpha_bar(t2) / alpha_bar(t1), max_beta)) 30 | return torch.tensor(betas) 31 | 32 | 33 | def DDIM_efficient_feed_forward(latent, model, ddim_scheduler): 34 | for i, t in enumerate(tqdm1(ddim_scheduler.timesteps)): 35 | t1 = (torch.ones(1) * t) .cuda()#.to(x_t.device) 36 | with torch.no_grad(): 37 | if i == 0: 38 | noise_pred = model(latent, t1) 39 | else: 40 | noise_pred = model(x_t, t1) #.sample 41 | noise_pred = noise_pred[:, :3] 42 | 43 | if i == 0: 44 | x_t = ddim_scheduler.step(noise_pred, t, latent, return_dict=True, use_clipped_model_output=True) 45 | else: 46 | x_t = ddim_scheduler.step(noise_pred, t, x_t, return_dict=True, use_clipped_model_output=True) 47 | return x_t 48 | 49 | class DDIMScheduler(SchedulerMixin, ConfigMixin): 50 | 51 | @register_to_config 52 | def __init__( 53 | self, 54 | num_train_timesteps: int = 1000, 55 | beta_start: float = 0.0001, 56 | beta_end: float = 0.02, 57 | beta_schedule: str = "linear", 58 | trained_betas: Optional[np.ndarray] = None, 59 | clip_sample: bool = True, 60 | set_alpha_to_one: bool = True, 61 | steps_offset: int = 0, 62 | ): 63 | if trained_betas is not None: 64 | self.betas = torch.from_numpy(trained_betas) 65 | elif beta_schedule == "linear": 66 | self.betas = torch.linspace(beta_start, beta_end, num_train_timesteps, dtype=torch.float32) 67 | elif beta_schedule == "scaled_linear": 68 | # this schedule is very specific to the latent diffusion model. 69 | self.betas = ( 70 | torch.linspace(beta_start**0.5, beta_end**0.5, num_train_timesteps, dtype=torch.float32) ** 2 71 | ) 72 | elif beta_schedule == "squaredcos_cap_v2": 73 | # Glide cosine schedule 74 | self.betas = betas_for_alpha_bar(num_train_timesteps) 75 | else: 76 | raise NotImplementedError(f"{beta_schedule} does is not implemented for {self.__class__}") 77 | 78 | self.alphas = 1.0 - self.betas 79 | self.alphas_cumprod = torch.cumprod(self.alphas, dim=0) 80 | 81 | # At every step in ddim, we are looking into the previous alphas_cumprod 82 | # For the final step, there is no previous alphas_cumprod because we are already at 0 83 | # `set_alpha_to_one` decides whether we set this parameter simply to one or 84 | # whether we use the final alpha of the "non-previous" one. 
85 | self.final_alpha_cumprod = torch.tensor(1.0) if set_alpha_to_one else self.alphas_cumprod[0] 86 | 87 | # standard deviation of the initial noise distribution 88 | self.init_noise_sigma = 1.0 89 | 90 | # setable values 91 | self.num_inference_steps = None 92 | self.timesteps = torch.from_numpy(np.arange(0, num_train_timesteps)[::-1].copy().astype(np.int64)) 93 | 94 | def _get_variance(self, timestep, prev_timestep): 95 | alpha_prod_t = self.alphas_cumprod[timestep] 96 | alpha_prod_t_prev = self.alphas_cumprod[prev_timestep] if prev_timestep >= 0 else self.final_alpha_cumprod 97 | beta_prod_t = 1 - alpha_prod_t 98 | beta_prod_t_prev = 1 - alpha_prod_t_prev 99 | 100 | variance = (beta_prod_t_prev / beta_prod_t) * (1 - alpha_prod_t / alpha_prod_t_prev) 101 | 102 | return variance 103 | 104 | def set_timesteps(self, num_inference_steps: int, device: Union[str, torch.device] = None): 105 | """ 106 | Sets the discrete timesteps used for the diffusion chain. Supporting function to be run before inference. 107 | Args: 108 | num_inference_steps (`int`): 109 | the number of diffusion steps used when generating samples with a pre-trained model. 110 | """ 111 | self.num_inference_steps = num_inference_steps 112 | step_ratio = self.config.num_train_timesteps // self.num_inference_steps 113 | # creates integer timesteps by multiplying by ratio 114 | # casting to int to avoid issues when num_inference_step is power of 3 115 | timesteps = (np.arange(0, num_inference_steps) * step_ratio).round()[::-1].copy().astype(np.int64) 116 | self.timesteps = torch.from_numpy(timesteps).to(device) 117 | self.timesteps += self.config.steps_offset 118 | 119 | def scale_model_input(self, sample: torch.FloatTensor, timestep: Optional[int] = None) -> torch.FloatTensor: 120 | """ 121 | Ensures interchangeability with schedulers that need to scale the denoising model input depending on the 122 | current timestep. 123 | Args: 124 | sample (`torch.FloatTensor`): input sample 125 | timestep (`int`, optional): current timestep 126 | Returns: 127 | `torch.FloatTensor`: scaled input sample 128 | """ 129 | return sample 130 | 131 | def step( 132 | self, 133 | model_output: torch.FloatTensor, 134 | timestep: int, 135 | sample: torch.FloatTensor, 136 | eta: float = 0.0, 137 | use_clipped_model_output: bool = False, 138 | generator=None, 139 | return_dict: bool = True, 140 | ) -> torch.FloatTensor: 141 | """ 142 | Predict the sample at the previous timestep by reversing the SDE. Core function to propagate the diffusion 143 | process from the learned model outputs (most often the predicted noise). 144 | 145 | Args: 146 | model_output (`torch.FloatTensor`): direct output from learned diffusion model. 147 | timestep (`int`): current discrete timestep in the diffusion chain. 148 | sample (`torch.FloatTensor`): 149 | current instance of sample being created by diffusion process. 150 | eta (`float`): weight of noise for added noise in diffusion step. 151 | use_clipped_model_output (`bool`): TODO 152 | generator: random number generator. 153 | return_dict (`bool`): option for returning tuple rather than DDIMSchedulerOutput class 154 | 155 | Returns: 156 | [`~schedulers.scheduling_utils.DDIMSchedulerOutput`] or `tuple`: 157 | [`~schedulers.scheduling_utils.DDIMSchedulerOutput`] if `return_dict` is True, otherwise a `tuple`. When 158 | returning a tuple, the first element is the sample tensor. 
159 | 160 | """ 161 | if self.num_inference_steps is None: 162 | raise ValueError( 163 | "Number of inference steps is 'None', you need to run 'set_timesteps' after creating the scheduler" 164 | ) 165 | 166 | # See formulas (12) and (16) of DDIM paper https://arxiv.org/pdf/2010.02502.pdf 167 | # Ideally, read the DDIM paper for an in-detail understanding 168 | 169 | # Notation (<variable name> -> <name in paper>) 170 | # - pred_noise_t -> e_theta(x_t, t) 171 | # - pred_original_sample -> f_theta(x_t, t) or x_0 172 | # - std_dev_t -> sigma_t 173 | # - eta -> η 174 | # - pred_sample_direction -> "direction pointing to x_t" 175 | # - pred_prev_sample -> "x_t-1" 176 | 177 | # 1. get previous step value (=t-1) 178 | prev_timestep = timestep - self.config.num_train_timesteps // self.num_inference_steps 179 | 180 | # 2. compute alphas, betas 181 | alpha_prod_t = self.alphas_cumprod[timestep] 182 | alpha_prod_t_prev = self.alphas_cumprod[prev_timestep] if prev_timestep >= 0 else self.final_alpha_cumprod 183 | 184 | beta_prod_t = 1 - alpha_prod_t 185 | 186 | # 3. compute predicted original sample from predicted noise also called 187 | # "predicted x_0" of formula (12) from https://arxiv.org/pdf/2010.02502.pdf 188 | pred_original_sample = (sample - beta_prod_t ** (0.5) * model_output) / alpha_prod_t ** (0.5) 189 | 190 | # 4. Clip "predicted x_0" 191 | if self.config.clip_sample: 192 | pred_original_sample = torch.clamp(pred_original_sample, -1, 1) 193 | 194 | # 5. compute variance: "sigma_t(η)" -> see formula (16) 195 | # σ_t = sqrt((1 − α_t−1)/(1 − α_t)) * sqrt(1 − α_t/α_t−1) 196 | variance = self._get_variance(timestep, prev_timestep) 197 | std_dev_t = eta * variance ** (0.5) 198 | 199 | if use_clipped_model_output: 200 | # the model_output is always re-derived from the clipped x_0 in Glide 201 | model_output = (sample - alpha_prod_t ** (0.5) * pred_original_sample) / beta_prod_t ** (0.5) 202 | 203 | # 6. compute "direction pointing to x_t" of formula (12) from https://arxiv.org/pdf/2010.02502.pdf 204 | pred_sample_direction = (1 - alpha_prod_t_prev - std_dev_t**2) ** (0.5) * model_output 205 | 206 | # 7.
compute x_t without "random noise" of formula (12) from https://arxiv.org/pdf/2010.02502.pdf 207 | prev_sample = alpha_prod_t_prev ** (0.5) * pred_original_sample + pred_sample_direction 208 | 209 | if eta > 0: 210 | device = model_output.device if torch.is_tensor(model_output) else "cpu" 211 | noise = torch.randn(model_output.shape, generator=generator).to(device) 212 | variance = self._get_variance(timestep, prev_timestep) ** (0.5) * eta * noise 213 | 214 | prev_sample = prev_sample + variance 215 | 216 | 217 | return prev_sample 218 | 219 | 220 | 221 | def add_noise( 222 | self, 223 | original_samples: torch.FloatTensor, 224 | noise: torch.FloatTensor, 225 | timesteps: torch.IntTensor, 226 | ) -> torch.FloatTensor: 227 | if self.alphas_cumprod.device != original_samples.device: 228 | self.alphas_cumprod = self.alphas_cumprod.to(original_samples.device) 229 | if timesteps.device != original_samples.device: 230 | timesteps = timesteps.to(original_samples.device) 231 | 232 | sqrt_alpha_prod = self.alphas_cumprod[timesteps] ** 0.5 233 | sqrt_alpha_prod = sqrt_alpha_prod.flatten() 234 | while len(sqrt_alpha_prod.shape) < len(original_samples.shape): 235 | sqrt_alpha_prod = sqrt_alpha_prod.unsqueeze(-1) 236 | 237 | sqrt_one_minus_alpha_prod = (1 - self.alphas_cumprod[timesteps]) ** 0.5 238 | sqrt_one_minus_alpha_prod = sqrt_one_minus_alpha_prod.flatten() 239 | while len(sqrt_one_minus_alpha_prod.shape) < len(original_samples.shape): 240 | sqrt_one_minus_alpha_prod = sqrt_one_minus_alpha_prod.unsqueeze(-1) 241 | 242 | noisy_samples = sqrt_alpha_prod * original_samples + sqrt_one_minus_alpha_prod * noise 243 | return noisy_samples 244 | 245 | def __len__(self): 246 | return self.config.num_train_timesteps 247 | 248 | -------------------------------------------------------------------------------- /denoising.py: -------------------------------------------------------------------------------- 1 | import os 2 | import yaml 3 | import numpy as np 4 | import tqdm 5 | import torch 6 | from torch import nn 7 | import sys 8 | sys.path.insert(0,'./') 9 | from guided_diffusion.models import Model 10 | import random 11 | from ddim_inversion_utils import * 12 | from utils import * 13 | 14 | with open('configs/denoising.yml', 'r') as f: 15 | task_config = yaml.safe_load(f) 16 | 17 | 18 | ### Reproducibility 19 | torch.set_printoptions(sci_mode=False) 20 | ensure_reproducibility(task_config['seed']) 21 | 22 | 23 | with open( "data/celeba_hq.yml", "r") as f: 24 | config1 = yaml.safe_load(f) 25 | config = dict2namespace(config1) 26 | model, device = load_pretrained_diffusion_model(config) 27 | 28 | ### Define the DDIM scheduler 29 | ddim_scheduler=DDIMScheduler(beta_start=config.diffusion.beta_start, beta_end=config.diffusion.beta_end, beta_schedule=config.diffusion.beta_schedule) 30 | ddim_scheduler.set_timesteps(config.diffusion.num_diffusion_timesteps // task_config['delta_t'])#task_config['Denoising_steps'] 31 | 32 | 33 | img_pil, img_np = generate_noisy_image('data/imgs/00201.png') 34 | img_torch = torch.tensor(img_np).permute(2, 0, 1).unsqueeze(0) 35 | radii = torch.ones([1, 1, 1]).cuda() * (np.sqrt(config.data.image_size*config.data.image_size*config.model.in_channels)) 36 | 37 | latent = torch.nn.parameter.Parameter(torch.randn( 1, config.model.in_channels, config.data.image_size, config.data.image_size).to(device)) 38 | l2_loss=nn.MSELoss() #nn.L1Loss() 39 | optimizer = torch.optim.Adam([{'params':latent,'lr':task_config['lr']}])# 40 | 41 | 42 | for iteration in 
range(task_config['Optimization_steps']): 43 | optimizer.zero_grad() 44 | x_0_hat = DDIM_efficient_feed_forward(latent, model, ddim_scheduler) 45 | loss = l2_loss(x_0_hat, img_torch.cuda()) 46 | loss.backward() 47 | optimizer.step() 48 | 49 | #Project to the Sphere of radius sqrt(D) 50 | for param in latent: 51 | param.data.div_((param.pow(2).sum(tuple(range(0, param.ndim)), keepdim=True) + 1e-9).sqrt()) 52 | param.data.mul_(radii) 53 | 54 | if iteration % 10 == 0: 55 | #psnr = psnr_orig(np.array(img_pil).astype(np.float32), process(x_0_hat, 0)) 56 | #print(iteration, 'loss:', loss.item(), torch.norm(latent.detach()), psnr) 57 | Image.fromarray(np.concatenate([ process(img_torch.cuda(), 0), process(x_0_hat, 0), np.array(img_pil).astype(np.uint8)], 1)).save('results/denoised.png') 58 | 59 | 60 | 61 | 62 | -------------------------------------------------------------------------------- /guided_diffusion/__init__.py: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/hamadichihaoui/BIRD/713dab700a0192b224641ecd5c0e264a96e47fab/guided_diffusion/__init__.py -------------------------------------------------------------------------------- /guided_diffusion/__pycache__/__init__.cpython-37.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/hamadichihaoui/BIRD/713dab700a0192b224641ecd5c0e264a96e47fab/guided_diffusion/__pycache__/__init__.cpython-37.pyc -------------------------------------------------------------------------------- /guided_diffusion/__pycache__/__init__.cpython-39.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/hamadichihaoui/BIRD/713dab700a0192b224641ecd5c0e264a96e47fab/guided_diffusion/__pycache__/__init__.cpython-39.pyc -------------------------------------------------------------------------------- /guided_diffusion/__pycache__/diffusion.cpython-37.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/hamadichihaoui/BIRD/713dab700a0192b224641ecd5c0e264a96e47fab/guided_diffusion/__pycache__/diffusion.cpython-37.pyc -------------------------------------------------------------------------------- /guided_diffusion/__pycache__/fp16_util.cpython-37.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/hamadichihaoui/BIRD/713dab700a0192b224641ecd5c0e264a96e47fab/guided_diffusion/__pycache__/fp16_util.cpython-37.pyc -------------------------------------------------------------------------------- /guided_diffusion/__pycache__/logger.cpython-37.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/hamadichihaoui/BIRD/713dab700a0192b224641ecd5c0e264a96e47fab/guided_diffusion/__pycache__/logger.cpython-37.pyc -------------------------------------------------------------------------------- /guided_diffusion/__pycache__/models.cpython-37.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/hamadichihaoui/BIRD/713dab700a0192b224641ecd5c0e264a96e47fab/guided_diffusion/__pycache__/models.cpython-37.pyc -------------------------------------------------------------------------------- /guided_diffusion/__pycache__/models.cpython-39.pyc: -------------------------------------------------------------------------------- 
https://raw.githubusercontent.com/hamadichihaoui/BIRD/713dab700a0192b224641ecd5c0e264a96e47fab/guided_diffusion/__pycache__/models.cpython-39.pyc -------------------------------------------------------------------------------- /guided_diffusion/__pycache__/nn.cpython-37.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/hamadichihaoui/BIRD/713dab700a0192b224641ecd5c0e264a96e47fab/guided_diffusion/__pycache__/nn.cpython-37.pyc -------------------------------------------------------------------------------- /guided_diffusion/__pycache__/script_util.cpython-37.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/hamadichihaoui/BIRD/713dab700a0192b224641ecd5c0e264a96e47fab/guided_diffusion/__pycache__/script_util.cpython-37.pyc -------------------------------------------------------------------------------- /guided_diffusion/__pycache__/unet.cpython-37.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/hamadichihaoui/BIRD/713dab700a0192b224641ecd5c0e264a96e47fab/guided_diffusion/__pycache__/unet.cpython-37.pyc -------------------------------------------------------------------------------- /guided_diffusion/diffusion.py: -------------------------------------------------------------------------------- 1 | import os 2 | import logging 3 | import time 4 | import glob 5 | 6 | import numpy as np 7 | import tqdm 8 | import torch 9 | import torch.utils.data as data 10 | 11 | from datasets import get_dataset, data_transform, inverse_data_transform 12 | from functions.ckpt_util import get_ckpt_path, download 13 | from functions.svd_ddnm import ddnm_diffusion, ddnm_plus_diffusion 14 | 15 | import torchvision.utils as tvu 16 | 17 | from guided_diffusion.models import Model 18 | from guided_diffusion.script_util import create_model, create_classifier, classifier_defaults, args_to_dict 19 | import random 20 | 21 | from scipy.linalg import orth 22 | 23 | 24 | def get_gaussian_noisy_img(img, noise_level): 25 | return img + torch.randn_like(img).cuda() * noise_level 26 | 27 | def MeanUpsample(x, scale): 28 | n, c, h, w = x.shape 29 | out = torch.zeros(n, c, h, scale, w, scale).to(x.device) + x.view(n,c,h,1,w,1) 30 | out = out.view(n, c, scale*h, scale*w) 31 | return out 32 | 33 | def color2gray(x): 34 | coef=1/3 35 | x = x[:,0,:,:] * coef + x[:,1,:,:]*coef + x[:,2,:,:]*coef 36 | return x.repeat(1,3,1,1) 37 | 38 | def gray2color(x): 39 | x = x[:,0,:,:] 40 | coef=1/3 41 | base = coef**2 + coef**2 + coef**2 42 | return torch.stack((x*coef/base, x*coef/base, x*coef/base), 1) 43 | 44 | 45 | 46 | def get_beta_schedule(beta_schedule, *, beta_start, beta_end, num_diffusion_timesteps): 47 | def sigmoid(x): 48 | return 1 / (np.exp(-x) + 1) 49 | 50 | if beta_schedule == "quad": 51 | betas = ( 52 | np.linspace( 53 | beta_start ** 0.5, 54 | beta_end ** 0.5, 55 | num_diffusion_timesteps, 56 | dtype=np.float64, 57 | ) 58 | ** 2 59 | ) 60 | elif beta_schedule == "linear": 61 | betas = np.linspace( 62 | beta_start, beta_end, num_diffusion_timesteps, dtype=np.float64 63 | ) 64 | elif beta_schedule == "const": 65 | betas = beta_end * np.ones(num_diffusion_timesteps, dtype=np.float64) 66 | elif beta_schedule == "jsd": 67 | betas = 1.0 / np.linspace( 68 | num_diffusion_timesteps, 1, num_diffusion_timesteps, dtype=np.float64 69 | ) 70 | elif beta_schedule == "sigmoid": 71 | betas = np.linspace(-6, 6, num_diffusion_timesteps) 72 | 
betas = sigmoid(betas) * (beta_end - beta_start) + beta_start 73 | else: 74 | raise NotImplementedError(beta_schedule) 75 | assert betas.shape == (num_diffusion_timesteps,) 76 | return betas 77 | 78 | 79 | class Diffusion(object): 80 | def __init__(self, args, config, device=None): 81 | self.args = args 82 | self.config = config 83 | if device is None: 84 | device = ( 85 | torch.device("cuda") 86 | if torch.cuda.is_available() 87 | else torch.device("cpu") 88 | ) 89 | self.device = device 90 | 91 | self.model_var_type = config.model.var_type 92 | betas = get_beta_schedule( 93 | beta_schedule=config.diffusion.beta_schedule, 94 | beta_start=config.diffusion.beta_start, 95 | beta_end=config.diffusion.beta_end, 96 | num_diffusion_timesteps=config.diffusion.num_diffusion_timesteps, 97 | ) 98 | betas = self.betas = torch.from_numpy(betas).float().to(self.device) 99 | self.num_timesteps = betas.shape[0] 100 | 101 | alphas = 1.0 - betas 102 | alphas_cumprod = alphas.cumprod(dim=0) 103 | alphas_cumprod_prev = torch.cat( 104 | [torch.ones(1).to(device), alphas_cumprod[:-1]], dim=0 105 | ) 106 | self.alphas_cumprod_prev = alphas_cumprod_prev 107 | posterior_variance = ( 108 | betas * (1.0 - alphas_cumprod_prev) / (1.0 - alphas_cumprod) 109 | ) 110 | if self.model_var_type == "fixedlarge": 111 | self.logvar = betas.log() 112 | elif self.model_var_type == "fixedsmall": 113 | self.logvar = posterior_variance.clamp(min=1e-20).log() 114 | 115 | def sample(self, simplified): 116 | cls_fn = None 117 | if self.config.model.type == 'simple': 118 | model = Model(self.config) 119 | 120 | if self.config.data.dataset == "CIFAR10": 121 | name = "cifar10" 122 | elif self.config.data.dataset == "LSUN": 123 | name = f"lsun_{self.config.data.category}" 124 | elif self.config.data.dataset == 'CelebA_HQ': 125 | name = 'celeba_hq' 126 | else: 127 | raise ValueError 128 | if name != 'celeba_hq': 129 | ckpt = get_ckpt_path(f"ema_{name}", prefix=self.args.exp) 130 | print("Loading checkpoint {}".format(ckpt)) 131 | elif name == 'celeba_hq': 132 | ckpt = os.path.join(self.args.exp, "logs/celeba/celeba_hq.ckpt") 133 | if not os.path.exists(ckpt): 134 | download('https://image-editing-test-12345.s3-us-west-2.amazonaws.com/checkpoints/celeba_hq.ckpt', 135 | ckpt) 136 | else: 137 | raise ValueError 138 | model.load_state_dict(torch.load(ckpt, map_location=self.device)) 139 | model.to(self.device) 140 | model = torch.nn.DataParallel(model) 141 | 142 | elif self.config.model.type == 'openai': 143 | config_dict = vars(self.config.model) 144 | model = create_model(**config_dict) 145 | if self.config.model.use_fp16: 146 | model.convert_to_fp16() 147 | if self.config.model.class_cond: 148 | ckpt = os.path.join(self.args.exp, 'logs/imagenet/%dx%d_diffusion.pt' % ( 149 | self.config.data.image_size, self.config.data.image_size)) 150 | if not os.path.exists(ckpt): 151 | download( 152 | 'https://openaipublic.blob.core.windows.net/diffusion/jul-2021/%dx%d_diffusion_uncond.pt' % ( 153 | self.config.data.image_size, self.config.data.image_size), ckpt) 154 | else: 155 | ckpt = os.path.join(self.args.exp, "logs/imagenet/256x256_diffusion_uncond.pt") 156 | if not os.path.exists(ckpt): 157 | download( 158 | 'https://openaipublic.blob.core.windows.net/diffusion/jul-2021/256x256_diffusion_uncond.pt', 159 | ckpt) 160 | 161 | model.load_state_dict(torch.load(ckpt, map_location=self.device)) 162 | model.to(self.device) 163 | model.eval() 164 | model = torch.nn.DataParallel(model) 165 | 166 | if self.config.model.class_cond: 167 | ckpt = 
os.path.join(self.args.exp, 'logs/imagenet/%dx%d_classifier.pt' % ( 168 | self.config.data.image_size, self.config.data.image_size)) 169 | if not os.path.exists(ckpt): 170 | image_size = self.config.data.image_size 171 | download( 172 | 'https://openaipublic.blob.core.windows.net/diffusion/jul-2021/%dx%d_classifier.pt' % image_size, 173 | ckpt) 174 | classifier = create_classifier(**args_to_dict(self.config.classifier, classifier_defaults().keys())) 175 | classifier.load_state_dict(torch.load(ckpt, map_location=self.device)) 176 | classifier.to(self.device) 177 | if self.config.classifier.classifier_use_fp16: 178 | classifier.convert_to_fp16() 179 | classifier.eval() 180 | classifier = torch.nn.DataParallel(classifier) 181 | 182 | import torch.nn.functional as F 183 | def cond_fn(x, t, y): 184 | with torch.enable_grad(): 185 | x_in = x.detach().requires_grad_(True) 186 | logits = classifier(x_in, t) 187 | log_probs = F.log_softmax(logits, dim=-1) 188 | selected = log_probs[range(len(logits)), y.view(-1)] 189 | return torch.autograd.grad(selected.sum(), x_in)[0] * self.config.classifier.classifier_scale 190 | 191 | cls_fn = cond_fn 192 | 193 | if simplified: 194 | print('Run Simplified DDNM, without SVD.', 195 | f'{self.config.time_travel.T_sampling} sampling steps.', 196 | f'travel_length = {self.config.time_travel.travel_length},', 197 | f'travel_repeat = {self.config.time_travel.travel_repeat}.', 198 | f'Task: {self.args.deg}.' 199 | ) 200 | self.simplified_ddnm_plus(model, cls_fn) 201 | else: 202 | print('Run SVD-based DDNM.', 203 | f'{self.config.time_travel.T_sampling} sampling steps.', 204 | f'travel_length = {self.config.time_travel.travel_length},', 205 | f'travel_repeat = {self.config.time_travel.travel_repeat}.', 206 | f'Task: {self.args.deg}.' 
207 | ) 208 | self.svd_based_ddnm_plus(model, cls_fn) 209 | 210 | 211 | def simplified_ddnm_plus(self, model, cls_fn): 212 | args, config = self.args, self.config 213 | 214 | dataset, test_dataset = get_dataset(args, config) 215 | 216 | device_count = torch.cuda.device_count() 217 | 218 | if args.subset_start >= 0 and args.subset_end > 0: 219 | assert args.subset_end > args.subset_start 220 | test_dataset = torch.utils.data.Subset(test_dataset, range(args.subset_start, args.subset_end)) 221 | else: 222 | args.subset_start = 0 223 | args.subset_end = len(test_dataset) 224 | 225 | print(f'Dataset has size {len(test_dataset)}') 226 | 227 | def seed_worker(worker_id): 228 | worker_seed = args.seed % 2 ** 32 229 | np.random.seed(worker_seed) 230 | random.seed(worker_seed) 231 | 232 | g = torch.Generator() 233 | g.manual_seed(args.seed) 234 | val_loader = data.DataLoader( 235 | test_dataset, 236 | batch_size=config.sampling.batch_size, 237 | shuffle=True, 238 | num_workers=config.data.num_workers, 239 | worker_init_fn=seed_worker, 240 | generator=g, 241 | ) 242 | 243 | # get degradation operator 244 | print("args.deg:",args.deg) 245 | if args.deg =='colorization': 246 | A = lambda z: color2gray(z) 247 | Ap = lambda z: gray2color(z) 248 | elif args.deg =='denoising': 249 | A = lambda z: z 250 | Ap = A 251 | elif args.deg =='sr_averagepooling': 252 | scale=round(args.deg_scale) 253 | A = torch.nn.AdaptiveAvgPool2d((256//scale,256//scale)) 254 | Ap = lambda z: MeanUpsample(z,scale) 255 | elif args.deg =='inpainting': 256 | loaded = np.load("exp/inp_masks/mask.npy") 257 | mask = torch.from_numpy(loaded).to(self.device) 258 | A = lambda z: z*mask 259 | Ap = A 260 | elif args.deg =='mask_color_sr': 261 | loaded = np.load("exp/inp_masks/mask.npy") 262 | mask = torch.from_numpy(loaded).to(self.device) 263 | A1 = lambda z: z*mask 264 | A1p = A1 265 | 266 | A2 = lambda z: color2gray(z) 267 | A2p = lambda z: gray2color(z) 268 | 269 | scale=round(args.deg_scale) 270 | A3 = torch.nn.AdaptiveAvgPool2d((256//scale,256//scale)) 271 | A3p = lambda z: MeanUpsample(z,scale) 272 | 273 | A = lambda z: A3(A2(A1(z))) 274 | Ap = lambda z: A1p(A2p(A3p(z))) 275 | elif args.deg =='diy': 276 | # design your own degradation 277 | loaded = np.load("exp/inp_masks/mask.npy") 278 | mask = torch.from_numpy(loaded).to(self.device) 279 | A1 = lambda z: z*mask 280 | A1p = A1 281 | 282 | A2 = lambda z: color2gray(z) 283 | A2p = lambda z: gray2color(z) 284 | 285 | scale=args.deg_scale 286 | A3 = torch.nn.AdaptiveAvgPool2d((256//scale,256//scale)) 287 | A3p = lambda z: MeanUpsample(z,scale) 288 | 289 | A = lambda z: A3(A2(A1(z))) 290 | Ap = lambda z: A1p(A2p(A3p(z))) 291 | else: 292 | raise NotImplementedError("degradation type not supported") 293 | 294 | args.sigma_y = 2 * args.sigma_y #to account for scaling to [-1,1] 295 | sigma_y = args.sigma_y 296 | 297 | print(f'Start from {args.subset_start}') 298 | idx_init = args.subset_start 299 | idx_so_far = args.subset_start 300 | avg_psnr = 0.0 301 | pbar = tqdm.tqdm(val_loader) 302 | for x_orig, classes in pbar: 303 | x_orig = x_orig.to(self.device) 304 | x_orig = data_transform(self.config, x_orig) 305 | 306 | y = A(x_orig) 307 | 308 | if config.sampling.batch_size!=1: 309 | raise ValueError("please change the config file to set batch size as 1") 310 | 311 | Apy = Ap(y) 312 | 313 | os.makedirs(os.path.join(self.args.image_folder, "Apy"), exist_ok=True) 314 | for i in range(len(Apy)): 315 | tvu.save_image( 316 | inverse_data_transform(config, Apy[i]), 317 | 
os.path.join(self.args.image_folder, f"Apy/Apy_{idx_so_far + i}.png") 318 | ) 319 | tvu.save_image( 320 | inverse_data_transform(config, x_orig[i]), 321 | os.path.join(self.args.image_folder, f"Apy/orig_{idx_so_far + i}.png") 322 | ) 323 | 324 | # init x_T 325 | x = torch.randn( 326 | y.shape[0], 327 | config.data.channels, 328 | config.data.image_size, 329 | config.data.image_size, 330 | device=self.device, 331 | ) 332 | 333 | with torch.no_grad(): 334 | skip = config.diffusion.num_diffusion_timesteps//config.time_travel.T_sampling 335 | n = x.size(0) 336 | x0_preds = [] 337 | xs = [x] 338 | 339 | times = get_schedule_jump(config.time_travel.T_sampling, 340 | config.time_travel.travel_length, 341 | config.time_travel.travel_repeat, 342 | ) 343 | time_pairs = list(zip(times[:-1], times[1:])) 344 | 345 | 346 | # reverse diffusion sampling 347 | for i, j in tqdm.tqdm(time_pairs): 348 | i, j = i*skip, j*skip 349 | if j<0: j=-1 350 | 351 | if j < i: # normal sampling 352 | t = (torch.ones(n) * i).to(x.device) 353 | next_t = (torch.ones(n) * j).to(x.device) 354 | at = compute_alpha(self.betas, t.long()) 355 | at_next = compute_alpha(self.betas, next_t.long()) 356 | sigma_t = (1 - at_next**2).sqrt() 357 | xt = xs[-1].to('cuda') 358 | 359 | et = model(xt, t) 360 | 361 | if et.size(1) == 6: 362 | et = et[:, :3] 363 | 364 | # Eq. 12 365 | x0_t = (xt - et * (1 - at).sqrt()) / at.sqrt() 366 | 367 | # Eq. 19 368 | if sigma_t >= at_next*sigma_y: 369 | lambda_t = 1. 370 | gamma_t = (sigma_t**2 - (at_next*sigma_y)**2).sqrt() 371 | else: 372 | lambda_t = (sigma_t)/(at_next*sigma_y) 373 | gamma_t = 0. 374 | 375 | # Eq. 17 376 | x0_t_hat = x0_t - lambda_t*Ap(A(x0_t) - y) 377 | 378 | eta = self.args.eta 379 | 380 | c1 = (1 - at_next).sqrt() * eta 381 | c2 = (1 - at_next).sqrt() * ((1 - eta ** 2) ** 0.5) 382 | 383 | # different from the paper, we use DDIM here instead of DDPM 384 | xt_next = at_next.sqrt() * x0_t_hat + gamma_t * (c1 * torch.randn_like(x0_t) + c2 * et) 385 | 386 | x0_preds.append(x0_t.to('cpu')) 387 | xs.append(xt_next.to('cpu')) 388 | else: # time-travel back 389 | next_t = (torch.ones(n) * j).to(x.device) 390 | at_next = compute_alpha(self.betas, next_t.long()) 391 | x0_t = x0_preds[-1].to('cuda') 392 | 393 | xt_next = at_next.sqrt() * x0_t + torch.randn_like(x0_t) * (1 - at_next).sqrt() 394 | 395 | xs.append(xt_next.to('cpu')) 396 | 397 | x = xs[-1] 398 | 399 | x = [inverse_data_transform(config, xi) for xi in x] 400 | 401 | tvu.save_image( 402 | x[0], os.path.join(self.args.image_folder, f"{idx_so_far + j}_{0}.png") 403 | ) 404 | orig = inverse_data_transform(config, x_orig[0]) 405 | mse = torch.mean((x[0].to(self.device) - orig) ** 2) 406 | psnr = 10 * torch.log10(1 / mse) 407 | avg_psnr += psnr 408 | 409 | idx_so_far += y.shape[0] 410 | 411 | pbar.set_description("PSNR: %.2f" % (avg_psnr / (idx_so_far - idx_init))) 412 | 413 | avg_psnr = avg_psnr / (idx_so_far - idx_init) 414 | print("Total Average PSNR: %.2f" % avg_psnr) 415 | print("Number of samples: %d" % (idx_so_far - idx_init)) 416 | 417 | 418 | 419 | def svd_based_ddnm_plus(self, model, cls_fn): 420 | args, config = self.args, self.config 421 | 422 | dataset, test_dataset = get_dataset(args, config) 423 | 424 | device_count = torch.cuda.device_count() 425 | 426 | if args.subset_start >= 0 and args.subset_end > 0: 427 | assert args.subset_end > args.subset_start 428 | test_dataset = torch.utils.data.Subset(test_dataset, range(args.subset_start, args.subset_end)) 429 | else: 430 | args.subset_start = 0 431 | args.subset_end = 
len(test_dataset) 432 | 433 | print(f'Dataset has size {len(test_dataset)}') 434 | 435 | def seed_worker(worker_id): 436 | worker_seed = args.seed % 2 ** 32 437 | np.random.seed(worker_seed) 438 | random.seed(worker_seed) 439 | 440 | g = torch.Generator() 441 | g.manual_seed(args.seed) 442 | val_loader = data.DataLoader( 443 | test_dataset, 444 | batch_size=config.sampling.batch_size, 445 | shuffle=True, 446 | num_workers=config.data.num_workers, 447 | worker_init_fn=seed_worker, 448 | generator=g, 449 | ) 450 | 451 | # get degradation matrix 452 | deg = args.deg 453 | A_funcs = None 454 | if deg == 'cs_walshhadamard': 455 | compress_by = round(1/args.deg_scale) 456 | from functions.svd_operators import WalshHadamardCS 457 | A_funcs = WalshHadamardCS(config.data.channels, self.config.data.image_size, compress_by, 458 | torch.randperm(self.config.data.image_size ** 2, device=self.device), self.device) 459 | elif deg == 'cs_blockbased': 460 | cs_ratio = args.deg_scale 461 | from functions.svd_operators import CS 462 | A_funcs = CS(config.data.channels, self.config.data.image_size, cs_ratio, self.device) 463 | elif deg == 'inpainting': 464 | from functions.svd_operators import Inpainting 465 | loaded = np.load("exp/inp_masks/mask.npy") 466 | mask = torch.from_numpy(loaded).to(self.device).reshape(-1) 467 | missing_r = torch.nonzero(mask == 0).long().reshape(-1) * 3 468 | missing_g = missing_r + 1 469 | missing_b = missing_g + 1 470 | missing = torch.cat([missing_r, missing_g, missing_b], dim=0) 471 | A_funcs = Inpainting(config.data.channels, config.data.image_size, missing, self.device) 472 | elif deg == 'denoising': 473 | from functions.svd_operators import Denoising 474 | A_funcs = Denoising(config.data.channels, self.config.data.image_size, self.device) 475 | elif deg == 'colorization': 476 | from functions.svd_operators import Colorization 477 | A_funcs = Colorization(config.data.image_size, self.device) 478 | elif deg == 'sr_averagepooling': 479 | blur_by = int(args.deg_scale) 480 | from functions.svd_operators import SuperResolution 481 | A_funcs = SuperResolution(config.data.channels, config.data.image_size, blur_by, self.device) 482 | elif deg == 'sr_bicubic': 483 | factor = int(args.deg_scale) 484 | from functions.svd_operators import SRConv 485 | def bicubic_kernel(x, a=-0.5): 486 | if abs(x) <= 1: 487 | return (a + 2) * abs(x) ** 3 - (a + 3) * abs(x) ** 2 + 1 488 | elif 1 < abs(x) and abs(x) < 2: 489 | return a * abs(x) ** 3 - 5 * a * abs(x) ** 2 + 8 * a * abs(x) - 4 * a 490 | else: 491 | return 0 492 | k = np.zeros((factor * 4)) 493 | for i in range(factor * 4): 494 | x = (1 / factor) * (i - np.floor(factor * 4 / 2) + 0.5) 495 | k[i] = bicubic_kernel(x) 496 | k = k / np.sum(k) 497 | kernel = torch.from_numpy(k).float().to(self.device) 498 | A_funcs = SRConv(kernel / kernel.sum(), \ 499 | config.data.channels, self.config.data.image_size, self.device, stride=factor) 500 | elif deg == 'deblur_uni': 501 | from functions.svd_operators import Deblurring 502 | A_funcs = Deblurring(torch.Tensor([1 / 9] * 9).to(self.device), config.data.channels, 503 | self.config.data.image_size, self.device) 504 | elif deg == 'deblur_gauss': 505 | from functions.svd_operators import Deblurring 506 | sigma = 10 507 | pdf = lambda x: torch.exp(torch.Tensor([-0.5 * (x / sigma) ** 2])) 508 | kernel = torch.Tensor([pdf(-2), pdf(-1), pdf(0), pdf(1), pdf(2)]).to(self.device) 509 | A_funcs = Deblurring(kernel / kernel.sum(), config.data.channels, self.config.data.image_size, self.device) 510 | elif deg == 
'deblur_aniso': 511 | from functions.svd_operators import Deblurring2D 512 | sigma = 20 513 | pdf = lambda x: torch.exp(torch.Tensor([-0.5 * (x / sigma) ** 2])) 514 | kernel2 = torch.Tensor([pdf(-4), pdf(-3), pdf(-2), pdf(-1), pdf(0), pdf(1), pdf(2), pdf(3), pdf(4)]).to( 515 | self.device) 516 | sigma = 1 517 | pdf = lambda x: torch.exp(torch.Tensor([-0.5 * (x / sigma) ** 2])) 518 | kernel1 = torch.Tensor([pdf(-4), pdf(-3), pdf(-2), pdf(-1), pdf(0), pdf(1), pdf(2), pdf(3), pdf(4)]).to( 519 | self.device) 520 | A_funcs = Deblurring2D(kernel1 / kernel1.sum(), kernel2 / kernel2.sum(), config.data.channels, 521 | self.config.data.image_size, self.device) 522 | else: 523 | raise ValueError("degradation type not supported") 524 | args.sigma_y = 2 * args.sigma_y #to account for scaling to [-1,1] 525 | sigma_y = args.sigma_y 526 | 527 | print(f'Start from {args.subset_start}') 528 | idx_init = args.subset_start 529 | idx_so_far = args.subset_start 530 | avg_psnr = 0.0 531 | pbar = tqdm.tqdm(val_loader) 532 | for x_orig, classes in pbar: 533 | x_orig = x_orig.to(self.device) 534 | x_orig = data_transform(self.config, x_orig) 535 | 536 | y = A_funcs.A(x_orig) 537 | 538 | b, hwc = y.size() 539 | if 'color' in deg: 540 | hw = hwc / 1 541 | h = w = int(hw ** 0.5) 542 | y = y.reshape((b, 1, h, w)) 543 | elif 'inp' in deg or 'cs' in deg: 544 | pass 545 | else: 546 | hw = hwc / 3 547 | h = w = int(hw ** 0.5) 548 | y = y.reshape((b, 3, h, w)) 549 | 550 | if self.args.add_noise: # for denoising test 551 | y = get_gaussian_noisy_img(y, sigma_y) 552 | 553 | y = y.reshape((b, hwc)) 554 | 555 | Apy = A_funcs.A_pinv(y).view(y.shape[0], config.data.channels, self.config.data.image_size, 556 | self.config.data.image_size) 557 | 558 | if deg[:6] == 'deblur': 559 | Apy = y.view(y.shape[0], config.data.channels, self.config.data.image_size, 560 | self.config.data.image_size) 561 | elif deg == 'colorization': 562 | Apy = y.view(y.shape[0], 1, self.config.data.image_size, self.config.data.image_size).repeat(1,3,1,1) 563 | elif deg == 'inpainting': 564 | Apy += A_funcs.A_pinv(A_funcs.A(torch.ones_like(Apy))).reshape(*Apy.shape) - 1 565 | 566 | os.makedirs(os.path.join(self.args.image_folder, "Apy"), exist_ok=True) 567 | for i in range(len(Apy)): 568 | tvu.save_image( 569 | inverse_data_transform(config, Apy[i]), 570 | os.path.join(self.args.image_folder, f"Apy/Apy_{idx_so_far + i}.png") 571 | ) 572 | tvu.save_image( 573 | inverse_data_transform(config, x_orig[i]), 574 | os.path.join(self.args.image_folder, f"Apy/orig_{idx_so_far + i}.png") 575 | ) 576 | 577 | #Start DDIM 578 | x = torch.randn( 579 | y.shape[0], 580 | config.data.channels, 581 | config.data.image_size, 582 | config.data.image_size, 583 | device=self.device, 584 | ) 585 | 586 | with torch.no_grad(): 587 | if sigma_y==0.: # noise-free case, turn to ddnm 588 | x, _ = ddnm_diffusion(x, model, self.betas, self.args.eta, A_funcs, y, cls_fn=cls_fn, classes=classes, config=config) 589 | else: # noisy case, turn to ddnm+ 590 | x, _ = ddnm_plus_diffusion(x, model, self.betas, self.args.eta, A_funcs, y, sigma_y, cls_fn=cls_fn, classes=classes, config=config) 591 | 592 | x = [inverse_data_transform(config, xi) for xi in x] 593 | 594 | 595 | for j in range(x[0].size(0)): 596 | tvu.save_image( 597 | x[0][j], os.path.join(self.args.image_folder, f"{idx_so_far + j}_{0}.png") 598 | ) 599 | orig = inverse_data_transform(config, x_orig[j]) 600 | mse = torch.mean((x[0][j].to(self.device) - orig) ** 2) 601 | psnr = 10 * torch.log10(1 / mse) 602 | avg_psnr += psnr 603 | 604 | 
idx_so_far += y.shape[0] 605 | 606 | pbar.set_description("PSNR: %.2f" % (avg_psnr / (idx_so_far - idx_init))) 607 | 608 | avg_psnr = avg_psnr / (idx_so_far - idx_init) 609 | print("Total Average PSNR: %.2f" % avg_psnr) 610 | print("Number of samples: %d" % (idx_so_far - idx_init)) 611 | 612 | # Code form RePaint 613 | def get_schedule_jump(T_sampling, travel_length, travel_repeat): 614 | jumps = {} 615 | for j in range(0, T_sampling - travel_length, travel_length): 616 | jumps[j] = travel_repeat - 1 617 | 618 | t = T_sampling 619 | ts = [] 620 | 621 | while t >= 1: 622 | t = t-1 623 | ts.append(t) 624 | 625 | if jumps.get(t, 0) > 0: 626 | jumps[t] = jumps[t] - 1 627 | for _ in range(travel_length): 628 | t = t + 1 629 | ts.append(t) 630 | 631 | ts.append(-1) 632 | 633 | _check_times(ts, -1, T_sampling) 634 | return ts 635 | 636 | def _check_times(times, t_0, T_sampling): 637 | # Check end 638 | assert times[0] > times[1], (times[0], times[1]) 639 | 640 | # Check beginning 641 | assert times[-1] == -1, times[-1] 642 | 643 | # Steplength = 1 644 | for t_last, t_cur in zip(times[:-1], times[1:]): 645 | assert abs(t_last - t_cur) == 1, (t_last, t_cur) 646 | 647 | # Value range 648 | for t in times: 649 | assert t >= t_0, (t, t_0) 650 | assert t <= T_sampling, (t, T_sampling) 651 | 652 | def compute_alpha(beta, t): 653 | beta = torch.cat([torch.zeros(1).to(beta.device), beta], dim=0) 654 | a = (1 - beta).cumprod(dim=0).index_select(0, t + 1).view(-1, 1, 1, 1) 655 | return a 656 | -------------------------------------------------------------------------------- /guided_diffusion/fp16_util.py: -------------------------------------------------------------------------------- 1 | """ 2 | Helpers to train with 16-bit precision. 3 | """ 4 | 5 | import numpy as np 6 | import torch as th 7 | import torch.nn as nn 8 | from torch._utils import _flatten_dense_tensors, _unflatten_dense_tensors 9 | 10 | from . import logger 11 | 12 | INITIAL_LOG_LOSS_SCALE = 20.0 13 | 14 | 15 | def convert_module_to_f16(l): 16 | """ 17 | Convert primitive modules to float16. 18 | """ 19 | if isinstance(l, (nn.Conv1d, nn.Conv2d, nn.Conv3d)): 20 | l.weight.data = l.weight.data.half() 21 | if l.bias is not None: 22 | l.bias.data = l.bias.data.half() 23 | 24 | 25 | def convert_module_to_f32(l): 26 | """ 27 | Convert primitive modules to float32, undoing convert_module_to_f16(). 28 | """ 29 | if isinstance(l, (nn.Conv1d, nn.Conv2d, nn.Conv3d)): 30 | l.weight.data = l.weight.data.float() 31 | if l.bias is not None: 32 | l.bias.data = l.bias.data.float() 33 | 34 | 35 | def make_master_params(param_groups_and_shapes): 36 | """ 37 | Copy model parameters into a (differently-shaped) list of full-precision 38 | parameters. 39 | """ 40 | master_params = [] 41 | for param_group, shape in param_groups_and_shapes: 42 | master_param = nn.Parameter( 43 | _flatten_dense_tensors( 44 | [param.detach().float() for (_, param) in param_group] 45 | ).view(shape) 46 | ) 47 | master_param.requires_grad = True 48 | master_params.append(master_param) 49 | return master_params 50 | 51 | 52 | def model_grads_to_master_grads(param_groups_and_shapes, master_params): 53 | """ 54 | Copy the gradients from the model parameters into the master parameters 55 | from make_master_params(). 
56 | """ 57 | for master_param, (param_group, shape) in zip( 58 | master_params, param_groups_and_shapes 59 | ): 60 | master_param.grad = _flatten_dense_tensors( 61 | [param_grad_or_zeros(param) for (_, param) in param_group] 62 | ).view(shape) 63 | 64 | 65 | def master_params_to_model_params(param_groups_and_shapes, master_params): 66 | """ 67 | Copy the master parameter data back into the model parameters. 68 | """ 69 | # Without copying to a list, if a generator is passed, this will 70 | # silently not copy any parameters. 71 | for master_param, (param_group, _) in zip(master_params, param_groups_and_shapes): 72 | for (_, param), unflat_master_param in zip( 73 | param_group, unflatten_master_params(param_group, master_param.view(-1)) 74 | ): 75 | param.detach().copy_(unflat_master_param) 76 | 77 | 78 | def unflatten_master_params(param_group, master_param): 79 | return _unflatten_dense_tensors(master_param, [param for (_, param) in param_group]) 80 | 81 | 82 | def get_param_groups_and_shapes(named_model_params): 83 | named_model_params = list(named_model_params) 84 | scalar_vector_named_params = ( 85 | [(n, p) for (n, p) in named_model_params if p.ndim <= 1], 86 | (-1), 87 | ) 88 | matrix_named_params = ( 89 | [(n, p) for (n, p) in named_model_params if p.ndim > 1], 90 | (1, -1), 91 | ) 92 | return [scalar_vector_named_params, matrix_named_params] 93 | 94 | 95 | def master_params_to_state_dict( 96 | model, param_groups_and_shapes, master_params, use_fp16 97 | ): 98 | if use_fp16: 99 | state_dict = model.state_dict() 100 | for master_param, (param_group, _) in zip( 101 | master_params, param_groups_and_shapes 102 | ): 103 | for (name, _), unflat_master_param in zip( 104 | param_group, unflatten_master_params(param_group, master_param.view(-1)) 105 | ): 106 | assert name in state_dict 107 | state_dict[name] = unflat_master_param 108 | else: 109 | state_dict = model.state_dict() 110 | for i, (name, _value) in enumerate(model.named_parameters()): 111 | assert name in state_dict 112 | state_dict[name] = master_params[i] 113 | return state_dict 114 | 115 | 116 | def state_dict_to_master_params(model, state_dict, use_fp16): 117 | if use_fp16: 118 | named_model_params = [ 119 | (name, state_dict[name]) for name, _ in model.named_parameters() 120 | ] 121 | param_groups_and_shapes = get_param_groups_and_shapes(named_model_params) 122 | master_params = make_master_params(param_groups_and_shapes) 123 | else: 124 | master_params = [state_dict[name] for name, _ in model.named_parameters()] 125 | return master_params 126 | 127 | 128 | def zero_master_grads(master_params): 129 | for param in master_params: 130 | param.grad = None 131 | 132 | 133 | def zero_grad(model_params): 134 | for param in model_params: 135 | # Taken from https://pytorch.org/docs/stable/_modules/torch/optim/optimizer.html#Optimizer.add_param_group 136 | if param.grad is not None: 137 | param.grad.detach_() 138 | param.grad.zero_() 139 | 140 | 141 | def param_grad_or_zeros(param): 142 | if param.grad is not None: 143 | return param.grad.data.detach() 144 | else: 145 | return th.zeros_like(param) 146 | 147 | 148 | class MixedPrecisionTrainer: 149 | def __init__( 150 | self, 151 | *, 152 | model, 153 | use_fp16=False, 154 | fp16_scale_growth=1e-3, 155 | initial_lg_loss_scale=INITIAL_LOG_LOSS_SCALE, 156 | ): 157 | self.model = model 158 | self.use_fp16 = use_fp16 159 | self.fp16_scale_growth = fp16_scale_growth 160 | 161 | self.model_params = list(self.model.parameters()) 162 | self.master_params = self.model_params 163 | 
self.param_groups_and_shapes = None 164 | self.lg_loss_scale = initial_lg_loss_scale 165 | 166 | if self.use_fp16: 167 | self.param_groups_and_shapes = get_param_groups_and_shapes( 168 | self.model.named_parameters() 169 | ) 170 | self.master_params = make_master_params(self.param_groups_and_shapes) 171 | self.model.convert_to_fp16() 172 | 173 | def zero_grad(self): 174 | zero_grad(self.model_params) 175 | 176 | def backward(self, loss: th.Tensor): 177 | if self.use_fp16: 178 | loss_scale = 2 ** self.lg_loss_scale 179 | (loss * loss_scale).backward() 180 | else: 181 | loss.backward() 182 | 183 | def optimize(self, opt: th.optim.Optimizer): 184 | if self.use_fp16: 185 | return self._optimize_fp16(opt) 186 | else: 187 | return self._optimize_normal(opt) 188 | 189 | def _optimize_fp16(self, opt: th.optim.Optimizer): 190 | logger.logkv_mean("lg_loss_scale", self.lg_loss_scale) 191 | model_grads_to_master_grads(self.param_groups_and_shapes, self.master_params) 192 | grad_norm, param_norm = self._compute_norms(grad_scale=2 ** self.lg_loss_scale) 193 | if check_overflow(grad_norm): 194 | self.lg_loss_scale -= 1 195 | logger.log(f"Found NaN, decreased lg_loss_scale to {self.lg_loss_scale}") 196 | zero_master_grads(self.master_params) 197 | return False 198 | 199 | logger.logkv_mean("grad_norm", grad_norm) 200 | logger.logkv_mean("param_norm", param_norm) 201 | 202 | self.master_params[0].grad.mul_(1.0 / (2 ** self.lg_loss_scale)) 203 | opt.step() 204 | zero_master_grads(self.master_params) 205 | master_params_to_model_params(self.param_groups_and_shapes, self.master_params) 206 | self.lg_loss_scale += self.fp16_scale_growth 207 | return True 208 | 209 | def _optimize_normal(self, opt: th.optim.Optimizer): 210 | grad_norm, param_norm = self._compute_norms() 211 | logger.logkv_mean("grad_norm", grad_norm) 212 | logger.logkv_mean("param_norm", param_norm) 213 | opt.step() 214 | return True 215 | 216 | def _compute_norms(self, grad_scale=1.0): 217 | grad_norm = 0.0 218 | param_norm = 0.0 219 | for p in self.master_params: 220 | with th.no_grad(): 221 | param_norm += th.norm(p, p=2, dtype=th.float32).item() ** 2 222 | if p.grad is not None: 223 | grad_norm += th.norm(p.grad, p=2, dtype=th.float32).item() ** 2 224 | return np.sqrt(grad_norm) / grad_scale, np.sqrt(param_norm) 225 | 226 | def master_params_to_state_dict(self, master_params): 227 | return master_params_to_state_dict( 228 | self.model, self.param_groups_and_shapes, master_params, self.use_fp16 229 | ) 230 | 231 | def state_dict_to_master_params(self, state_dict): 232 | return state_dict_to_master_params(self.model, state_dict, self.use_fp16) 233 | 234 | 235 | def check_overflow(value): 236 | return (value == float("inf")) or (value == -float("inf")) or (value != value) 237 | -------------------------------------------------------------------------------- /guided_diffusion/logger.py: -------------------------------------------------------------------------------- 1 | """ 2 | Logger copied from OpenAI baselines to avoid extra RL-based dependencies: 3 | https://github.com/openai/baselines/blob/ea25b9e8b234e6ee1bca43083f8f3cf974143998/baselines/logger.py 4 | """ 5 | 6 | import os 7 | import sys 8 | import shutil 9 | import os.path as osp 10 | import json 11 | import time 12 | import datetime 13 | import tempfile 14 | import warnings 15 | from collections import defaultdict 16 | from contextlib import contextmanager 17 | 18 | DEBUG = 10 19 | INFO = 20 20 | WARN = 30 21 | ERROR = 40 22 | 23 | DISABLED = 50 24 | 25 | 26 | class KVWriter(object): 
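    # Abstract interface for key/value output backends (human-readable text, JSON,
    # CSV, TensorBoard); concrete formats implement writekvs().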
27 | def writekvs(self, kvs): 28 | raise NotImplementedError 29 | 30 | 31 | class SeqWriter(object): 32 | def writeseq(self, seq): 33 | raise NotImplementedError 34 | 35 | 36 | class HumanOutputFormat(KVWriter, SeqWriter): 37 | def __init__(self, filename_or_file): 38 | if isinstance(filename_or_file, str): 39 | self.file = open(filename_or_file, "wt") 40 | self.own_file = True 41 | else: 42 | assert hasattr(filename_or_file, "read"), ( 43 | "expected file or str, got %s" % filename_or_file 44 | ) 45 | self.file = filename_or_file 46 | self.own_file = False 47 | 48 | def writekvs(self, kvs): 49 | # Create strings for printing 50 | key2str = {} 51 | for (key, val) in sorted(kvs.items()): 52 | if hasattr(val, "__float__"): 53 | valstr = "%-8.3g" % val 54 | else: 55 | valstr = str(val) 56 | key2str[self._truncate(key)] = self._truncate(valstr) 57 | 58 | # Find max widths 59 | if len(key2str) == 0: 60 | print("WARNING: tried to write empty key-value dict") 61 | return 62 | else: 63 | keywidth = max(map(len, key2str.keys())) 64 | valwidth = max(map(len, key2str.values())) 65 | 66 | # Write out the data 67 | dashes = "-" * (keywidth + valwidth + 7) 68 | lines = [dashes] 69 | for (key, val) in sorted(key2str.items(), key=lambda kv: kv[0].lower()): 70 | lines.append( 71 | "| %s%s | %s%s |" 72 | % (key, " " * (keywidth - len(key)), val, " " * (valwidth - len(val))) 73 | ) 74 | lines.append(dashes) 75 | self.file.write("\n".join(lines) + "\n") 76 | 77 | # Flush the output to the file 78 | self.file.flush() 79 | 80 | def _truncate(self, s): 81 | maxlen = 30 82 | return s[: maxlen - 3] + "..." if len(s) > maxlen else s 83 | 84 | def writeseq(self, seq): 85 | seq = list(seq) 86 | for (i, elem) in enumerate(seq): 87 | self.file.write(elem) 88 | if i < len(seq) - 1: # add space unless this is the last one 89 | self.file.write(" ") 90 | self.file.write("\n") 91 | self.file.flush() 92 | 93 | def close(self): 94 | if self.own_file: 95 | self.file.close() 96 | 97 | 98 | class JSONOutputFormat(KVWriter): 99 | def __init__(self, filename): 100 | self.file = open(filename, "wt") 101 | 102 | def writekvs(self, kvs): 103 | for k, v in sorted(kvs.items()): 104 | if hasattr(v, "dtype"): 105 | kvs[k] = float(v) 106 | self.file.write(json.dumps(kvs) + "\n") 107 | self.file.flush() 108 | 109 | def close(self): 110 | self.file.close() 111 | 112 | 113 | class CSVOutputFormat(KVWriter): 114 | def __init__(self, filename): 115 | self.file = open(filename, "w+t") 116 | self.keys = [] 117 | self.sep = "," 118 | 119 | def writekvs(self, kvs): 120 | # Add our current row to the history 121 | extra_keys = list(kvs.keys() - self.keys) 122 | extra_keys.sort() 123 | if extra_keys: 124 | self.keys.extend(extra_keys) 125 | self.file.seek(0) 126 | lines = self.file.readlines() 127 | self.file.seek(0) 128 | for (i, k) in enumerate(self.keys): 129 | if i > 0: 130 | self.file.write(",") 131 | self.file.write(k) 132 | self.file.write("\n") 133 | for line in lines[1:]: 134 | self.file.write(line[:-1]) 135 | self.file.write(self.sep * len(extra_keys)) 136 | self.file.write("\n") 137 | for (i, k) in enumerate(self.keys): 138 | if i > 0: 139 | self.file.write(",") 140 | v = kvs.get(k) 141 | if v is not None: 142 | self.file.write(str(v)) 143 | self.file.write("\n") 144 | self.file.flush() 145 | 146 | def close(self): 147 | self.file.close() 148 | 149 | 150 | class TensorBoardOutputFormat(KVWriter): 151 | """ 152 | Dumps key/value pairs into TensorBoard's numeric format. 
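    Writes raw event files through TensorFlow's EventsWriter, so TensorFlow must be
    importable when this output format is selected.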
153 | """ 154 | 155 | def __init__(self, dir): 156 | os.makedirs(dir, exist_ok=True) 157 | self.dir = dir 158 | self.step = 1 159 | prefix = "events" 160 | path = osp.join(osp.abspath(dir), prefix) 161 | import tensorflow as tf 162 | from tensorflow.python import pywrap_tensorflow 163 | from tensorflow.core.util import event_pb2 164 | from tensorflow.python.util import compat 165 | 166 | self.tf = tf 167 | self.event_pb2 = event_pb2 168 | self.pywrap_tensorflow = pywrap_tensorflow 169 | self.writer = pywrap_tensorflow.EventsWriter(compat.as_bytes(path)) 170 | 171 | def writekvs(self, kvs): 172 | def summary_val(k, v): 173 | kwargs = {"tag": k, "simple_value": float(v)} 174 | return self.tf.Summary.Value(**kwargs) 175 | 176 | summary = self.tf.Summary(value=[summary_val(k, v) for k, v in kvs.items()]) 177 | event = self.event_pb2.Event(wall_time=time.time(), summary=summary) 178 | event.step = ( 179 | self.step 180 | ) # is there any reason why you'd want to specify the step? 181 | self.writer.WriteEvent(event) 182 | self.writer.Flush() 183 | self.step += 1 184 | 185 | def close(self): 186 | if self.writer: 187 | self.writer.Close() 188 | self.writer = None 189 | 190 | 191 | def make_output_format(format, ev_dir, log_suffix=""): 192 | os.makedirs(ev_dir, exist_ok=True) 193 | if format == "stdout": 194 | return HumanOutputFormat(sys.stdout) 195 | elif format == "log": 196 | return HumanOutputFormat(osp.join(ev_dir, "log%s.txt" % log_suffix)) 197 | elif format == "json": 198 | return JSONOutputFormat(osp.join(ev_dir, "progress%s.json" % log_suffix)) 199 | elif format == "csv": 200 | return CSVOutputFormat(osp.join(ev_dir, "progress%s.csv" % log_suffix)) 201 | elif format == "tensorboard": 202 | return TensorBoardOutputFormat(osp.join(ev_dir, "tb%s" % log_suffix)) 203 | else: 204 | raise ValueError("Unknown format specified: %s" % (format,)) 205 | 206 | 207 | # ================================================================ 208 | # API 209 | # ================================================================ 210 | 211 | 212 | def logkv(key, val): 213 | """ 214 | Log a value of some diagnostic 215 | Call this once for each diagnostic quantity, each iteration 216 | If called many times, last value will be used. 217 | """ 218 | get_current().logkv(key, val) 219 | 220 | 221 | def logkv_mean(key, val): 222 | """ 223 | The same as logkv(), but if called many times, values averaged. 224 | """ 225 | get_current().logkv_mean(key, val) 226 | 227 | 228 | def logkvs(d): 229 | """ 230 | Log a dictionary of key-value pairs 231 | """ 232 | for (k, v) in d.items(): 233 | logkv(k, v) 234 | 235 | 236 | def dumpkvs(): 237 | """ 238 | Write all of the diagnostics from the current iteration 239 | """ 240 | return get_current().dumpkvs() 241 | 242 | 243 | def getkvs(): 244 | return get_current().name2val 245 | 246 | 247 | def log(*args, level=INFO): 248 | """ 249 | Write the sequence of args, with no separators, to the console and output files (if you've configured an output file). 250 | """ 251 | get_current().log(*args, level=level) 252 | 253 | 254 | def debug(*args): 255 | log(*args, level=DEBUG) 256 | 257 | 258 | def info(*args): 259 | log(*args, level=INFO) 260 | 261 | 262 | def warn(*args): 263 | log(*args, level=WARN) 264 | 265 | 266 | def error(*args): 267 | log(*args, level=ERROR) 268 | 269 | 270 | def set_level(level): 271 | """ 272 | Set logging threshold on current logger. 
273 | """ 274 | get_current().set_level(level) 275 | 276 | 277 | def set_comm(comm): 278 | get_current().set_comm(comm) 279 | 280 | 281 | def get_dir(): 282 | """ 283 | Get directory that log files are being written to. 284 | will be None if there is no output directory (i.e., if you didn't call start) 285 | """ 286 | return get_current().get_dir() 287 | 288 | 289 | record_tabular = logkv 290 | dump_tabular = dumpkvs 291 | 292 | 293 | @contextmanager 294 | def profile_kv(scopename): 295 | logkey = "wait_" + scopename 296 | tstart = time.time() 297 | try: 298 | yield 299 | finally: 300 | get_current().name2val[logkey] += time.time() - tstart 301 | 302 | 303 | def profile(n): 304 | """ 305 | Usage: 306 | @profile("my_func") 307 | def my_func(): code 308 | """ 309 | 310 | def decorator_with_name(func): 311 | def func_wrapper(*args, **kwargs): 312 | with profile_kv(n): 313 | return func(*args, **kwargs) 314 | 315 | return func_wrapper 316 | 317 | return decorator_with_name 318 | 319 | 320 | # ================================================================ 321 | # Backend 322 | # ================================================================ 323 | 324 | 325 | def get_current(): 326 | if Logger.CURRENT is None: 327 | _configure_default_logger() 328 | 329 | return Logger.CURRENT 330 | 331 | 332 | class Logger(object): 333 | DEFAULT = None # A logger with no output files. (See right below class definition) 334 | # So that you can still log to the terminal without setting up any output files 335 | CURRENT = None # Current logger being used by the free functions above 336 | 337 | def __init__(self, dir, output_formats, comm=None): 338 | self.name2val = defaultdict(float) # values this iteration 339 | self.name2cnt = defaultdict(int) 340 | self.level = INFO 341 | self.dir = dir 342 | self.output_formats = output_formats 343 | self.comm = comm 344 | 345 | # Logging API, forwarded 346 | # ---------------------------------------- 347 | def logkv(self, key, val): 348 | self.name2val[key] = val 349 | 350 | def logkv_mean(self, key, val): 351 | oldval, cnt = self.name2val[key], self.name2cnt[key] 352 | self.name2val[key] = oldval * cnt / (cnt + 1) + val / (cnt + 1) 353 | self.name2cnt[key] = cnt + 1 354 | 355 | def dumpkvs(self): 356 | if self.comm is None: 357 | d = self.name2val 358 | else: 359 | d = mpi_weighted_mean( 360 | self.comm, 361 | { 362 | name: (val, self.name2cnt.get(name, 1)) 363 | for (name, val) in self.name2val.items() 364 | }, 365 | ) 366 | if self.comm.rank != 0: 367 | d["dummy"] = 1 # so we don't get a warning about empty dict 368 | out = d.copy() # Return the dict for unit testing purposes 369 | for fmt in self.output_formats: 370 | if isinstance(fmt, KVWriter): 371 | fmt.writekvs(d) 372 | self.name2val.clear() 373 | self.name2cnt.clear() 374 | return out 375 | 376 | def log(self, *args, level=INFO): 377 | if self.level <= level: 378 | self._do_log(args) 379 | 380 | # Configuration 381 | # ---------------------------------------- 382 | def set_level(self, level): 383 | self.level = level 384 | 385 | def set_comm(self, comm): 386 | self.comm = comm 387 | 388 | def get_dir(self): 389 | return self.dir 390 | 391 | def close(self): 392 | for fmt in self.output_formats: 393 | fmt.close() 394 | 395 | # Misc 396 | # ---------------------------------------- 397 | def _do_log(self, args): 398 | for fmt in self.output_formats: 399 | if isinstance(fmt, SeqWriter): 400 | fmt.writeseq(map(str, args)) 401 | 402 | 403 | def get_rank_without_mpi_import(): 404 | # check environment variables here 
instead of importing mpi4py 405 | # to avoid calling MPI_Init() when this module is imported 406 | for varname in ["PMI_RANK", "OMPI_COMM_WORLD_RANK"]: 407 | if varname in os.environ: 408 | return int(os.environ[varname]) 409 | return 0 410 | 411 | 412 | def mpi_weighted_mean(comm, local_name2valcount): 413 | """ 414 | Copied from: https://github.com/openai/baselines/blob/ea25b9e8b234e6ee1bca43083f8f3cf974143998/baselines/common/mpi_util.py#L110 415 | Perform a weighted average over dicts that are each on a different node 416 | Input: local_name2valcount: dict mapping key -> (value, count) 417 | Returns: key -> mean 418 | """ 419 | all_name2valcount = comm.gather(local_name2valcount) 420 | if comm.rank == 0: 421 | name2sum = defaultdict(float) 422 | name2count = defaultdict(float) 423 | for n2vc in all_name2valcount: 424 | for (name, (val, count)) in n2vc.items(): 425 | try: 426 | val = float(val) 427 | except ValueError: 428 | if comm.rank == 0: 429 | warnings.warn( 430 | "WARNING: tried to compute mean on non-float {}={}".format( 431 | name, val 432 | ) 433 | ) 434 | else: 435 | name2sum[name] += val * count 436 | name2count[name] += count 437 | return {name: name2sum[name] / name2count[name] for name in name2sum} 438 | else: 439 | return {} 440 | 441 | 442 | def configure(dir=None, format_strs=None, comm=None, log_suffix=""): 443 | """ 444 | If comm is provided, average all numerical stats across that comm 445 | """ 446 | if dir is None: 447 | dir = os.getenv("OPENAI_LOGDIR") 448 | if dir is None: 449 | dir = osp.join( 450 | tempfile.gettempdir(), 451 | datetime.datetime.now().strftime("openai-%Y-%m-%d-%H-%M-%S-%f"), 452 | ) 453 | assert isinstance(dir, str) 454 | dir = os.path.expanduser(dir) 455 | os.makedirs(os.path.expanduser(dir), exist_ok=True) 456 | 457 | rank = get_rank_without_mpi_import() 458 | if rank > 0: 459 | log_suffix = log_suffix + "-rank%03i" % rank 460 | 461 | if format_strs is None: 462 | if rank == 0: 463 | format_strs = os.getenv("OPENAI_LOG_FORMAT", "stdout,log,csv").split(",") 464 | else: 465 | format_strs = os.getenv("OPENAI_LOG_FORMAT_MPI", "log").split(",") 466 | format_strs = filter(None, format_strs) 467 | output_formats = [make_output_format(f, dir, log_suffix) for f in format_strs] 468 | 469 | Logger.CURRENT = Logger(dir=dir, output_formats=output_formats, comm=comm) 470 | if output_formats: 471 | log("Logging to %s" % dir) 472 | 473 | 474 | def _configure_default_logger(): 475 | configure() 476 | Logger.DEFAULT = Logger.CURRENT 477 | 478 | 479 | def reset(): 480 | if Logger.CURRENT is not Logger.DEFAULT: 481 | Logger.CURRENT.close() 482 | Logger.CURRENT = Logger.DEFAULT 483 | log("Reset logger") 484 | 485 | 486 | @contextmanager 487 | def scoped_configure(dir=None, format_strs=None, comm=None): 488 | prevlogger = Logger.CURRENT 489 | configure(dir=dir, format_strs=format_strs, comm=comm) 490 | try: 491 | yield 492 | finally: 493 | Logger.CURRENT.close() 494 | Logger.CURRENT = prevlogger 495 | 496 | -------------------------------------------------------------------------------- /guided_diffusion/models.py: -------------------------------------------------------------------------------- 1 | import math 2 | import torch 3 | import torch.nn as nn 4 | 5 | from huggingface_hub import PyTorchModelHubMixin 6 | 7 | 8 | def get_timestep_embedding(timesteps, embedding_dim): 9 | """ 10 | This matches the implementation in Denoising Diffusion Probabilistic Models: 11 | From Fairseq. 12 | Build sinusoidal embeddings. 
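    Each timestep t is mapped to [sin(t * w_k), cos(t * w_k)] with frequencies
    w_k = exp(-k * log(10000) / (half_dim - 1)), k = 0 .. half_dim - 1.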
13 | This matches the implementation in tensor2tensor, but differs slightly 14 | from the description in Section 3.5 of "Attention Is All You Need". 15 | """ 16 | assert len(timesteps.shape) == 1 17 | 18 | half_dim = embedding_dim // 2 19 | emb = math.log(10000) / (half_dim - 1) 20 | emb = torch.exp(torch.arange(half_dim, dtype=torch.float32) * -emb) 21 | emb = emb.to(device=timesteps.device) 22 | emb = timesteps.float()[:, None] * emb[None, :] 23 | emb = torch.cat([torch.sin(emb), torch.cos(emb)], dim=1) 24 | if embedding_dim % 2 == 1: # zero pad 25 | emb = torch.nn.functional.pad(emb, (0, 1, 0, 0)) 26 | return emb 27 | 28 | 29 | def nonlinearity(x): 30 | # swish 31 | return x*torch.sigmoid(x) 32 | 33 | 34 | def Normalize(in_channels): 35 | return torch.nn.GroupNorm(num_groups=32, num_channels=in_channels, eps=1e-6, affine=True) 36 | 37 | 38 | class Upsample(nn.Module): 39 | def __init__(self, in_channels, with_conv): 40 | super().__init__() 41 | self.with_conv = with_conv 42 | if self.with_conv: 43 | self.conv = torch.nn.Conv2d(in_channels, 44 | in_channels, 45 | kernel_size=3, 46 | stride=1, 47 | padding=1) 48 | 49 | def forward(self, x): 50 | x = torch.nn.functional.interpolate( 51 | x, scale_factor=2.0, mode="nearest") 52 | if self.with_conv: 53 | x = self.conv(x) 54 | return x 55 | 56 | 57 | class Downsample(nn.Module): 58 | def __init__(self, in_channels, with_conv): 59 | super().__init__() 60 | self.with_conv = with_conv 61 | if self.with_conv: 62 | # no asymmetric padding in torch conv, must do it ourselves 63 | self.conv = torch.nn.Conv2d(in_channels, 64 | in_channels, 65 | kernel_size=3, 66 | stride=2, 67 | padding=0) 68 | 69 | def forward(self, x): 70 | if self.with_conv: 71 | pad = (0, 1, 0, 1) 72 | x = torch.nn.functional.pad(x, pad, mode="constant", value=0) 73 | x = self.conv(x) 74 | else: 75 | x = torch.nn.functional.avg_pool2d(x, kernel_size=2, stride=2) 76 | return x 77 | 78 | 79 | class ResnetBlock(nn.Module): 80 | def __init__(self, *, in_channels, out_channels=None, conv_shortcut=False, 81 | dropout, temb_channels=512): 82 | super().__init__() 83 | self.in_channels = in_channels 84 | out_channels = in_channels if out_channels is None else out_channels 85 | self.out_channels = out_channels 86 | self.use_conv_shortcut = conv_shortcut 87 | 88 | self.norm1 = Normalize(in_channels) 89 | self.conv1 = torch.nn.Conv2d(in_channels, 90 | out_channels, 91 | kernel_size=3, 92 | stride=1, 93 | padding=1) 94 | self.temb_proj = torch.nn.Linear(temb_channels, 95 | out_channels) 96 | self.norm2 = Normalize(out_channels) 97 | self.dropout = torch.nn.Dropout(dropout) 98 | self.conv2 = torch.nn.Conv2d(out_channels, 99 | out_channels, 100 | kernel_size=3, 101 | stride=1, 102 | padding=1) 103 | if self.in_channels != self.out_channels: 104 | if self.use_conv_shortcut: 105 | self.conv_shortcut = torch.nn.Conv2d(in_channels, 106 | out_channels, 107 | kernel_size=3, 108 | stride=1, 109 | padding=1) 110 | else: 111 | self.nin_shortcut = torch.nn.Conv2d(in_channels, 112 | out_channels, 113 | kernel_size=1, 114 | stride=1, 115 | padding=0) 116 | 117 | def forward(self, x, temb): 118 | h = x 119 | h = self.norm1(h) 120 | h = nonlinearity(h) 121 | h = self.conv1(h) 122 | 123 | h = h + self.temb_proj(nonlinearity(temb))[:, :, None, None] 124 | 125 | h = self.norm2(h) 126 | h = nonlinearity(h) 127 | h = self.dropout(h) 128 | h = self.conv2(h) 129 | 130 | if self.in_channels != self.out_channels: 131 | if self.use_conv_shortcut: 132 | x = self.conv_shortcut(x) 133 | else: 134 | x = 
self.nin_shortcut(x) 135 | 136 | return x+h 137 | 138 | 139 | class AttnBlock(nn.Module): 140 | def __init__(self, in_channels): 141 | super().__init__() 142 | self.in_channels = in_channels 143 | 144 | self.norm = Normalize(in_channels) 145 | self.q = torch.nn.Conv2d(in_channels, 146 | in_channels, 147 | kernel_size=1, 148 | stride=1, 149 | padding=0) 150 | self.k = torch.nn.Conv2d(in_channels, 151 | in_channels, 152 | kernel_size=1, 153 | stride=1, 154 | padding=0) 155 | self.v = torch.nn.Conv2d(in_channels, 156 | in_channels, 157 | kernel_size=1, 158 | stride=1, 159 | padding=0) 160 | self.proj_out = torch.nn.Conv2d(in_channels, 161 | in_channels, 162 | kernel_size=1, 163 | stride=1, 164 | padding=0) 165 | 166 | def forward(self, x): 167 | h_ = x 168 | h_ = self.norm(h_) 169 | q = self.q(h_) 170 | k = self.k(h_) 171 | v = self.v(h_) 172 | 173 | # compute attention 174 | b, c, h, w = q.shape 175 | q = q.reshape(b, c, h*w) 176 | q = q.permute(0, 2, 1) # b,hw,c 177 | k = k.reshape(b, c, h*w) # b,c,hw 178 | w_ = torch.bmm(q, k) # b,hw,hw w[b,i,j]=sum_c q[b,i,c]k[b,c,j] 179 | w_ = w_ * (int(c)**(-0.5)) 180 | w_ = torch.nn.functional.softmax(w_, dim=2) 181 | 182 | # attend to values 183 | v = v.reshape(b, c, h*w) 184 | w_ = w_.permute(0, 2, 1) # b,hw,hw (first hw of k, second of q) 185 | # b, c,hw (hw of q) h_[b,c,j] = sum_i v[b,c,i] w_[b,i,j] 186 | h_ = torch.bmm(v, w_) 187 | h_ = h_.reshape(b, c, h, w) 188 | 189 | h_ = self.proj_out(h_) 190 | 191 | return x+h_ 192 | 193 | 194 | class Model(nn.Module, PyTorchModelHubMixin): 195 | def __init__(self, config): 196 | super().__init__() 197 | self.config = config 198 | ch, out_ch, ch_mult = config.model.ch, config.model.out_ch, tuple(config.model.ch_mult) 199 | num_res_blocks = config.model.num_res_blocks 200 | attn_resolutions = config.model.attn_resolutions 201 | dropout = config.model.dropout 202 | in_channels = config.model.in_channels 203 | resolution = config.data.image_size 204 | resamp_with_conv = config.model.resamp_with_conv 205 | num_timesteps = config.diffusion.num_diffusion_timesteps 206 | 207 | if config.model.type == 'bayesian': 208 | self.logvar = nn.Parameter(torch.zeros(num_timesteps)) 209 | 210 | self.ch = ch 211 | self.temb_ch = self.ch*4 212 | self.num_resolutions = len(ch_mult) 213 | self.num_res_blocks = num_res_blocks 214 | self.resolution = resolution 215 | self.in_channels = in_channels 216 | 217 | # timestep embedding 218 | self.temb = nn.Module() 219 | self.temb.dense = nn.ModuleList([ 220 | torch.nn.Linear(self.ch, 221 | self.temb_ch), 222 | torch.nn.Linear(self.temb_ch, 223 | self.temb_ch), 224 | ]) 225 | 226 | # downsampling 227 | self.conv_in = torch.nn.Conv2d(in_channels, 228 | self.ch, 229 | kernel_size=3, 230 | stride=1, 231 | padding=1) 232 | 233 | curr_res = resolution 234 | in_ch_mult = (1,)+ch_mult 235 | self.down = nn.ModuleList() 236 | block_in = None 237 | for i_level in range(self.num_resolutions): 238 | block = nn.ModuleList() 239 | attn = nn.ModuleList() 240 | block_in = ch*in_ch_mult[i_level] 241 | block_out = ch*ch_mult[i_level] 242 | for i_block in range(self.num_res_blocks): 243 | block.append(ResnetBlock(in_channels=block_in, 244 | out_channels=block_out, 245 | temb_channels=self.temb_ch, 246 | dropout=dropout)) 247 | block_in = block_out 248 | if curr_res in attn_resolutions: 249 | attn.append(AttnBlock(block_in)) 250 | down = nn.Module() 251 | down.block = block 252 | down.attn = attn 253 | if i_level != self.num_resolutions-1: 254 | down.downsample = Downsample(block_in, resamp_with_conv) 255 | 
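                # Halve the tracked spatial resolution after each downsampling level;
                # curr_res is what gets compared against attn_resolutions above.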
curr_res = curr_res // 2 256 | self.down.append(down) 257 | 258 | # middle 259 | self.mid = nn.Module() 260 | self.mid.block_1 = ResnetBlock(in_channels=block_in, 261 | out_channels=block_in, 262 | temb_channels=self.temb_ch, 263 | dropout=dropout) 264 | self.mid.attn_1 = AttnBlock(block_in) 265 | self.mid.block_2 = ResnetBlock(in_channels=block_in, 266 | out_channels=block_in, 267 | temb_channels=self.temb_ch, 268 | dropout=dropout) 269 | 270 | # upsampling 271 | self.up = nn.ModuleList() 272 | for i_level in reversed(range(self.num_resolutions)): 273 | block = nn.ModuleList() 274 | attn = nn.ModuleList() 275 | block_out = ch*ch_mult[i_level] 276 | skip_in = ch*ch_mult[i_level] 277 | for i_block in range(self.num_res_blocks+1): 278 | if i_block == self.num_res_blocks: 279 | skip_in = ch*in_ch_mult[i_level] 280 | block.append(ResnetBlock(in_channels=block_in+skip_in, 281 | out_channels=block_out, 282 | temb_channels=self.temb_ch, 283 | dropout=dropout)) 284 | block_in = block_out 285 | if curr_res in attn_resolutions: 286 | attn.append(AttnBlock(block_in)) 287 | up = nn.Module() 288 | up.block = block 289 | up.attn = attn 290 | if i_level != 0: 291 | up.upsample = Upsample(block_in, resamp_with_conv) 292 | curr_res = curr_res * 2 293 | self.up.insert(0, up) # prepend to get consistent order 294 | 295 | # end 296 | self.norm_out = Normalize(block_in) 297 | self.conv_out = torch.nn.Conv2d(block_in, 298 | out_ch, 299 | kernel_size=3, 300 | stride=1, 301 | padding=1) 302 | 303 | def forward(self, x, t): 304 | assert x.shape[2] == x.shape[3] == self.resolution 305 | 306 | # timestep embedding 307 | temb = get_timestep_embedding(t, self.ch) 308 | temb = self.temb.dense[0](temb) 309 | temb = nonlinearity(temb) 310 | temb = self.temb.dense[1](temb) 311 | 312 | # downsampling 313 | hs = [self.conv_in(x)] 314 | for i_level in range(self.num_resolutions): 315 | for i_block in range(self.num_res_blocks): 316 | h = self.down[i_level].block[i_block](hs[-1], temb) 317 | if len(self.down[i_level].attn) > 0: 318 | h = self.down[i_level].attn[i_block](h) 319 | hs.append(h) 320 | if i_level != self.num_resolutions-1: 321 | hs.append(self.down[i_level].downsample(hs[-1])) 322 | 323 | # middle 324 | h = hs[-1] 325 | h = self.mid.block_1(h, temb) 326 | h = self.mid.attn_1(h) 327 | h = self.mid.block_2(h, temb) 328 | 329 | # upsampling 330 | for i_level in reversed(range(self.num_resolutions)): 331 | for i_block in range(self.num_res_blocks+1): 332 | h = self.up[i_level].block[i_block]( 333 | torch.cat([h, hs.pop()], dim=1), temb) 334 | if len(self.up[i_level].attn) > 0: 335 | h = self.up[i_level].attn[i_block](h) 336 | if i_level != 0: 337 | h = self.up[i_level].upsample(h) 338 | 339 | # end 340 | h = self.norm_out(h) 341 | h = nonlinearity(h) 342 | h = self.conv_out(h) 343 | return h 344 | -------------------------------------------------------------------------------- /guided_diffusion/nn.py: -------------------------------------------------------------------------------- 1 | """ 2 | Various utilities for neural networks. 3 | """ 4 | 5 | import math 6 | 7 | import torch as th 8 | import torch.nn as nn 9 | 10 | 11 | # PyTorch 1.7 has SiLU, but we support PyTorch 1.5. 12 | class SiLU(nn.Module): 13 | def forward(self, x): 14 | return x * th.sigmoid(x) 15 | 16 | 17 | class GroupNorm32(nn.GroupNorm): 18 | def forward(self, x): 19 | return super().forward(x.float()).type(x.dtype) 20 | 21 | 22 | def conv_nd(dims, *args, **kwargs): 23 | """ 24 | Create a 1D, 2D, or 3D convolution module. 
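    dims selects between nn.Conv1d, nn.Conv2d and nn.Conv3d; all other arguments
    are forwarded unchanged.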
25 | """ 26 | if dims == 1: 27 | return nn.Conv1d(*args, **kwargs) 28 | elif dims == 2: 29 | return nn.Conv2d(*args, **kwargs) 30 | elif dims == 3: 31 | return nn.Conv3d(*args, **kwargs) 32 | raise ValueError(f"unsupported dimensions: {dims}") 33 | 34 | 35 | def linear(*args, **kwargs): 36 | """ 37 | Create a linear module. 38 | """ 39 | return nn.Linear(*args, **kwargs) 40 | 41 | 42 | def avg_pool_nd(dims, *args, **kwargs): 43 | """ 44 | Create a 1D, 2D, or 3D average pooling module. 45 | """ 46 | if dims == 1: 47 | return nn.AvgPool1d(*args, **kwargs) 48 | elif dims == 2: 49 | return nn.AvgPool2d(*args, **kwargs) 50 | elif dims == 3: 51 | return nn.AvgPool3d(*args, **kwargs) 52 | raise ValueError(f"unsupported dimensions: {dims}") 53 | 54 | 55 | def update_ema(target_params, source_params, rate=0.99): 56 | """ 57 | Update target parameters to be closer to those of source parameters using 58 | an exponential moving average. 59 | 60 | :param target_params: the target parameter sequence. 61 | :param source_params: the source parameter sequence. 62 | :param rate: the EMA rate (closer to 1 means slower). 63 | """ 64 | for targ, src in zip(target_params, source_params): 65 | targ.detach().mul_(rate).add_(src, alpha=1 - rate) 66 | 67 | 68 | def zero_module(module): 69 | """ 70 | Zero out the parameters of a module and return it. 71 | """ 72 | for p in module.parameters(): 73 | p.detach().zero_() 74 | return module 75 | 76 | 77 | def scale_module(module, scale): 78 | """ 79 | Scale the parameters of a module and return it. 80 | """ 81 | for p in module.parameters(): 82 | p.detach().mul_(scale) 83 | return module 84 | 85 | 86 | def mean_flat(tensor): 87 | """ 88 | Take the mean over all non-batch dimensions. 89 | """ 90 | return tensor.mean(dim=list(range(1, len(tensor.shape)))) 91 | 92 | 93 | def normalization(channels): 94 | """ 95 | Make a standard normalization layer. 96 | 97 | :param channels: number of input channels. 98 | :return: an nn.Module for normalization. 99 | """ 100 | return GroupNorm32(32, channels) 101 | 102 | 103 | def timestep_embedding(timesteps, dim, max_period=10000): 104 | """ 105 | Create sinusoidal timestep embeddings. 106 | 107 | :param timesteps: a 1-D Tensor of N indices, one per batch element. 108 | These may be fractional. 109 | :param dim: the dimension of the output. 110 | :param max_period: controls the minimum frequency of the embeddings. 111 | :return: an [N x dim] Tensor of positional embeddings. 112 | """ 113 | half = dim // 2 114 | freqs = th.exp( 115 | -math.log(max_period) * th.arange(start=0, end=half, dtype=th.float32) / half 116 | ).to(device=timesteps.device) 117 | args = timesteps[:, None].float() * freqs[None] 118 | embedding = th.cat([th.cos(args), th.sin(args)], dim=-1) 119 | if dim % 2: 120 | embedding = th.cat([embedding, th.zeros_like(embedding[:, :1])], dim=-1) 121 | return embedding 122 | 123 | 124 | def checkpoint(func, inputs, params, flag): 125 | """ 126 | Evaluate a function without caching intermediate activations, allowing for 127 | reduced memory at the expense of extra compute in the backward pass. 128 | 129 | :param func: the function to evaluate. 130 | :param inputs: the argument sequence to pass to `func`. 131 | :param params: a sequence of parameters `func` depends on but does not 132 | explicitly take as arguments. 133 | :param flag: if False, disable gradient checkpointing. 
134 | """ 135 | if flag: 136 | args = tuple(inputs) + tuple(params) 137 | return CheckpointFunction.apply(func, len(inputs), *args) 138 | else: 139 | return func(*inputs) 140 | 141 | 142 | class CheckpointFunction(th.autograd.Function): 143 | @staticmethod 144 | def forward(ctx, run_function, length, *args): 145 | ctx.run_function = run_function 146 | ctx.input_tensors = list(args[:length]) 147 | ctx.input_params = list(args[length:]) 148 | with th.no_grad(): 149 | output_tensors = ctx.run_function(*ctx.input_tensors) 150 | return output_tensors 151 | 152 | @staticmethod 153 | def backward(ctx, *output_grads): 154 | ctx.input_tensors = [x.detach().requires_grad_(True) for x in ctx.input_tensors] 155 | with th.enable_grad(): 156 | # Fixes a bug where the first op in run_function modifies the 157 | # Tensor storage in place, which is not allowed for detach()'d 158 | # Tensors. 159 | shallow_copies = [x.view_as(x) for x in ctx.input_tensors] 160 | output_tensors = ctx.run_function(*shallow_copies) 161 | input_grads = th.autograd.grad( 162 | output_tensors, 163 | ctx.input_tensors + ctx.input_params, 164 | output_grads, 165 | allow_unused=True, 166 | ) 167 | del ctx.input_tensors 168 | del ctx.input_params 169 | del output_tensors 170 | return (None, None) + input_grads 171 | -------------------------------------------------------------------------------- /guided_diffusion/script_util.py: -------------------------------------------------------------------------------- 1 | import argparse 2 | import inspect 3 | 4 | #from . import gaussian_diffusion as gd 5 | #from .respace import SpacedDiffusion, space_timesteps 6 | from .unet import SuperResModel, UNetModel, EncoderUNetModel 7 | 8 | NUM_CLASSES = 1000 9 | 10 | 11 | def diffusion_defaults(): 12 | """ 13 | Defaults for image and classifier training. 14 | """ 15 | return dict( 16 | learn_sigma=False, 17 | diffusion_steps=1000, 18 | noise_schedule="linear", 19 | timestep_respacing="", 20 | use_kl=False, 21 | predict_xstart=False, 22 | rescale_timesteps=False, 23 | rescale_learned_sigmas=False, 24 | ) 25 | 26 | 27 | def classifier_defaults(): 28 | """ 29 | Defaults for classifier models. 30 | """ 31 | return dict( 32 | image_size=64, 33 | classifier_use_fp16=False, 34 | classifier_width=128, 35 | classifier_depth=2, 36 | classifier_attention_resolutions="32,16,8", # 16 37 | classifier_use_scale_shift_norm=True, # False 38 | classifier_resblock_updown=True, # False 39 | classifier_pool="attention", 40 | ) 41 | 42 | 43 | def model_and_diffusion_defaults(): 44 | """ 45 | Defaults for image training. 
46 | """ 47 | res = dict( 48 | image_size=64, 49 | num_channels=128, 50 | num_res_blocks=2, 51 | num_heads=4, 52 | num_heads_upsample=-1, 53 | num_head_channels=-1, 54 | attention_resolutions="16,8", 55 | channel_mult="", 56 | dropout=0.0, 57 | class_cond=False, 58 | use_checkpoint=False, 59 | use_scale_shift_norm=True, 60 | resblock_updown=False, 61 | use_fp16=False, 62 | use_new_attention_order=False, 63 | ) 64 | res.update(diffusion_defaults()) 65 | return res 66 | 67 | 68 | def classifier_and_diffusion_defaults(): 69 | res = classifier_defaults() 70 | res.update(diffusion_defaults()) 71 | return res 72 | 73 | 74 | def create_model_and_diffusion( 75 | image_size, 76 | class_cond, 77 | learn_sigma, 78 | num_channels, 79 | num_res_blocks, 80 | channel_mult, 81 | num_heads, 82 | num_head_channels, 83 | num_heads_upsample, 84 | attention_resolutions, 85 | dropout, 86 | diffusion_steps, 87 | noise_schedule, 88 | timestep_respacing, 89 | use_kl, 90 | predict_xstart, 91 | rescale_timesteps, 92 | rescale_learned_sigmas, 93 | use_checkpoint, 94 | use_scale_shift_norm, 95 | resblock_updown, 96 | use_fp16, 97 | use_new_attention_order, 98 | ): 99 | model = create_model( 100 | image_size, 101 | num_channels, 102 | num_res_blocks, 103 | channel_mult=channel_mult, 104 | learn_sigma=learn_sigma, 105 | class_cond=class_cond, 106 | use_checkpoint=use_checkpoint, 107 | attention_resolutions=attention_resolutions, 108 | num_heads=num_heads, 109 | num_head_channels=num_head_channels, 110 | num_heads_upsample=num_heads_upsample, 111 | use_scale_shift_norm=use_scale_shift_norm, 112 | dropout=dropout, 113 | resblock_updown=resblock_updown, 114 | use_fp16=use_fp16, 115 | use_new_attention_order=use_new_attention_order, 116 | ) 117 | diffusion = create_gaussian_diffusion( 118 | steps=diffusion_steps, 119 | learn_sigma=learn_sigma, 120 | noise_schedule=noise_schedule, 121 | use_kl=use_kl, 122 | predict_xstart=predict_xstart, 123 | rescale_timesteps=rescale_timesteps, 124 | rescale_learned_sigmas=rescale_learned_sigmas, 125 | timestep_respacing=timestep_respacing, 126 | ) 127 | return model, diffusion 128 | 129 | 130 | def create_model( 131 | image_size, 132 | num_channels, 133 | num_res_blocks, 134 | channel_mult="", 135 | learn_sigma=False, 136 | class_cond=False, 137 | use_checkpoint=False, 138 | attention_resolutions="16", 139 | num_heads=1, 140 | num_head_channels=-1, 141 | num_heads_upsample=-1, 142 | use_scale_shift_norm=False, 143 | dropout=0, 144 | resblock_updown=False, 145 | use_fp16=False, 146 | use_new_attention_order=False, 147 | **kwargs 148 | ): 149 | if channel_mult == "": 150 | if image_size == 512: 151 | channel_mult = (0.5, 1, 1, 2, 2, 4, 4) 152 | elif image_size == 256: 153 | channel_mult = (1, 1, 2, 2, 4, 4) 154 | elif image_size == 128: 155 | channel_mult = (1, 1, 2, 3, 4) 156 | elif image_size == 64: 157 | channel_mult = (1, 2, 3, 4) 158 | else: 159 | raise ValueError(f"unsupported image size: {image_size}") 160 | else: 161 | channel_mult = tuple(int(ch_mult) for ch_mult in channel_mult.split(",")) 162 | 163 | attention_ds = [] 164 | for res in attention_resolutions.split(","): 165 | attention_ds.append(image_size // int(res)) 166 | 167 | return UNetModel( 168 | image_size=image_size, 169 | in_channels=3, 170 | model_channels=num_channels, 171 | out_channels=(3 if not learn_sigma else 6), 172 | num_res_blocks=num_res_blocks, 173 | attention_resolutions=tuple(attention_ds), 174 | dropout=dropout, 175 | channel_mult=channel_mult, 176 | num_classes=(NUM_CLASSES if class_cond else None), 177 | 
use_checkpoint=use_checkpoint, 178 | use_fp16=use_fp16, 179 | num_heads=num_heads, 180 | num_head_channels=num_head_channels, 181 | num_heads_upsample=num_heads_upsample, 182 | use_scale_shift_norm=use_scale_shift_norm, 183 | resblock_updown=resblock_updown, 184 | use_new_attention_order=use_new_attention_order, 185 | ) 186 | 187 | 188 | def create_classifier_and_diffusion( 189 | image_size, 190 | classifier_use_fp16, 191 | classifier_width, 192 | classifier_depth, 193 | classifier_attention_resolutions, 194 | classifier_use_scale_shift_norm, 195 | classifier_resblock_updown, 196 | classifier_pool, 197 | learn_sigma, 198 | diffusion_steps, 199 | noise_schedule, 200 | timestep_respacing, 201 | use_kl, 202 | predict_xstart, 203 | rescale_timesteps, 204 | rescale_learned_sigmas, 205 | ): 206 | classifier = create_classifier( 207 | image_size, 208 | classifier_use_fp16, 209 | classifier_width, 210 | classifier_depth, 211 | classifier_attention_resolutions, 212 | classifier_use_scale_shift_norm, 213 | classifier_resblock_updown, 214 | classifier_pool, 215 | ) 216 | diffusion = create_gaussian_diffusion( 217 | steps=diffusion_steps, 218 | learn_sigma=learn_sigma, 219 | noise_schedule=noise_schedule, 220 | use_kl=use_kl, 221 | predict_xstart=predict_xstart, 222 | rescale_timesteps=rescale_timesteps, 223 | rescale_learned_sigmas=rescale_learned_sigmas, 224 | timestep_respacing=timestep_respacing, 225 | ) 226 | return classifier, diffusion 227 | 228 | 229 | def create_classifier( 230 | image_size, 231 | classifier_use_fp16, 232 | classifier_width, 233 | classifier_depth, 234 | classifier_attention_resolutions, 235 | classifier_use_scale_shift_norm, 236 | classifier_resblock_updown, 237 | classifier_pool, 238 | ): 239 | if image_size == 512: 240 | channel_mult = (0.5, 1, 1, 2, 2, 4, 4) 241 | elif image_size == 256: 242 | channel_mult = (1, 1, 2, 2, 4, 4) 243 | elif image_size == 128: 244 | channel_mult = (1, 1, 2, 3, 4) 245 | elif image_size == 64: 246 | channel_mult = (1, 2, 3, 4) 247 | else: 248 | raise ValueError(f"unsupported image size: {image_size}") 249 | 250 | attention_ds = [] 251 | for res in classifier_attention_resolutions.split(","): 252 | attention_ds.append(image_size // int(res)) 253 | 254 | return EncoderUNetModel( 255 | image_size=image_size, 256 | in_channels=3, 257 | model_channels=classifier_width, 258 | out_channels=1000, 259 | num_res_blocks=classifier_depth, 260 | attention_resolutions=tuple(attention_ds), 261 | channel_mult=channel_mult, 262 | use_fp16=classifier_use_fp16, 263 | num_head_channels=64, 264 | use_scale_shift_norm=classifier_use_scale_shift_norm, 265 | resblock_updown=classifier_resblock_updown, 266 | pool=classifier_pool, 267 | ) 268 | 269 | 270 | def sr_model_and_diffusion_defaults(): 271 | res = model_and_diffusion_defaults() 272 | res["large_size"] = 256 273 | res["small_size"] = 64 274 | arg_names = inspect.getfullargspec(sr_create_model_and_diffusion)[0] 275 | for k in res.copy().keys(): 276 | if k not in arg_names: 277 | del res[k] 278 | return res 279 | 280 | 281 | def sr_create_model_and_diffusion( 282 | large_size, 283 | small_size, 284 | class_cond, 285 | learn_sigma, 286 | num_channels, 287 | num_res_blocks, 288 | num_heads, 289 | num_head_channels, 290 | num_heads_upsample, 291 | attention_resolutions, 292 | dropout, 293 | diffusion_steps, 294 | noise_schedule, 295 | timestep_respacing, 296 | use_kl, 297 | predict_xstart, 298 | rescale_timesteps, 299 | rescale_learned_sigmas, 300 | use_checkpoint, 301 | use_scale_shift_norm, 302 | resblock_updown, 
303 | use_fp16, 304 | ): 305 | model = sr_create_model( 306 | large_size, 307 | small_size, 308 | num_channels, 309 | num_res_blocks, 310 | learn_sigma=learn_sigma, 311 | class_cond=class_cond, 312 | use_checkpoint=use_checkpoint, 313 | attention_resolutions=attention_resolutions, 314 | num_heads=num_heads, 315 | num_head_channels=num_head_channels, 316 | num_heads_upsample=num_heads_upsample, 317 | use_scale_shift_norm=use_scale_shift_norm, 318 | dropout=dropout, 319 | resblock_updown=resblock_updown, 320 | use_fp16=use_fp16, 321 | ) 322 | diffusion = create_gaussian_diffusion( 323 | steps=diffusion_steps, 324 | learn_sigma=learn_sigma, 325 | noise_schedule=noise_schedule, 326 | use_kl=use_kl, 327 | predict_xstart=predict_xstart, 328 | rescale_timesteps=rescale_timesteps, 329 | rescale_learned_sigmas=rescale_learned_sigmas, 330 | timestep_respacing=timestep_respacing, 331 | ) 332 | return model, diffusion 333 | 334 | 335 | def sr_create_model( 336 | large_size, 337 | small_size, 338 | num_channels, 339 | num_res_blocks, 340 | learn_sigma, 341 | class_cond, 342 | use_checkpoint, 343 | attention_resolutions, 344 | num_heads, 345 | num_head_channels, 346 | num_heads_upsample, 347 | use_scale_shift_norm, 348 | dropout, 349 | resblock_updown, 350 | use_fp16, 351 | ): 352 | _ = small_size # hack to prevent unused variable 353 | 354 | if large_size == 512: 355 | channel_mult = (1, 1, 2, 2, 4, 4) 356 | elif large_size == 256: 357 | channel_mult = (1, 1, 2, 2, 4, 4) 358 | elif large_size == 64: 359 | channel_mult = (1, 2, 3, 4) 360 | else: 361 | raise ValueError(f"unsupported large size: {large_size}") 362 | 363 | attention_ds = [] 364 | for res in attention_resolutions.split(","): 365 | attention_ds.append(large_size // int(res)) 366 | 367 | return SuperResModel( 368 | image_size=large_size, 369 | in_channels=3, 370 | model_channels=num_channels, 371 | out_channels=(3 if not learn_sigma else 6), 372 | num_res_blocks=num_res_blocks, 373 | attention_resolutions=tuple(attention_ds), 374 | dropout=dropout, 375 | channel_mult=channel_mult, 376 | num_classes=(NUM_CLASSES if class_cond else None), 377 | use_checkpoint=use_checkpoint, 378 | num_heads=num_heads, 379 | num_head_channels=num_head_channels, 380 | num_heads_upsample=num_heads_upsample, 381 | use_scale_shift_norm=use_scale_shift_norm, 382 | resblock_updown=resblock_updown, 383 | use_fp16=use_fp16, 384 | ) 385 | 386 | 387 | def create_gaussian_diffusion( 388 | *, 389 | steps=1000, 390 | learn_sigma=False, 391 | sigma_small=False, 392 | noise_schedule="linear", 393 | use_kl=False, 394 | predict_xstart=False, 395 | rescale_timesteps=False, 396 | rescale_learned_sigmas=False, 397 | timestep_respacing="", 398 | ): 399 | betas = gd.get_named_beta_schedule(noise_schedule, steps) 400 | if use_kl: 401 | loss_type = gd.LossType.RESCALED_KL 402 | elif rescale_learned_sigmas: 403 | loss_type = gd.LossType.RESCALED_MSE 404 | else: 405 | loss_type = gd.LossType.MSE 406 | if not timestep_respacing: 407 | timestep_respacing = [steps] 408 | return SpacedDiffusion( 409 | use_timesteps=space_timesteps(steps, timestep_respacing), 410 | betas=betas, 411 | model_mean_type=( 412 | gd.ModelMeanType.EPSILON if not predict_xstart else gd.ModelMeanType.START_X 413 | ), 414 | model_var_type=( 415 | ( 416 | gd.ModelVarType.FIXED_LARGE 417 | if not sigma_small 418 | else gd.ModelVarType.FIXED_SMALL 419 | ) 420 | if not learn_sigma 421 | else gd.ModelVarType.LEARNED_RANGE 422 | ), 423 | loss_type=loss_type, 424 | rescale_timesteps=rescale_timesteps, 425 | ) 426 | 427 | 428 
| def add_dict_to_argparser(parser, default_dict): 429 | for k, v in default_dict.items(): 430 | v_type = type(v) 431 | if v is None: 432 | v_type = str 433 | elif isinstance(v, bool): 434 | v_type = str2bool 435 | parser.add_argument(f"--{k}", default=v, type=v_type) 436 | 437 | 438 | def args_to_dict(args, keys): 439 | return {k: getattr(args, k) for k in keys} 440 | 441 | 442 | def str2bool(v): 443 | """ 444 | https://stackoverflow.com/questions/15008758/parsing-boolean-values-with-argparse 445 | """ 446 | if isinstance(v, bool): 447 | return v 448 | if v.lower() in ("yes", "true", "t", "y", "1"): 449 | return True 450 | elif v.lower() in ("no", "false", "f", "n", "0"): 451 | return False 452 | else: 453 | raise argparse.ArgumentTypeError("boolean value expected") 454 | -------------------------------------------------------------------------------- /guided_diffusion/unet.py: -------------------------------------------------------------------------------- 1 | from abc import abstractmethod 2 | 3 | import math 4 | 5 | import numpy as np 6 | import torch as th 7 | import torch.nn as nn 8 | import torch.nn.functional as F 9 | 10 | from .fp16_util import convert_module_to_f16, convert_module_to_f32 11 | from .nn import ( 12 | checkpoint, 13 | conv_nd, 14 | linear, 15 | avg_pool_nd, 16 | zero_module, 17 | normalization, 18 | timestep_embedding, 19 | ) 20 | 21 | 22 | class AttentionPool2d(nn.Module): 23 | """ 24 | Adapted from CLIP: https://github.com/openai/CLIP/blob/main/clip/model.py 25 | """ 26 | 27 | def __init__( 28 | self, 29 | spacial_dim: int, 30 | embed_dim: int, 31 | num_heads_channels: int, 32 | output_dim: int = None, 33 | ): 34 | super().__init__() 35 | self.positional_embedding = nn.Parameter( 36 | th.randn(embed_dim, spacial_dim ** 2 + 1) / embed_dim ** 0.5 37 | ) 38 | self.qkv_proj = conv_nd(1, embed_dim, 3 * embed_dim, 1) 39 | self.c_proj = conv_nd(1, embed_dim, output_dim or embed_dim, 1) 40 | self.num_heads = embed_dim // num_heads_channels 41 | self.attention = QKVAttention(self.num_heads) 42 | 43 | def forward(self, x): 44 | b, c, *_spatial = x.shape 45 | x = x.reshape(b, c, -1) # NC(HW) 46 | x = th.cat([x.mean(dim=-1, keepdim=True), x], dim=-1) # NC(HW+1) 47 | x = x + self.positional_embedding[None, :, :].to(x.dtype) # NC(HW+1) 48 | x = self.qkv_proj(x) 49 | x = self.attention(x) 50 | x = self.c_proj(x) 51 | return x[:, :, 0] 52 | 53 | 54 | class TimestepBlock(nn.Module): 55 | """ 56 | Any module where forward() takes timestep embeddings as a second argument. 57 | """ 58 | 59 | @abstractmethod 60 | def forward(self, x, emb): 61 | """ 62 | Apply the module to `x` given `emb` timestep embeddings. 63 | """ 64 | 65 | 66 | class TimestepEmbedSequential(nn.Sequential, TimestepBlock): 67 | """ 68 | A sequential module that passes timestep embeddings to the children that 69 | support it as an extra input. 70 | """ 71 | 72 | def forward(self, x, emb): 73 | for layer in self: 74 | if isinstance(layer, TimestepBlock): 75 | x = layer(x, emb) 76 | else: 77 | x = layer(x) 78 | return x 79 | 80 | 81 | class Upsample(nn.Module): 82 | """ 83 | An upsampling layer with an optional convolution. 84 | 85 | :param channels: channels in the inputs and outputs. 86 | :param use_conv: a bool determining if a convolution is applied. 87 | :param dims: determines if the signal is 1D, 2D, or 3D. If 3D, then 88 | upsampling occurs in the inner-two dimensions. 
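    :param out_channels: if given, the optional convolution also changes the
        number of output channels (defaults to `channels`).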
89 | """ 90 | 91 | def __init__(self, channels, use_conv, dims=2, out_channels=None): 92 | super().__init__() 93 | self.channels = channels 94 | self.out_channels = out_channels or channels 95 | self.use_conv = use_conv 96 | self.dims = dims 97 | if use_conv: 98 | self.conv = conv_nd(dims, self.channels, self.out_channels, 3, padding=1) 99 | 100 | def forward(self, x): 101 | assert x.shape[1] == self.channels 102 | if self.dims == 3: 103 | x = F.interpolate( 104 | x, (x.shape[2], x.shape[3] * 2, x.shape[4] * 2), mode="nearest" 105 | ) 106 | else: 107 | x = F.interpolate(x, scale_factor=2, mode="nearest") 108 | if self.use_conv: 109 | x = self.conv(x) 110 | return x 111 | 112 | 113 | class Downsample(nn.Module): 114 | """ 115 | A downsampling layer with an optional convolution. 116 | 117 | :param channels: channels in the inputs and outputs. 118 | :param use_conv: a bool determining if a convolution is applied. 119 | :param dims: determines if the signal is 1D, 2D, or 3D. If 3D, then 120 | downsampling occurs in the inner-two dimensions. 121 | """ 122 | 123 | def __init__(self, channels, use_conv, dims=2, out_channels=None): 124 | super().__init__() 125 | self.channels = channels 126 | self.out_channels = out_channels or channels 127 | self.use_conv = use_conv 128 | self.dims = dims 129 | stride = 2 if dims != 3 else (1, 2, 2) 130 | if use_conv: 131 | self.op = conv_nd( 132 | dims, self.channels, self.out_channels, 3, stride=stride, padding=1 133 | ) 134 | else: 135 | assert self.channels == self.out_channels 136 | self.op = avg_pool_nd(dims, kernel_size=stride, stride=stride) 137 | 138 | def forward(self, x): 139 | assert x.shape[1] == self.channels 140 | return self.op(x) 141 | 142 | 143 | class ResBlock(TimestepBlock): 144 | """ 145 | A residual block that can optionally change the number of channels. 146 | 147 | :param channels: the number of input channels. 148 | :param emb_channels: the number of timestep embedding channels. 149 | :param dropout: the rate of dropout. 150 | :param out_channels: if specified, the number of out channels. 151 | :param use_conv: if True and out_channels is specified, use a spatial 152 | convolution instead of a smaller 1x1 convolution to change the 153 | channels in the skip connection. 154 | :param dims: determines if the signal is 1D, 2D, or 3D. 155 | :param use_checkpoint: if True, use gradient checkpointing on this module. 156 | :param up: if True, use this block for upsampling. 157 | :param down: if True, use this block for downsampling. 
158 | """ 159 | 160 | def __init__( 161 | self, 162 | channels, 163 | emb_channels, 164 | dropout, 165 | out_channels=None, 166 | use_conv=False, 167 | use_scale_shift_norm=False, 168 | dims=2, 169 | use_checkpoint=False, 170 | up=False, 171 | down=False, 172 | ): 173 | super().__init__() 174 | self.channels = channels 175 | self.emb_channels = emb_channels 176 | self.dropout = dropout 177 | self.out_channels = out_channels or channels 178 | self.use_conv = use_conv 179 | self.use_checkpoint = use_checkpoint 180 | self.use_scale_shift_norm = use_scale_shift_norm 181 | 182 | self.in_layers = nn.Sequential( 183 | normalization(channels), 184 | nn.SiLU(), 185 | conv_nd(dims, channels, self.out_channels, 3, padding=1), 186 | ) 187 | 188 | self.updown = up or down 189 | 190 | if up: 191 | self.h_upd = Upsample(channels, False, dims) 192 | self.x_upd = Upsample(channels, False, dims) 193 | elif down: 194 | self.h_upd = Downsample(channels, False, dims) 195 | self.x_upd = Downsample(channels, False, dims) 196 | else: 197 | self.h_upd = self.x_upd = nn.Identity() 198 | 199 | self.emb_layers = nn.Sequential( 200 | nn.SiLU(), 201 | linear( 202 | emb_channels, 203 | 2 * self.out_channels if use_scale_shift_norm else self.out_channels, 204 | ), 205 | ) 206 | self.out_layers = nn.Sequential( 207 | normalization(self.out_channels), 208 | nn.SiLU(), 209 | nn.Dropout(p=dropout), 210 | zero_module( 211 | conv_nd(dims, self.out_channels, self.out_channels, 3, padding=1) 212 | ), 213 | ) 214 | 215 | if self.out_channels == channels: 216 | self.skip_connection = nn.Identity() 217 | elif use_conv: 218 | self.skip_connection = conv_nd( 219 | dims, channels, self.out_channels, 3, padding=1 220 | ) 221 | else: 222 | self.skip_connection = conv_nd(dims, channels, self.out_channels, 1) 223 | 224 | def forward(self, x, emb): 225 | """ 226 | Apply the block to a Tensor, conditioned on a timestep embedding. 227 | 228 | :param x: an [N x C x ...] Tensor of features. 229 | :param emb: an [N x emb_channels] Tensor of timestep embeddings. 230 | :return: an [N x C x ...] Tensor of outputs. 231 | """ 232 | return checkpoint( 233 | self._forward, (x, emb), self.parameters(), self.use_checkpoint 234 | ) 235 | 236 | def _forward(self, x, emb): 237 | if self.updown: 238 | in_rest, in_conv = self.in_layers[:-1], self.in_layers[-1] 239 | h = in_rest(x) 240 | h = self.h_upd(h) 241 | x = self.x_upd(x) 242 | h = in_conv(h) 243 | else: 244 | h = self.in_layers(x) 245 | emb_out = self.emb_layers(emb).type(h.dtype) 246 | while len(emb_out.shape) < len(h.shape): 247 | emb_out = emb_out[..., None] 248 | if self.use_scale_shift_norm: 249 | out_norm, out_rest = self.out_layers[0], self.out_layers[1:] 250 | scale, shift = th.chunk(emb_out, 2, dim=1) 251 | h = out_norm(h) * (1 + scale) + shift 252 | h = out_rest(h) 253 | else: 254 | h = h + emb_out 255 | h = self.out_layers(h) 256 | return self.skip_connection(x) + h 257 | 258 | 259 | class AttentionBlock(nn.Module): 260 | """ 261 | An attention block that allows spatial positions to attend to each other. 262 | 263 | Originally ported from here, but adapted to the N-d case. 264 | https://github.com/hojonathanho/diffusion/blob/1e0dceb3b3495bbe19116a5e1b3596cd0706c543/diffusion_tf/models/unet.py#L66. 
265 | """ 266 | 267 | def __init__( 268 | self, 269 | channels, 270 | num_heads=1, 271 | num_head_channels=-1, 272 | use_checkpoint=False, 273 | use_new_attention_order=False, 274 | ): 275 | super().__init__() 276 | self.channels = channels 277 | if num_head_channels == -1: 278 | self.num_heads = num_heads 279 | else: 280 | assert ( 281 | channels % num_head_channels == 0 282 | ), f"q,k,v channels {channels} is not divisible by num_head_channels {num_head_channels}" 283 | self.num_heads = channels // num_head_channels 284 | self.use_checkpoint = use_checkpoint 285 | self.norm = normalization(channels) 286 | self.qkv = conv_nd(1, channels, channels * 3, 1) 287 | if use_new_attention_order: 288 | # split qkv before split heads 289 | self.attention = QKVAttention(self.num_heads) 290 | else: 291 | # split heads before split qkv 292 | self.attention = QKVAttentionLegacy(self.num_heads) 293 | 294 | self.proj_out = zero_module(conv_nd(1, channels, channels, 1)) 295 | 296 | def forward(self, x): 297 | return checkpoint(self._forward, (x,), self.parameters(), True) 298 | 299 | def _forward(self, x): 300 | b, c, *spatial = x.shape 301 | x = x.reshape(b, c, -1) 302 | qkv = self.qkv(self.norm(x)) 303 | h = self.attention(qkv) 304 | h = self.proj_out(h) 305 | return (x + h).reshape(b, c, *spatial) 306 | 307 | 308 | def count_flops_attn(model, _x, y): 309 | """ 310 | A counter for the `thop` package to count the operations in an 311 | attention operation. 312 | Meant to be used like: 313 | macs, params = thop.profile( 314 | model, 315 | inputs=(inputs, timestamps), 316 | custom_ops={QKVAttention: QKVAttention.count_flops}, 317 | ) 318 | """ 319 | b, c, *spatial = y[0].shape 320 | num_spatial = int(np.prod(spatial)) 321 | # We perform two matmuls with the same number of ops. 322 | # The first computes the weight matrix, the second computes 323 | # the combination of the value vectors. 324 | matmul_ops = 2 * b * (num_spatial ** 2) * c 325 | model.total_ops += th.DoubleTensor([matmul_ops]) 326 | 327 | 328 | class QKVAttentionLegacy(nn.Module): 329 | """ 330 | A module which performs QKV attention. Matches legacy QKVAttention + input/output heads shaping 331 | """ 332 | 333 | def __init__(self, n_heads): 334 | super().__init__() 335 | self.n_heads = n_heads 336 | 337 | def forward(self, qkv): 338 | """ 339 | Apply QKV attention. 340 | 341 | :param qkv: an [N x (H * 3 * C) x T] tensor of Qs, Ks, and Vs. 342 | :return: an [N x (H * C) x T] tensor after attention. 343 | """ 344 | bs, width, length = qkv.shape 345 | assert width % (3 * self.n_heads) == 0 346 | ch = width // (3 * self.n_heads) 347 | q, k, v = qkv.reshape(bs * self.n_heads, ch * 3, length).split(ch, dim=1) 348 | scale = 1 / math.sqrt(math.sqrt(ch)) 349 | weight = th.einsum( 350 | "bct,bcs->bts", q * scale, k * scale 351 | ) # More stable with f16 than dividing afterwards 352 | weight = th.softmax(weight.float(), dim=-1).type(weight.dtype) 353 | a = th.einsum("bts,bcs->bct", weight, v) 354 | return a.reshape(bs, -1, length) 355 | 356 | @staticmethod 357 | def count_flops(model, _x, y): 358 | return count_flops_attn(model, _x, y) 359 | 360 | 361 | class QKVAttention(nn.Module): 362 | """ 363 | A module which performs QKV attention and splits in a different order. 364 | """ 365 | 366 | def __init__(self, n_heads): 367 | super().__init__() 368 | self.n_heads = n_heads 369 | 370 | def forward(self, qkv): 371 | """ 372 | Apply QKV attention. 373 | 374 | :param qkv: an [N x (3 * H * C) x T] tensor of Qs, Ks, and Vs. 
375 | :return: an [N x (H * C) x T] tensor after attention. 376 | """ 377 | bs, width, length = qkv.shape 378 | assert width % (3 * self.n_heads) == 0 379 | ch = width // (3 * self.n_heads) 380 | q, k, v = qkv.chunk(3, dim=1) 381 | scale = 1 / math.sqrt(math.sqrt(ch)) 382 | weight = th.einsum( 383 | "bct,bcs->bts", 384 | (q * scale).view(bs * self.n_heads, ch, length), 385 | (k * scale).view(bs * self.n_heads, ch, length), 386 | ) # More stable with f16 than dividing afterwards 387 | weight = th.softmax(weight.float(), dim=-1).type(weight.dtype) 388 | a = th.einsum("bts,bcs->bct", weight, v.reshape(bs * self.n_heads, ch, length)) 389 | return a.reshape(bs, -1, length) 390 | 391 | @staticmethod 392 | def count_flops(model, _x, y): 393 | return count_flops_attn(model, _x, y) 394 | 395 | 396 | class UNetModel(nn.Module): 397 | """ 398 | The full UNet model with attention and timestep embedding. 399 | 400 | :param in_channels: channels in the input Tensor. 401 | :param model_channels: base channel count for the model. 402 | :param out_channels: channels in the output Tensor. 403 | :param num_res_blocks: number of residual blocks per downsample. 404 | :param attention_resolutions: a collection of downsample rates at which 405 | attention will take place. May be a set, list, or tuple. 406 | For example, if this contains 4, then at 4x downsampling, attention 407 | will be used. 408 | :param dropout: the dropout probability. 409 | :param channel_mult: channel multiplier for each level of the UNet. 410 | :param conv_resample: if True, use learned convolutions for upsampling and 411 | downsampling. 412 | :param dims: determines if the signal is 1D, 2D, or 3D. 413 | :param num_classes: if specified (as an int), then this model will be 414 | class-conditional with `num_classes` classes. 415 | :param use_checkpoint: use gradient checkpointing to reduce memory usage. 416 | :param num_heads: the number of attention heads in each attention layer. 417 | :param num_heads_channels: if specified, ignore num_heads and instead use 418 | a fixed channel width per attention head. 419 | :param num_heads_upsample: works with num_heads to set a different number 420 | of heads for upsampling. Deprecated. 421 | :param use_scale_shift_norm: use a FiLM-like conditioning mechanism. 422 | :param resblock_updown: use residual blocks for up/downsampling. 423 | :param use_new_attention_order: use a different attention pattern for potentially 424 | increased efficiency. 
425 | """ 426 | 427 | def __init__( 428 | self, 429 | image_size, 430 | in_channels, 431 | model_channels, 432 | out_channels, 433 | num_res_blocks, 434 | attention_resolutions, 435 | dropout=0, 436 | channel_mult=(1, 2, 4, 8), 437 | conv_resample=True, 438 | dims=2, 439 | num_classes=None, 440 | use_checkpoint=False, 441 | use_fp16=False, 442 | num_heads=1, 443 | num_head_channels=-1, 444 | num_heads_upsample=-1, 445 | use_scale_shift_norm=False, 446 | resblock_updown=False, 447 | use_new_attention_order=False, 448 | **kwargs 449 | ): 450 | super().__init__() 451 | 452 | if num_heads_upsample == -1: 453 | num_heads_upsample = num_heads 454 | 455 | self.image_size = image_size 456 | self.in_channels = in_channels 457 | self.model_channels = model_channels 458 | self.out_channels = out_channels 459 | self.num_res_blocks = num_res_blocks 460 | self.attention_resolutions = attention_resolutions 461 | self.dropout = dropout 462 | self.channel_mult = channel_mult 463 | self.conv_resample = conv_resample 464 | self.num_classes = num_classes 465 | self.use_checkpoint = use_checkpoint 466 | self.dtype = th.float16 if use_fp16 else th.float32 467 | self.num_heads = num_heads 468 | self.num_head_channels = num_head_channels 469 | self.num_heads_upsample = num_heads_upsample 470 | 471 | time_embed_dim = model_channels * 4 472 | self.time_embed = nn.Sequential( 473 | linear(model_channels, time_embed_dim), 474 | nn.SiLU(), 475 | linear(time_embed_dim, time_embed_dim), 476 | ) 477 | 478 | if self.num_classes is not None: 479 | self.label_emb = nn.Embedding(num_classes, time_embed_dim) 480 | 481 | ch = input_ch = int(channel_mult[0] * model_channels) 482 | self.input_blocks = nn.ModuleList( 483 | [TimestepEmbedSequential(conv_nd(dims, in_channels, ch, 3, padding=1))] 484 | ) 485 | self._feature_size = ch 486 | input_block_chans = [ch] 487 | ds = 1 488 | for level, mult in enumerate(channel_mult): 489 | for _ in range(num_res_blocks): 490 | layers = [ 491 | ResBlock( 492 | ch, 493 | time_embed_dim, 494 | dropout, 495 | out_channels=int(mult * model_channels), 496 | dims=dims, 497 | use_checkpoint=use_checkpoint, 498 | use_scale_shift_norm=use_scale_shift_norm, 499 | ) 500 | ] 501 | ch = int(mult * model_channels) 502 | if ds in attention_resolutions: 503 | layers.append( 504 | AttentionBlock( 505 | ch, 506 | use_checkpoint=use_checkpoint, 507 | num_heads=num_heads, 508 | num_head_channels=num_head_channels, 509 | use_new_attention_order=use_new_attention_order, 510 | ) 511 | ) 512 | self.input_blocks.append(TimestepEmbedSequential(*layers)) 513 | self._feature_size += ch 514 | input_block_chans.append(ch) 515 | if level != len(channel_mult) - 1: 516 | out_ch = ch 517 | self.input_blocks.append( 518 | TimestepEmbedSequential( 519 | ResBlock( 520 | ch, 521 | time_embed_dim, 522 | dropout, 523 | out_channels=out_ch, 524 | dims=dims, 525 | use_checkpoint=use_checkpoint, 526 | use_scale_shift_norm=use_scale_shift_norm, 527 | down=True, 528 | ) 529 | if resblock_updown 530 | else Downsample( 531 | ch, conv_resample, dims=dims, out_channels=out_ch 532 | ) 533 | ) 534 | ) 535 | ch = out_ch 536 | input_block_chans.append(ch) 537 | ds *= 2 538 | self._feature_size += ch 539 | 540 | self.middle_block = TimestepEmbedSequential( 541 | ResBlock( 542 | ch, 543 | time_embed_dim, 544 | dropout, 545 | dims=dims, 546 | use_checkpoint=use_checkpoint, 547 | use_scale_shift_norm=use_scale_shift_norm, 548 | ), 549 | AttentionBlock( 550 | ch, 551 | use_checkpoint=use_checkpoint, 552 | num_heads=num_heads, 553 | 
num_head_channels=num_head_channels, 554 | use_new_attention_order=use_new_attention_order, 555 | ), 556 | ResBlock( 557 | ch, 558 | time_embed_dim, 559 | dropout, 560 | dims=dims, 561 | use_checkpoint=use_checkpoint, 562 | use_scale_shift_norm=use_scale_shift_norm, 563 | ), 564 | ) 565 | self._feature_size += ch 566 | 567 | self.output_blocks = nn.ModuleList([]) 568 | for level, mult in list(enumerate(channel_mult))[::-1]: 569 | for i in range(num_res_blocks + 1): 570 | ich = input_block_chans.pop() 571 | layers = [ 572 | ResBlock( 573 | ch + ich, 574 | time_embed_dim, 575 | dropout, 576 | out_channels=int(model_channels * mult), 577 | dims=dims, 578 | use_checkpoint=use_checkpoint, 579 | use_scale_shift_norm=use_scale_shift_norm, 580 | ) 581 | ] 582 | ch = int(model_channels * mult) 583 | if ds in attention_resolutions: 584 | layers.append( 585 | AttentionBlock( 586 | ch, 587 | use_checkpoint=use_checkpoint, 588 | num_heads=num_heads_upsample, 589 | num_head_channels=num_head_channels, 590 | use_new_attention_order=use_new_attention_order, 591 | ) 592 | ) 593 | if level and i == num_res_blocks: 594 | out_ch = ch 595 | layers.append( 596 | ResBlock( 597 | ch, 598 | time_embed_dim, 599 | dropout, 600 | out_channels=out_ch, 601 | dims=dims, 602 | use_checkpoint=use_checkpoint, 603 | use_scale_shift_norm=use_scale_shift_norm, 604 | up=True, 605 | ) 606 | if resblock_updown 607 | else Upsample(ch, conv_resample, dims=dims, out_channels=out_ch) 608 | ) 609 | ds //= 2 610 | self.output_blocks.append(TimestepEmbedSequential(*layers)) 611 | self._feature_size += ch 612 | 613 | self.out = nn.Sequential( 614 | normalization(ch), 615 | nn.SiLU(), 616 | zero_module(conv_nd(dims, input_ch, out_channels, 3, padding=1)), 617 | ) 618 | 619 | def convert_to_fp16(self): 620 | """ 621 | Convert the torso of the model to float16. 622 | """ 623 | self.input_blocks.apply(convert_module_to_f16) 624 | self.middle_block.apply(convert_module_to_f16) 625 | self.output_blocks.apply(convert_module_to_f16) 626 | 627 | def convert_to_fp32(self): 628 | """ 629 | Convert the torso of the model to float32. 630 | """ 631 | self.input_blocks.apply(convert_module_to_f32) 632 | self.middle_block.apply(convert_module_to_f32) 633 | self.output_blocks.apply(convert_module_to_f32) 634 | 635 | def forward(self, x, timesteps, y=None): 636 | """ 637 | Apply the model to an input batch. 638 | 639 | :param x: an [N x C x ...] Tensor of inputs. 640 | :param timesteps: a 1-D batch of timesteps. 641 | :param y: an [N] Tensor of labels, if class-conditional. 642 | :return: an [N x C x ...] Tensor of outputs. 643 | """ 644 | assert (y is not None) == ( 645 | self.num_classes is not None 646 | ), "must specify y if and only if the model is class-conditional" 647 | 648 | hs = [] 649 | emb = self.time_embed(timestep_embedding(timesteps, self.model_channels)) 650 | 651 | if self.num_classes is not None: 652 | assert y.shape == (x.shape[0],) 653 | emb = emb + self.label_emb(y) 654 | 655 | h = x.type(self.dtype) 656 | for module in self.input_blocks: 657 | h = module(h, emb) 658 | hs.append(h) 659 | h = self.middle_block(h, emb) 660 | for module in self.output_blocks: 661 | h = th.cat([h, hs.pop()], dim=1) 662 | h = module(h, emb) 663 | h = h.type(x.dtype) 664 | return self.out(h) 665 | 666 | 667 | class SuperResModel(UNetModel): 668 | """ 669 | A UNetModel that performs super-resolution. 670 | 671 | Expects an extra kwarg `low_res` to condition on a low-resolution image. 
672 | """ 673 | 674 | def __init__(self, image_size, in_channels, *args, **kwargs): 675 | super().__init__(image_size, in_channels * 2, *args, **kwargs) 676 | 677 | def forward(self, x, timesteps, low_res=None, **kwargs): 678 | _, _, new_height, new_width = x.shape 679 | upsampled = F.interpolate(low_res, (new_height, new_width), mode="bilinear") 680 | x = th.cat([x, upsampled], dim=1) 681 | return super().forward(x, timesteps, **kwargs) 682 | 683 | 684 | class EncoderUNetModel(nn.Module): 685 | """ 686 | The half UNet model with attention and timestep embedding. 687 | 688 | For usage, see UNet. 689 | """ 690 | 691 | def __init__( 692 | self, 693 | image_size, 694 | in_channels, 695 | model_channels, 696 | out_channels, 697 | num_res_blocks, 698 | attention_resolutions, 699 | dropout=0, 700 | channel_mult=(1, 2, 4, 8), 701 | conv_resample=True, 702 | dims=2, 703 | use_checkpoint=False, 704 | use_fp16=False, 705 | num_heads=1, 706 | num_head_channels=-1, 707 | num_heads_upsample=-1, 708 | use_scale_shift_norm=False, 709 | resblock_updown=False, 710 | use_new_attention_order=False, 711 | pool="adaptive", 712 | ): 713 | super().__init__() 714 | 715 | if num_heads_upsample == -1: 716 | num_heads_upsample = num_heads 717 | 718 | self.in_channels = in_channels 719 | self.model_channels = model_channels 720 | self.out_channels = out_channels 721 | self.num_res_blocks = num_res_blocks 722 | self.attention_resolutions = attention_resolutions 723 | self.dropout = dropout 724 | self.channel_mult = channel_mult 725 | self.conv_resample = conv_resample 726 | self.use_checkpoint = use_checkpoint 727 | self.dtype = th.float16 if use_fp16 else th.float32 728 | self.num_heads = num_heads 729 | self.num_head_channels = num_head_channels 730 | self.num_heads_upsample = num_heads_upsample 731 | 732 | time_embed_dim = model_channels * 4 733 | self.time_embed = nn.Sequential( 734 | linear(model_channels, time_embed_dim), 735 | nn.SiLU(), 736 | linear(time_embed_dim, time_embed_dim), 737 | ) 738 | 739 | ch = int(channel_mult[0] * model_channels) 740 | self.input_blocks = nn.ModuleList( 741 | [TimestepEmbedSequential(conv_nd(dims, in_channels, ch, 3, padding=1))] 742 | ) 743 | self._feature_size = ch 744 | input_block_chans = [ch] 745 | ds = 1 746 | for level, mult in enumerate(channel_mult): 747 | for _ in range(num_res_blocks): 748 | layers = [ 749 | ResBlock( 750 | ch, 751 | time_embed_dim, 752 | dropout, 753 | out_channels=int(mult * model_channels), 754 | dims=dims, 755 | use_checkpoint=use_checkpoint, 756 | use_scale_shift_norm=use_scale_shift_norm, 757 | ) 758 | ] 759 | ch = int(mult * model_channels) 760 | if ds in attention_resolutions: 761 | layers.append( 762 | AttentionBlock( 763 | ch, 764 | use_checkpoint=use_checkpoint, 765 | num_heads=num_heads, 766 | num_head_channels=num_head_channels, 767 | use_new_attention_order=use_new_attention_order, 768 | ) 769 | ) 770 | self.input_blocks.append(TimestepEmbedSequential(*layers)) 771 | self._feature_size += ch 772 | input_block_chans.append(ch) 773 | if level != len(channel_mult) - 1: 774 | out_ch = ch 775 | self.input_blocks.append( 776 | TimestepEmbedSequential( 777 | ResBlock( 778 | ch, 779 | time_embed_dim, 780 | dropout, 781 | out_channels=out_ch, 782 | dims=dims, 783 | use_checkpoint=use_checkpoint, 784 | use_scale_shift_norm=use_scale_shift_norm, 785 | down=True, 786 | ) 787 | if resblock_updown 788 | else Downsample( 789 | ch, conv_resample, dims=dims, out_channels=out_ch 790 | ) 791 | ) 792 | ) 793 | ch = out_ch 794 | input_block_chans.append(ch) 
795 | ds *= 2 796 | self._feature_size += ch 797 | 798 | self.middle_block = TimestepEmbedSequential( 799 | ResBlock( 800 | ch, 801 | time_embed_dim, 802 | dropout, 803 | dims=dims, 804 | use_checkpoint=use_checkpoint, 805 | use_scale_shift_norm=use_scale_shift_norm, 806 | ), 807 | AttentionBlock( 808 | ch, 809 | use_checkpoint=use_checkpoint, 810 | num_heads=num_heads, 811 | num_head_channels=num_head_channels, 812 | use_new_attention_order=use_new_attention_order, 813 | ), 814 | ResBlock( 815 | ch, 816 | time_embed_dim, 817 | dropout, 818 | dims=dims, 819 | use_checkpoint=use_checkpoint, 820 | use_scale_shift_norm=use_scale_shift_norm, 821 | ), 822 | ) 823 | self._feature_size += ch 824 | self.pool = pool 825 | if pool == "adaptive": 826 | self.out = nn.Sequential( 827 | normalization(ch), 828 | nn.SiLU(), 829 | nn.AdaptiveAvgPool2d((1, 1)), 830 | zero_module(conv_nd(dims, ch, out_channels, 1)), 831 | nn.Flatten(), 832 | ) 833 | elif pool == "attention": 834 | assert num_head_channels != -1 835 | self.out = nn.Sequential( 836 | normalization(ch), 837 | nn.SiLU(), 838 | AttentionPool2d( 839 | (image_size // ds), ch, num_head_channels, out_channels 840 | ), 841 | ) 842 | elif pool == "spatial": 843 | self.out = nn.Sequential( 844 | nn.Linear(self._feature_size, 2048), 845 | nn.ReLU(), 846 | nn.Linear(2048, self.out_channels), 847 | ) 848 | elif pool == "spatial_v2": 849 | self.out = nn.Sequential( 850 | nn.Linear(self._feature_size, 2048), 851 | normalization(2048), 852 | nn.SiLU(), 853 | nn.Linear(2048, self.out_channels), 854 | ) 855 | else: 856 | raise NotImplementedError(f"Unexpected {pool} pooling") 857 | 858 | def convert_to_fp16(self): 859 | """ 860 | Convert the torso of the model to float16. 861 | """ 862 | self.input_blocks.apply(convert_module_to_f16) 863 | self.middle_block.apply(convert_module_to_f16) 864 | 865 | def convert_to_fp32(self): 866 | """ 867 | Convert the torso of the model to float32. 868 | """ 869 | self.input_blocks.apply(convert_module_to_f32) 870 | self.middle_block.apply(convert_module_to_f32) 871 | 872 | def forward(self, x, timesteps): 873 | """ 874 | Apply the model to an input batch. 875 | 876 | :param x: an [N x C x ...] Tensor of inputs. 877 | :param timesteps: a 1-D batch of timesteps. 878 | :return: an [N x K] Tensor of outputs. 
879 | """ 880 | emb = self.time_embed(timestep_embedding(timesteps, self.model_channels)) 881 | 882 | results = [] 883 | h = x.type(self.dtype) 884 | for module in self.input_blocks: 885 | h = module(h, emb) 886 | if self.pool.startswith("spatial"): 887 | results.append(h.type(x.dtype).mean(dim=(2, 3))) 888 | h = self.middle_block(h, emb) 889 | if self.pool.startswith("spatial"): 890 | results.append(h.type(x.dtype).mean(dim=(2, 3))) 891 | h = th.cat(results, axis=-1) 892 | return self.out(h) 893 | else: 894 | h = h.type(x.dtype) 895 | return self.out(h) 896 | -------------------------------------------------------------------------------- /inpainting.py: -------------------------------------------------------------------------------- 1 | import os 2 | import yaml 3 | import numpy as np 4 | import tqdm 5 | import torch 6 | from torch import nn 7 | import sys 8 | sys.path.insert(0,'./') 9 | from guided_diffusion.models import Model 10 | import random 11 | from ddim_inversion_utils import * 12 | from utils import * 13 | 14 | with open('configs/inpainting.yml', 'r') as f: 15 | task_config = yaml.safe_load(f) 16 | 17 | 18 | ### Reproducibility 19 | torch.set_printoptions(sci_mode=False) 20 | ensure_reproducibility(task_config['seed']) 21 | 22 | 23 | with open( "data/celeba_hq.yml", "r") as f: 24 | config1 = yaml.safe_load(f) 25 | config = dict2namespace(config1) 26 | model, device = load_pretrained_diffusion_model(config) 27 | 28 | ### Define the DDIM scheduler 29 | ddim_scheduler=DDIMScheduler(beta_start=config.diffusion.beta_start, beta_end=config.diffusion.beta_end, beta_schedule=config.diffusion.beta_schedule) 30 | ddim_scheduler.set_timesteps(config.diffusion.num_diffusion_timesteps // task_config['delta_t'])#task_config['Denoising_steps'] 31 | 32 | 33 | img_pil, img_np, mask = generate_noisy_image_and_mask('data/imgs/00243.png') 34 | img_torch = torch.tensor(img_np).permute(2,0,1).unsqueeze(0) 35 | t_mask = torch.from_numpy(mask).unsqueeze(0).unsqueeze(0).cuda() 36 | radii = torch.ones([1, 1, 1]).cuda() * (np.sqrt(config.data.image_size*config.data.image_size*config.model.in_channels)) 37 | 38 | latent = torch.nn.parameter.Parameter(torch.randn( 1, config.model.in_channels, config.data.image_size, config.data.image_size).to(device)) 39 | l2_loss=nn.MSELoss() #nn.L1Loss() 40 | optimizer = torch.optim.Adam([{'params':latent,'lr':task_config['lr']}])# 41 | 42 | 43 | for iteration in range(task_config['Optimization_steps']): 44 | optimizer.zero_grad() 45 | x_0_hat = DDIM_efficient_feed_forward(latent, model, ddim_scheduler) 46 | loss = l2_loss(x_0_hat*t_mask, img_torch.cuda()*t_mask) 47 | loss.backward() 48 | optimizer.step() 49 | 50 | #Project to the Sphere of radius sqrt(D) 51 | for param in latent: 52 | param.data.div_((param.pow(2).sum(tuple(range(0, param.ndim)), keepdim=True) + 1e-9).sqrt()) 53 | param.data.mul_(radii) 54 | 55 | if iteration % 10 == 0: 56 | #psnr = psnr_orig(np.array(img_pil).astype(np.float32), process(x_0_hat, 0)) 57 | #print(iteration, 'loss:', loss.item(), torch.norm(latent.detach()), psnr) 58 | Image.fromarray(np.concatenate([ process(img_torch.cuda()*t_mask, 0), process(x_0_hat, 0), np.array(img_pil).astype(np.uint8)], 1)).save('results/inpainted.png') 59 | 60 | 61 | 62 | 63 | -------------------------------------------------------------------------------- /push_to_hf.py: -------------------------------------------------------------------------------- 1 | from guided_diffusion.models import Model 2 | import torch 3 | import yaml 4 | from utils import 
dict2namespace 5 | 6 | 7 | def load_pretrained_diffusion_model(config): 8 | model = Model(config) 9 | ckpt = "checkpoints/celeba_hq.ckpt" 10 | device = torch.device("cuda") if torch.cuda.is_available() else torch.device("cpu") 11 | config.device = device 12 | model.load_state_dict(torch.load(ckpt, map_location=device)) 13 | model.to(device) 14 | model.eval() 15 | for param in model.parameters(): 16 | param.requires_grad = False 17 | model = torch.nn.DataParallel(model) 18 | return model, device 19 | 20 | 21 | with open( "data/celeba_hq.yml", "r") as f: 22 | config1 = yaml.safe_load(f) 23 | 24 | config = dict2namespace(config1) 25 | 26 | model, device = load_pretrained_diffusion_model(config) 27 | 28 | # push to hub 29 | model.push_to_hub("nielsr/bird-demo") 30 | 31 | # reload 32 | model = Model.from_pretrained("nielsr/bird-demo") -------------------------------------------------------------------------------- /results/blind_deblurring.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/hamadichihaoui/BIRD/713dab700a0192b224641ecd5c0e264a96e47fab/results/blind_deblurring.png -------------------------------------------------------------------------------- /results/denoised.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/hamadichihaoui/BIRD/713dab700a0192b224641ecd5c0e264a96e47fab/results/denoised.png -------------------------------------------------------------------------------- /results/inpainted.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/hamadichihaoui/BIRD/713dab700a0192b224641ecd5c0e264a96e47fab/results/inpainted.png -------------------------------------------------------------------------------- /results/non_uniform_deblurring.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/hamadichihaoui/BIRD/713dab700a0192b224641ecd5c0e264a96e47fab/results/non_uniform_deblurring.png -------------------------------------------------------------------------------- /results/super_resolution.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/hamadichihaoui/BIRD/713dab700a0192b224641ecd5c0e264a96e47fab/results/super_resolution.png -------------------------------------------------------------------------------- /super_resolution.py: -------------------------------------------------------------------------------- 1 | import os 2 | import yaml 3 | import numpy as np 4 | import tqdm 5 | import torch 6 | from torch import nn 7 | import sys 8 | sys.path.insert(0,'./') 9 | from guided_diffusion.models import Model 10 | import random 11 | from ddim_inversion_utils import * 12 | from utils import * 13 | 14 | with open('configs/super_resolution.yml', 'r') as f: 15 | task_config = yaml.safe_load(f) 16 | 17 | 18 | ### Reproducibility 19 | torch.set_printoptions(sci_mode=False) 20 | ensure_reproducibility(task_config['seed']) 21 | 22 | 23 | with open( "data/celeba_hq.yml", "r") as f: 24 | config1 = yaml.safe_load(f) 25 | config = dict2namespace(config1) 26 | model, device = load_pretrained_diffusion_model(config) 27 | 28 | ### Define the DDIM scheduler 29 | ddim_scheduler=DDIMScheduler(beta_start=config.diffusion.beta_start, beta_end=config.diffusion.beta_end, beta_schedule=config.diffusion.beta_schedule) 30 | 
ddim_scheduler.set_timesteps(config.diffusion.num_diffusion_timesteps // task_config['delta_t'])#task_config['Denoising_steps'] 31 | 32 | img_pil, downsampled_torch, downsampling_op = generate_lr_image('data/imgs/00205.png', task_config['downsampling_ratio']) 33 | radii = torch.ones([1, 1, 1]).cuda() * (np.sqrt(config.data.image_size*config.data.image_size*config.model.in_channels)) 34 | 35 | latent = torch.nn.parameter.Parameter(torch.randn( 1, config.model.in_channels, config.data.image_size, config.data.image_size).to(device)) 36 | l2_loss=nn.MSELoss() #nn.L1Loss() 37 | optimizer = torch.optim.Adam([{'params':latent,'lr':task_config['lr']}])# 38 | 39 | 40 | for iteration in range(task_config['Optimization_steps']): 41 | optimizer.zero_grad() 42 | x_0_hat = DDIM_efficient_feed_forward(latent, model, ddim_scheduler) 43 | loss = l2_loss(downsampling_op(x_0_hat), downsampled_torch) 44 | loss.backward() 45 | optimizer.step() 46 | 47 | #Project to the Sphere of radius 1 48 | for param in latent: 49 | param.data.div_((param.pow(2).sum(tuple(range(0, param.ndim)), keepdim=True) + 1e-9).sqrt()) 50 | param.data.mul_(radii) 51 | 52 | if iteration % 10 == 0: 53 | #psnr = psnr_orig(np.array(img_pil).astype(np.float32), process(x_0_hat, 0)) 54 | #print(iteration, 'loss:', loss.item(), torch.norm(latent.detach()), psnr) 55 | Image.fromarray(np.concatenate([ process(MeanUpsample(downsampled_torch, task_config['downsampling_ratio']), 0), process(x_0_hat, 0), np.array(img_pil).astype(np.uint8)], 1)).save('results/super_resolution.png') 56 | 57 | 58 | 59 | 60 | -------------------------------------------------------------------------------- /utils.py: -------------------------------------------------------------------------------- 1 | from PIL import Image 2 | import argparse 3 | import numpy as np 4 | import torch 5 | from torch import nn 6 | import random 7 | from math import log10, sqrt 8 | import sys 9 | sys.path.insert(0,'./') 10 | from guided_diffusion.models import Model 11 | 12 | def ensure_reproducibility(seed): 13 | torch.backends.cudnn.benchmark = False 14 | torch.backends.cudnn.deterministic = True 15 | random.seed(seed) 16 | np.random.seed(seed) 17 | torch.manual_seed(seed) 18 | 19 | 20 | 21 | def dict2namespace(config): 22 | namespace = argparse.Namespace() 23 | for key, value in config.items(): 24 | if isinstance(value, dict): 25 | new_value = dict2namespace(value) 26 | else: 27 | new_value = value 28 | setattr(namespace, key, new_value) 29 | return namespace 30 | 31 | def get_beta_schedule(beta_schedule, *, beta_start, beta_end, num_diffusion_timesteps): 32 | def sigmoid(x): 33 | return 1 / (np.exp(-x) + 1) 34 | 35 | if beta_schedule == "quad": 36 | betas = ( 37 | np.linspace( 38 | beta_start ** 0.5, 39 | beta_end ** 0.5, 40 | num_diffusion_timesteps, 41 | dtype=np.float64, 42 | ) 43 | ** 2 44 | ) 45 | elif beta_schedule == "linear": 46 | betas = np.linspace( 47 | beta_start, beta_end, num_diffusion_timesteps, dtype=np.float64 48 | ) 49 | elif beta_schedule == "const": 50 | betas = beta_end * np.ones(num_diffusion_timesteps, dtype=np.float64) 51 | elif beta_schedule == "jsd": 52 | betas = 1.0 / np.linspace( 53 | num_diffusion_timesteps, 1, num_diffusion_timesteps, dtype=np.float64 54 | ) 55 | elif beta_schedule == "sigmoid": 56 | betas = np.linspace(-6, 6, num_diffusion_timesteps) 57 | betas = sigmoid(betas) * (beta_end - beta_start) + beta_start 58 | else: 59 | raise NotImplementedError(beta_schedule) 60 | assert betas.shape == (num_diffusion_timesteps,) 61 | return betas 62 | 63 | 64 | # 
Code form RePaint 65 | def get_schedule_jump(T_sampling, travel_length, travel_repeat): 66 | jumps = {} 67 | for j in range(0, T_sampling - travel_length, travel_length): 68 | jumps[j] = travel_repeat - 1 69 | 70 | t = T_sampling 71 | ts = [] 72 | 73 | while t >= 1: 74 | t = t-1 75 | ts.append(t) 76 | 77 | if jumps.get(t, 0) > 0: 78 | jumps[t] = jumps[t] - 1 79 | for _ in range(travel_length): 80 | t = t + 1 81 | ts.append(t) 82 | 83 | ts.append(-1) 84 | 85 | _check_times(ts, -1, T_sampling) 86 | return ts 87 | 88 | def _check_times(times, t_0, T_sampling): 89 | # Check end 90 | assert times[0] > times[1], (times[0], times[1]) 91 | 92 | # Check beginning 93 | assert times[-1] == -1, times[-1] 94 | 95 | # Steplength = 1 96 | for t_last, t_cur in zip(times[:-1], times[1:]): 97 | assert abs(t_last - t_cur) == 1, (t_last, t_cur) 98 | 99 | # Value range 100 | for t in times: 101 | assert t >= t_0, (t, t_0) 102 | assert t <= T_sampling, (t, T_sampling) 103 | 104 | def compute_alpha(beta, t): 105 | beta = torch.cat([torch.zeros(1).to(beta.device), beta], dim=0) 106 | a = (1 - beta).cumprod(dim=0).index_select(0, t + 1).view(-1, 1, 1, 1) 107 | return a 108 | 109 | 110 | 111 | def process(sample, i): 112 | image_processed = sample.detach().cpu().permute(0, 2, 3, 1) 113 | image_processed = image_processed.squeeze(0) 114 | image_processed = torch.clip((image_processed + 1.0) * 127.5, 0., 255.) 115 | image_processed = image_processed.numpy().astype(np.uint8) 116 | return image_processed 117 | 118 | def process_gray(sample, i): 119 | image_processed = sample.detach().cpu().permute(0, 2, 3, 1) 120 | image_processed = image_processed.squeeze(0) 121 | image_processed = torch.clip((image_processed + 1.0) * 127.5, 0., 255.) 122 | image_processed = image_processed.numpy().astype(np.uint8) 123 | init_image=Image.fromarray(image_processed).convert('L') 124 | img1 = np.expand_dims(np.array(init_image).astype(np.uint8), axis=2) 125 | #mask1 = np.where(img1 < 50) 126 | img2 = np.tile(img1, [1, 1, 3]) 127 | return img2 128 | 129 | def process_gray_thresh(sample, i, thresh=170): 130 | image_processed = sample.detach().cpu().permute(0, 2, 3, 1) 131 | image_processed = image_processed.squeeze(0) 132 | image_processed = torch.clip((image_processed + 1.0) * 127.5, 0., 255.) 133 | image_processed = image_processed.numpy().astype(np.uint8) 134 | init_image=Image.fromarray(image_processed).convert('L') 135 | img1 = np.expand_dims(np.array(init_image).astype(np.uint8), axis=2) 136 | img3 = (np.where((img1 > thresh), 255, 0)).astype(np.uint8) 137 | img2 = np.tile(img3, [1, 1, 3]) 138 | 139 | return img2 140 | 141 | 142 | def get_mask(sample, i, thresh=170): 143 | image_processed = sample.detach().cpu().permute(0, 2, 3, 1) 144 | image_processed = image_processed.squeeze(0) 145 | image_processed = torch.clip((image_processed + 1.0) * 127.5, 0., 255.) 146 | image_processed = image_processed.numpy().astype(np.uint8) 147 | init_image=Image.fromarray(image_processed).convert('L') 148 | img1 = np.expand_dims(np.array(init_image).astype(np.uint8), axis=2) 149 | img3 = (np.where((img1 > thresh), 0., 1.)).astype(np.float32) 150 | 151 | 152 | return img3[:,:,0] 153 | 154 | 155 | 156 | def psnr_orig(original, compressed): 157 | mse = np.mean((original - compressed) ** 2) 158 | if(mse == 0): # MSE is zero means no noise is present in the signal . 159 | # Therefore PSNR have no importance. 
160 | return 100 161 | max_pixel = 255.0 162 | psnr = 20 * log10(max_pixel / sqrt(mse)) 163 | return psnr 164 | 165 | def psnr_mask(original, compressed, mask): 166 | mse = ((original*mask - compressed*mask) ** 2).sum() / mask.sum() 167 | if(mse == 0): # MSE is zero means no noise is present in the signal . 168 | # Therefore PSNR have no importance. 169 | return 100 170 | max_pixel = 255.0 171 | psnr = 20 * log10(max_pixel / sqrt(mse)) 172 | return psnr 173 | 174 | def generate_noisy_image(path, speckle_coef=0.8, speckle_lambda=0.4, gauss_coef=0.2, gauss_sigma=0.15): 175 | pil_img = Image.open(path).resize((256, 256)) 176 | gauss = np.random.normal(0, speckle_lambda, 3*256*256) 177 | gauss = gauss.reshape(256, 256, 3).astype(np.float32) 178 | x = np.array(pil_img).astype(np.float32) 179 | img_np = speckle_coef * np.array(x + x * gauss, dtype=np.float32) + gauss_coef * (x + np.random.normal(size=np.array(pil_img).shape).astype(np.float32) * gauss_sigma * 255) 180 | img_np = np.clip(img_np, 0., 255.) 181 | #print(path1, 'std', np.std(img_np-x) / 255.) 182 | img_np = img_np/ 255 * 2 - 1 183 | return pil_img, img_np 184 | 185 | def generate_noisy_image_and_mask(path='imgs/00205.png', speckle_coef=0.8, speckle_lambda=0.12, gauss_coef=0.2, gauss_sigma=0.05): 186 | init_image = Image.open(path).resize((256, 256)) 187 | gauss = np.random.normal(0,speckle_lambda, 3*256*256) 188 | gauss = gauss.reshape(256,256,3).astype(np.float32) 189 | x = np.array(init_image).astype(np.float32) 190 | img_np = speckle_coef * np.array(x + x * gauss, dtype=np.float32) + gauss_coef * (x + np.random.normal(size=np.array(init_image).shape).astype(np.float32) * gauss_sigma * 255.) 191 | img_np = np.clip(img_np, 0., 255.) 192 | img_np = img_np/ 255 * 2 - 1 193 | mask = np.ones((256, 256)) 194 | for i in range(128-40, 128+40): 195 | for j in range(128-40, 128+40): 196 | mask[i, j]=0. 197 | 198 | return init_image, img_np, mask 199 | 200 | 201 | def generate_lr_image(path, downsampling_ratio, speckle_coef=0.8, speckle_lambda=0.12, gauss_coef=0.2, gauss_sigma=0.05): 202 | 203 | init_image = Image.open(path).resize((256, 256)) 204 | img_np = np.array(init_image).astype(np.float32) / 255 * 2 - 1 205 | img = torch.tensor(img_np).permute(2,0,1).unsqueeze(0) 206 | downsampling_op = torch.nn.AdaptiveAvgPool2d((256//downsampling_ratio,256//downsampling_ratio)).cuda() 207 | for param in downsampling_op.parameters(): 208 | param.requires_grad = False 209 | #b, c, h, w = img.shape 210 | downsampled = downsampling_op(img.cuda()) 211 | downsampled_resc1 = (downsampled + 1.) / 2. 212 | gauss = torch.randn_like(downsampled) * speckle_lambda 213 | 214 | downsampled_resc = speckle_coef *(downsampled_resc1 + downsampled_resc1 * gauss) + gauss_coef * (downsampled_resc1 + gauss_sigma * torch.randn_like(downsampled)) 215 | #print('std', (downsampled_resc - downsampled_resc1).std()) 216 | downsampled = downsampled_resc * 2. - 1. 
217 | return init_image, downsampled, downsampling_op 218 | 219 | def MeanUpsample(x, scale): 220 | n, c, h, w = x.shape 221 | out = torch.zeros(n, c, h, scale, w, scale).to(x.device) + x.view(n,c,h,1,w,1) 222 | out = out.view(n, c, scale*h, scale*w) 223 | return out 224 | 225 | 226 | def get_conv(scale): 227 | conv = torch.nn.Conv2d(in_channels=1, out_channels=1, kernel_size=scale, stride=1, padding=scale//2, bias=False) 228 | kernel = np.load('data/kernel.npy') 229 | kernel_torch = torch.from_numpy(kernel).unsqueeze(0).unsqueeze(0).float() 230 | conv.weight = torch.nn.Parameter(kernel_torch) 231 | return conv 232 | 233 | def generate_blurry_image(path='imgs/00205.png', kernel_size=41, speckle_coef=0.8, speckle_lambda=0.12, gauss_coef=0.2, gauss_sigma=0.05): 234 | 235 | pil_image = Image.open(path).resize((256, 256)) 236 | img_np = np.array(pil_image).astype(np.float32) / 255 * 2 - 1 237 | img = torch.tensor(img_np).permute(2,0,1).unsqueeze(0).cuda() 238 | conv = get_conv(kernel_size).cuda() 239 | 240 | for param in conv.parameters(): 241 | param.requires_grad = False 242 | b, c, h, w = img.shape 243 | blurred = conv(img.view(-1, 1, h, w)) 244 | blurred = blurred.view(1, c, h, w) 245 | downsampled = blurred#[:,:, ::32, ::32] 246 | 247 | downsampled_resc1 = (downsampled + 1.) / 2. 248 | gauss = torch.randn_like(downsampled) * speckle_lambda 249 | downsampled_resc = speckle_coef *(downsampled_resc1 + downsampled_resc1 * gauss) + gauss_coef * (downsampled_resc1 + gauss_sigma * torch.randn_like(downsampled)) 250 | #print('std', (downsampled_resc - downsampled_resc1).std()) 251 | downsampled = downsampled_resc * 2. - 1. 252 | return pil_image, downsampled 253 | 254 | def fill_noise(x, noise_type): 255 | """Fills tensor `x` with noise of type `noise_type`.""" 256 | torch.manual_seed(0) 257 | if noise_type == 'u': 258 | x.uniform_() 259 | elif noise_type == 'n': 260 | x.normal_() 261 | else: 262 | assert False 263 | 264 | def get_noise(input_depth, method, spatial_size, noise_type='u', var=1./10): 265 | """Returns a pytorch.Tensor of size (1 x `input_depth` x `spatial_size[0]` x `spatial_size[1]`) 266 | initialized in a specific way. 267 | Args: 268 | input_depth: number of channels in the tensor 269 | method: `noise` for fillting tensor with noise; `meshgrid` for np.meshgrid 270 | spatial_size: spatial size of the tensor to initialize 271 | noise_type: 'u' for uniform; 'n' for normal 272 | var: a factor, a noise will be multiplicated by. Basically it is standard deviation scaler. 
273 | """ 274 | if isinstance(spatial_size, int): 275 | spatial_size = (spatial_size, spatial_size) 276 | if method == 'noise': 277 | shape = [1, input_depth, spatial_size[0], spatial_size[1]] 278 | net_input = torch.zeros(shape) 279 | 280 | fill_noise(net_input, noise_type) 281 | net_input *= var 282 | elif method == 'meshgrid': 283 | assert input_depth == 2 284 | X, Y = np.meshgrid(np.arange(0, spatial_size[1])/float(spatial_size[1]-1), np.arange(0, spatial_size[0])/float(spatial_size[0]-1)) 285 | meshgrid = np.concatenate([X[None, :], Y[None, :]]) 286 | net_input = np_to_torch(meshgrid) 287 | else: 288 | assert False 289 | 290 | return net_input 291 | 292 | def fcn(num_input_channels=200, num_output_channels=1, num_hidden=1000): 293 | 294 | layers = [] 295 | layers.append(nn.Linear(num_input_channels, num_hidden,bias=True)) 296 | layers.append(nn.ReLU6()) 297 | # 298 | layers.append(nn.Linear(num_hidden, num_output_channels)) 299 | layers.append(nn.Softmax()) 300 | model2 = nn.Sequential(*layers) 301 | 302 | return model2 303 | 304 | def load_pretrained_diffusion_model(config): 305 | model = Model(config) 306 | ckpt = "checkpoints/celeba_hq.ckpt" 307 | device = torch.device("cuda") if torch.cuda.is_available() else torch.device("cpu") 308 | config.device = device 309 | model.load_state_dict(torch.load(ckpt, map_location=device)) 310 | model.to(device) 311 | model.eval() 312 | for param in model.parameters(): 313 | param.requires_grad = False 314 | model = torch.nn.DataParallel(model) 315 | return model, device 316 | --------------------------------------------------------------------------------