├── assets
│   ├── method.png
│   └── visualization.png
├── code
│   ├── ckpt
│   │   └── pretrained.txt
│   ├── test.sh
│   ├── train.sh
│   ├── datasets_setting.py
│   ├── test.py
│   ├── train.py
│   └── model.py
├── LICENSE.md
└── Readme.md

--------------------------------------------------------------------------------
/assets/method.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/MiliLab/DGSolver/HEAD/assets/method.png

--------------------------------------------------------------------------------
/assets/visualization.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/MiliLab/DGSolver/HEAD/assets/visualization.png

--------------------------------------------------------------------------------
/code/ckpt/pretrained.txt:
--------------------------------------------------------------------------------
[Baidu Cloud](https://pan.baidu.com/s/1CvZ2HAiwqM1t2VOJe5dYkg?pwd=fz2d) (extraction code: fz2d)
[Google Cloud](https://drive.google.com/file/d/1cc7WkG2E8gGKEzBt7WZFXmkpV7VzLWk7/view?usp=drive_link)

--------------------------------------------------------------------------------
/code/test.sh:
--------------------------------------------------------------------------------
export CUDA_DEVICE_ORDER="PCI_BUS_ID"
export CUDA_VISIBLE_DEVICES="0,1,2,3"

start_time=$(date +%s)
echo "start_time: ${start_time}"

# TORCH_DISTRIBUTED_DEBUG=DETAIL
# Single-process alternative:
# nohup python test.py > ./test.log 2>&1 &

nohup python -m torch.distributed.launch --nproc_per_node 4 --use_env test.py > ./test.log 2>&1 &

end_time=$(date +%s)
e2e_time=$(($end_time - $start_time))

echo "------------------ Final result ------------------"
# Note: the job above runs in the background via nohup, so e2e_time measures
# launch time only; add `wait` before end_time to time the full run.
echo "e2e_time: ${e2e_time}s"

--------------------------------------------------------------------------------
/code/train.sh:
--------------------------------------------------------------------------------
export CUDA_DEVICE_ORDER="PCI_BUS_ID"
export CUDA_VISIBLE_DEVICES="0,1,2,3"

start_time=$(date +%s)
echo "start_time: ${start_time}"

# TORCH_DISTRIBUTED_DEBUG=DETAIL
# Single-process alternative:
# nohup python train.py > ./train.log 2>&1 &

nohup python -m torch.distributed.launch --nproc_per_node 4 --use_env train.py > ./train.log 2>&1 &

end_time=$(date +%s)
e2e_time=$(($end_time - $start_time))

echo "------------------ Final result ------------------"
# Note: the job above runs in the background via nohup, so e2e_time measures
# launch time only; add `wait` before end_time to time the full run.
echo "e2e_time: ${e2e_time}s"

--------------------------------------------------------------------------------
/LICENSE.md:
--------------------------------------------------------------------------------
MIT License

Copyright (c) 2024 MiliLab

Permission is hereby granted, free of charge, to any person obtaining a copy
of this software and associated documentation files (the "Software"), to deal
in the Software without restriction, including without limitation the rights
to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
copies of the Software, and to permit persons to whom the Software is
furnished to do so, subject to the following conditions:

The above copyright notice and this permission notice shall be included in all
copies or substantial portions of the Software.

THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE
SOFTWARE.

--------------------------------------------------------------------------------
/Readme.md:
--------------------------------------------------------------------------------
# DGSolver: Diffusion Generalist Solver with Universal Posterior Sampling for Image Restoration

Hebaixu Wang, Jing Zhang, Haonan Guo, Di Wang, Jiayi Ma and Bo Du.

[Paper](https://arxiv.org/abs/2504.21487) | [Github Code](https://github.com/MiliLab/DGSolver)

## Abstract

Diffusion models have achieved remarkable progress in universal image restoration. While existing methods speed up inference by reducing sampling steps, substantial step intervals often introduce cumulative errors. Moreover, they struggle to balance the commonality of degradation representations and restoration quality. To address these challenges, we introduce **DGSolver**, a diffusion generalist solver with universal posterior sampling. We first derive the exact ordinary differential equations for generalist diffusion models and tailor high-order solvers with a queue-based accelerated sampling strategy to improve both accuracy and efficiency. We then integrate universal posterior sampling to better approximate manifold-constrained gradients, yielding a more accurate noise estimation and correcting errors in inverse inference. Extensive experiments show that DGSolver outperforms state-of-the-art methods in restoration accuracy, stability, and scalability, both qualitatively and quantitatively.
## Overview

![Method overview](assets/method.png)

## Visualization

![Visualization results](assets/visualization.png)

## Datasets Information

| Task | Dataset | Synthetic/Real | Download Links |
|--------------------------|--------------------------------|---------------------|----------------|
| **Deraining** | DID | Synthetic | [URL](https://github.com/hezhangsprinter/DID-MDN) |
| | DeRaindrop | Real | [URL](https://github.com/rui1996/DeRaindrop) |
| | Rain13K | Synthetic | [URL](https://github.com/kuijiang94/MSPFN) |
| | Rain_100H | Synthetic | [URL](https://github.com/kuijiang94/MSPFN) |
| | Rain_100L | Synthetic | [URL](https://github.com/kuijiang94/MSPFN) |
| | GT-Rain | Real | [URL](https://github.com/UCLA-VMG/GT-RAIN) |
| | RealRain-1k | Real | [URL](https://github.com/hiker-lw/RealRain-1k) |
| **Low-light Enhancement**| LOL | Real | [URL](https://github.com/weichen582/RetinexNet?tab=readme-ov-file) |
| | MEF | Real | [URL](https://ieeexplore.ieee.org/abstract/document/7120119) |
| | VE-LOL-L | Synthetic/Real | [URL](https://flyywh.github.io/IJCV2021LowLight_VELOL/) |
| | NPE | Real | [URL](https://ieeexplore.ieee.org/abstract/document/6512558) |
| **Desnowing** | CSD | Synthetic | [URL](https://github.com/weitingchen83/ICCV2021-Single-Image-Desnowing-HDCWNet) |
| | Snow100K-Real | Real | [URL](https://sites.google.com/view/yunfuliu/desnownet) |
| **Dehazing** | SOTS | Synthetic | [URL](https://sites.google.com/view/reside-dehaze-datasets/reside-standard?authuser=0) |
| | ITS_v2 | Synthetic | [URL](https://sites.google.com/view/reside-dehaze-datasets/reside-standard?authuser=0) |
| | D-HAZY | Synthetic | [URL](https://www.cvmart.net/dataSets/detail/559?channel_id=op10&utm_source=cvmartmp&utm_campaign=datasets&utm_medium=article) |
| | NH-HAZE | Real | [URL](https://data.vision.ee.ethz.ch/cvl/ntire20/nh-haze/) |
| | Dense-Haze | Real | [URL](https://data.vision.ee.ethz.ch/cvl/ntire19/dense-haze/) |
| | NHRW | Real | [URL](https://github.com/chaimi2013/3R) |
| **Deblur** | GoPro | Synthetic | [URL](https://github.com/SeungjunNah/DeepDeblur-PyTorch) |
| | RealBlur | Real | [URL](https://github.com/rimchang/RealBlur) |

## Model Checkpoint

[Google Cloud](https://drive.google.com/file/d/1cc7WkG2E8gGKEzBt7WZFXmkpV7VzLWk7/view?usp=drive_link)

[Baidu Cloud](https://pan.baidu.com/s/1CvZ2HAiwqM1t2VOJe5dYkg?pwd=fz2d) (extraction code: fz2d)

[Huggingface](https://huggingface.co/BaiXuYa/DGSolver)

### Contributor

Baixuzx7 @ wanghebaixu@gmail.com

### Copyright statement

The project is released under the MIT License; see [LICENSE.md](LICENSE.md).

--------------------------------------------------------------------------------
/code/datasets_setting.py:
--------------------------------------------------------------------------------
import os
import torch
from torch.utils.data import Dataset, DataLoader
from torchvision import transforms as T, utils

from PIL import Image
from torch import nn
import imageio
import cv2
import numpy as np
import random

# import torch_npu
# from torch_npu.contrib import transfer_to_npu

def exists(x):
    return x is not None

def cycle(dl):
    # Turn a finite DataLoader into an infinite iterator.
    while True:
        for data in dl:
            yield data

def default(val, d):
    if exists(val):
        return val
    return d() if callable(d) else d
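
# --- Editor's sketch (not part of the original file): quick usage of the two
# --- helpers above; a plain list stands in for a DataLoader.
def _demo_cycle_default():
    stream = cycle([1, 2, 3])
    assert [next(stream) for _ in range(5)] == [1, 2, 3, 1, 2]
    assert default(None, 16) == 16            # fallback used
    assert default(8, 16) == 8                # value kept
    assert default(None, lambda: 'x') == 'x'  # callables are invoked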
def set_seed(SEED):
    torch.manual_seed(SEED)
    torch.cuda.manual_seed_all(SEED)
    np.random.seed(SEED)
    random.seed(SEED)

degradation_cache = ['Enlighening', 'Desnowing', 'Deraining', 'Deblur', 'Dehazing']

class train_dataset(Dataset):
    def __init__(self, root_dir, task_folder, sub_folder = None, image_size = 256):
        super().__init__()
        assert task_folder in degradation_cache
        self.root_path = root_dir
        self.task_path = os.path.join(root_dir,task_folder)
        if sub_folder is not None:
            self.sub_datasets = [sub_folder]
        else:
            self.sub_datasets = os.listdir(self.task_path)
        self.image_size = image_size
        self.skip_datasets = []
        self.image_load_path,self.image_name_list = [],[]
        self.condi_load_path,self.condi_name_list = [],[]

        for sub_dataset in self.sub_datasets:
            assert sub_dataset in os.listdir(self.task_path)
            if sub_dataset in self.skip_datasets:
                continue
            else:
                if sub_dataset == 'ITS_v2' or sub_dataset == 'SOTS' or sub_dataset == 'RESIDE':
                    # RESIDE-style sets name conditions like '1400_1.png' while
                    # the label is '1400.png'; derive label paths from the
                    # condition list so both lists stay aligned index by index.
                    dataset_path = os.path.join(self.task_path,sub_dataset,'train')
                    image_path = os.path.join(dataset_path,'label')
                    condi_path = os.path.join(dataset_path,'condition')
                    condi_file_list = os.listdir(condi_path)
                    condi_path_list = [os.path.join(condi_path,x) for x in condi_file_list]
                    image_file_list = condi_file_list
                    image_path_list = [os.path.join(image_path,x.split('_',1)[0]+x[-4:]) for x in condi_file_list]
                else:
                    dataset_path = os.path.join(self.task_path,sub_dataset,'train')
                    image_path = os.path.join(dataset_path,'label')
                    image_file_list = os.listdir(image_path)
                    condi_path = os.path.join(dataset_path,'condition')
                    condi_file_list = os.listdir(condi_path)
                    image_path_list = [os.path.join(image_path,x) for x in image_file_list]
                    condi_path_list = [os.path.join(condi_path,x) for x in condi_file_list]

                self.image_load_path = self.image_load_path + image_path_list
                self.image_name_list = self.image_name_list + image_file_list
                self.condi_load_path = self.condi_load_path + condi_path_list
                self.condi_name_list = self.condi_name_list + condi_file_list

        self.transform = T.Compose([T.ToTensor()])

    def __len__(self):
        assert len(self.condi_name_list) == len(self.image_name_list), 'the number of label files does not match the number of condition files'
        return len(self.condi_name_list)

    def __getitem__(self, index):
        image_name_file = self.image_name_list[index]
        condi_name_file = self.condi_name_list[index]
        if condi_name_file != image_name_file:
            print(self.condi_load_path[index],'\n',self.image_load_path[index])
        assert condi_name_file == image_name_file, 'image pairs are not matched'
        image_file_path = self.image_load_path[index]
        condi_file_path = self.condi_load_path[index]
        image = imageio.imread(image_file_path)
        condi = imageio.imread(condi_file_path)
        if image.shape != condi.shape:
            print(image_file_path,'\n',condi_file_path)
        assert image.shape == condi.shape, 'image sizes are not matched'
        image,condi = self.random_crop_size(image,condi,self.image_size)

        if self.transform is not None:
            image_tf = self.transform(image)
            condi_tf = self.transform(condi)

        return condi_name_file,image_tf,condi_tf
    def resize_shape(self, image, short_side_length):
        # Upscale so that the short side is at least `short_side_length`.
        oldh, oldw = image.shape[0], image.shape[1]
        if min(oldh, oldw) < short_side_length:
            scale = short_side_length * 1.0 / min(oldh, oldw)
            newh, neww = oldh * scale, oldw * scale
            image = cv2.resize(image,(int(neww + 0.5),int(newh + 0.5)))
        return image

    def random_crop_size(self,imageA,imageB,crop_size):
        # Crop the same random crop_size x crop_size window from both images.
        imageA = self.resize_shape(imageA,crop_size)
        imageB = self.resize_shape(imageB,crop_size)
        assert imageA.shape == imageB.shape, 'image sizes are not matched'
        h,w,_ = imageA.shape
        h_start,w_start = np.random.randint(0,h-crop_size+1),np.random.randint(0,w-crop_size+1)
        imageA_crop = imageA[h_start:h_start+crop_size,w_start:w_start+crop_size,:]
        imageB_crop = imageB[h_start:h_start+crop_size,w_start:w_start+crop_size,:]
        return imageA_crop,imageB_crop
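
# --- Editor's sketch (not part of the original file): on-disk layout the
# --- datasets expect, and a usage example; '/data/restoration' and 'Rain13K'
# --- are placeholder names.
# ---     <root>/Deraining/Rain13K/train/condition/0001.png   (degraded input)
# ---     <root>/Deraining/Rain13K/train/label/0001.png       (clean target)
def _demo_train_dataset():
    ds = train_dataset('/data/restoration', task_folder='Deraining',
                       sub_folder='Rain13K', image_size=256)
    dl = DataLoader(ds, batch_size=4, shuffle=True)
    name, label, condition = next(iter(dl))
    print(label.shape, condition.shape)  # torch.Size([4, 3, 256, 256]) twice, for RGB data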
class test_dataset(Dataset):
    def __init__(self, root_dir, task_folder, sub_folder = None):
        super().__init__()
        assert task_folder in degradation_cache
        self.root_path = root_dir
        self.task_path = os.path.join(root_dir,task_folder)
        if sub_folder is not None:
            self.sub_datasets = [sub_folder]
        else:
            self.sub_datasets = os.listdir(self.task_path)
        self.skip_datasets = []
        self.image_load_path,self.image_name_list = [],[]
        self.condi_load_path,self.condi_name_list = [],[]

        for sub_dataset in self.sub_datasets:
            assert sub_dataset in os.listdir(self.task_path)
            if sub_dataset in self.skip_datasets:
                continue
            else:
                if sub_dataset == 'ITS_v2' or sub_dataset == 'SOTS':
                    dataset_path = os.path.join(self.task_path,sub_dataset,'test')
                    image_path = os.path.join(dataset_path,'label')
                    condi_path = os.path.join(dataset_path,'condition')
                    condi_file_list = os.listdir(condi_path)
                    condi_path_list = [os.path.join(condi_path,x) for x in condi_file_list]
                    image_file_list = condi_file_list
                    image_path_list = [os.path.join(image_path,x.split('_',1)[0]+x[-4:]) for x in condi_file_list]
                else:
                    dataset_path = os.path.join(self.task_path,sub_dataset,'test')
                    image_path = os.path.join(dataset_path,'label')
                    image_file_list = os.listdir(image_path)
                    condi_path = os.path.join(dataset_path,'condition')
                    condi_file_list = os.listdir(condi_path)
                    image_path_list = [os.path.join(image_path,x) for x in image_file_list]
                    condi_path_list = [os.path.join(condi_path,x) for x in condi_file_list]

                self.image_load_path = self.image_load_path + image_path_list
                self.image_name_list = self.image_name_list + image_file_list
                self.condi_load_path = self.condi_load_path + condi_path_list
                self.condi_name_list = self.condi_name_list + condi_file_list

        self.transform = T.Compose([T.ToTensor()])

    def __len__(self):
        assert len(self.condi_name_list) == len(self.image_name_list), 'the number of label files does not match the number of condition files'
        return len(self.condi_name_list)

    def __getitem__(self, index):
        image_name_file = self.image_name_list[index]
        condi_name_file = self.condi_name_list[index]
        if condi_name_file != image_name_file:
            print(self.condi_load_path[index],'\n',self.image_load_path[index])
        assert condi_name_file == image_name_file, 'image pairs are not matched'
        image_file_path = self.image_load_path[index]
        condi_file_path = self.condi_load_path[index]
        image = imageio.imread(image_file_path)
        condi = imageio.imread(condi_file_path)
        assert image.shape == condi.shape, 'image sizes are not matched'

        if self.transform is not None:
            image_tf = self.transform(image)
            condi_tf = self.transform(condi)

        return condi_file_path,image_tf,condi_tf

--------------------------------------------------------------------------------
/code/test.py:
--------------------------------------------------------------------------------
import os
import numpy as np
import math
import time
import imageio
import json

import accelerate
import torch
import torch.nn as nn
import torch.nn.functional as F

from argparse import ArgumentParser
from tqdm.auto import tqdm
from ema_pytorch import EMA
from pathlib import Path
from skimage.metrics import structural_similarity
from torch.optim import Adam
from torchvision import transforms as T, utils
from torch.utils.data import DataLoader
from model import (ResidualDiffusion, Unet, UnetRes)
from datasets_setting import train_dataset, test_dataset, set_seed

# import torch_npu
# from torch_npu.contrib import transfer_to_npu


parser = ArgumentParser()
parser.add_argument("--project_description", type=str, default="UniDiffSolver For Image Restoration", help="Name of Project")

parser.add_argument("--step_number", type=int, default=5000, help="step_number")
parser.add_argument("--batch_size", type=int, default=8, help="batch_size")
parser.add_argument("--image_size", type=int, default=512, help="image_size")
parser.add_argument("--num_unet", type=int, default=1, help="num_unet")
parser.add_argument("--objective", type=str, default='pred_res', help="[pred_res_noise,pred_x0_noise,pred_noise,pred_res]")
parser.add_argument("--test_res_or_noise", type=str, default='res', help="[res_noise,res,noise]")
parser.add_argument("--lr", type=float, default=0.0003, help="learning_rate")
parser.add_argument("--sampling_timesteps", type=int, default=1, help="sampling_timesteps")

def exists(x):
    return x is not None

def has_int_squareroot(num):
    return (math.sqrt(num) ** 2) == num

def cycle(dl):
    while True:
        for data in dl:
            yield data

def create_folder(folder_path):
    if not os.path.exists(folder_path):
        os.makedirs(folder_path)

def create_empty_json(json_path):
    with open(json_path, 'w') as file:
        pass

def remove_json(json_path):
    os.remove(json_path)

def write_json(json_path,item):
    with open(json_path, 'a+', encoding='utf-8') as f:
        line = json.dumps(item)
        f.write(line+'\n')

def readline_json(json_path,key=None):
    data = []
    with open(json_path, 'r') as f:
        items = f.readlines()
        file_flag = []
        if key is not None:
            # Under DDP several ranks may append the same record; keep only the
            # first entry per file_path before averaging the requested metric.
            for item in items:
                file_name = json.loads(item)['file_path']
                if file_name in file_flag:
                    continue
                file_flag.append(file_name)
                data.append(json.loads(item)[key])
            return np.asarray(data).mean()
        else:
            for item in items:
                data.append(json.loads(item))
            return data
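
# --- Editor's sketch (not part of the original file): round trip of the JSONL
# --- helpers above; the path is illustrative.
def _demo_metric_jsonl(path='/tmp/Deraining.json'):
    create_empty_json(path)
    write_json(path, {'file_path': 'a.png', 'psnr': 30.0, 'ssim': 0.90})
    write_json(path, {'file_path': 'b.png', 'psnr': 32.0, 'ssim': 0.94})
    print(readline_json(path, 'psnr'))  # 31.0, the mean over unique file_path entries
    print(readline_json(path))          # both records as dicts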
class Trainer(object):
    def __init__(
        self,
        diffusion_model,
        train_folder,
        eval_folder,
        train_num_steps = 100000,
        train_batch_size = 1,
        save_and_sample_every = 5000,
        save_best_and_latest_only = True,
        calculate_metric = True,
        results_folder = './results/',
        gradient_accumulate_every = 1,
        *,
        augment_horizontal_flip = True,
        train_lr = 8e-5,
        ema_update_every = 1,
        ema_decay = 0.995,
        adam_betas = (0.9, 0.99),
        save_row = 10,
        amp = False,
        mixed_precision_type = 'fp16',
        split_batches = True,
        convert_image_to = None,
        max_grad_norm = 1.,
    ):
        super().__init__()

        self.accelerator = accelerate.Accelerator(split_batches = split_batches)
        self.model = diffusion_model
        is_ddim_sampling = diffusion_model.is_ddim_sampling
        self.save_row = save_row
        self.save_and_sample_every = save_and_sample_every
        self.batch_size = train_batch_size
        self.gradient_accumulate_every = gradient_accumulate_every
        self.image_size = diffusion_model.image_size
        self.max_grad_norm = max_grad_norm

        self.train_folder = train_folder
        self.eval_folder = eval_folder
        self.ds_eval_hazy = test_dataset(eval_folder,task_folder='Dehazing')
        self.ds_eval_light = test_dataset(eval_folder,task_folder='Enlighening')
        self.ds_eval_rain = test_dataset(eval_folder,task_folder='Deraining')
        self.ds_eval_snow = test_dataset(eval_folder,task_folder='Desnowing')
        self.ds_eval_blur = test_dataset(eval_folder,task_folder='Deblur')
        self.dl_eval_hazy = self.accelerator.prepare(DataLoader(self.ds_eval_hazy, batch_size = 1))
        self.dl_eval_light = self.accelerator.prepare(DataLoader(self.ds_eval_light, batch_size = 1))
        self.dl_eval_rain = self.accelerator.prepare(DataLoader(self.ds_eval_rain, batch_size = 1))
        self.dl_eval_snow = self.accelerator.prepare(DataLoader(self.ds_eval_snow, batch_size = 1))
        self.dl_eval_blur = self.accelerator.prepare(DataLoader(self.ds_eval_blur, batch_size = 1))

        if self.accelerator.is_main_process:
            self.accelerator.print('Validation Samples :')
            self.accelerator.print(' : (hazy :{})'.format(len(self.ds_eval_hazy)))
            self.accelerator.print(' : (light:{})'.format(len(self.ds_eval_light)))
            self.accelerator.print(' : (rain :{})'.format(len(self.ds_eval_rain)))
            self.accelerator.print(' : (snow :{})'.format(len(self.ds_eval_snow)))
            self.accelerator.print(' : (blur :{})'.format(len(self.ds_eval_blur)))

        self.opt = Adam(diffusion_model.parameters(), lr = train_lr, betas = adam_betas)
        self.ema = EMA(diffusion_model, beta = ema_decay, update_every = ema_update_every)
        self.ema.to(self.device)
        self.results_folder = Path(results_folder)
        self.results_folder.mkdir(exist_ok = True)
        self.train_num_steps = train_num_steps
        self.step = 0
        self.model, self.opt = self.accelerator.prepare(self.model, self.opt)
        self.calculate_metric = calculate_metric and self.accelerator.is_main_process

    @property
    def device(self):
        return self.accelerator.device

    def save(self, milestone = None):
        if not self.accelerator.is_local_main_process:
            return
        data = {
            'step': self.step,
            'model': self.accelerator.get_state_dict(self.model),
            'opt': self.opt.state_dict(),
            'ema': self.ema.state_dict(),
            'scaler': self.accelerator.scaler.state_dict() if exists(self.accelerator.scaler) else None,
        }
        checkpoint_save_path = os.path.join(self.results_folder,f'model-{milestone}')
        if not os.path.exists(checkpoint_save_path):
            os.makedirs(checkpoint_save_path)
        torch.save(data, checkpoint_save_path + '/' + f'model-{milestone}.pt')
    def load(self, milestone = None):
        accelerator = self.accelerator
        device = accelerator.device
        # Note: the milestone path is computed, but the hard-coded pretrained
        # checkpoint ./ckpt/pretrained.pt is what actually gets loaded.
        checkpoint_save_path = os.path.join(self.results_folder,f'model-{milestone}')
        data = torch.load('./ckpt/pretrained.pt', map_location=device)
        model = self.accelerator.unwrap_model(self.model)
        model.load_state_dict(data['model'])
        self.step = data['step']
        self.opt.load_state_dict(data['opt'])
        self.ema.load_state_dict(data["ema"])
        if exists(self.accelerator.scaler) and exists(data['scaler']):
            self.accelerator.scaler.load_state_dict(data['scaler'])

    def cal_psnr(self,img_ref, img_gen, data_range = 255.0):
        # Both uint8 images are rescaled to [0, 1], so the peak value is 1.
        mse = np.mean((img_ref.astype(np.float32)/data_range - img_gen.astype(np.float32)/data_range) ** 2)
        if mse < 1.0e-10:
            return 100
        PIXEL_MAX = 1
        return 20 * math.log10(PIXEL_MAX / math.sqrt(mse))

    def cal_ssim(self,img_ref, img_gen):
        # Average SSIM over the colour channels.
        ssim_val = 0
        for i in range(img_ref.shape[-1]):
            ssim_val = ssim_val + structural_similarity(img_ref[:,:,i], img_gen[:,:,i])
        return ssim_val/img_ref.shape[-1]
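
    # --- Editor's note: the two metrics above implement, for uint8 inputs
    # --- rescaled to [0, 1],
    # ---     MSE  = mean((x/255 - y/255)^2)
    # ---     PSNR = 20 * log10(1 / sqrt(MSE))  [dB], clamped to 100 dB as MSE -> 0,
    # --- and SSIM averaged over the three colour channels.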
    def train(self):
        # In test.py, train() only runs one evaluation pass over every task.
        accelerator = self.accelerator
        device = accelerator.device
        track_metric_json_path = os.path.join(self.results_folder,'metric.json')
        if self.accelerator.is_main_process:
            create_empty_json(track_metric_json_path)

        self.test(dataloader = self.dl_eval_rain, degradation = 'Deraining')
        self.test(dataloader = self.dl_eval_light, degradation = 'Enlighening')
        self.test(dataloader = self.dl_eval_blur, degradation = 'Deblur')
        self.test(dataloader = self.dl_eval_snow, degradation = 'Desnowing')
        self.test(dataloader = self.dl_eval_hazy, degradation = 'Dehazing')

        if self.accelerator.is_main_process:
            write_json(track_metric_json_path,f'model-{self.step} : ')
            degradation_types = ['Deraining', 'Enlighening', 'Desnowing', 'Dehazing', 'Deblur']
            for degradation in degradation_types:
                json_path = os.path.join(self.results_folder,f'model-{self.step}') + '/{}.json'.format(degradation)
                psnr_val,ssim_val = readline_json(json_path,'psnr'),readline_json(json_path,'ssim')
                accelerator.print('{} -> (PSNR/SSIM) : {:.6f}/{:.6f} '.format(degradation,psnr_val,ssim_val))
                write_json(track_metric_json_path,'{} -> (PSNR/SSIM) : {:.6f}/{:.6f} '.format(degradation,psnr_val,ssim_val))

        accelerator.print('Testing complete')

    def test(self,dataloader,degradation):
        self.accelerator.wait_for_everyone()
        if self.accelerator.is_main_process:
            start_time = time.time()
            save_json_dir = os.path.join(self.results_folder,f'model-{self.step}')
            create_folder(save_json_dir)
            save_json_path = save_json_dir + '/{}.json'.format(degradation)
            create_empty_json(save_json_path)
        self.accelerator.wait_for_everyone()
        save_json_path = os.path.join(self.results_folder,f'model-{self.step}') + '/{}.json'.format(degradation)
        self.ema.model.eval()
        for batch_id,batch in enumerate(dataloader):
            name_path,image_tf,condi_tf = batch
            img_gen = self.ema.model.sample(condi_tf.to(self.device))
            for element_id in range(len(name_path)):
                image_np_ref = self.tf2img(image_tf[element_id,:,:,].unsqueeze(0))
                image_np_gen = self.tf2img(img_gen[element_id,:,:,].unsqueeze(0))
                psnr_val = self.cal_psnr(image_np_ref,image_np_gen)
                ssim_val = self.cal_ssim(image_np_ref,image_np_gen)
                data_dump_info = {
                    'file_path' : name_path[element_id],
                    'psnr' : psnr_val,
                    'ssim' : ssim_val,
                }
                print(batch_id,name_path,'PSNR / SSIM : {:.6f} : {:.6f}'.format(psnr_val,ssim_val))
                write_json(save_json_path,data_dump_info)
                image_save_dir = os.path.join(self.results_folder,f'model-{self.step}',name_path[element_id].split('/')[-5],name_path[element_id].split('/')[-4])
                create_folder(image_save_dir)
                imageio.imwrite(os.path.join(image_save_dir,name_path[element_id].split('/')[-1]),image_np_gen)

        if self.accelerator.is_main_process:
            end_time = time.time()
            test_time_consuming = end_time - start_time
            self.accelerator.print('Test_time_consuming : {:.6} s'.format(test_time_consuming))

        self.accelerator.wait_for_everyone()

    def tf2np(self,image_tf):
        n,c,h,w = image_tf.size()
        assert n == 1
        if c == 1:
            image_np = image_tf.squeeze(0).squeeze(0).detach().cpu().numpy()
        else:
            image_np = image_tf.squeeze(0).permute(1,2,0).detach().cpu().numpy()

        return image_np

    def tf2img(self,image_tf):
        image_np = self.tf2np(torch.clamp(image_tf,min=0.,max=1.))
        image_np = (image_np * 255).astype(np.uint8)
        return image_np


def test_ddp_accelerate(args):
    train_folder = ' '  # set to your training dataset root
    eval_folder = ' '   # set to your evaluation dataset root
    print('Procedure Running: ',args.project_description)
    image_size = 256
    num_unet = 1
    objective = 'pred_res'
    ddim_sampling_eta = 0.0
    test_res_or_noise = "res"
    sum_scale = 0.01
    delta_end = 2.0e-3
    condition = True
    sampling_timesteps = 2
    model = UnetRes(dim=64, dim_mults=(1, 2, 4, 8), num_unet=num_unet, condition=condition, objective=objective, test_res_or_noise = test_res_or_noise)
    diffusion = ResidualDiffusion(model, image_size=image_size, timesteps=1000, delta_end = delta_end, sampling_timesteps=sampling_timesteps, objective=objective, ddim_sampling_eta= ddim_sampling_eta, loss_type='l1', condition=condition, sum_scale=sum_scale, test_res_or_noise = test_res_or_noise)
    diffusion_process_trainer = Trainer(
        diffusion_model = diffusion,
        train_folder = train_folder,
        eval_folder = eval_folder,
        train_num_steps = 500000,
        train_batch_size = 16,
        save_and_sample_every = 5000,
        save_best_and_latest_only = True,
        calculate_metric = True,
        results_folder = './save_folder',
        gradient_accumulate_every = 1,
    )

    diffusion_process_trainer.load()
    diffusion_process_trainer.train()
    print('Procedure Termination: (Finished)')


if __name__ == '__main__':
    args = parser.parse_args()
    set_seed(0)
    test_ddp_accelerate(args)
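
# --- Editor's sketch (not part of the original file): minimal single-image
# --- inference, mirroring what Trainer.test() does with the (EMA) diffusion
# --- model; `diffusion` is the ResidualDiffusion built in test_ddp_accelerate,
# --- the image paths are placeholders, and a CUDA device is assumed.
def _demo_restore_one_image(diffusion, path_in='degraded.png', path_out='restored.png'):
    condi = T.ToTensor()(imageio.imread(path_in)).unsqueeze(0)
    with torch.no_grad():
        restored = diffusion.sample(condi.cuda())
    out = (restored.clamp(0, 1).squeeze(0).permute(1, 2, 0).cpu().numpy() * 255).astype(np.uint8)
    imageio.imwrite(path_out, out)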

--------------------------------------------------------------------------------
/code/train.py:
--------------------------------------------------------------------------------
import os
import numpy as np
import math
import time
import imageio
import json

import accelerate
import torch
import torch.nn as nn
import torch.nn.functional as F

from argparse import ArgumentParser
from tqdm.auto import tqdm
from ema_pytorch import EMA
from pathlib import Path
from skimage.metrics import structural_similarity
from torch.optim import Adam
from torchvision import transforms as T, utils
from torch.utils.data import DataLoader
from model import (ResidualDiffusion, Unet, UnetRes)
from datasets_setting import train_dataset, test_dataset, set_seed
# import torch_npu
# from torch_npu.contrib import transfer_to_npu


parser = ArgumentParser()
parser.add_argument("--project_description", type=str, default="UniDiffSolver For Image Restoration", help="Name of Project")

parser.add_argument("--step_number", type=int, default=5000, help="step_number")
parser.add_argument("--batch_size", type=int, default=8, help="batch_size")
parser.add_argument("--image_size", type=int, default=512, help="image_size")
parser.add_argument("--num_unet", type=int, default=1, help="num_unet")
parser.add_argument("--objective", type=str, default='pred_res', help="[pred_res_noise,pred_x0_noise,pred_noise,pred_res]")
parser.add_argument("--test_res_or_noise", type=str, default='res', help="[res_noise,res,noise]")
parser.add_argument("--lr", type=float, default=0.0003, help="learning_rate")
parser.add_argument("--sampling_timesteps", type=int, default=1, help="sampling_timesteps")

def exists(x):
    return x is not None

def has_int_squareroot(num):
    return (math.sqrt(num) ** 2) == num

def cycle(dl):
    while True:
        for data in dl:
            yield data

def divisible_by(numer, denom):
    return (numer % denom) == 0

def create_folder(folder_path):
    if not os.path.exists(folder_path):
        os.makedirs(folder_path)

def create_empty_json(json_path):
    with open(json_path, 'w') as file:
        pass

def remove_json(json_path):
    os.remove(json_path)

def write_json(json_path,item):
    with open(json_path, 'a+', encoding='utf-8') as f:
        line = json.dumps(item)
        f.write(line+'\n')

def readline_json(json_path,key=None):
    data = []
    with open(json_path, 'r') as f:
        items = f.readlines()
        if key is not None:
            for item in items:
                data.append(json.loads(item)[key])
            return np.asarray(data).mean()
        else:
            for item in items:
                data.append(json.loads(item))
            return data
class Trainer(object):
    def __init__(
        self,
        diffusion_model,
        train_folder,
        eval_folder,
        train_num_steps = 100000,
        train_batch_size = 1,
        save_and_sample_every = 5000,
        save_best_and_latest_only = True,
        calculate_metric = True,
        results_folder = './results/',
        gradient_accumulate_every = 1,
        *,
        augment_horizontal_flip = True,
        train_lr = 8e-5,
        ema_update_every = 1,
        ema_decay = 0.995,
        adam_betas = (0.9, 0.99),
        save_row = 10,
        amp = False,
        mixed_precision_type = 'fp16',
        split_batches = True,
        convert_image_to = None,
        max_grad_norm = 1.,
    ):
        super().__init__()

        self.accelerator = accelerate.Accelerator(
            split_batches = split_batches,
            mixed_precision = mixed_precision_type if amp else 'no'
        )
        self.model = diffusion_model
        is_ddim_sampling = diffusion_model.is_ddim_sampling
        self.save_row = save_row
        self.save_and_sample_every = save_and_sample_every
        self.batch_size = train_batch_size
        self.gradient_accumulate_every = gradient_accumulate_every
        self.image_size = diffusion_model.image_size
        self.max_grad_norm = max_grad_norm

        self.train_folder = train_folder
        self.eval_folder = eval_folder

        self.ds_hazy = train_dataset(train_folder,task_folder='Dehazing',image_size = 256)
        self.ds_light = train_dataset(train_folder,task_folder='Enlighening',image_size = 256)
        self.ds_rain = train_dataset(train_folder,task_folder='Deraining',image_size = 256)
        self.ds_snow = train_dataset(train_folder,task_folder='Desnowing',image_size = 256)
        self.ds_blur = train_dataset(train_folder,task_folder='Deblur',image_size = 256)

        self.dl_hazy = cycle(self.accelerator.prepare(DataLoader(self.ds_hazy, batch_size = 1)))
        self.dl_light = cycle(self.accelerator.prepare(DataLoader(self.ds_light, batch_size = 1)))
        self.dl_rain = cycle(self.accelerator.prepare(DataLoader(self.ds_rain, batch_size = 1)))
        self.dl_snow = cycle(self.accelerator.prepare(DataLoader(self.ds_snow, batch_size = 1)))
        self.dl_blur = cycle(self.accelerator.prepare(DataLoader(self.ds_blur, batch_size = 1)))

        if self.accelerator.is_main_process:
            self.accelerator.print('Training Samples : (hazy :{})'.format(len(self.ds_hazy)))
            self.accelerator.print(' : (light:{})'.format(len(self.ds_light)))
            self.accelerator.print(' : (rain :{})'.format(len(self.ds_rain)))
            self.accelerator.print(' : (snow :{})'.format(len(self.ds_snow)))
            self.accelerator.print(' : (blur :{})'.format(len(self.ds_blur)))

        self.ds_eval_hazy = test_dataset(eval_folder,task_folder='Dehazing')
        self.ds_eval_light = test_dataset(eval_folder,task_folder='Enlighening')
        self.ds_eval_rain = test_dataset(eval_folder,task_folder='Deraining')
        self.ds_eval_snow = test_dataset(eval_folder,task_folder='Desnowing')
        self.ds_eval_blur = test_dataset(eval_folder,task_folder='Deblur')

        self.dl_eval_hazy = self.accelerator.prepare(DataLoader(self.ds_eval_hazy, batch_size = 1))
        self.dl_eval_light = self.accelerator.prepare(DataLoader(self.ds_eval_light, batch_size = 1))
        self.dl_eval_rain = self.accelerator.prepare(DataLoader(self.ds_eval_rain, batch_size = 1))
        self.dl_eval_snow = self.accelerator.prepare(DataLoader(self.ds_eval_snow, batch_size = 1))
        self.dl_eval_blur = self.accelerator.prepare(DataLoader(self.ds_eval_blur, batch_size = 1))

        if self.accelerator.is_main_process:
            self.accelerator.print('Validation Samples : (hazy :{})'.format(len(self.ds_eval_hazy)))
            self.accelerator.print(' : (light:{})'.format(len(self.ds_eval_light)))
            self.accelerator.print(' : (rain :{})'.format(len(self.ds_eval_rain)))
            self.accelerator.print(' : (snow :{})'.format(len(self.ds_eval_snow)))
            self.accelerator.print(' : (blur :{})'.format(len(self.ds_eval_blur)))

        self.opt = Adam(diffusion_model.parameters(), lr = train_lr, betas = adam_betas)
        self.ema = EMA(diffusion_model, beta = ema_decay, update_every = ema_update_every)
        self.ema.to(self.device)

        self.results_folder = Path(results_folder)
        self.results_folder.mkdir(exist_ok = True)

        self.train_num_steps = train_num_steps
        self.step = 0

        self.model, self.opt = self.accelerator.prepare(self.model, self.opt)
        self.calculate_metric = calculate_metric and self.accelerator.is_main_process

        if save_best_and_latest_only:
            self.best_metric = 0

        self.save_best_and_latest_only = save_best_and_latest_only
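
    # --- Editor's note: with ema_decay = 0.995 and ema_update_every = 1, the
    # --- EMA weights follow the standard exponential moving average
    # ---     theta_ema <- 0.995 * theta_ema + (1 - 0.995) * theta
    # --- and it is this smoothed copy (self.ema.model) that test() samples from.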
    @property
    def device(self):
        return self.accelerator.device

    def save(self, milestone = None):
        if not self.accelerator.is_local_main_process:
            return
        data = {
            'step': self.step,
            'model': self.accelerator.get_state_dict(self.model),
            'opt': self.opt.state_dict(),
            'ema': self.ema.state_dict(),
            'scaler': self.accelerator.scaler.state_dict() if exists(self.accelerator.scaler) else None,
        }
        checkpoint_save_path = os.path.join(self.results_folder,f'model-{milestone}')
        if not os.path.exists(checkpoint_save_path):
            os.makedirs(checkpoint_save_path)
        torch.save(data, checkpoint_save_path + '/' + f'model-{milestone}.pt')

    def load(self, milestone = None):
        accelerator = self.accelerator
        device = accelerator.device
        # Note: the milestone path is computed, but the hard-coded pretrained
        # checkpoint ./ckpt/pretrained.pt is what actually gets loaded.
        checkpoint_save_path = os.path.join(self.results_folder,f'model-{milestone}')
        data = torch.load('./ckpt/pretrained.pt', map_location=device)
        model = self.accelerator.unwrap_model(self.model)
        model.load_state_dict(data['model'])
        self.step = data['step']
        self.opt.load_state_dict(data['opt'])
        self.ema.load_state_dict(data["ema"])
        if exists(self.accelerator.scaler) and exists(data['scaler']):
            self.accelerator.scaler.load_state_dict(data['scaler'])

    def cal_psnr(self,img_ref, img_gen, data_range = 255.0):
        mse = np.mean((img_ref.astype(np.float32)/data_range - img_gen.astype(np.float32)/data_range) ** 2)
        if mse < 1.0e-10:
            return 100
        PIXEL_MAX = 1
        return 20 * math.log10(PIXEL_MAX / math.sqrt(mse))

    def cal_ssim(self,img_ref, img_gen):
        ssim_val = 0
        for i in range(img_ref.shape[-1]):
            ssim_val = ssim_val + structural_similarity(img_ref[:,:,i], img_gen[:,:,i])
        return ssim_val/img_ref.shape[-1]

    def train(self):
        accelerator = self.accelerator
        device = accelerator.device
        track_metric_json_path = os.path.join(self.results_folder,'metric.json')
        if self.accelerator.is_main_process:
            create_empty_json(track_metric_json_path)
        with tqdm(initial=self.step, total=self.train_num_steps, disable=not accelerator.is_main_process) as pbar:
            while self.step < self.train_num_steps:
                self.model.train()
                # Draw one sample from every task and concatenate them, so each
                # optimization step mixes all five degradations.
                hazy_name,hazy_target,hazy_input = next(self.dl_hazy)
                light_name,light_target,light_input = next(self.dl_light)
                rain_name,rain_target,rain_input = next(self.dl_rain)
                snow_name,snow_target,snow_input = next(self.dl_snow)
                blur_name,blur_target,blur_input = next(self.dl_blur)
                file_name = hazy_name + light_name + rain_name + snow_name + blur_name
                file_target = torch.cat([hazy_target,light_target,rain_target,snow_target,blur_target],dim=0)
                file_input = torch.cat([hazy_input,light_input,rain_input,snow_input,blur_input],dim=0)
                with self.accelerator.autocast():
                    loss = self.model(img = [file_target.to(device),file_input.to(device)])[0]
                self.accelerator.backward(loss)
                pbar.set_description(f'loss: {loss.item():.4f}')
                accelerator.wait_for_everyone()
                accelerator.clip_grad_norm_(self.model.parameters(), self.max_grad_norm)
                self.opt.step()
                self.opt.zero_grad()
                accelerator.wait_for_everyone()

                if self.step != 0 and divisible_by(self.step, self.save_and_sample_every):
                    self.test(dataloader = self.dl_eval_rain, degradation = 'Deraining')
                    self.test(dataloader = self.dl_eval_light, degradation = 'Enlighening')
                    self.test(dataloader = self.dl_eval_blur, degradation = 'Deblur')
                    self.test(dataloader = self.dl_eval_snow, degradation = 'Desnowing')
                    self.test(dataloader = self.dl_eval_hazy, degradation = 'Dehazing')
                if self.accelerator.is_main_process:
                    if self.step != 0 and divisible_by(self.step, self.save_and_sample_every):
                        write_json(track_metric_json_path,f'model-{self.step} : ')
                        degradation_types = ['Enlighening', 'Desnowing', 'Deraining', 'Deblur', 'Dehazing']
                        for degradation in degradation_types:
                            json_path = os.path.join(self.results_folder,f'model-{self.step}') + '/{}.json'.format(degradation)
                            psnr_val,ssim_val = readline_json(json_path,'psnr'),readline_json(json_path,'ssim')
                            accelerator.print('{} -> (PSNR/SSIM) : {:.6f}/{:.6f} '.format(degradation,psnr_val,ssim_val))
                            write_json(track_metric_json_path,'{} -> (PSNR/SSIM) : {:.6f}/{:.6f} '.format(degradation,psnr_val,ssim_val))

                accelerator.wait_for_everyone()

                self.ema.update()
                if self.accelerator.is_main_process:
                    if self.step != 0 and divisible_by(self.step, self.save_and_sample_every):
                        accelerator.print('save model checkpoint')
                        self.save(self.step)

                accelerator.wait_for_everyone()
                self.step += 1
                pbar.update(1)

        accelerator.print('Training complete')

    def test(self,dataloader,degradation):
        self.accelerator.wait_for_everyone()
        if self.accelerator.is_main_process:
            start_time = time.time()
            save_json_dir = os.path.join(self.results_folder,f'model-{self.step}')
            create_folder(save_json_dir)
            save_json_path = save_json_dir + '/{}.json'.format(degradation)
            create_empty_json(save_json_path)
        self.accelerator.wait_for_everyone()
        save_json_path = os.path.join(self.results_folder,f'model-{self.step}') + '/{}.json'.format(degradation)
        self.ema.model.eval()
        for batch_id,batch in enumerate(dataloader):
            name_path,image_tf,condi_tf = batch
            img_gen = self.ema.model.sample(condi_tf.to(self.device))
            for element_id in range(len(name_path)):
                image_np_ref = self.tf2img(image_tf[element_id,:,:,].unsqueeze(0))
                image_np_gen = self.tf2img(img_gen[element_id,:,:,].unsqueeze(0))

                psnr_val = self.cal_psnr(image_np_ref,image_np_gen)
                ssim_val = self.cal_ssim(image_np_ref,image_np_gen)

                data_dump_info = {
                    'file_path' : name_path[element_id],
                    'psnr' : psnr_val,
                    'ssim' : ssim_val,
                }
                print(batch_id,name_path,'PSNR / SSIM : {:.6f} : {:.6f}'.format(psnr_val,ssim_val))
                write_json(save_json_path,data_dump_info)
                image_save_dir = os.path.join(self.results_folder,f'model-{self.step}',name_path[element_id].split('/')[-5],name_path[element_id].split('/')[-4])
                create_folder(image_save_dir)
                imageio.imwrite(os.path.join(image_save_dir,name_path[element_id].split('/')[-1]),image_np_gen)

        if self.accelerator.is_main_process:
            end_time = time.time()
            test_time_consuming = end_time - start_time
            self.accelerator.print('Test_time_consuming : {:.6} s'.format(test_time_consuming))

        self.accelerator.wait_for_everyone()

    def tf2np(self,image_tf):
        n,c,h,w = image_tf.size()
        assert n == 1
        if c == 1:
            image_np = image_tf.squeeze(0).squeeze(0).detach().cpu().numpy()
        else:
            image_np = image_tf.squeeze(0).permute(1,2,0).detach().cpu().numpy()

        return image_np

    def tf2img(self,image_tf):
        image_np = self.tf2np(torch.clamp(image_tf,min=0.,max=1.))
        image_np = (image_np * 255).astype(np.uint8)
        return image_np
def train_ddp_accelerate(args):
    train_folder = ''  # set to your training dataset root
    eval_folder = ''   # set to your evaluation dataset root
    print('Procedure Running: ',args.project_description)
    image_size = 256
    num_unet = 1
    objective = 'pred_res'
    ddim_sampling_eta = 0.0
    test_res_or_noise = "res"
    sum_scale = 0.01
    delta_end = 2.0e-3
    condition = True
    sampling_timesteps = 8
    model = UnetRes(dim=64, dim_mults=(1, 2, 4, 8), num_unet=num_unet, condition=condition, objective=objective, test_res_or_noise = test_res_or_noise)
    diffusion = ResidualDiffusion(model, image_size=image_size, timesteps=1000, delta_end = delta_end, sampling_timesteps=sampling_timesteps, objective=objective, ddim_sampling_eta= ddim_sampling_eta, loss_type='l1', condition=condition, sum_scale=sum_scale, test_res_or_noise = test_res_or_noise)
    diffusion_process_trainer = Trainer(
        diffusion_model = diffusion,
        train_folder = train_folder,
        eval_folder = eval_folder,
        train_num_steps = 500000,
        train_batch_size = 16,
        save_and_sample_every = 5000,
        save_best_and_latest_only = True,
        calculate_metric = True,
        results_folder = './save_folder',
        gradient_accumulate_every = 1,
    )
    diffusion_process_trainer.load()
    diffusion_process_trainer.train()
    print('Procedure Termination: (Finished)')

if __name__ == '__main__':
    args = parser.parse_args()
    set_seed(0)
    train_ddp_accelerate(args)

--------------------------------------------------------------------------------
/code/model.py:
--------------------------------------------------------------------------------
import copy
import glob
import math
import os
import random
from collections import namedtuple
from functools import partial
from multiprocessing import cpu_count
from pathlib import Path
import torchvision.transforms as transforms

import cv2
import numpy as np
import torch
import torch.nn.functional as F
import torchvision.transforms.functional as TF
from accelerate import Accelerator
from einops import rearrange, reduce
from einops.layers.torch import Rearrange
from ema_pytorch import EMA
from PIL import Image
import time
from torch import einsum, nn
from torch.optim import Adam, RAdam
from torch.utils.data import DataLoader
from torchvision import transforms as T
from torchvision import utils
from tqdm.auto import tqdm

# import torch_npu
# from torch_npu.contrib import transfer_to_npu

ModelResPrediction = namedtuple('ModelResPrediction', ['pred_res', 'pred_noise', 'pred_x_start'])

def set_seed(SEED):
    torch.manual_seed(SEED)
    torch.cuda.manual_seed_all(SEED)
    np.random.seed(SEED)
    random.seed(SEED)

def exists(x):
    return x is not None

def default(val, d):
    if exists(val):
        return val
    return d() if callable(d) else d

def identity(t, *args, **kwargs):
    return t

def normalize_to_neg_one_to_one(img):
    if isinstance(img, list):
        return [img[k] * 2 - 1 for k in range(len(img))]
    else:
        return img * 2 - 1

def unnormalize_to_zero_to_one(img):
    if isinstance(img, list):
        return [(img[k] + 1) * 0.5 for k in range(len(img))]
    else:
        return (img + 1) * 0.5
class Residual(nn.Module):
    def __init__(self, fn):
        super().__init__()
        self.fn = fn

    def forward(self, x, *args, **kwargs):
        return self.fn(x, *args, **kwargs) + x


def Upsample(dim, dim_out=None):
    return nn.Sequential(
        nn.Upsample(scale_factor=2, mode='nearest'),
        nn.Conv2d(dim, default(dim_out, dim), 3, padding=1)
    )


def Downsample(dim, dim_out=None):
    return nn.Conv2d(dim, default(dim_out, dim), 4, 2, 1)


class WeightStandardizedConv2d(nn.Conv2d):
    def forward(self, x):
        eps = 1e-5 if x.dtype == torch.float32 else 1e-3

        weight = self.weight
        mean = reduce(weight, 'o ... -> o 1 1 1', 'mean')
        var = reduce(weight, 'o ... -> o 1 1 1',
                     partial(torch.var, unbiased=False))
        normalized_weight = (weight - mean) * (var + eps).rsqrt()

        return F.conv2d(x, normalized_weight, self.bias, self.stride, self.padding, self.dilation, self.groups)


class LayerNorm(nn.Module):
    def __init__(self, dim):
        super().__init__()
        self.g = nn.Parameter(torch.ones(1, dim, 1, 1))

    def forward(self, x):
        eps = 1e-5 if x.dtype == torch.float32 else 1e-3
        var = torch.var(x, dim=1, unbiased=False, keepdim=True)
        mean = torch.mean(x, dim=1, keepdim=True)
        return (x - mean) * (var + eps).rsqrt() * self.g


class PreNorm(nn.Module):
    def __init__(self, dim, fn):
        super().__init__()
        self.fn = fn
        self.norm = LayerNorm(dim)

    def forward(self, x):
        x = self.norm(x)
        return self.fn(x)

class SinusoidalPosEmb(nn.Module):
    def __init__(self, dim):
        super().__init__()
        self.dim = dim

    def forward(self, x):
        device = x.device
        half_dim = self.dim // 2
        emb = math.log(10000) / (half_dim - 1)
        emb = torch.exp(torch.arange(half_dim, device=device) * -emb)
        emb = x[:, None] * emb[None, :]
        emb = torch.cat((emb.sin(), emb.cos()), dim=-1)
        return emb


class RandomOrLearnedSinusoidalPosEmb(nn.Module):
    def __init__(self, dim, is_random=False):
        super().__init__()
        assert (dim % 2) == 0
        half_dim = dim // 2
        self.weights = nn.Parameter(torch.randn(
            half_dim), requires_grad=not is_random)

    def forward(self, x):
        x = rearrange(x, 'b -> b 1')
        freqs = x * rearrange(self.weights, 'd -> 1 d') * 2 * math.pi
        fouriered = torch.cat((freqs.sin(), freqs.cos()), dim=-1)
        fouriered = torch.cat((x, fouriered), dim=-1)
        return fouriered
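
# --- Editor's note: SinusoidalPosEmb maps a timestep t to the transformer-style
# --- embedding with frequencies omega_k = 10000**(-k / (dim/2 - 1)):
# ---     PE(t) = [sin(t*omega_0), ..., sin(t*omega_{d/2-1}),
# ---              cos(t*omega_0), ..., cos(t*omega_{d/2-1})]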
class Block(nn.Module):
    def __init__(self, dim, dim_out, groups=8):
        super().__init__()
        self.proj = WeightStandardizedConv2d(dim, dim_out, 3, padding=1)
        self.norm = nn.GroupNorm(groups, dim_out)
        self.act = nn.SiLU()

    def forward(self, x, scale_shift=None):
        x = self.proj(x)
        x = self.norm(x)

        if exists(scale_shift):
            scale, shift = scale_shift
            x = x * (scale + 1) + shift

        x = self.act(x)
        return x


class ResnetBlock(nn.Module):
    def __init__(self, dim, dim_out, *, time_emb_dim=None, groups=8):
        super().__init__()
        self.mlp = nn.Sequential(
            nn.SiLU(),
            nn.Linear(time_emb_dim, dim_out * 2)
        ) if exists(time_emb_dim) else None

        self.block1 = Block(dim, dim_out, groups=groups)
        self.block2 = Block(dim_out, dim_out, groups=groups)
        self.res_conv = nn.Conv2d(
            dim, dim_out, 1) if dim != dim_out else nn.Identity()

    def forward(self, x, time_emb=None):
        scale_shift = None
        if exists(self.mlp) and exists(time_emb):
            time_emb = self.mlp(time_emb)
            time_emb = rearrange(time_emb, 'b c -> b c 1 1')
            scale_shift = time_emb.chunk(2, dim=1)

        h = self.block1(x, scale_shift=scale_shift)

        h = self.block2(h)

        return h + self.res_conv(x)


class LinearAttention(nn.Module):
    def __init__(self, dim, heads=4, dim_head=32):
        super().__init__()
        self.scale = dim_head ** -0.5
        self.heads = heads
        hidden_dim = dim_head * heads
        self.to_qkv = nn.Conv2d(dim, hidden_dim * 3, 1, bias=False)

        self.to_out = nn.Sequential(
            nn.Conv2d(hidden_dim, dim, 1),
            LayerNorm(dim)
        )

    def forward(self, x):
        b, c, h, w = x.shape
        qkv = self.to_qkv(x).chunk(3, dim=1)
        q, k, v = map(lambda t: rearrange(
            t, 'b (h c) x y -> b h c (x y)', h=self.heads), qkv)

        q = q.softmax(dim=-2)
        k = k.softmax(dim=-1)

        q = q * self.scale
        v = v / (h * w)

        context = torch.einsum('b h d n, b h e n -> b h d e', k, v)

        out = torch.einsum('b h d e, b h d n -> b h e n', context, q)
        out = rearrange(out, 'b h c (x y) -> b (h c) x y',
                        h=self.heads, x=h, y=w)
        return self.to_out(out)


class Attention(nn.Module):
    def __init__(self, dim, heads=4, dim_head=32):
        super().__init__()
        self.scale = dim_head ** -0.5
        self.heads = heads
        hidden_dim = dim_head * heads

        self.to_qkv = nn.Conv2d(dim, hidden_dim * 3, 1, bias=False)
        self.to_out = nn.Conv2d(hidden_dim, dim, 1)

    def forward(self, x):
        b, c, h, w = x.shape
        qkv = self.to_qkv(x).chunk(3, dim=1)
        q, k, v = map(lambda t: rearrange(
            t, 'b (h c) x y -> b h c (x y)', h=self.heads), qkv)

        q = q * self.scale

        sim = einsum('b h d i, b h d j -> b h i j', q, k)
        attn = sim.softmax(dim=-1)
        out = einsum('b h i j, b h d j -> b h i d', attn, v)

        out = rearrange(out, 'b h (x y) d -> b (h d) x y', x=h, y=w)
        return self.to_out(out)
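
# --- Editor's note: LinearAttention builds the (d x e) context k @ v^T over all
# --- N = h*w positions first, so its cost is O(N * d^2) rather than the
# --- O(N^2 * d) of the softmax Attention above; that is why it runs at every
# --- resolution while full Attention is used only at the low-resolution
# --- bottleneck.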
class Unet(nn.Module):
    def __init__(
        self,
        dim,
        init_dim=None,
        out_dim=None,
        dim_mults=(1, 2, 4, 8),
        channels=3,
        resnet_block_groups=8,
        learned_variance=False,
        learned_sinusoidal_cond=False,
        random_fourier_features=False,
        learned_sinusoidal_dim=16,
        condition=False,
    ):
        super().__init__()

        self.channels = channels
        self.depth = len(dim_mults)
        input_channels = channels + channels * (1 if condition else 0)

        init_dim = default(init_dim, dim)
        self.init_conv = nn.Conv2d(input_channels, init_dim, 7, padding=3)

        dims = [init_dim, *map(lambda m: dim * m, dim_mults)]
        in_out = list(zip(dims[:-1], dims[1:]))

        block_klass = partial(ResnetBlock, groups=resnet_block_groups)

        time_dim = dim * 4

        self.random_or_learned_sinusoidal_cond = learned_sinusoidal_cond or random_fourier_features

        if self.random_or_learned_sinusoidal_cond:
            sinu_pos_emb = RandomOrLearnedSinusoidalPosEmb(
                learned_sinusoidal_dim, random_fourier_features)
            fourier_dim = learned_sinusoidal_dim + 1
        else:
            sinu_pos_emb = SinusoidalPosEmb(dim)
            fourier_dim = dim

        self.time_mlp = nn.Sequential(
            sinu_pos_emb,
            nn.Linear(fourier_dim, time_dim),
            nn.GELU(),
            nn.Linear(time_dim, time_dim)
        )

        self.downs = nn.ModuleList([])
        self.ups = nn.ModuleList([])
        num_resolutions = len(in_out)

        for ind, (dim_in, dim_out) in enumerate(in_out):
            is_last = ind >= (num_resolutions - 1)

            self.downs.append(nn.ModuleList([
                block_klass(dim_in, dim_in, time_emb_dim=time_dim),
                block_klass(dim_in, dim_in, time_emb_dim=time_dim),
                Residual(PreNorm(dim_in, LinearAttention(dim_in))),
                Downsample(dim_in, dim_out) if not is_last else nn.Conv2d(
                    dim_in, dim_out, 3, padding=1)
            ]))

        mid_dim = dims[-1]
        self.mid_block1 = block_klass(mid_dim, mid_dim, time_emb_dim=time_dim)
        self.mid_attn = Residual(PreNorm(mid_dim, Attention(mid_dim)))
        self.mid_block2 = block_klass(mid_dim, mid_dim, time_emb_dim=time_dim)

        for ind, (dim_in, dim_out) in enumerate(reversed(in_out)):
            is_last = ind == (len(in_out) - 1)

            self.ups.append(nn.ModuleList([
                block_klass(dim_out + dim_in, dim_out, time_emb_dim=time_dim),
                block_klass(dim_out + dim_in, dim_out, time_emb_dim=time_dim),
                Residual(PreNorm(dim_out, LinearAttention(dim_out))),
                Upsample(dim_out, dim_in) if not is_last else nn.Conv2d(
                    dim_out, dim_in, 3, padding=1)
            ]))

        default_out_dim = channels * (1 if not learned_variance else 2)
        self.out_dim = default(out_dim, default_out_dim)

        self.final_res_block = block_klass(dim * 2, dim, time_emb_dim=time_dim)
        self.final_conv = nn.Conv2d(dim, self.out_dim, 1)

    def check_image_size(self, x, h, w):
        # Reflect-pad H and W up to the next multiple of 2**depth so that
        # repeated downsampling and upsampling restore the original size.
        s = int(math.pow(2, self.depth))
        mod_pad_h = (s - h % s) % s
        mod_pad_w = (s - w % s) % s
        x = F.pad(x, (0, mod_pad_w, 0, mod_pad_h), 'reflect')
        return x

    def forward(self, x, time):
        H, W = x.shape[2:]
        x = self.check_image_size(x, H, W)
        x = self.init_conv(x)
        r = x.clone()

        t = self.time_mlp(time)

        h = []

        for block1, block2, attn, downsample in self.downs:
            x = block1(x, t)
            h.append(x)

            x = block2(x, t)
            x = attn(x)
            h.append(x)

            x = downsample(x)

        x = self.mid_block1(x, t)
        x = self.mid_attn(x)
        x = self.mid_block2(x, t)

        for block1, block2, attn, upsample in self.ups:
            x = torch.cat((x, h.pop()), dim=1)
            x = block1(x, t)

            x = torch.cat((x, h.pop()), dim=1)
            x = block2(x, t)
            x = attn(x)

            x = upsample(x)

        x = torch.cat((x, r), dim=1)

        x = self.final_res_block(x, t)
        x = self.final_conv(x)
        x = x[..., :H, :W].contiguous()
        return x
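
# --- Editor's note: a worked example of check_image_size. With dim_mults of
# --- length 4 the depth is 4, so sizes are padded to multiples of 2**4 = 16:
# ---     h = 250 -> mod_pad_h = (16 - 250 % 16) % 16 = 6 -> padded to 256
# ---     h = 256 -> mod_pad_h = 0                        -> unchanged
# --- forward() then crops back with x[..., :H, :W].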
objective 416 | self.test_res_or_noise = test_res_or_noise 417 | 418 | self.unet0 = Unet(dim, 419 | init_dim=init_dim, 420 | out_dim=out_dim, 421 | dim_mults=dim_mults, 422 | channels=channels, 423 | resnet_block_groups=resnet_block_groups, 424 | learned_variance=learned_variance, 425 | learned_sinusoidal_cond=learned_sinusoidal_cond, 426 | random_fourier_features=random_fourier_features, 427 | learned_sinusoidal_dim=learned_sinusoidal_dim, 428 | condition=condition) 429 | 430 | def forward(self, x, time): 431 | if self.objective == "pred_noise": 432 | time = time[1] 433 | elif self.objective == "pred_res": 434 | time = time[0] 435 | return [self.unet0(x, time)] 436 | 437 | def extract(a, t, x_shape): 438 | b, *_ = t.shape 439 | out = a.gather(-1, t) 440 | return out.reshape(b, *((1,) * (len(x_shape) - 1))) 441 | 442 | def gen_coefficients(timesteps, schedule="increased", sum_scale=1, ratio=1): 443 | if schedule == "increased": 444 | x = np.linspace(0, 1, timesteps, dtype=np.float32) 445 | y = x**ratio 446 | y = torch.from_numpy(y) 447 | y_sum = y.sum() 448 | alphas = y/y_sum 449 | elif schedule == "decreased": 450 | x = np.linspace(0, 1, timesteps, dtype=np.float32) 451 | y = x**ratio 452 | y = torch.from_numpy(y) 453 | y_sum = y.sum() 454 | y = torch.flip(y, dims=[0]) 455 | alphas = y/y_sum 456 | elif schedule == "lamda": 457 | x = np.linspace(0.0001, 0.02, timesteps, dtype=np.float32) 458 | y = x**ratio 459 | y = torch.from_numpy(y) 460 | alphas = 1 - y 461 | elif schedule == "average": 462 | alphas = torch.full([timesteps], 1/timesteps, dtype=torch.float32) 463 | elif schedule == "normal": 464 | sigma = 1.0 465 | mu = 0.0 466 | x = np.linspace(-3+mu, 3+mu, timesteps, dtype=np.float32) 467 | y = np.e**(-((x-mu)**2)/(2*(sigma**2)))/(np.sqrt(2*np.pi)*(sigma**2)) 468 | y = torch.from_numpy(y) 469 | alphas = y/y.sum() 470 | else: 471 | alphas = torch.full([timesteps], 1/timesteps, dtype=torch.float32) 472 | 473 | return alphas*sum_scale 474 | 475 | def betas_for_alpha_bar(num_diffusion_timesteps, max_beta=0.999) -> torch.Tensor: 476 | def alpha_bar(time_step): 477 | return math.cos((time_step + 0.008) / 1.008 * math.pi / 2) ** 2 478 | 479 | betas = [] 480 | for i in range(num_diffusion_timesteps): 481 | t1 = i / num_diffusion_timesteps 482 | t2 = (i + 1) / num_diffusion_timesteps 483 | betas.append(min(1 - alpha_bar(t2) / alpha_bar(t1), max_beta)) 484 | return torch.tensor(betas, dtype=torch.float32) 485 | 486 | class ResidualDiffusion(nn.Module): 487 | def __init__( 488 | self, 489 | model, 490 | *, 491 | image_size, 492 | timesteps=1000, 493 | delta_end = 2.0e-3, 494 | sampling_timesteps=None, 495 | loss_type='l1', 496 | objective='pred_res_noise', 497 | ddim_sampling_eta= 0, 498 | condition=False, 499 | sum_scale=None, 500 | test_res_or_noise="None", 501 | ): 502 | super().__init__() 503 | assert not ( 504 | type(self) == ResidualDiffusion and model.channels != model.out_dim) 505 | assert not model.random_or_learned_sinusoidal_cond 506 | 507 | self.model = model 508 | self.channels = self.model.channels 509 | self.image_size = image_size 510 | self.objective = objective 511 | self.condition = condition 512 | self.test_res_or_noise = test_res_or_noise 513 | self.delta_end = delta_end 514 | 515 | if self.condition: 516 | self.sum_scale = sum_scale if sum_scale else 0.01 517 | else: 518 | self.sum_scale = sum_scale if sum_scale else 1. 
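        # Schedule construction (below): from the DDPM betas we build
        #   alphas_cumprod = prod(1 - beta_i),
        #   alphas_cumsum  = 1 - sqrt(alphas_cumprod)  (cumulative residual weight),
        #   betas2_cumsum  = 1 - alphas_cumprod        (cumulative noise variance),
        # plus a linear delta ramp whose clipped cumsum down-weights the
        # conditional input; the per-step coefficients are first differences
        # of these cumulative arrays.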

        beta_schedule = "linear"
        beta_start = 0.0001
        beta_end = 0.02
        if beta_schedule == "linear":
            betas = torch.linspace(beta_start, beta_end, timesteps, dtype=torch.float32)
        elif beta_schedule == "scaled_linear":
            betas = (torch.linspace(beta_start**0.5, beta_end**0.5, timesteps, dtype=torch.float32) ** 2)
        elif beta_schedule == "squaredcos_cap_v2":
            betas = betas_for_alpha_bar(timesteps)
        else:
            raise NotImplementedError(f"{beta_schedule} is not implemented for {self.__class__}")

        delta_start = 1e-6
        delta = torch.linspace(delta_start, self.delta_end, timesteps, dtype=torch.float32)
        delta_cumsum = delta.cumsum(dim=0).clip(0, 1)

        alphas = 1.0 - betas
        alphas_cumprod = torch.cumprod(alphas, dim=0)
        alphas_cumsum = 1 - alphas_cumprod ** 0.5
        betas2_cumsum = 1 - alphas_cumprod

        alphas_cumsum_prev = F.pad(alphas_cumsum[:-1], (1, 0), value=1.)
        betas2_cumsum_prev = F.pad(betas2_cumsum[:-1], (1, 0), value=1.)
        delta_cumsum_prev = F.pad(delta_cumsum[:-1], (1, 0), value=1.)
        alphas = alphas_cumsum - alphas_cumsum_prev
        alphas[0] = 0
        betas2 = betas2_cumsum - betas2_cumsum_prev
        betas2[0] = 0
        betas_cumsum = torch.sqrt(betas2_cumsum)

        posterior_variance = betas2 * betas2_cumsum_prev / betas2_cumsum
        posterior_variance[0] = 0

        timesteps, = alphas.shape
        self.num_timesteps = int(timesteps)
        self.loss_type = loss_type

        self.sampling_timesteps = default(sampling_timesteps, timesteps)

        assert self.sampling_timesteps <= timesteps
        self.is_ddim_sampling = self.sampling_timesteps < timesteps
        self.ddim_sampling_eta = ddim_sampling_eta

        def register_buffer(name, val): return self.register_buffer(
            name, val.to(torch.float32))

        register_buffer('alphas', alphas)
        register_buffer('alphas_cumsum', alphas_cumsum)
        register_buffer('delta', delta)
        register_buffer('delta_cumsum', delta_cumsum)
        register_buffer('one_minus_alphas_cumsum', 1 - alphas_cumsum)
        register_buffer('betas2', betas2)
        register_buffer('betas', torch.sqrt(betas2))
        register_buffer('betas2_cumsum', betas2_cumsum)
        register_buffer('betas_cumsum', betas_cumsum)
        register_buffer('posterior_mean_coef1',
                        betas2_cumsum_prev / betas2_cumsum)
        register_buffer('posterior_mean_coef2',
                        (betas2_cumsum_prev) / (betas2_cumsum) * (alphas - delta) + (betas2) / (betas2_cumsum) * (alphas_cumsum_prev - delta_cumsum_prev)
                        )
        register_buffer('posterior_mean_coef3', delta + betas2 / betas2_cumsum * (1 - delta_cumsum_prev))
        register_buffer('posterior_variance', posterior_variance)
        register_buffer('posterior_log_variance_clipped',
                        torch.log(posterior_variance.clamp(min=1e-20)))

        self.posterior_mean_coef1[0] = 0
        self.posterior_mean_coef2[0] = 0
        self.posterior_mean_coef3[0] = 1
        self.one_minus_alphas_cumsum[-1] = 1e-6

    def predict_noise_from_res(self, x_t, t, x_input, pred_res):
        return (
            (x_t - (1 - extract(self.delta_cumsum, t, x_t.shape)) * x_input - (extract(self.alphas_cumsum, t, x_t.shape) - 1) * pred_res) / extract(self.betas_cumsum, t, x_t.shape)
        )

    def predict_start_from_xinput_noise(self, x_t, t, x_input, noise):
        return (
            (x_t - extract(self.alphas_cumsum, t, x_t.shape) * x_input -
             extract(self.betas_cumsum, t, x_t.shape) * noise + extract(self.delta_cumsum, t,
x_t.shape) * x_input )/extract(self.one_minus_alphas_cumsum, t, x_t.shape) 600 | ) 601 | 602 | def predict_start_from_res_noise(self, x_t, t, x_res, noise, x_input): 603 | return ( 604 | x_t-extract(self.alphas_cumsum, t, x_t.shape) * x_res - 605 | extract(self.betas_cumsum, t, x_t.shape) * noise + extract(self.delta_cumsum, t, x_t.shape) * x_input 606 | ) 607 | 608 | def q_posterior_from_res_noise(self, x_res, noise, x_t, t, x_input): 609 | return (x_t-extract(self.alphas, t, x_t.shape) * x_res + extract(self.delta, t, x_t.shape) * x_input - 610 | (extract(self.betas2, t, x_t.shape)/extract(self.betas_cumsum, t, x_t.shape)) * noise) 611 | 612 | def q_posterior(self, pred_res, x_start, x_t, t): 613 | posterior_mean = ( 614 | extract(self.posterior_mean_coef1, t, x_t.shape) * x_t + 615 | extract(self.posterior_mean_coef2, t, x_t.shape) * pred_res + 616 | extract(self.posterior_mean_coef3, t, x_t.shape) * x_start 617 | ) 618 | posterior_variance = extract(self.posterior_variance, t, x_t.shape) 619 | posterior_log_variance_clipped = extract( 620 | self.posterior_log_variance_clipped, t, x_t.shape) 621 | return posterior_mean, posterior_variance, posterior_log_variance_clipped 622 | 623 | def model_predictions(self, x_input, x, t, task=None, clip_denoised=True): 624 | if not self.condition: 625 | x_in = x 626 | else: 627 | x_in = torch.cat((x, x_input), dim=1) 628 | model_output = self.model(x_in,[t,t]) 629 | 630 | maybe_clip = partial(torch.clamp, min=-1., 631 | max=1.) if clip_denoised else identity 632 | 633 | if self.objective == 'pred_res_noise': 634 | if self.test_res_or_noise == "res_noise": 635 | pred_res = model_output[0] 636 | pred_noise = model_output[1] 637 | pred_res = maybe_clip(pred_res) 638 | x_start = self.predict_start_from_res_noise( 639 | x, t, pred_res, pred_noise, x_input) 640 | x_start = maybe_clip(x_start) 641 | elif self.test_res_or_noise == "res": 642 | pred_res = model_output[0] 643 | pred_res = maybe_clip(pred_res) 644 | pred_noise = self.predict_noise_from_res( 645 | x, t, x_input, pred_res) 646 | x_start = x_input - pred_res 647 | x_start = maybe_clip(x_start) 648 | elif self.test_res_or_noise == "noise": 649 | pred_noise = model_output[1] 650 | x_start = self.predict_start_from_xinput_noise( 651 | x, t, x_input, pred_noise) 652 | x_start = maybe_clip(x_start) 653 | pred_res = x_input - x_start 654 | pred_res = maybe_clip(pred_res) 655 | elif self.objective == 'pred_x0_noise': 656 | pred_res = x_input-model_output[0] 657 | pred_noise = model_output[1] 658 | pred_res = maybe_clip(pred_res) 659 | x_start = maybe_clip(model_output[0]) 660 | elif self.objective == "pred_noise": 661 | pred_noise = model_output[0] 662 | x_start = self.predict_start_from_xinput_noise( 663 | x, t, x_input, pred_noise) 664 | x_start = maybe_clip(x_start) 665 | pred_res = x_input - x_start 666 | pred_res = maybe_clip(pred_res) 667 | elif self.objective == "pred_res": 668 | pred_res = model_output[0] 669 | pred_res = maybe_clip(pred_res) 670 | pred_noise = self.predict_noise_from_res(x, t, x_input, pred_res) 671 | x_start = self.predict_start_from_res_noise(x, t, pred_res, pred_noise, x_input) 672 | x_start = maybe_clip(x_start) 673 | 674 | return ModelResPrediction(pred_res, pred_noise, x_start) 675 | 676 | def p_mean_variance(self, x_input, x, t): 677 | preds = self.model_predictions(x_input, x, t) 678 | pred_res = preds.pred_res 679 | x_start = preds.pred_x_start 680 | 681 | model_mean, posterior_variance, posterior_log_variance = self.q_posterior( 682 | pred_res=pred_res, 
x_start=x_start, x_t=x, t=t)
        return model_mean, posterior_variance, posterior_log_variance, x_start

    @torch.no_grad()
    def p_sample(self, x_input, x, t: int):
        b, *_, device = *x.shape, x.device
        batched_times = torch.full(
            (x.shape[0],), t, device=x.device, dtype=torch.long)
        model_mean, _, model_log_variance, x_start = self.p_mean_variance(x_input, x=x, t=batched_times)
        noise = torch.randn_like(x) if t > 0 else 0.
        pred_img = model_mean + (0.5 * model_log_variance).exp() * noise
        return pred_img, x_start

    @torch.no_grad()
    def p_sample_loop(self, x_input, shape, last=True):
        x_input = x_input[0]

        batch, device = shape[0], self.betas.device

        if self.condition:
            img = x_input + math.sqrt(self.sum_scale) * \
                torch.randn(shape, device=device)
            input_add_noise = img
        else:
            img = torch.randn(shape, device=device)

        x_start = None

        if not last:
            img_list = []

        for t in tqdm(reversed(range(0, self.num_timesteps)), desc='sampling loop time step', total=self.num_timesteps):
            img, x_start = self.p_sample(x_input, img, t)

            if not last:
                img_list.append(img)

        if self.condition:
            if not last:
                img_list = [input_add_noise] + img_list
            else:
                img_list = [input_add_noise, img]
            return unnormalize_to_zero_to_one(img_list)
        else:
            if not last:
                img_list = img_list
            else:
                img_list = [img]
            return unnormalize_to_zero_to_one(img_list)

    @torch.no_grad()
    def first_order_sample(self, x_input, shape, last=True, task=None):
        x_input = x_input[0]
        batch, device, total_timesteps, sampling_timesteps, eta, objective = shape[
            0], self.betas.device, self.num_timesteps, self.sampling_timesteps, self.ddim_sampling_eta, self.objective

        times = torch.linspace(-1, total_timesteps - 1,
                               steps=sampling_timesteps + 1)

        times = list(reversed(times.int().tolist()))
        time_pairs = list(zip(times[:-1], times[1:]))

        if self.condition:
            # initialise at x_T of the forward process, matching the UPS solvers:
            # (1 - delta_cumsum_T) * x_input + betas_cumsum_T * noise
            img = (1 - self.delta_cumsum[-1]) * x_input + self.betas_cumsum[-1] * torch.randn(shape, device=device)
            input_add_noise = img
        else:
            img = torch.randn(shape, device=device)

        x_start = None
        type = "use_pred_noise"
        if not last:
            img_list = []

        for time, time_next in tqdm(time_pairs, desc='sampling loop time step', disable=True):
            time_cond = torch.full(
                (batch,), time, device=device, dtype=torch.long)
            preds = self.model_predictions(x_input, img, time_cond, task)

            pred_res = preds.pred_res
            pred_noise = preds.pred_noise
            x_start = preds.pred_x_start

            if time_next < 0:
                img = x_start
                if not last:
                    img_list.append(img)
                continue

            alpha_cumsum = self.alphas_cumsum[time]
            alpha_cumsum_next = self.alphas_cumsum[time_next]
            alpha = alpha_cumsum - alpha_cumsum_next
            delta_cumsum = self.delta_cumsum[time]
            delta_cumsum_next = self.delta_cumsum[time_next]
            delta = delta_cumsum - delta_cumsum_next
            betas2_cumsum = self.betas2_cumsum[time]
            betas2_cumsum_next = self.betas2_cumsum[time_next]
            betas2 = betas2_cumsum - betas2_cumsum_next
            betas = betas2.sqrt()
            betas_cumsum = self.betas_cumsum[time]
            betas_cumsum_next = self.betas_cumsum[time_next]
            # note: in this function the name holds betas_cumsum - betas_cumsum_next;
            # first_order_UPS recomputes it as betas2 / betas_cumsum
            betas2_div_betas_cumsum = betas_cumsum - betas_cumsum_next

            if type == "use_pred_noise":
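                # First-order (Euler) step of the reverse process: subtract one
                # increment of the predicted residual, restore one increment of
                # the delta-weighted input, and remove the matching noise term.
                # The elif variants below are unreachable with the hard-coded
                # type above; enabling "use_x_start" or "special_eta_1" would
                # additionally require defining q, sigma2 and noise here.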
785 | img = img - alpha*pred_res + delta*x_input - betas2_div_betas_cumsum * pred_noise 786 | 787 | elif type == "use_x_start": 788 | img = q*img + \ 789 | (1-q)*x_start + \ 790 | (alpha_cumsum_next-alpha_cumsum*q)*pred_res + \ 791 | (delta_cumsum*q-delta_cumsum_next)*x_input + \ 792 | sigma2.sqrt()*noise 793 | elif type == "special_eta_0": 794 | img = img - alpha*pred_res - \ 795 | (betas_cumsum-betas_cumsum_next)*pred_noise 796 | elif type == "special_eta_1": 797 | img = img - alpha*pred_res - betas2/betas_cumsum*pred_noise + \ 798 | betas*betas2_cumsum_next.sqrt()/betas_cumsum*noise 799 | 800 | if not last: 801 | img_list.append(img) 802 | 803 | if self.condition: 804 | if not last: 805 | img_list = [input_add_noise]+img_list 806 | else: 807 | img_list = [input_add_noise, img] 808 | return unnormalize_to_zero_to_one(img_list) 809 | else: 810 | if not last: 811 | img_list = img_list 812 | else: 813 | img_list = [img] 814 | return unnormalize_to_zero_to_one(img_list) 815 | 816 | 817 | @torch.no_grad() 818 | def second_order_sample(self, x_input, shape, last=True, task=None): 819 | x_input = x_input[0] 820 | batch, device, total_timesteps, sampling_timesteps, eta, objective = shape[ 821 | 0], self.betas.device, self.num_timesteps, self.sampling_timesteps, self.ddim_sampling_eta, self.objective 822 | 823 | times = torch.linspace(-1, total_timesteps - 1,steps=sampling_timesteps + 1) 824 | times = list(reversed(times.int().tolist())) 825 | time_pairs = list(zip(times[:-1], times[1:])) 826 | 827 | if self.condition: 828 | img = (1-self.delta_cumsum[-1]) * x_input + math.sqrt(self.sum_scale) * torch.randn(shape, device=device) 829 | input_add_noise = img 830 | else: 831 | img = torch.randn(shape, device=device) 832 | 833 | x_start = None 834 | type = "use_pred_noise" 835 | 836 | if not last: 837 | img_list = [] 838 | 839 | r = 0.5 840 | for time, time_next in tqdm(time_pairs, desc='sampling loop time step',disable = True): 841 | time_internal = int((1-r) * time + r * time_next) 842 | time_cond = torch.full((batch,), time, device=device, dtype=torch.long) 843 | time_cond_internal = torch.full((batch,), time_internal, device=device, dtype=torch.long) 844 | time_cond_next = torch.full((batch,), time_next, device=device, dtype=torch.long) 845 | 846 | alpha_cumsum = self.alphas_cumsum[time] 847 | alpha_cumsum_internal = self.alphas_cumsum[time_internal] 848 | alpha_cumsum_next = self.alphas_cumsum[time_next] 849 | alpha_u = alpha_cumsum-alpha_cumsum_internal 850 | alpha_t = alpha_cumsum-alpha_cumsum_next 851 | delta_cumsum = self.delta_cumsum[time] 852 | delta_cumsum_internal = self.delta_cumsum[time_internal] 853 | delta_cumsum_next = self.delta_cumsum[time_next] 854 | delta_u = delta_cumsum-delta_cumsum_internal 855 | delta_t = delta_cumsum-delta_cumsum_next 856 | betas2_cumsum = self.betas2_cumsum[time] 857 | betas2_cumsum_internal = self.betas2_cumsum[time_internal] 858 | betas2_cumsum_next = self.betas2_cumsum[time_next] 859 | betas2_u = betas2_cumsum-betas2_cumsum_internal 860 | betas2_t = betas2_cumsum-betas2_cumsum_next 861 | betas_cumsum = self.betas_cumsum[time] 862 | betas_cumsum_internal = self.betas_cumsum[time_internal] 863 | betas_cumsum_next = self.betas_cumsum[time_next] 864 | betas_u = betas_cumsum-betas_cumsum_internal 865 | betas_t = betas_cumsum-betas_cumsum_next 866 | 867 | preds_time_cond = self.model_predictions(x_input, img, time_cond) 868 | pred_res_time_cond = preds_time_cond.pred_res 869 | pred_noise_time_cond = preds_time_cond.pred_noise 870 | x_start_time_cond = 
preds_time_cond.pred_x_start

            if time_next < 0:
                img = x_start_time_cond
                if not last:
                    img_list.append(img)
                continue

            # Euler predictor to the internal (midpoint) time
            img_u = img + delta_u * x_input - alpha_u * pred_res_time_cond - betas_u * pred_noise_time_cond

            preds_time_cond_internal = self.model_predictions(x_input, img_u.clone().detach(), time_cond_internal)
            pred_res_time_cond_internal = preds_time_cond_internal.pred_res
            pred_noise_time_cond_internal = preds_time_cond_internal.pred_noise
            x_start_time_cond_internal = preds_time_cond_internal.pred_x_start

            # full step to time_next, with the midpoint corrections folded
            # into a single update
            img_target = (img + delta_t * x_input - alpha_t * pred_res_time_cond - betas_t * pred_noise_time_cond
                          - 1 / (2 * r) * alpha_t * (pred_res_time_cond_internal - pred_res_time_cond)
                          - 1 / (2 * r) * betas_t * (pred_noise_time_cond_internal - pred_noise_time_cond))

            img = img_target.clone().detach()

            if not last:
                img_list.append(img)

        if self.condition:
            if not last:
                img_list = [input_add_noise] + img_list
            else:
                img_list = [input_add_noise, img]
            return unnormalize_to_zero_to_one(img_list)
        else:
            if not last:
                img_list = img_list
            else:
                img_list = [img]
            return unnormalize_to_zero_to_one(img_list)
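    # third_order_sample extends the midpoint scheme with two internal nodes
    # at r1 = 1/3 and r2 = 2/3 of each step; first differences (D1_*) and
    # second differences (D2_*) of the predicted residual and noise supply
    # the higher-order corrections to the Euler update.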
    @torch.no_grad()
    def third_order_sample(self, x_input, shape, last=True, task=None):
        x_input = x_input[0]
        batch, device, total_timesteps, sampling_timesteps, eta, objective = shape[
            0], self.betas.device, self.num_timesteps, self.sampling_timesteps, self.ddim_sampling_eta, self.objective

        times = torch.linspace(-1, total_timesteps - 1, steps=sampling_timesteps + 1)
        times = list(reversed(times.int().tolist()))
        time_pairs = list(zip(times[:-1], times[1:]))

        if self.condition:
            img = (1 - self.delta_cumsum[-1]) * x_input + math.sqrt(self.sum_scale) * torch.randn(shape, device=device)
            input_add_noise = img
        else:
            img = torch.randn(shape, device=device)

        x_start = None
        type = "use_pred_noise"

        if not last:
            img_list = []

        r1 = 1/3
        r2 = 2/3
        for time, time_next in tqdm(time_pairs, desc='sampling loop time step', disable=True):
            time_u = int((1 - r1) * time + r1 * time_next)
            time_s = int((1 - r2) * time + r2 * time_next)
            time_cond = torch.full((batch,), time, device=device, dtype=torch.long)
            time_cond_u = torch.full((batch,), time_u, device=device, dtype=torch.long)
            time_cond_s = torch.full((batch,), time_s, device=device, dtype=torch.long)
            time_cond_next = torch.full((batch,), time_next, device=device, dtype=torch.long)

            alpha_cumsum = self.alphas_cumsum[time]
            alpha_cumsum_u = self.alphas_cumsum[time_u]
            alpha_cumsum_s = self.alphas_cumsum[time_s]
            alpha_cumsum_next = self.alphas_cumsum[time_next]
            alpha_u = alpha_cumsum - alpha_cumsum_u
            alpha_s = alpha_cumsum - alpha_cumsum_s
            alpha_t = alpha_cumsum - alpha_cumsum_next
            delta_cumsum = self.delta_cumsum[time]
            delta_cumsum_u = self.delta_cumsum[time_u]
            delta_cumsum_s = self.delta_cumsum[time_s]
            delta_cumsum_next = self.delta_cumsum[time_next]
            delta_u = delta_cumsum - delta_cumsum_u
            delta_s = delta_cumsum - delta_cumsum_s
            delta_t = delta_cumsum - delta_cumsum_next
            betas2_cumsum = self.betas2_cumsum[time]
            betas2_cumsum_u = self.betas2_cumsum[time_u]
            betas2_cumsum_s = self.betas2_cumsum[time_s]
            betas2_cumsum_next = self.betas2_cumsum[time_next]
            betas2_u = betas2_cumsum - betas2_cumsum_u
            betas2_s = betas2_cumsum - betas2_cumsum_s
            betas2_t = betas2_cumsum - betas2_cumsum_next
            betas_cumsum = self.betas_cumsum[time]
            betas_cumsum_u = self.betas_cumsum[time_u]
            betas_cumsum_s = self.betas_cumsum[time_s]
            betas_cumsum_next = self.betas_cumsum[time_next]
            betas_u = betas_cumsum - betas_cumsum_u
            betas_s = betas_cumsum - betas_cumsum_s
            betas_t = betas_cumsum - betas_cumsum_next

            preds_time_cond = self.model_predictions(x_input, img, time_cond)
            pred_res_time_cond = preds_time_cond.pred_res
            pred_noise_time_cond = preds_time_cond.pred_noise
            x_start_time_cond = preds_time_cond.pred_x_start

            if time_next < 0:
                img = x_start_time_cond
                if not last:
                    img_list.append(img)
                continue

            img_u = img + delta_u * x_input - alpha_u * pred_res_time_cond - betas_u * pred_noise_time_cond
            preds_time_u = self.model_predictions(x_input, img_u, time_cond_u)
            pred_res_time_u = preds_time_u.pred_res
            pred_noise_time_u = preds_time_u.pred_noise
            x_start_time_u = preds_time_u.pred_x_start

            img_s = img + delta_s * x_input - alpha_s * pred_res_time_cond - betas_s * pred_noise_time_cond
            preds_time_s = self.model_predictions(x_input, img_s, time_cond_s)
            pred_res_time_s = preds_time_s.pred_res
            pred_noise_time_s = preds_time_s.pred_noise
            x_start_time_s = preds_time_s.pred_x_start

            D1_res = pred_res_time_u - pred_res_time_cond
            D1_eps = pred_noise_time_u - pred_noise_time_cond
            D2_res = (2 / (r1 * r2 * (r2 - r1))) * (r1 * pred_res_time_s - r2 * pred_res_time_u + (r2 - r1) * pred_res_time_cond)
            D2_eps = (2 / (r1 * r2 * (r2 - r1))) * (r1 * pred_noise_time_s - r2 * pred_noise_time_u + (r2 - r1) * pred_noise_time_cond)

            # single update combining the Euler step with the first- and
            # second-difference corrections
            img_target = (img + delta_t * x_input - alpha_t * pred_res_time_cond - betas_t * pred_noise_time_cond
                          - 1 / (2 * r1) * alpha_t * D1_res - 1 / (2 * r1) * betas_t * D1_eps
                          - 1/6 * alpha_t * D2_res - 1/6 * betas_t * D2_eps)

            img = img_target

            if not last:
                img_list.append(img)

        if self.condition:
            if not last:
                img_list = [input_add_noise] + img_list
            else:
                img_list = [input_add_noise, img]
            return unnormalize_to_zero_to_one(img_list)
        else:
            if not last:
                img_list = img_list
            else:
                img_list = [img]
            return unnormalize_to_zero_to_one(img_list)


    def grad_and_value(self, x_prev, x_0_hat, y0, x_res):
        difference = y0 - (x_0_hat + x_res)
        norm = torch.linalg.norm(difference)
        norm_grad = torch.autograd.grad(outputs=norm, inputs=x_prev)[0]
        return norm_grad, norm
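    # Universal posterior sampling (UPS) guidance: grad_and_value() returns
    # the gradient of the data-fidelity norm ||y0 - (x_0_hat + x_res)|| with
    # respect to the current iterate. The *_UPS solvers below add this
    # gradient, scaled by betas_cumsum / norm, to the predicted noise so that
    # each solver step stays consistent with the degraded input.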
    def first_order_UPS(self, x_input, shape, last=True, task=None):
        x_input = x_input[0]
        batch, device, total_timesteps, sampling_timesteps, eta, objective = shape[
            0], self.betas.device, self.num_timesteps, self.sampling_timesteps, self.ddim_sampling_eta, self.objective

        times = torch.linspace(-1, total_timesteps - 1, steps=sampling_timesteps + 1)

        times = list(reversed(times.int().tolist()))
        time_pairs = list(zip(times[:-1], times[1:]))

        if self.condition:
            img = (1 - self.delta_cumsum[-1]) * x_input + self.betas_cumsum[-1] * torch.randn(shape, device=device)
            input_add_noise = img
        else:
            img = torch.randn(shape, device=device)

        x_start = None
        type = "use_pred_noise"

        if not last:
            img_list = []

        for time, time_next in tqdm(time_pairs, desc='sampling loop time step', disable=True):
            img = img.requires_grad_()

            time_cond = torch.full(
                (batch,), time, device=device, dtype=torch.long)
            preds = self.model_predictions(x_input, img, time_cond, task)

            pred_res = preds.pred_res
            pred_noise = preds.pred_noise
            x_start = preds.pred_x_start

            if time_next < 0:
                img = x_start
                if not last:
                    img_list.append(img)
                continue

            alpha_cumsum = self.alphas_cumsum[time]
            alpha_cumsum_next = self.alphas_cumsum[time_next]
            alpha = alpha_cumsum - alpha_cumsum_next
            delta_cumsum = self.delta_cumsum[time]
            delta_cumsum_next = self.delta_cumsum[time_next]
            delta = delta_cumsum - delta_cumsum_next
            betas2_cumsum = self.betas2_cumsum[time]
            betas2_cumsum_next = self.betas2_cumsum[time_next]
            betas2 = betas2_cumsum - betas2_cumsum_next
            betas = betas2.sqrt()
            betas_cumsum = self.betas_cumsum[time]
            betas_cumsum_next = self.betas_cumsum[time_next]

            betas2_div_betas_cumsum = betas2 / betas_cumsum

            norm_grad, norm = self.grad_and_value(x_prev=img, x_0_hat=x_start, y0=x_input, x_res=pred_res)
            pred_noise = pred_noise + betas_cumsum / norm * norm_grad

            if type == "use_pred_noise":
                img = img - alpha * pred_res + delta * x_input - betas2_div_betas_cumsum * pred_noise
            elif type == "use_x_start":
                # unused variant: q, sigma2 and noise must be defined before enabling it
                img = q * img + \
                    (1 - q) * x_start + \
                    (alpha_cumsum_next - alpha_cumsum * q) * pred_res + \
                    (delta_cumsum * q - delta_cumsum_next) * x_input + \
                    sigma2.sqrt() * noise
            elif type == "special_eta_0":
                img = img - alpha * pred_res - \
                    (betas_cumsum - betas_cumsum_next) * pred_noise
            elif type == "special_eta_1":
                # unused variant: noise must be defined before enabling it
                img = img - alpha * pred_res - betas2 / betas_cumsum * pred_noise + \
                    betas * betas2_cumsum_next.sqrt() / betas_cumsum * noise

            # detach so the next iteration can call requires_grad_() on a leaf tensor
            img = img.detach_()

            if not last:
                img_list.append(img)

        if self.condition:
            if not last:
                img_list = [input_add_noise] + img_list
            else:
                img_list = [input_add_noise, img]
            return unnormalize_to_zero_to_one(img_list)
        else:
            if not last:
                img_list = img_list
            else:
                img_list = [img]
            return unnormalize_to_zero_to_one(img_list)
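    # second_order_UPS: the guided midpoint solver. Relative to
    # second_order_sample, every model evaluation is followed by a
    # grad_and_value() guidance update, so the iterate must require gradients
    # going into each evaluation and is detached again before the next stage.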
    def second_order_UPS(self, x_input, shape, last=True, task=None):
        x_input = x_input[0]
        batch, device, total_timesteps, sampling_timesteps, eta, objective = shape[
            0], self.betas.device, self.num_timesteps, self.sampling_timesteps, self.ddim_sampling_eta, self.objective

        times = torch.linspace(-1, total_timesteps - 1, steps=sampling_timesteps + 1)
        times = list(reversed(times.int().tolist()))
        time_pairs = list(zip(times[:-1], times[1:]))

        if self.condition:
            img = (1 - self.delta_cumsum[-1]) * x_input + math.sqrt(self.sum_scale) * torch.randn(shape, device=device)
            input_add_noise = img
        else:
            img = torch.randn(shape, device=device)

        x_start = None
        type = "use_pred_noise"

        if not last:
            img_list = []

        r = 0.5
        for time, time_next in tqdm(time_pairs, desc='sampling loop time step', disable=True):
            img = img.requires_grad_()

            time_internal = int((1 - r) * time + r * time_next)
            time_cond = torch.full((batch,), time, device=device, dtype=torch.long)
            time_cond_internal = torch.full((batch,), time_internal, device=device, dtype=torch.long)
            time_cond_next = torch.full((batch,), time_next, device=device, dtype=torch.long)

            alpha_cumsum = self.alphas_cumsum[time]
            alpha_cumsum_internal = self.alphas_cumsum[time_internal]
            alpha_cumsum_next = self.alphas_cumsum[time_next]
            alpha_u = alpha_cumsum - alpha_cumsum_internal
            alpha_t = alpha_cumsum - alpha_cumsum_next
            delta_cumsum = self.delta_cumsum[time]
            delta_cumsum_internal = self.delta_cumsum[time_internal]
            delta_cumsum_next = self.delta_cumsum[time_next]
            delta_u = delta_cumsum - delta_cumsum_internal
            delta_t = delta_cumsum - delta_cumsum_next
            betas2_cumsum = self.betas2_cumsum[time]
            betas2_cumsum_internal = self.betas2_cumsum[time_internal]
            betas2_cumsum_next = self.betas2_cumsum[time_next]
            betas2_u = betas2_cumsum - betas2_cumsum_internal
            betas2_t = betas2_cumsum - betas2_cumsum_next
            betas_cumsum = self.betas_cumsum[time]
            betas_cumsum_internal = self.betas_cumsum[time_internal]
            betas_cumsum_next = self.betas_cumsum[time_next]
            betas_u = betas_cumsum - betas_cumsum_internal
            betas_t = betas_cumsum - betas_cumsum_next

            preds_time_cond = self.model_predictions(x_input, img, time_cond)
            pred_res_time_cond = preds_time_cond.pred_res
            pred_noise_time_cond = preds_time_cond.pred_noise
            x_start_time_cond = preds_time_cond.pred_x_start

            if time_next < 0:
                img = x_start_time_cond
                if not last:
                    img_list.append(img)
                continue

            norm_grad, norm = self.grad_and_value(x_prev=img, x_0_hat=x_start_time_cond, y0=x_input, x_res=pred_res_time_cond)
            pred_noise_time_cond = pred_noise_time_cond + betas_cumsum / norm * norm_grad

            img_u = img + delta_u * x_input - alpha_u * pred_res_time_cond - betas_u * pred_noise_time_cond

            img_u = img_u.clone().detach_()
            img_u = img_u.requires_grad_()

            preds_time_cond_internal = self.model_predictions(x_input, img_u, time_cond_internal)
            pred_res_time_cond_internal = preds_time_cond_internal.pred_res
            pred_noise_time_cond_internal = preds_time_cond_internal.pred_noise
            x_start_time_cond_internal = preds_time_cond_internal.pred_x_start

            norm_grad_internal, norm_internal = self.grad_and_value(x_prev=img_u, x_0_hat=x_start_time_cond_internal, y0=x_input, x_res=pred_res_time_cond_internal)
            pred_noise_time_cond_internal = pred_noise_time_cond_internal + betas_cumsum_internal / norm_internal * norm_grad_internal

            # full guided step with the midpoint corrections folded into one update
            img_target = (img + delta_t * x_input - alpha_t * pred_res_time_cond - betas_t * pred_noise_time_cond
                          - 1 / (2 * r) * alpha_t * (pred_res_time_cond_internal - pred_res_time_cond)
                          - 1 / (2 * r) * betas_t * (pred_noise_time_cond_internal - pred_noise_time_cond))

            img = img_target.clone().detach_()

            if not last:
                img_list.append(img)

        if self.condition:
            if not last:
                img_list = [input_add_noise] + img_list
            else:
                img_list = [input_add_noise, img]
            return unnormalize_to_zero_to_one(img_list)
        else:
            if not last:
                img_list = img_list
            else:
                img_list = [img]
            return unnormalize_to_zero_to_one(img_list)
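    # third_order_UPS: the guided version of third_order_sample. Guidance
    # gradients are taken at the step start and at both internal nodes
    # (time_u, time_s), with detach_() after each gradient so the autograd
    # graphs stay local to one model evaluation.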
    def third_order_UPS(self, x_input, shape, last=True, task=None):
        x_input = x_input[0]
        batch, device, total_timesteps, sampling_timesteps, eta, objective = shape[
            0], self.betas.device, self.num_timesteps, self.sampling_timesteps, self.ddim_sampling_eta, self.objective

        times = torch.linspace(-1, total_timesteps - 1, steps=sampling_timesteps + 1)
        times = list(reversed(times.int().tolist()))
        time_pairs = list(zip(times[:-1], times[1:]))

        if self.condition:
            img = (1 - self.delta_cumsum[-1]) * x_input + math.sqrt(self.sum_scale) * torch.randn(shape, device=device)
            input_add_noise = img
        else:
            img = torch.randn(shape, device=device)

        x_start = None
        type = "use_pred_noise"

        if not last:
            img_list = []

        r1 = 1/3
        r2 = 2/3
        for time, time_next in tqdm(time_pairs, desc='sampling loop time step', disable=True):
            img = img.requires_grad_()

            time_u = int((1 - r1) * time + r1 * time_next)
            time_s = int((1 - r2) * time + r2 * time_next)
            time_cond = torch.full((batch,), time, device=device, dtype=torch.long)
            time_cond_u = torch.full((batch,), time_u, device=device, dtype=torch.long)
            time_cond_s = torch.full((batch,), time_s, device=device, dtype=torch.long)
            time_cond_next = torch.full((batch,), time_next, device=device, dtype=torch.long)

            alpha_cumsum = self.alphas_cumsum[time]
            alpha_cumsum_u = self.alphas_cumsum[time_u]
            alpha_cumsum_s = self.alphas_cumsum[time_s]
            alpha_cumsum_next = self.alphas_cumsum[time_next]
            alpha_u = alpha_cumsum - alpha_cumsum_u
            alpha_s = alpha_cumsum - alpha_cumsum_s
            alpha_t = alpha_cumsum - alpha_cumsum_next
            delta_cumsum = self.delta_cumsum[time]
            delta_cumsum_u = self.delta_cumsum[time_u]
            delta_cumsum_s = self.delta_cumsum[time_s]
            delta_cumsum_next = self.delta_cumsum[time_next]
            delta_u = delta_cumsum - delta_cumsum_u
            delta_s = delta_cumsum - delta_cumsum_s
            delta_t = delta_cumsum - delta_cumsum_next
            betas2_cumsum = self.betas2_cumsum[time]
            betas2_cumsum_u = self.betas2_cumsum[time_u]
            betas2_cumsum_s = self.betas2_cumsum[time_s]
            betas2_cumsum_next = self.betas2_cumsum[time_next]
            betas2_u = betas2_cumsum - betas2_cumsum_u
            betas2_s = betas2_cumsum - betas2_cumsum_s
            betas2_t = betas2_cumsum - betas2_cumsum_next
            betas_cumsum = self.betas_cumsum[time]
            betas_cumsum_u = self.betas_cumsum[time_u]
            betas_cumsum_s = self.betas_cumsum[time_s]
            betas_cumsum_next = self.betas_cumsum[time_next]
            betas_u = betas_cumsum - betas_cumsum_u
            betas_s = betas_cumsum - betas_cumsum_s
            betas_t = betas_cumsum - betas_cumsum_next

            preds_time_cond = self.model_predictions(x_input, img, time_cond)
            pred_res_time_cond = preds_time_cond.pred_res
            pred_noise_time_cond = preds_time_cond.pred_noise
            x_start_time_cond = preds_time_cond.pred_x_start

            if time_next < 0:
                img = x_start_time_cond
                if not last:
                    img_list.append(img)
                continue

            norm_grad, norm = self.grad_and_value(x_prev=img, x_0_hat=x_start_time_cond, y0=x_input, x_res=pred_res_time_cond)
            pred_noise_time_cond = pred_noise_time_cond + betas_cumsum / norm * norm_grad
            img = img.detach_()

            img_u = img + delta_u * x_input - alpha_u * pred_res_time_cond - betas_u * pred_noise_time_cond
            img_u = img_u.clone().detach_()
            img_u = img_u.requires_grad_()
            preds_time_cond_u = self.model_predictions(x_input, img_u, time_cond_u)
            pred_res_time_cond_u = preds_time_cond_u.pred_res
            pred_noise_time_cond_u = preds_time_cond_u.pred_noise
            x_start_time_cond_u = preds_time_cond_u.pred_x_start

            norm_grad_u, norm_u = self.grad_and_value(x_prev=img_u, x_0_hat=x_start_time_cond_u, y0=x_input, x_res=pred_res_time_cond_u)
            pred_noise_time_cond_u = pred_noise_time_cond_u + betas_cumsum_u / norm_u * norm_grad_u
            img_u = img_u.detach_()

            img_s = (img + delta_s * x_input - alpha_s * pred_res_time_cond - betas_s * pred_noise_time_cond
                     - r2 / (2 * r1) * alpha_s * (pred_res_time_cond_u - pred_res_time_cond)
                     - r2 / (2 * r1) * betas_s * (pred_noise_time_cond_u - pred_noise_time_cond))

            img_s = img_s.clone().detach_()
            img_s = img_s.requires_grad_()
            preds_time_cond_s = self.model_predictions(x_input, img_s, time_cond_s)
            pred_res_time_cond_s = preds_time_cond_s.pred_res
            pred_noise_time_cond_s = preds_time_cond_s.pred_noise
            x_start_time_cond_s = preds_time_cond_s.pred_x_start

            norm_grad_s, norm_s = self.grad_and_value(x_prev=img_s, x_0_hat=x_start_time_cond_s, y0=x_input, x_res=pred_res_time_cond_s)
            pred_noise_time_cond_s = pred_noise_time_cond_s + betas_cumsum_s / norm_s * norm_grad_s
            img_s = img_s.detach_()

            D2_res = (2 / (r1 * r2 * (r2 - r1))) * (r1 * pred_res_time_cond_s - r2 * pred_res_time_cond_u + (r2 - r1) * pred_res_time_cond)
            D2_eps = (2 / (r1 * r2 * (r2 - r1))) * (r1 * pred_noise_time_cond_s - r2 * pred_noise_time_cond_u + (r2 - r1) * pred_noise_time_cond)

            # single guided update with the first- and second-difference corrections
            img = (img + delta_t * x_input - alpha_t * pred_res_time_cond - betas_t * pred_noise_time_cond
                   - 1 / (2 * r1) * alpha_t * (pred_res_time_cond_u - pred_res_time_cond) - 1 / (2 * r1) * betas_t * (pred_noise_time_cond_u - pred_noise_time_cond)
                   - 1/6 * alpha_t * D2_res - 1/6 * betas_t * D2_eps)

            img = img.clone().detach_()

            if not last:
                img_list.append(img)

        if self.condition:
            if not last:
                img_list = [input_add_noise] + img_list
            else:
                img_list = [input_add_noise, img]
            return unnormalize_to_zero_to_one(img_list)
        else:
            if not last:
                img_list = img_list
            else:
                img_list = [img]
            return unnormalize_to_zero_to_one(img_list)


    def sample(self, x_input=None, batch_size=16, last=True, task=None):
        image_size, channels = self.image_size, self.channels
        sample_fn = self.second_order_UPS
        # self.first_order_sample   1st
        # self.first_order_UPS      UPS_1st
        # self.second_order_sample  2nd
        # self.second_order_UPS     UPS_2nd
        # self.third_order_sample   3rd
        # self.third_order_UPS      UPS_3rd
        if self.condition:
            x_input = 2 * x_input - 1
            x_input = x_input.unsqueeze(0)

            batch_size, channels, h, w = x_input[0].shape
            size = (batch_size, channels, h, w)
        else:
            size = (batch_size, channels, image_size, image_size)

        gen_samples = sample_fn(x_input, size, last=last, task=task)[1]
        gen_samples = gen_samples.detach()
        return gen_samples
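    # Forward (training-time) process: q_sample draws
    #   x_t = x_0 + alphas_cumsum_t * x_res + betas_cumsum_t * eps - delta_cumsum_t * x_input,
    # i.e. the sample drifts toward the degraded input through the residual
    # term while noise accumulates and the delta ramp removes part of the
    # conditional input.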
    def q_sample(self, x_start, x_res, condition, t, noise=None):
        noise = default(noise, lambda: torch.randn_like(x_start))

        return (
            x_start + extract(self.alphas_cumsum, t, x_start.shape) * x_res +
            extract(self.betas_cumsum, t, x_start.shape) * noise -
            extract(self.delta_cumsum, t, x_start.shape) * condition
        )

    @property
    def loss_fn(self):
        if self.loss_type == 'l1':
            return F.l1_loss
        elif self.loss_type == 'l2':
            return F.mse_loss
        else:
            raise ValueError(f'invalid loss type {self.loss_type}')

    def p_losses(self, imgs, t, noise=None):
        if isinstance(imgs, list):
            x_input = 2 * imgs[1] - 1
            x_start = 2 * imgs[0] - 1
            task = None
        else:
            raise ValueError('p_losses expects a [ground_truth, degraded_input] pair')

        noise = default(noise, lambda: torch.randn_like(x_start))
        x_res = x_input - x_start
        b, c, h, w = x_start.shape
        x = self.q_sample(x_start, x_res, x_input, t, noise=noise)
        if not self.condition:
            x_in = x
        else:
            x_in = torch.cat((x, x_input), dim=1)

        model_out = self.model(x_in, [t, t])

        target = []
        if self.objective == 'pred_res_noise':
            target.append(x_res)
            target.append(noise)

            pred_res = model_out[0]
            pred_noise = model_out[1]
        elif self.objective == 'pred_x0_noise':
            target.append(x_start)
            target.append(noise)

            pred_res = x_input - model_out[0]
            pred_noise = model_out[1]
        elif self.objective == "pred_noise":
            target.append(noise)
            pred_noise = model_out[0]

        elif self.objective == "pred_res":
            target.append(x_res)
            pred_res = model_out[0]

        else:
            raise ValueError(f'unknown objective {self.objective}')

        u_loss = False
        if u_loss:
            x_u = self.q_posterior_from_res_noise(pred_res, pred_noise, x, t, x_input)
            u_gt = self.q_posterior_from_res_noise(x_res, noise, x, t, x_input)
            loss = 10000 * self.loss_fn(x_u, u_gt, reduction='none')
            return [loss]
        else:
            loss_list = []
            for i in range(len(model_out)):
                loss = self.loss_fn(model_out[i], target[i], reduction='none')
                loss = reduce(loss, 'b ... -> b (...)', 'mean').mean()
                loss_list.append(loss)
            return loss_list

    def forward(self, img, *args, **kwargs):
        if isinstance(img, list):
            b, c, h, w, device, img_size = *img[0].shape, img[0].device, self.image_size
        else:
            b, c, h, w, device, img_size = *img.shape, img.device, self.image_size
        t = torch.randint(0, self.num_timesteps, (b,), device=device).long()

        return self.p_losses(img, t, *args, **kwargs)


if __name__ == '__main__':
    print('Hello World')
--------------------------------------------------------------------------------
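A minimal smoke-test sketch for the classes above (an assumed usage, not part of the released train.py/test.py entry points). It presumes the helpers defined earlier in model.py (`default`, `SinusoidalPosEmb`, `unnormalize_to_zero_to_one`, einops' `reduce`) are importable, and uses the single-output `pred_res` objective, since `UnetRes` wraps one backbone:

```python
import torch
from model import UnetRes, ResidualDiffusion

# Tiny configuration purely for a forward/backward sanity check.
unet = UnetRes(dim=16, channels=3, condition=True, objective='pred_res')
diffusion = ResidualDiffusion(
    unet, image_size=64, timesteps=1000, sampling_timesteps=5,
    objective='pred_res', condition=True, sum_scale=0.01,
    test_res_or_noise='res')

gt = torch.rand(1, 3, 64, 64)   # clean target in [0, 1]
lq = torch.rand(1, 3, 64, 64)   # degraded input in [0, 1]

loss = sum(diffusion([gt, lq]))  # training losses, one per prediction head
restored = diffusion.sample(lq)  # second-order UPS solver by default
print(loss.item(), restored.shape)
```

Note that `sample` must run with autograd enabled (no `torch.no_grad()` wrapper), because the UPS solvers differentiate the data-fidelity norm with respect to the current iterate.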