├── assets
│   ├── method.png
│   └── visualization.png
├── code
│   ├── ckpt
│   │   └── pretrained.txt
│   ├── test.sh
│   ├── train.sh
│   ├── datasets_setting.py
│   ├── test.py
│   ├── train.py
│   └── model.py
├── LICENSE.md
└── Readme.md
/assets/method.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/MiliLab/DGSolver/HEAD/assets/method.png
--------------------------------------------------------------------------------
/assets/visualization.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/MiliLab/DGSolver/HEAD/assets/visualization.png
--------------------------------------------------------------------------------
/code/ckpt/pretrained.txt:
--------------------------------------------------------------------------------
1 | [Baidu Cloud](https://pan.baidu.com/s/1CvZ2HAiwqM1t2VOJe5dYkg?pwd=fz2d) (extraction code: fz2d)
2 | [Google Cloud](https://drive.google.com/file/d/1cc7WkG2E8gGKEzBt7WZFXmkpV7VzLWk7/view?usp=drive_link)
--------------------------------------------------------------------------------
/code/test.sh:
--------------------------------------------------------------------------------
1 | export CUDA_DEVICE_ORDER="PCI_BUS_ID"
2 | export CUDA_VISIBLE_DEVICES="0,1,2,3"
3 |
4 | start_time=$(date +%s)
5 | echo "start_time: ${start_time}"
6 |
7 | # TORCH_DISTRIBUTED_DEBUG=DETAIL
8 | # single-process debug run: python test.py > ./test.log 2>&1
9 |
10 | nohup python -m torch.distributed.launch --nproc_per_node 4 --use_env test.py > ./test.log 2>&1 &
11 | wait  # block until the backgrounded run finishes, otherwise the elapsed time below is meaningless
12 |
13 | end_time=$(date +%s)
14 | e2e_time=$(($end_time - $start_time))
15 |
16 | echo "------------------ Final result ------------------"
17 | echo "e2e_time: ${e2e_time}s"
--------------------------------------------------------------------------------
/code/train.sh:
--------------------------------------------------------------------------------
1 | export CUDA_DEVICE_ORDER="PCI_BUS_ID"
2 | export CUDA_VISIBLE_DEVICES="0,1,2,3"
3 |
4 | start_time=$(date +%s)
5 | echo "start_time: ${start_time}"
6 |
7 | # TORCH_DISTRIBUTED_DEBUG=DETAIL
8 | # single-process debug run: python train.py > ./train.log 2>&1
9 |
10 | nohup python -m torch.distributed.launch --nproc_per_node 4 --use_env train.py > ./train.log 2>&1 &
11 | wait  # block until the backgrounded run finishes, otherwise the elapsed time below is meaningless
12 |
13 | end_time=$(date +%s)
14 | e2e_time=$(($end_time - $start_time))
15 |
16 | echo "------------------ Final result ------------------"
17 | echo "e2e_time: ${e2e_time}s"
--------------------------------------------------------------------------------
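Note: `torch.distributed.launch` is deprecated in recent PyTorch releases; on torch >= 1.10 the equivalent launch command is `torchrun --nproc_per_node 4 train.py` (and likewise for `test.py`), since `torchrun` implies the `--use_env` behavior.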
/LICENSE.md:
--------------------------------------------------------------------------------
1 | MIT License
2 |
3 | Copyright (c) 2024 MiliLab
4 |
5 | Permission is hereby granted, free of charge, to any person obtaining a copy
6 | of this software and associated documentation files (the "Software"), to deal
7 | in the Software without restriction, including without limitation the rights
8 | to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
9 | copies of the Software, and to permit persons to whom the Software is
10 | furnished to do so, subject to the following conditions:
11 |
12 | The above copyright notice and this permission notice shall be included in all
13 | copies or substantial portions of the Software.
14 |
15 | THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
16 | IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
17 | FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
18 | AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
19 | LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
20 | OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE
21 | SOFTWARE.
22 |
--------------------------------------------------------------------------------
/Readme.md:
--------------------------------------------------------------------------------
3 | # DGSolver: Diffusion Generalist Solver with Universal Posterior Sampling for Image Restoration
4 |
5 | Hebaixu Wang, Jing Zhang, Haonan Guo, Di Wang, Jiayi Ma and Bo Du.
6 |
7 | [Paper](https://arxiv.org/abs/2504.21487) | [GitHub Code](https://github.com/MiliLab/DGSolver)
8 |
9 | ## Abstract
10 |
11 | Diffusion models have achieved remarkable progress in universal image restoration. While existing methods speed up inference by reducing sampling steps, substantial step intervals often introduce cumulative errors. Moreover, they struggle to balance the commonality of degradation representations and restoration quality. To address these challenges, we introduce **DGSolver**, a diffusion generalist solver with universal posterior sampling. We first derive the exact ordinary differential equations for generalist diffusion models and tailor high-order solvers with a queue-based accelerated sampling strategy to improve both accuracy and efficiency. We then integrate universal posterior sampling to better approximate manifold-constrained gradients, yielding a more accurate noise estimation and correcting errors in inverse inference. Extensive experiments show that DGSolver outperforms state-of-the-art methods in restoration accuracy, stability, and scalability, both qualitatively and quantitatively.
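
As a generic point of reference for the "high-order solvers" mentioned above, the sketch below shows a second-order (Heun) step of a probability-flow ODE sampler. This is a minimal illustrative sketch only: `eps_model` is a hypothetical derivative estimator, and the paper's tailored solver, queue-based acceleration, and universal posterior sampling are not reproduced here.

```python
def heun_step(x, t, t_next, eps_model):
    # One second-order (Heun) step of an ODE sampler for x' = eps_model(x, t).
    # eps_model(x, t) is a hypothetical callable returning the model's
    # derivative (noise) estimate at state x and time t.
    h = t_next - t
    d = eps_model(x, t)                  # first-order derivative estimate
    x_euler = x + h * d                  # Euler proposal
    d_next = eps_model(x_euler, t_next)  # derivative at the proposal
    return x + 0.5 * h * (d + d_next)    # trapezoidal (second-order) correction
```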
12 |
13 | ## Overview
14 |
15 | ![Overview of DGSolver](assets/method.png)
16 |
17 | ## Visualization
18 |
19 | ![Visualization of restoration results](assets/visualization.png)
20 |
21 | ## Datasets Information
22 |
23 | | Task | Dataset | Synthetic/Real | Download Links |
24 | |--------------------------|--------------------------------|---------------------|----------------|
25 | | **Deraining** | DID | Synthetic | [URL](https://github.com/hezhangsprinter/DID-MDN) |
26 | | | DeRaindrop | Real | [URL](https://github.com/rui1996/DeRaindrop) |
27 | | | Rain13K | Synthetic | [URL](https://github.com/kuijiang94/MSPFN) |
28 | | | Rain_100H | Synthetic | [URL](https://github.com/kuijiang94/MSPFN) |
29 | | | Rain_100L | Synthetic | [URL](https://github.com/kuijiang94/MSPFN) |
30 | | | GT-Rain | Real | [URL](https://github.com/UCLA-VMG/GT-RAIN) |
31 | | | RealRain-1k | Real | [URL](https://github.com/hiker-lw/RealRain-1k) |
32 | | **Low-light Enhancement**| LOL | Real | [URL](https://github.com/weichen582/RetinexNet?tab=readme-ov-file) |
33 | | | MEF | Real | [URL](https://ieeexplore.ieee.org/abstract/document/7120119) |
34 | | | VE-LOL-L | Synthetic/Real | [URL](https://flyywh.github.io/IJCV2021LowLight_VELOL/) |
35 | | | NPE | Real | [URL](https://ieeexplore.ieee.org/abstract/document/6512558) |
36 | | **Desnowing** | CSD | Synthetic | [URL](https://github.com/weitingchen83/ICCV2021-Single-Image-Desnowing-HDCWNet) |
37 | | | Snow100K-Real | Real | [URL](https://sites.google.com/view/yunfuliu/desnownet) |
38 | | **Dehazing**             | SOTS                           | Synthetic           | [URL](https://sites.google.com/view/reside-dehaze-datasets/reside-standard?authuser=0) |
39 | |                          | ITS_v2                         | Synthetic           | [URL](https://sites.google.com/view/reside-dehaze-datasets/reside-standard?authuser=0) |
40 | | | D-HAZY | Synthetic | [URL](https://www.cvmart.net/dataSets/detail/559?channel_id=op10&utm_source=cvmartmp&utm_campaign=datasets&utm_medium=article) |
41 | | | NH-HAZE | Real | [URL](https://data.vision.ee.ethz.ch/cvl/ntire20/nh-haze/) |
42 | | | Dense-Haze | Real | [URL](https://data.vision.ee.ethz.ch/cvl/ntire19/dense-haze/) |
43 | | | NHRW | Real | [URL](https://github.com/chaimi2013/3R) |
44 | | **Deblur** | GoPro | Synthetic | [URL](https://github.com/SeungjunNah/DeepDeblur-PyTorch) |
45 | | | RealBlur | Real | [URL](https://github.com/rimchang/RealBlur) |
46 |
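The data loaders in `code/datasets_setting.py` assume each dataset has been reorganized as `<data_root>/<Task>/<SubDataset>/{train,test}/{condition,label}`, where `condition/` holds the degraded inputs and `label/` the ground-truth targets with matching file names (the `Task` folder names are `Deraining`, `Enlighening`, `Desnowing`, `Dehazing`, and `Deblur`, as spelled in the code). A sketch of the assumed layout, using `Rain13K` as an example:

```
<data_root>/
├── Deraining/
│   └── Rain13K/
│       ├── train/
│       │   ├── condition/   # degraded inputs
│       │   └── label/       # ground-truth targets
│       └── test/
│           ├── condition/
│           └── label/
└── ...
```
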
47 | ## Model Checkpoint
48 |
49 | [Google Cloud](https://drive.google.com/file/d/1cc7WkG2E8gGKEzBt7WZFXmkpV7VzLWk7/view?usp=drive_link)
50 |
51 | [Baidu Cloud](https://pan.baidu.com/s/1CvZ2HAiwqM1t2VOJe5dYkg?pwd=fz2d) (extraction code: fz2d)
52 |
53 | [Huggingface](https://huggingface.co/BaiXuYa/DGSolver)
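
Both `train.py` and `test.py` load weights from the fixed relative path `./ckpt/pretrained.pt`, so place the downloaded checkpoint at `code/ckpt/pretrained.pt` (renaming it to `pretrained.pt` if necessary) before running `train.sh` or `test.sh`.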
54 |
55 | ### Contributor
56 |
57 | Baixuzx7 @ wanghebaixu@gmail.com
58 |
59 | ### Copyright statement
60 |
61 | This project is released under the MIT license; see [LICENSE.md](LICENSE.md).
62 |
--------------------------------------------------------------------------------
/code/datasets_setting.py:
--------------------------------------------------------------------------------
1 | import os
2 | import torch
3 | from torch.utils.data import Dataset, DataLoader
4 | from torchvision import transforms as T, utils
5 |
6 | from PIL import Image
7 | from torch import nn
8 | import imageio
9 | import cv2
10 | import numpy as np
11 | import random
12 |
13 | # import torch_npu
14 | # from torch_npu.contrib import transfer_to_npu
15 |
16 | def exists(x):
17 | return x is not None
18 |
19 | def cycle(dl):
20 | while True:
21 | for data in dl:
22 | yield data
23 |
24 | def default(val, d):
25 |     if exists(val):
26 |         return val
27 |     return d() if callable(d) else d
28 |
29 |
30 | def set_seed(SEED):
31 | torch.manual_seed(SEED)
32 | torch.cuda.manual_seed_all(SEED)
33 | np.random.seed(SEED)
34 | random.seed(SEED)
35 |
36 | degradation_cache = ['Enlighening', 'Desnowing', 'Deraining', 'Deblur', 'Dehazing']  # 'Enlighening' matches the on-disk folder spelling
37 |
38 | class train_dataset(Dataset):
39 | def __init__(self, root_dir, task_folder, sub_folder = None, image_size = 256):
40 | super().__init__()
41 |         assert task_folder in degradation_cache
42 | self.root_path = root_dir
43 | self.task_path = os.path.join(root_dir,task_folder)
44 | if sub_folder is not None:
45 | self.sub_datasets = [sub_folder]
46 | else:
47 | self.sub_datasets = os.listdir(self.task_path)
48 | self.image_size = image_size
49 | self.skip_datasets = []
50 | self.image_load_path,self.image_name_list = [],[]
51 | self.condi_load_path,self.condi_name_list = [],[]
52 |
53 | for sub_dataset in self.sub_datasets:
54 | assert sub_dataset in os.listdir(self.task_path)
55 | if sub_dataset in self.skip_datasets:
56 | continue
57 | else:
58 | if sub_dataset == 'ITS_v2' or sub_dataset == 'SOTS' or sub_dataset == 'RESIDE' :
59 | dataset_path = os.path.join(self.task_path,sub_dataset,'train')
60 | image_path = os.path.join(dataset_path,'label')
61 | condi_path = os.path.join(dataset_path ,'condition')
62 | condi_file_list = os.listdir(condi_path)
63 | condi_path_list = [os.path.join(condi_path,x) for x in condi_file_list]
64 |                     image_file_list = condi_file_list  # label names are derived from the condition file names below
65 | image_path_list = [os.path.join(image_path,x.split('_',1)[0]+x[-4:]) for x in condi_file_list]
66 | else:
67 | dataset_path = os.path.join(self.task_path,sub_dataset,'train')
68 | image_path = os.path.join(dataset_path,'label')
69 | image_file_list = os.listdir(image_path)
70 | condi_path = os.path.join(dataset_path ,'condition')
71 | condi_file_list = os.listdir(condi_path)
72 | image_path_list = [os.path.join(image_path,x) for x in image_file_list]
73 | condi_path_list = [os.path.join(condi_path,x) for x in condi_file_list]
74 |
75 | self.image_load_path = self.image_load_path + image_path_list
76 | self.image_name_list = self.image_name_list + image_file_list
77 | self.condi_load_path = self.condi_load_path + condi_path_list
78 | self.condi_name_list = self.condi_name_list + condi_file_list
79 |
80 | self.transform = T.Compose([T.ToTensor()])
81 |
82 | def __len__(self):
83 | assert len(self.condi_name_list) == len(self.image_name_list),f'the number of label files does not match the number of condition files'
84 | return len(self.condi_name_list)
85 |
86 | def __getitem__(self, index):
87 | image_name_file = self.image_name_list[index]
88 | condi_name_file = self.condi_name_list[index]
89 | if condi_name_file != image_name_file:
90 | print(self.condi_load_path[index],'\n',self.image_load_path[index])
91 | assert condi_name_file == image_name_file, f'image pairs are not matched'
92 | image_file_path = self.image_load_path[index]
93 | condi_file_path = self.condi_load_path[index]
94 | image = imageio.imread(image_file_path)
95 | condi = imageio.imread(condi_file_path)
96 | if image.shape != condi.shape:
97 | print(image_file_path,'\n',condi_file_path)
98 | assert image.shape == condi.shape, 'image sizes are not matched'
99 | image,condi = self.random_crop_size(image,condi,self.image_size)
100 |
101 | if self.transform is not None:
102 | image_tf = self.transform(image)
103 | condi_tf = self.transform(condi)
104 |
105 | return condi_name_file,image_tf,condi_tf
106 |
107 | def resize_shape(self, image, short_side_length):
108 |         oldh, oldw = image.shape[:2]
109 | if min(oldh, oldw) < short_side_length:
110 | scale = short_side_length * 1.0 / min(oldh, oldw)
111 | newh, neww = oldh * scale, oldw * scale
112 | image = cv2.resize(image,(int(neww + 0.5),int(newh + 0.5)))
113 | return image
114 |
115 | def random_crop_size(self,imageA,imageB,crop_size):
116 | imageA = self.resize_shape(imageA,crop_size)
117 | imageB = self.resize_shape(imageB,crop_size)
118 | assert imageA.shape == imageB.shape, f'image sizes are not matched'
119 | h,w,_ = imageA.shape
120 | h_start,w_start = np.random.randint(0,h-crop_size+1),np.random.randint(0,w-crop_size+1)
121 | imageA_crop = imageA[h_start:h_start+crop_size,w_start:w_start+crop_size,:]
122 | imageB_crop = imageB[h_start:h_start+crop_size,w_start:w_start+crop_size,:]
123 | return imageA_crop,imageB_crop
124 |
125 |
126 | class test_dataset(Dataset):
127 | def __init__(self, root_dir, task_folder, sub_folder = None):
128 | super().__init__()
129 |         assert task_folder in degradation_cache
130 | self.root_path = root_dir
131 | self.task_path = os.path.join(root_dir,task_folder)
132 | if sub_folder is not None:
133 | self.sub_datasets = [sub_folder]
134 | else:
135 | self.sub_datasets = os.listdir(self.task_path)
136 | self.skip_datasets = []
137 | self.image_load_path,self.image_name_list = [],[]
138 | self.condi_load_path,self.condi_name_list = [],[]
139 |
140 | for sub_dataset in self.sub_datasets:
141 | assert sub_dataset in os.listdir(self.task_path)
142 | if sub_dataset in self.skip_datasets:
143 | continue
144 | else:
145 | if sub_dataset == 'ITS_v2' or sub_dataset == 'SOTS':
146 | dataset_path = os.path.join(self.task_path,sub_dataset,'test')
147 | image_path = os.path.join(dataset_path,'label')
148 | condi_path = os.path.join(dataset_path ,'condition')
149 | condi_file_list = os.listdir(condi_path)
150 | condi_path_list = [os.path.join(condi_path,x) for x in condi_file_list]
151 |                     image_file_list = condi_file_list  # label names are derived from the condition file names below
152 | image_path_list = [os.path.join(image_path,x.split('_',1)[0]+x[-4:]) for x in condi_file_list]
153 | else:
154 | dataset_path = os.path.join(self.task_path,sub_dataset,'test')
155 | image_path = os.path.join(dataset_path,'label')
156 | image_file_list = os.listdir(image_path)
157 | condi_path = os.path.join(dataset_path ,'condition')
158 | condi_file_list = os.listdir(condi_path)
159 | image_path_list = [os.path.join(image_path,x) for x in image_file_list]
160 | condi_path_list = [os.path.join(condi_path,x) for x in condi_file_list]
161 |
162 | self.image_load_path = self.image_load_path + image_path_list
163 | self.image_name_list = self.image_name_list + image_file_list
164 | self.condi_load_path = self.condi_load_path + condi_path_list
165 | self.condi_name_list = self.condi_name_list + condi_file_list
166 |
167 | self.transform = T.Compose([T.ToTensor()])
168 |
169 | def __len__(self):
170 | assert len(self.condi_name_list) == len(self.image_name_list),f'the number of label files does not match the number of condition files'
171 | return len(self.condi_name_list)
172 |
173 | def __getitem__(self, index):
174 | image_name_file = self.image_name_list[index]
175 | condi_name_file = self.condi_name_list[index]
176 | if condi_name_file != image_name_file:
177 | print(self.condi_load_path[index],'\n',self.image_load_path[index])
178 | assert condi_name_file == image_name_file, f'image pairs are not matched'
179 | image_file_path = self.image_load_path[index]
180 | condi_file_path = self.condi_load_path[index]
181 | image = imageio.imread(image_file_path)
182 | condi = imageio.imread(condi_file_path)
183 | assert image.shape == condi.shape, 'image sizes are not matched'
184 |
185 | if self.transform is not None:
186 | image_tf = self.transform(image)
187 | condi_tf = self.transform(condi)
188 |
189 | return condi_file_path,image_tf,condi_tf
--------------------------------------------------------------------------------
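For reference, a minimal usage sketch of the two dataset classes above, assuming the directory layout described in the README (`/path/to/data_root` and the `Rain13K` sub-folder are placeholders):

```python
from torch.utils.data import DataLoader
from datasets_setting import train_dataset, set_seed

set_seed(0)

# expects /path/to/data_root/Deraining/Rain13K/train/{condition,label}
ds = train_dataset('/path/to/data_root', task_folder='Deraining',
                   sub_folder='Rain13K', image_size=256)
dl = DataLoader(ds, batch_size=1, shuffle=True)

# each item: (file name, clean target tensor, degraded condition tensor)
name, target, condition = next(iter(dl))
print(name, target.shape, condition.shape)  # 1 x 3 x 256 x 256 after random cropping
```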
/code/test.py:
--------------------------------------------------------------------------------
1 | import os
2 | import numpy as np
3 | import math
4 | import time
5 | import imageio
6 | import json
7 |
8 | import accelerate
9 | import torch
10 | import torch.nn as nn
11 | import torch.nn.functional as F
12 |
13 | from argparse import ArgumentParser
14 | from tqdm.auto import tqdm
15 | from ema_pytorch import EMA
16 | from pathlib import Path
17 | from skimage.metrics import structural_similarity
18 | from torch.optim import Adam
19 | from torchvision import transforms as T, utils
20 | from torch.utils.data import DataLoader
21 | from model import (ResidualDiffusion, Unet, UnetRes, set_seed)
22 | from datasets_setting import train_dataset,test_dataset,set_seed
23 |
24 | # import torch_npu
25 | # from torch_npu.contrib import transfer_to_npu
26 |
27 |
28 | parser = ArgumentParser()
29 | parser.add_argument("--project_description", type=str, default="UniDiffSolver For Image Restoration", help="Name of Project")
30 |
31 | parser.add_argument("--step_number", type=int, default=5000, help="step_number")
32 | parser.add_argument("--batch_size", type=int, default=8, help="batch_size")
33 | parser.add_argument("--image_size", type=int, default=512, help="image_size")
34 | parser.add_argument("--num_unet", type=int, default=1, help="num_unet")
35 | parser.add_argument("--objective", type=str, default='pred_res', help="[pred_res_noise,pred_x0_noise,pred_noise,pred_res]")
36 | parser.add_argument("--test_res_or_noise", type=str, default='res', help="[res_noise,res,noise]")
37 | parser.add_argument("--lr", type=float, default=0.0003, help="learning_rate")
38 | parser.add_argument("--sampling_timesteps", type=int, default=1, help="sampling_timesteps")
39 |
40 | def exists(x):
41 | return x is not None
42 |
43 | def has_int_squareroot(num):
44 | return (math.sqrt(num) ** 2) == num
45 |
46 | def cycle(dl):
47 | while True:
48 | for data in dl:
49 | yield data
50 |
51 | def create_folder(folder_path):
52 | if not os.path.exists(folder_path):
53 | os.makedirs(folder_path)
54 |
55 | def create_empty_json(json_path):
56 | with open(json_path, 'w') as file:
57 | pass
58 |
59 | def remove_json(json_path):
60 | os.remove(json_path)
61 |
62 | def write_json(json_path,item):
63 | with open(json_path, 'a+', encoding='utf-8') as f:
64 | line = json.dumps(item)
65 | f.write(line+'\n')
66 |
67 | def readline_json(json_path,key=None):
68 | data = []
69 | with open(json_path, 'r') as f:
70 | items = f.readlines()
71 | file_flag = []
72 | if key is not None:
73 | for item in items:
74 | file_name = json.loads(item)['file_path']
75 | if file_name not in file_flag:
76 | file_flag.append(file_name)
77 | data.append(json.loads(item)[key])
78 | return np.asarray(data).mean()
79 | else:
80 | for item in items:
81 | data.append(json.loads(item))
82 | return data
83 |
84 |
85 | class Trainer(object):
86 | def __init__(
87 | self,
88 | diffusion_model,
89 | train_folder,
90 | eval_folder,
91 | train_num_steps = 100000,
92 | train_batch_size = 1,
93 | save_and_sample_every = 5000,
94 | save_best_and_latest_only = True,
95 | calculate_metric = True,
96 | results_folder = './results/',
97 | gradient_accumulate_every = 1,
98 | *,
99 | augment_horizontal_flip = True,
100 | train_lr = 8e-5,
101 | ema_update_every = 1,
102 | ema_decay = 0.995,
103 | adam_betas = (0.9, 0.99),
104 | save_row = 10,
105 | amp = False,
106 | mixed_precision_type = 'fp16',
107 | split_batches = True,
108 | convert_image_to = None,
109 | max_grad_norm = 1.,
110 | ):
111 | super().__init__()
112 |
113 | self.accelerator = accelerate.Accelerator(split_batches = split_batches)
114 | self.model = diffusion_model
115 | is_ddim_sampling = diffusion_model.is_ddim_sampling
116 | self.save_row = save_row
117 | self.save_and_sample_every = save_and_sample_every
118 | self.batch_size = train_batch_size
119 | self.gradient_accumulate_every = gradient_accumulate_every
120 | self.image_size = diffusion_model.image_size
121 | self.max_grad_norm = max_grad_norm
122 |
123 | self.train_folder = train_folder
124 | self.eval_folder = eval_folder
125 | self.ds_eval_hazy = test_dataset(eval_folder,task_folder='Dehazing')
126 | self.ds_eval_light = test_dataset(eval_folder,task_folder='Enlighening')
127 | self.ds_eval_rain = test_dataset(eval_folder,task_folder='Deraining')
128 | self.ds_eval_snow = test_dataset(eval_folder,task_folder='Desnowing')
129 | self.ds_eval_blur = test_dataset(eval_folder,task_folder='Deblur')
130 | self.dl_eval_hazy = self.accelerator.prepare(DataLoader(self.ds_eval_hazy, batch_size = 1))
131 | self.dl_eval_light = self.accelerator.prepare(DataLoader(self.ds_eval_light, batch_size = 1))
132 | self.dl_eval_rain = self.accelerator.prepare(DataLoader(self.ds_eval_rain, batch_size = 1))
133 | self.dl_eval_snow = self.accelerator.prepare(DataLoader(self.ds_eval_snow, batch_size = 1))
134 | self.dl_eval_blur = self.accelerator.prepare(DataLoader(self.ds_eval_blur, batch_size = 1))
135 |
136 | if self.accelerator.is_main_process:
137 |             self.accelerator.print('Validation Samples :')
138 | self.accelerator.print(' : (hazy :{})'.format(len(self.ds_eval_hazy)))
139 | self.accelerator.print(' : (light:{})'.format(len(self.ds_eval_light)))
140 | self.accelerator.print(' : (rain :{})'.format(len(self.ds_eval_rain)))
141 | self.accelerator.print(' : (snow :{})'.format(len(self.ds_eval_snow)))
142 | self.accelerator.print(' : (blur :{})'.format(len(self.ds_eval_blur)))
143 |
144 | self.opt = Adam(diffusion_model.parameters(), lr = train_lr, betas = adam_betas)
145 | self.ema = EMA(diffusion_model, beta = ema_decay, update_every = ema_update_every)
146 | self.ema.to(self.device)
147 | self.results_folder = Path(results_folder)
148 | self.results_folder.mkdir(exist_ok = True)
149 | self.train_num_steps = train_num_steps
150 | self.step = 0
151 | self.model, self.opt = self.accelerator.prepare(self.model, self.opt)
152 | self.calculate_metric = calculate_metric and self.accelerator.is_main_process
153 |
154 | @property
155 | def device(self):
156 | return self.accelerator.device
157 |
158 | def save(self, milestone = None):
159 | if not self.accelerator.is_local_main_process:
160 | return
161 | data = {
162 | 'step': self.step,
163 | 'model': self.accelerator.get_state_dict(self.model),
164 | 'opt': self.opt.state_dict(),
165 | 'ema': self.ema.state_dict(),
166 | 'scaler': self.accelerator.scaler.state_dict() if exists(self.accelerator.scaler) else None,
167 | }
168 | checkpoint_save_path = os.path.join(self.results_folder,f'model-{milestone}')
169 | if not os.path.exists(checkpoint_save_path):
170 | os.makedirs(checkpoint_save_path)
171 | torch.save(data, checkpoint_save_path + '/' + f'model-{milestone}.pt')
172 |
173 | def load(self, milestone = None):
174 | accelerator = self.accelerator
175 | device = accelerator.device
176 |         # NOTE: the milestone argument is unused here; weights are loaded from the fixed path below
177 | data = torch.load('./ckpt/pretrained.pt', map_location=device)
178 | model = self.accelerator.unwrap_model(self.model)
179 | model.load_state_dict(data['model'])
180 | self.step = data['step']
181 | self.opt.load_state_dict(data['opt'])
182 | self.ema.load_state_dict(data["ema"])
183 | if exists(self.accelerator.scaler) and exists(data['scaler']):
184 | self.accelerator.scaler.load_state_dict(data['scaler'])
185 |
186 | def cal_psnr(self,img_ref, img_gen, data_range = 255.0):
187 | mse = np.mean((img_ref.astype(np.float32)/data_range - img_gen.astype(np.float32)/data_range) ** 2)
188 | if mse < 1.0e-10:
189 | return 100
190 | PIXEL_MAX = 1
191 | return 20 * math.log10(PIXEL_MAX / math.sqrt(mse))
192 |
193 | def cal_ssim(self,img_ref, img_gen):
194 | ssim_val = 0
195 | for i in range(img_ref.shape[-1]):
196 | ssim_val = ssim_val + structural_similarity(img_ref[:,:,i], img_gen[:,:,i])
197 | return ssim_val/img_ref.shape[-1]
198 |
199 | def train(self):
200 | accelerator = self.accelerator
201 | device = accelerator.device
202 | track_metric_json_path = os.path.join(self.results_folder,'metric.json')
203 | if self.accelerator.is_main_process:
204 | create_empty_json(track_metric_json_path)
205 |
206 | self.test(dataloader = self.dl_eval_rain, degradation = 'Deraining')
207 | self.test(dataloader = self.dl_eval_light, degradation = 'Enlighening')
208 | self.test(dataloader = self.dl_eval_blur, degradation = 'Deblur')
209 | self.test(dataloader = self.dl_eval_snow, degradation = 'Desnowing')
210 | self.test(dataloader = self.dl_eval_hazy, degradation = 'Dehazing')
211 |
212 | if self.accelerator.is_main_process:
213 | write_json(track_metric_json_path,f'model-{self.step} : ')
214 | degradation_types = ['Deraining', 'Enlighening', 'Desnowing', 'Dehazing' , 'Deblur']
215 | for degradation in degradation_types:
216 | json_path = os.path.join(self.results_folder,f'model-{self.step}') + '/{}.json'.format(degradation)
217 | psnr_val,ssim_val = readline_json(json_path,'psnr'),readline_json(json_path,'ssim')
218 | accelerator.print('{} -> (PSNR/SSIM) : {:.6f}/{:.6f} '.format(degradation,psnr_val,ssim_val))
219 | write_json(track_metric_json_path,'{} -> (PSNR/SSIM) : {:.6f}/{:.6f} '.format(degradation,psnr_val,ssim_val))
220 |
221 | accelerator.print('Testing complete')
222 |
223 | def test(self,dataloader,degradation):
224 | self.accelerator.wait_for_everyone()
225 | if self.accelerator.is_main_process:
226 | start_time = time.time()
227 | save_json_dir = os.path.join(self.results_folder,f'model-{self.step}')
228 | create_folder(save_json_dir)
229 | save_json_path = save_json_dir + '/{}.json'.format(degradation)
230 | create_empty_json(save_json_path)
231 | self.accelerator.wait_for_everyone()
232 | save_json_path = os.path.join(self.results_folder,f'model-{self.step}') + '/{}.json'.format(degradation)
233 | self.ema.model.eval()
234 | for batch_id,batch in enumerate(dataloader):
235 | name_path,image_tf,condi_tf = batch
236 | img_gen = self.ema.model.sample(condi_tf.to(self.device))
237 | for element_id in range(len(name_path)):
238 | image_np_ref = self.tf2img(image_tf[element_id,:,:,].unsqueeze(0))
239 | image_np_gen = self.tf2img(img_gen[element_id,:,:,].unsqueeze(0))
240 | psnr_val = self.cal_psnr(image_np_ref,image_np_gen)
241 | ssim_val = self.cal_ssim(image_np_ref,image_np_gen)
242 | data_dump_info = {
243 | 'file_path' : name_path[element_id],
244 | 'psnr' : psnr_val,
245 | 'ssim' : ssim_val,
246 | }
247 | print(batch_id,name_path,'PSNR / SSIM : {:.6f} : {:.6f}'.format(psnr_val,ssim_val))
248 | write_json(save_json_path,data_dump_info)
249 | image_save_dir = os.path.join(self.results_folder,f'model-{self.step}',name_path[element_id].split('/')[-5],name_path[element_id].split('/')[-4])
250 | create_folder(image_save_dir)
251 | imageio.imwrite(os.path.join(image_save_dir,name_path[element_id].split('/')[-1]),image_np_gen)
252 |
253 | if self.accelerator.is_main_process:
254 | end_time = time.time()
255 | test_time_consuming = end_time - start_time
256 | self.accelerator.print('Test_time_consuming : {:.6} s'.format(test_time_consuming))
257 |
258 | self.accelerator.wait_for_everyone()
259 |
260 | def tf2np(self,image_tf):
261 | n,c,h,w = image_tf.size()
262 | assert n == 1
263 | if c == 1:
264 | image_np = image_tf.squeeze(0).squeeze(0).detach().cpu().numpy()
265 | else:
266 | image_np = image_tf.squeeze(0).permute(1,2,0).detach().cpu().numpy()
267 |
268 | return image_np
269 |
270 | def tf2img(self,image_tf):
271 | image_np = self.tf2np(torch.clamp(image_tf,min=0.,max=1.))
272 | image_np = (image_np * 255).astype(np.uint8)
273 | return image_np
274 |
275 |
276 | def test_ddp_accelerate(args):
277 |     train_folder = ' '  # TODO: set to the dataset root used for training
278 |     eval_folder = ' '  # TODO: set to the dataset root used for evaluation
279 | print('Procedure Running: ',args.project_description)
280 | image_size = 256
281 | num_unet = 1
282 | objective = 'pred_res'
283 | ddim_sampling_eta = 0.0
284 | test_res_or_noise = "res"
285 | sum_scale = 0.01
286 | delta_end = 2.0e-3
287 | condition = True
288 | sampling_timesteps = 2
289 | model = UnetRes(dim=64, dim_mults=(1, 2, 4, 8),num_unet=num_unet, condition=condition, objective=objective, test_res_or_noise = test_res_or_noise)
290 | diffusion = ResidualDiffusion(model,image_size=image_size, timesteps=1000,delta_end = delta_end,sampling_timesteps=sampling_timesteps, objective=objective,ddim_sampling_eta= ddim_sampling_eta,loss_type='l1',condition=condition,sum_scale=sum_scale,test_res_or_noise = test_res_or_noise)
291 | diffusion_process_trainer = Trainer(
292 | diffusion_model = diffusion,
293 | train_folder = train_folder,
294 | eval_folder = eval_folder,
295 | train_num_steps = 500000,
296 | train_batch_size = 16,
297 | save_and_sample_every = 5000,
298 | save_best_and_latest_only = True,
299 | calculate_metric = True,
300 | results_folder = './save_folder',
301 | gradient_accumulate_every = 1,
302 | )
303 |
304 | diffusion_process_trainer.load()
305 | diffusion_process_trainer.train()
306 | print('Procedure Termination: (Finished)')
307 |
308 |
309 |
310 | if __name__ == '__main__':
311 | args = parser.parse_args()
312 | set_seed(0)
313 | test_ddp_accelerate(args)
--------------------------------------------------------------------------------
/code/train.py:
--------------------------------------------------------------------------------
1 | import os
2 | import numpy as np
3 | import math
4 | import time
5 | import imageio
6 | import json
7 |
8 | import accelerate
9 | import torch
10 | import torch.nn as nn
11 | import torch.nn.functional as F
12 |
13 | from argparse import ArgumentParser
14 | from tqdm.auto import tqdm
15 | from ema_pytorch import EMA
16 | from pathlib import Path
17 | from skimage.metrics import structural_similarity
18 | from torch.optim import Adam
19 | from torchvision import transforms as T, utils
20 | from torch.utils.data import DataLoader
21 | from model import (ResidualDiffusion, Unet, UnetRes, set_seed)
22 | from datasets_setting import train_dataset,test_dataset,set_seed
23 | # import torch_npu
24 | # from torch_npu.contrib import transfer_to_npu
25 |
26 |
27 | parser = ArgumentParser()
28 | parser.add_argument("--project_description", type=str, default="UniDiffSolver For Image Restoration", help="Name of Project")
29 |
30 | parser.add_argument("--step_number", type=int, default=5000, help="step_number")
31 | parser.add_argument("--batch_size", type=int, default=8, help="batch_size")
32 | parser.add_argument("--image_size", type=int, default=512, help="image_size")
33 | parser.add_argument("--num_unet", type=int, default=1, help="num_unet")
34 | parser.add_argument("--objective", type=str, default='pred_res', help="[pred_res_noise,pred_x0_noise,pred_noise,pred_res]")
35 | parser.add_argument("--test_res_or_noise", type=str, default='res', help="[res_noise,res,noise]")
36 | parser.add_argument("--lr", type=float, default=0.0003, help="learning_rate")
37 | parser.add_argument("--sampling_timesteps", type=int, default=1, help="sampling_timesteps")
38 |
39 | def exists(x):
40 | return x is not None
41 |
42 | def has_int_squareroot(num):
43 | return (math.sqrt(num) ** 2) == num
44 |
45 | def cycle(dl):
46 | while True:
47 | for data in dl:
48 | yield data
49 |
50 | def divisible_by(numer, denom):
51 | return (numer % denom) == 0
52 |
53 | def create_folder(folder_path):
54 | if not os.path.exists(folder_path):
55 | os.makedirs(folder_path)
56 |
57 | def create_empty_json(json_path):
58 | with open(json_path, 'w') as file:
59 | pass
60 |
61 | def remove_json(json_path):
62 | os.remove(json_path)
63 |
64 | def write_json(json_path,item):
65 | with open(json_path, 'a+', encoding='utf-8') as f:
66 | line = json.dumps(item)
67 | f.write(line+'\n')
68 |
69 | def readline_json(json_path,key=None):
70 | data = []
71 | with open(json_path, 'r') as f:
72 | items = f.readlines()
73 | if key is not None:
74 | for item in items:
75 | data.append(json.loads(item)[key])
76 | return np.asarray(data).mean()
77 | else:
78 | for item in items:
79 | data.append(json.loads(item))
80 | return data
81 |
82 | class Trainer(object):
83 | def __init__(
84 | self,
85 | diffusion_model,
86 | train_folder,
87 | eval_folder,
88 | train_num_steps = 100000,
89 | train_batch_size = 1,
90 | save_and_sample_every = 5000,
91 | save_best_and_latest_only = True,
92 | calculate_metric = True,
93 | results_folder = './results/',
94 | gradient_accumulate_every = 1,
95 | *,
96 | augment_horizontal_flip = True,
97 | train_lr = 8e-5,
98 | ema_update_every = 1,
99 | ema_decay = 0.995,
100 | adam_betas = (0.9, 0.99),
101 | save_row = 10,
102 | amp = False,
103 | mixed_precision_type = 'fp16',
104 | split_batches = True,
105 | convert_image_to = None,
106 | max_grad_norm = 1.,
107 | ):
108 | super().__init__()
109 |
110 | self.accelerator = accelerate.Accelerator(
111 | split_batches = split_batches,
112 | mixed_precision = mixed_precision_type if amp else 'no'
113 | )
114 | self.model = diffusion_model
115 | is_ddim_sampling = diffusion_model.is_ddim_sampling
116 | self.save_row = save_row
117 | self.save_and_sample_every = save_and_sample_every
118 | self.batch_size = train_batch_size
119 | self.gradient_accumulate_every = gradient_accumulate_every
120 | self.image_size = diffusion_model.image_size
121 | self.max_grad_norm = max_grad_norm
122 |
123 | self.train_folder = train_folder
124 | self.eval_folder = eval_folder
125 |
126 | self.ds_hazy = train_dataset(train_folder,task_folder='Dehazing',image_size = 256)
127 | self.ds_light = train_dataset(train_folder,task_folder='Enlighening',image_size = 256)
128 | self.ds_rain = train_dataset(train_folder,task_folder='Deraining',image_size = 256)
129 | self.ds_snow = train_dataset(train_folder,task_folder='Desnowing',image_size = 256)
130 | self.ds_blur = train_dataset(train_folder,task_folder='Deblur',image_size = 256)
131 |
132 | self.dl_hazy = cycle(self.accelerator.prepare(DataLoader(self.ds_hazy, batch_size = 1)))
133 | self.dl_light = cycle(self.accelerator.prepare(DataLoader(self.ds_light, batch_size = 1)))
134 | self.dl_rain = cycle(self.accelerator.prepare(DataLoader(self.ds_rain, batch_size = 1)))
135 | self.dl_snow = cycle(self.accelerator.prepare(DataLoader(self.ds_snow, batch_size = 1)))
136 | self.dl_blur = cycle(self.accelerator.prepare(DataLoader(self.ds_blur, batch_size = 1)))
137 |
138 | if self.accelerator.is_main_process:
139 |             self.accelerator.print('Training Samples : (hazy :{})'.format(len(self.ds_hazy)))
140 | self.accelerator.print(' : (light:{})'.format(len(self.ds_light)))
141 | self.accelerator.print(' : (rain :{})'.format(len(self.ds_rain)))
142 | self.accelerator.print(' : (snow :{})'.format(len(self.ds_snow)))
143 | self.accelerator.print(' : (blur :{})'.format(len(self.ds_blur)))
144 |
145 | self.ds_eval_hazy = test_dataset(eval_folder,task_folder='Dehazing')
146 | self.ds_eval_light = test_dataset(eval_folder,task_folder='Enlighening')
147 | self.ds_eval_rain = test_dataset(eval_folder,task_folder='Deraining')
148 | self.ds_eval_snow = test_dataset(eval_folder,task_folder='Desnowing')
149 | self.ds_eval_blur = test_dataset(eval_folder,task_folder='Deblur')
150 |
151 | self.dl_eval_hazy = self.accelerator.prepare(DataLoader(self.ds_eval_hazy, batch_size = 1))
152 | self.dl_eval_light = self.accelerator.prepare(DataLoader(self.ds_eval_light, batch_size = 1))
153 | self.dl_eval_rain = self.accelerator.prepare(DataLoader(self.ds_eval_rain, batch_size = 1))
154 | self.dl_eval_snow = self.accelerator.prepare(DataLoader(self.ds_eval_snow, batch_size = 1))
155 | self.dl_eval_blur = self.accelerator.prepare(DataLoader(self.ds_eval_blur, batch_size = 1))
156 |
157 | if self.accelerator.is_main_process:
158 |             self.accelerator.print('Validation Samples : (hazy :{})'.format(len(self.ds_eval_hazy)))
159 | self.accelerator.print(' : (light:{})'.format(len(self.ds_eval_light)))
160 | self.accelerator.print(' : (rain :{})'.format(len(self.ds_eval_rain)))
161 | self.accelerator.print(' : (snow :{})'.format(len(self.ds_eval_snow)))
162 | self.accelerator.print(' : (blur :{})'.format(len(self.ds_eval_blur)))
163 |
164 | self.opt = Adam(diffusion_model.parameters(), lr = train_lr, betas = adam_betas)
165 | self.ema = EMA(diffusion_model, beta = ema_decay, update_every = ema_update_every)
166 | self.ema.to(self.device)
167 |
168 | self.results_folder = Path(results_folder)
169 | self.results_folder.mkdir(exist_ok = True)
170 |
171 | self.train_num_steps = train_num_steps
172 | self.step = 0
173 |
174 | self.model, self.opt = self.accelerator.prepare(self.model, self.opt)
175 | self.calculate_metric = calculate_metric and self.accelerator.is_main_process
176 |
177 | if save_best_and_latest_only:
178 | self.best_metric = 0
179 |
180 | self.save_best_and_latest_only = save_best_and_latest_only
181 |
182 | @property
183 | def device(self):
184 | return self.accelerator.device
185 |
186 | def save(self, milestone = None):
187 | if not self.accelerator.is_local_main_process:
188 | return
189 | data = {
190 | 'step': self.step,
191 | 'model': self.accelerator.get_state_dict(self.model),
192 | 'opt': self.opt.state_dict(),
193 | 'ema': self.ema.state_dict(),
194 | 'scaler': self.accelerator.scaler.state_dict() if exists(self.accelerator.scaler) else None,
195 | }
196 | checkpoint_save_path = os.path.join(self.results_folder,f'model-{milestone}')
197 | if not os.path.exists(checkpoint_save_path):
198 | os.makedirs(checkpoint_save_path)
199 | torch.save(data, checkpoint_save_path + '/' + f'model-{milestone}.pt')
200 |
201 | def load(self, milestone= None):
202 | accelerator = self.accelerator
203 | device = accelerator.device
204 |         # NOTE: the milestone argument is unused here; weights are loaded from the fixed path below
205 | data = torch.load('./ckpt/pretrained.pt', map_location=device)
206 | model = self.accelerator.unwrap_model(self.model)
207 | model.load_state_dict(data['model'])
208 | self.step = data['step']
209 | self.opt.load_state_dict(data['opt'])
210 | self.ema.load_state_dict(data["ema"])
211 | if exists(self.accelerator.scaler) and exists(data['scaler']):
212 | self.accelerator.scaler.load_state_dict(data['scaler'])
213 |
214 | def cal_psnr(self,img_ref, img_gen, data_range = 255.0):
215 | mse = np.mean((img_ref.astype(np.float32)/data_range - img_gen.astype(np.float32)/data_range) ** 2)
216 | if mse < 1.0e-10:
217 | return 100
218 | PIXEL_MAX = 1
219 | return 20 * math.log10(PIXEL_MAX / math.sqrt(mse))
220 |
221 | def cal_ssim(self,img_ref, img_gen):
222 | ssim_val = 0
223 | for i in range(img_ref.shape[-1]):
224 | ssim_val = ssim_val + structural_similarity(img_ref[:,:,i], img_gen[:,:,i])
225 | return ssim_val/img_ref.shape[-1]
226 |
227 | def train(self):
228 | accelerator = self.accelerator
229 | device = accelerator.device
230 | track_metric_json_path = os.path.join(self.results_folder,'metric.json')
231 | if self.accelerator.is_main_process:
232 | create_empty_json(track_metric_json_path)
233 | with tqdm(initial=self.step, total=self.train_num_steps, disable=not accelerator.is_main_process) as pbar:
234 | while self.step < self.train_num_steps:
235 | self.model.train()
236 | hazy_name,hazy_target,hazy_input = next(self.dl_hazy)
237 | light_name,light_target,light_input = next(self.dl_light)
238 | rain_name,rain_target,rain_input = next(self.dl_rain)
239 | snow_name,snow_target,snow_input = next(self.dl_snow)
240 | blur_name,blur_target,blur_input = next(self.dl_blur)
241 | file_name = hazy_name + light_name + rain_name + snow_name + blur_name
242 | file_target = torch.cat([hazy_target,light_target,rain_target,snow_target,blur_target],dim=0)
243 | file_input = torch.cat([hazy_input,light_input,rain_input,snow_input,blur_input],dim=0)
244 | with self.accelerator.autocast():
245 | loss = self.model(img = [file_target.to(device),file_input.to(device)])[0]
246 | self.accelerator.backward(loss)
247 | pbar.set_description(f'loss: {loss.item():.4f}')
248 | accelerator.wait_for_everyone()
249 | accelerator.clip_grad_norm_(self.model.parameters(), self.max_grad_norm)
250 | self.opt.step()
251 | self.opt.zero_grad()
252 | accelerator.wait_for_everyone()
253 |
254 | if self.step != 0 and divisible_by(self.step, self.save_and_sample_every):
255 | self.test(dataloader = self.dl_eval_rain, degradation = 'Deraining')
256 | self.test(dataloader = self.dl_eval_light, degradation = 'Enlighening')
257 | self.test(dataloader = self.dl_eval_blur, degradation = 'Deblur')
258 | self.test(dataloader = self.dl_eval_snow, degradation = 'Desnowing')
259 | self.test(dataloader = self.dl_eval_hazy, degradation = 'Dehazing')
260 |
261 | if self.accelerator.is_main_process:
262 | if self.step != 0 and divisible_by(self.step, self.save_and_sample_every):
263 | write_json(track_metric_json_path,f'model-{self.step} : ')
264 | degradation_types = ['Enlighening', 'Desnowing', 'Deraining', 'Deblur', 'Dehazing']
265 | for degradation in degradation_types:
266 | json_path = os.path.join(self.results_folder,f'model-{self.step}') + '/{}.json'.format(degradation)
267 | psnr_val,ssim_val = readline_json(json_path,'psnr'),readline_json(json_path,'ssim')
268 | accelerator.print('{} -> (PSNR/SSIM) : {:.6f}/{:.6f} '.format(degradation,psnr_val,ssim_val))
269 | write_json(track_metric_json_path,'{} -> (PSNR/SSIM) : {:.6f}/{:.6f} '.format(degradation,psnr_val,ssim_val))
270 |
271 | accelerator.wait_for_everyone()
272 |
273 | self.ema.update()
274 | if self.accelerator.is_main_process:
275 | if self.step != 0 and divisible_by(self.step, self.save_and_sample_every):
276 | accelerator.print('save model checkpoint')
277 | self.save(self.step)
278 |
279 | accelerator.wait_for_everyone()
280 | self.step += 1
281 | pbar.update(1)
282 |
283 | accelerator.print('Training complete')
284 |
285 | def test(self,dataloader,degradation):
286 | self.accelerator.wait_for_everyone()
287 | if self.accelerator.is_main_process:
288 | start_time = time.time()
289 | save_json_dir = os.path.join(self.results_folder,f'model-{self.step}')
290 | create_folder(save_json_dir)
291 | save_json_path = save_json_dir + '/{}.json'.format(degradation)
292 | create_empty_json(save_json_path)
293 | self.accelerator.wait_for_everyone()
294 | save_json_path = os.path.join(self.results_folder,f'model-{self.step}') + '/{}.json'.format(degradation)
295 | self.ema.model.eval()
296 | for batch_id,batch in enumerate(dataloader):
297 | name_path,image_tf,condi_tf = batch
298 | img_gen = self.ema.model.sample(condi_tf.to(self.device))
299 | for element_id in range(len(name_path)):
300 | image_np_ref = self.tf2img(image_tf[element_id,:,:,].unsqueeze(0))
301 | image_np_gen = self.tf2img(img_gen[element_id,:,:,].unsqueeze(0))
302 |
303 | psnr_val = self.cal_psnr(image_np_ref,image_np_gen)
304 | ssim_val = self.cal_ssim(image_np_ref,image_np_gen)
305 |
306 | data_dump_info = {
307 | 'file_path' : name_path[element_id],
308 | 'psnr' : psnr_val,
309 | 'ssim' : ssim_val,
310 | }
311 | print(batch_id,name_path,'PSNR / SSIM : {:.6f} : {:.6f}'.format(psnr_val,ssim_val))
312 | write_json(save_json_path,data_dump_info)
313 | image_save_dir = os.path.join(self.results_folder,f'model-{self.step}',name_path[element_id].split('/')[-5],name_path[element_id].split('/')[-4])
314 | create_folder(image_save_dir)
315 | imageio.imwrite(os.path.join(image_save_dir,name_path[element_id].split('/')[-1]),image_np_gen)
316 |
317 | if self.accelerator.is_main_process:
318 | end_time = time.time()
319 | test_time_consuming = end_time - start_time
320 | self.accelerator.print('Test_time_consuming : {:.6} s'.format(test_time_consuming))
321 |
322 | self.accelerator.wait_for_everyone()
323 |
324 | def tf2np(self,image_tf):
325 | n,c,h,w = image_tf.size()
326 | assert n == 1
327 | if c == 1:
328 | image_np = image_tf.squeeze(0).squeeze(0).detach().cpu().numpy()
329 | else:
330 | image_np = image_tf.squeeze(0).permute(1,2,0).detach().cpu().numpy()
331 |
332 | return image_np
333 |
334 | def tf2img(self,image_tf):
335 | image_np = self.tf2np(torch.clamp(image_tf,min=0.,max=1.))
336 | image_np = (image_np * 255).astype(np.uint8)
337 | return image_np
338 |
339 | def train_ddp_accelerate(args):
340 |     train_folder = ''  # TODO: set to the dataset root used for training
341 |     eval_folder = ''  # TODO: set to the dataset root used for evaluation
342 | print('Procedure Running: ',args.project_description)
343 | image_size = 256
344 | num_unet = 1
345 | objective = 'pred_res'
346 | ddim_sampling_eta = 0.0
347 | test_res_or_noise = "res"
348 | sum_scale = 0.01
349 | delta_end = 2.0e-3
350 | condition = True
351 | sampling_timesteps = 8
352 | model = UnetRes(dim=64, dim_mults=(1, 2, 4, 8),num_unet=num_unet, condition=condition, objective=objective, test_res_or_noise = test_res_or_noise)
353 | diffusion = ResidualDiffusion(model,image_size=image_size, timesteps=1000,delta_end = delta_end,sampling_timesteps=sampling_timesteps, objective=objective,ddim_sampling_eta= ddim_sampling_eta,loss_type='l1',condition=condition,sum_scale=sum_scale,test_res_or_noise = test_res_or_noise)
354 | diffusion_process_trainer = Trainer(
355 | diffusion_model = diffusion,
356 | train_folder = train_folder,
357 | eval_folder = eval_folder,
358 | train_num_steps = 500000,
359 | train_batch_size = 16,
360 | save_and_sample_every = 5000,
361 | save_best_and_latest_only = True,
362 | calculate_metric = True,
363 | results_folder = './save_folder',
364 | gradient_accumulate_every = 1,
365 | )
366 | diffusion_process_trainer.load()
367 | diffusion_process_trainer.train()
368 | print('Procedure Termination: (Finished)')
369 |
370 | if __name__ == '__main__':
371 | args = parser.parse_args()
372 | set_seed(0)
373 | train_ddp_accelerate(args)
--------------------------------------------------------------------------------
/code/model.py:
--------------------------------------------------------------------------------
1 | import copy
2 | import glob
3 | import math
4 | import os
5 | import random
6 | from collections import namedtuple
7 | from functools import partial
8 | from multiprocessing import cpu_count
9 | from pathlib import Path
10 | import torchvision.transforms as transforms
11 |
12 | import cv2
13 | import numpy as np
14 | import torch
15 | import torch.nn.functional as F
16 | import torchvision.transforms.functional as TF
17 | from accelerate import Accelerator
18 | from einops import rearrange, reduce
19 | from einops.layers.torch import Rearrange
20 | from ema_pytorch import EMA
21 | from PIL import Image
22 | import time
23 | from torch import einsum, nn
24 | from torch.optim import Adam, RAdam
25 | from torch.utils.data import DataLoader
26 | from torchvision import transforms as T
27 | from torchvision import utils
28 | from tqdm.auto import tqdm
29 |
30 |
31 | # import torch_npu
32 | # from torch_npu.contrib import transfer_to_npu
33 |
34 | ModelResPrediction = namedtuple('ModelResPrediction', ['pred_res', 'pred_noise', 'pred_x_start'])
35 |
36 | def set_seed(SEED):
37 | torch.manual_seed(SEED)
38 | torch.cuda.manual_seed_all(SEED)
39 | np.random.seed(SEED)
40 | random.seed(SEED)
41 |
42 | def exists(x):
43 | return x is not None
44 |
45 | def default(val, d):
46 | if exists(val):
47 | return val
48 | return d() if callable(d) else d
49 |
50 | def identity(t, *args, **kwargs):
51 | return t
52 |
53 | def normalize_to_neg_one_to_one(img):
54 | if isinstance(img, list):
55 | return [img[k] * 2 - 1 for k in range(len(img))]
56 | else:
57 | return img * 2 - 1
58 |
59 | def unnormalize_to_zero_to_one(img):
60 | if isinstance(img, list):
61 | return [(img[k] + 1) * 0.5 for k in range(len(img))]
62 | else:
63 | return (img + 1) * 0.5
64 |
65 | class Residual(nn.Module):
66 | def __init__(self, fn):
67 | super().__init__()
68 | self.fn = fn
69 |
70 | def forward(self, x, *args, **kwargs):
71 | return self.fn(x, *args, **kwargs) + x
72 |
73 |
74 | def Upsample(dim, dim_out=None):
75 | return nn.Sequential(
76 | nn.Upsample(scale_factor=2, mode='nearest'),
77 | nn.Conv2d(dim, default(dim_out, dim), 3, padding=1)
78 | )
79 |
80 |
81 | def Downsample(dim, dim_out=None):
82 | return nn.Conv2d(dim, default(dim_out, dim), 4, 2, 1)
83 |
84 |
85 | class WeightStandardizedConv2d(nn.Conv2d):
86 | def forward(self, x):
87 | eps = 1e-5 if x.dtype == torch.float32 else 1e-3
88 |
89 | weight = self.weight
90 | mean = reduce(weight, 'o ... -> o 1 1 1', 'mean')
91 | var = reduce(weight, 'o ... -> o 1 1 1',
92 | partial(torch.var, unbiased=False))
93 | normalized_weight = (weight - mean) * (var + eps).rsqrt()
94 |
95 | return F.conv2d(x, normalized_weight, self.bias, self.stride, self.padding, self.dilation, self.groups)
96 |
97 |
98 | class LayerNorm(nn.Module):
99 | def __init__(self, dim):
100 | super().__init__()
101 | self.g = nn.Parameter(torch.ones(1, dim, 1, 1))
102 |
103 | def forward(self, x):
104 | eps = 1e-5 if x.dtype == torch.float32 else 1e-3
105 | var = torch.var(x, dim=1, unbiased=False, keepdim=True)
106 | mean = torch.mean(x, dim=1, keepdim=True)
107 | return (x - mean) * (var + eps).rsqrt() * self.g
108 |
109 |
110 | class PreNorm(nn.Module):
111 | def __init__(self, dim, fn):
112 | super().__init__()
113 | self.fn = fn
114 | self.norm = LayerNorm(dim)
115 |
116 | def forward(self, x):
117 | x = self.norm(x)
118 | return self.fn(x)
119 |
120 | class SinusoidalPosEmb(nn.Module):
121 | def __init__(self, dim):
122 | super().__init__()
123 | self.dim = dim
124 |
125 | def forward(self, x):
126 | device = x.device
127 | half_dim = self.dim // 2
128 | emb = math.log(10000) / (half_dim - 1)
129 | emb = torch.exp(torch.arange(half_dim, device=device) * -emb)
130 | emb = x[:, None] * emb[None, :]
131 | emb = torch.cat((emb.sin(), emb.cos()), dim=-1)
132 | return emb
133 |
134 |
135 | class RandomOrLearnedSinusoidalPosEmb(nn.Module):
136 | def __init__(self, dim, is_random=False):
137 | super().__init__()
138 | assert (dim % 2) == 0
139 | half_dim = dim // 2
140 | self.weights = nn.Parameter(torch.randn(
141 | half_dim), requires_grad=not is_random)
142 |
143 | def forward(self, x):
144 | x = rearrange(x, 'b -> b 1')
145 | freqs = x * rearrange(self.weights, 'd -> 1 d') * 2 * math.pi
146 | fouriered = torch.cat((freqs.sin(), freqs.cos()), dim=-1)
147 | fouriered = torch.cat((x, fouriered), dim=-1)
148 | return fouriered
149 |
150 | class Block(nn.Module):
151 | def __init__(self, dim, dim_out, groups=8):
152 | super().__init__()
153 | self.proj = WeightStandardizedConv2d(dim, dim_out, 3, padding=1)
154 | self.norm = nn.GroupNorm(groups, dim_out)
155 | self.act = nn.SiLU()
156 |
157 | def forward(self, x, scale_shift=None):
158 | x = self.proj(x)
159 | x = self.norm(x)
160 |
161 | if exists(scale_shift):
162 | scale, shift = scale_shift
163 | x = x * (scale + 1) + shift
164 |
165 | x = self.act(x)
166 | return x
167 |
168 |
169 | class ResnetBlock(nn.Module):
170 | def __init__(self, dim, dim_out, *, time_emb_dim=None, groups=8):
171 | super().__init__()
172 | self.mlp = nn.Sequential(
173 | nn.SiLU(),
174 | nn.Linear(time_emb_dim, dim_out * 2)
175 | ) if exists(time_emb_dim) else None
176 |
177 | self.block1 = Block(dim, dim_out, groups=groups)
178 | self.block2 = Block(dim_out, dim_out, groups=groups)
179 | self.res_conv = nn.Conv2d(
180 | dim, dim_out, 1) if dim != dim_out else nn.Identity()
181 |
182 | def forward(self, x, time_emb=None):
183 |
184 | scale_shift = None
185 | if exists(self.mlp) and exists(time_emb):
186 | time_emb = self.mlp(time_emb)
187 | time_emb = rearrange(time_emb, 'b c -> b c 1 1')
188 | scale_shift = time_emb.chunk(2, dim=1)
189 |
190 | h = self.block1(x, scale_shift=scale_shift)
191 |
192 | h = self.block2(h)
193 |
194 | return h + self.res_conv(x)
195 |
196 |
197 | class LinearAttention(nn.Module):
198 | def __init__(self, dim, heads=4, dim_head=32):
199 | super().__init__()
200 | self.scale = dim_head ** -0.5
201 | self.heads = heads
202 | hidden_dim = dim_head * heads
203 | self.to_qkv = nn.Conv2d(dim, hidden_dim * 3, 1, bias=False)
204 |
205 | self.to_out = nn.Sequential(
206 | nn.Conv2d(hidden_dim, dim, 1),
207 | LayerNorm(dim)
208 | )
209 |
210 | def forward(self, x):
211 | b, c, h, w = x.shape
212 | qkv = self.to_qkv(x).chunk(3, dim=1)
213 | q, k, v = map(lambda t: rearrange(
214 | t, 'b (h c) x y -> b h c (x y)', h=self.heads), qkv)
215 |
216 | q = q.softmax(dim=-2)
217 | k = k.softmax(dim=-1)
218 |
219 | q = q * self.scale
220 | v = v / (h * w)
221 |
222 | context = torch.einsum('b h d n, b h e n -> b h d e', k, v)
223 |
224 | out = torch.einsum('b h d e, b h d n -> b h e n', context, q)
225 | out = rearrange(out, 'b h c (x y) -> b (h c) x y',
226 | h=self.heads, x=h, y=w)
227 | return self.to_out(out)
228 |
229 |
230 | class Attention(nn.Module):
231 | def __init__(self, dim, heads=4, dim_head=32):
232 | super().__init__()
233 | self.scale = dim_head ** -0.5
234 | self.heads = heads
235 | hidden_dim = dim_head * heads
236 |
237 | self.to_qkv = nn.Conv2d(dim, hidden_dim * 3, 1, bias=False)
238 | self.to_out = nn.Conv2d(hidden_dim, dim, 1)
239 |
240 | def forward(self, x):
241 | b, c, h, w = x.shape
242 | qkv = self.to_qkv(x).chunk(3, dim=1)
243 | q, k, v = map(lambda t: rearrange(
244 | t, 'b (h c) x y -> b h c (x y)', h=self.heads), qkv)
245 |
246 | q = q * self.scale
247 |
248 | sim = einsum('b h d i, b h d j -> b h i j', q, k)
249 | attn = sim.softmax(dim=-1)
250 | out = einsum('b h i j, b h d j -> b h i d', attn, v)
251 |
252 | out = rearrange(out, 'b h (x y) d -> b (h d) x y', x=h, y=w)
253 | return self.to_out(out)
254 |
255 |
256 | class Unet(nn.Module):
257 | def __init__(
258 | self,
259 | dim,
260 | init_dim=None,
261 | out_dim=None,
262 | dim_mults=(1, 2, 4, 8),
263 | channels=3,
264 | resnet_block_groups=8,
265 | learned_variance=False,
266 | learned_sinusoidal_cond=False,
267 | random_fourier_features=False,
268 | learned_sinusoidal_dim=16,
269 | condition=False,
270 | ):
271 | super().__init__()
272 |
273 | self.channels = channels
274 | self.depth = len(dim_mults)
275 | input_channels = channels + channels * (1 if condition else 0)
276 |
277 | init_dim = default(init_dim, dim)
278 | self.init_conv = nn.Conv2d(input_channels, init_dim, 7, padding=3)
279 |
280 | dims = [init_dim, *map(lambda m: dim * m, dim_mults)]
281 | in_out = list(zip(dims[:-1], dims[1:]))
282 |
283 | block_klass = partial(ResnetBlock, groups=resnet_block_groups)
284 |
285 | time_dim = dim * 4
286 |
287 | self.random_or_learned_sinusoidal_cond = learned_sinusoidal_cond or random_fourier_features
288 |
289 | if self.random_or_learned_sinusoidal_cond:
290 | sinu_pos_emb = RandomOrLearnedSinusoidalPosEmb(
291 | learned_sinusoidal_dim, random_fourier_features)
292 | fourier_dim = learned_sinusoidal_dim + 1
293 | else:
294 | sinu_pos_emb = SinusoidalPosEmb(dim)
295 | fourier_dim = dim
296 |
297 | self.time_mlp = nn.Sequential(
298 | sinu_pos_emb,
299 | nn.Linear(fourier_dim, time_dim),
300 | nn.GELU(),
301 | nn.Linear(time_dim, time_dim)
302 | )
303 |
304 | self.downs = nn.ModuleList([])
305 | self.ups = nn.ModuleList([])
306 | num_resolutions = len(in_out)
307 |
308 | for ind, (dim_in, dim_out) in enumerate(in_out):
309 | is_last = ind >= (num_resolutions - 1)
310 |
311 | self.downs.append(nn.ModuleList([
312 | block_klass(dim_in, dim_in, time_emb_dim=time_dim),
313 | block_klass(dim_in, dim_in, time_emb_dim=time_dim),
314 | Residual(PreNorm(dim_in, LinearAttention(dim_in))),
315 | Downsample(dim_in, dim_out) if not is_last else nn.Conv2d(
316 | dim_in, dim_out, 3, padding=1)
317 | ]))
318 |
319 | mid_dim = dims[-1]
320 | self.mid_block1 = block_klass(mid_dim, mid_dim, time_emb_dim=time_dim)
321 | self.mid_attn = Residual(PreNorm(mid_dim, Attention(mid_dim)))
322 | self.mid_block2 = block_klass(mid_dim, mid_dim, time_emb_dim=time_dim)
323 |
324 | for ind, (dim_in, dim_out) in enumerate(reversed(in_out)):
325 | is_last = ind == (len(in_out) - 1)
326 |
327 | self.ups.append(nn.ModuleList([
328 | block_klass(dim_out + dim_in, dim_out, time_emb_dim=time_dim),
329 | block_klass(dim_out + dim_in, dim_out, time_emb_dim=time_dim),
330 | Residual(PreNorm(dim_out, LinearAttention(dim_out))),
331 | Upsample(dim_out, dim_in) if not is_last else nn.Conv2d(
332 | dim_out, dim_in, 3, padding=1)
333 | ]))
334 |
335 | default_out_dim = channels * (1 if not learned_variance else 2)
336 | self.out_dim = default(out_dim, default_out_dim)
337 |
338 | self.final_res_block = block_klass(dim * 2, dim, time_emb_dim=time_dim)
339 | self.final_conv = nn.Conv2d(dim, self.out_dim, 1)
340 |
341 | def check_image_size(self, x, h, w):
342 | s = int(math.pow(2, self.depth))
343 | mod_pad_h = (s - h % s) % s
344 | mod_pad_w = (s - w % s) % s
345 | x = F.pad(x, (0, mod_pad_w, 0, mod_pad_h), 'reflect')
346 | return x
347 |
348 | def forward(self, x, time):
349 | H, W = x.shape[2:]
350 | x = self.check_image_size(x, H, W)
351 | x = self.init_conv(x)
352 | r = x.clone()
353 |
354 | t = self.time_mlp(time)
355 |
356 | h = []
357 |
358 | for block1, block2, attn, downsample in self.downs:
359 | x = block1(x, t)
360 | h.append(x)
361 |
362 | x = block2(x, t)
363 | x = attn(x)
364 | h.append(x)
365 |
366 | x = downsample(x)
367 |
368 | x = self.mid_block1(x, t)
369 | x = self.mid_attn(x)
370 | x = self.mid_block2(x, t)
371 |
372 | for block1, block2, attn, upsample in self.ups:
373 | x = torch.cat((x, h.pop()), dim=1)
374 | x = block1(x, t)
375 |
376 | x = torch.cat((x, h.pop()), dim=1)
377 | x = block2(x, t)
378 | x = attn(x)
379 |
380 | x = upsample(x)
381 |
382 | x = torch.cat((x, r), dim=1)
383 |
384 | x = self.final_res_block(x, t)
385 | x = self.final_conv(x)
386 | x = x[..., :H, :W].contiguous()
387 | return x
388 |
389 |
390 | class UnetRes(nn.Module):
391 | def __init__(
392 | self,
393 | dim,
394 | init_dim=None,
395 | out_dim=None,
396 | dim_mults=(1, 2, 4, 8),
397 | channels=3,
398 | resnet_block_groups=8,
399 | learned_variance=False,
400 | learned_sinusoidal_cond=False,
401 | random_fourier_features=False,
402 | learned_sinusoidal_dim=16,
403 | num_unet=1,
404 | condition=False,
405 | objective='pred_res_noise',
406 | test_res_or_noise="res_noise"
407 | ):
408 | super().__init__()
409 | self.condition = condition
410 | self.channels = channels
411 | default_out_dim = channels * (1 if not learned_variance else 2)
412 | self.out_dim = default(out_dim, default_out_dim)
413 | self.random_or_learned_sinusoidal_cond = learned_sinusoidal_cond or random_fourier_features
414 | self.num_unet = num_unet
415 | self.objective = objective
416 | self.test_res_or_noise = test_res_or_noise
417 |
418 | self.unet0 = Unet(dim,
419 | init_dim=init_dim,
420 | out_dim=out_dim,
421 | dim_mults=dim_mults,
422 | channels=channels,
423 | resnet_block_groups=resnet_block_groups,
424 | learned_variance=learned_variance,
425 | learned_sinusoidal_cond=learned_sinusoidal_cond,
426 | random_fourier_features=random_fourier_features,
427 | learned_sinusoidal_dim=learned_sinusoidal_dim,
428 | condition=condition)
429 |
430 | def forward(self, x, time):
431 | if self.objective == "pred_noise":
432 | time = time[1]
433 | elif self.objective == "pred_res":
434 | time = time[0]
435 | return [self.unet0(x, time)]
436 |
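# Note: with the single inner U-Net above, forward() returns a one-element list,
# so only objectives that read model_output[0] ('pred_res', 'pred_noise') work
# end-to-end; the 'pred_res_noise' paths in ResidualDiffusion below index
# model_output[1] and also leave `time` as the [t, t] list, which the sinusoidal
# embedding cannot consume. A second U-Net (num_unet > 1) or a two-headed output
# would be needed for that objective.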
437 | def extract(a, t, x_shape):
438 | b, *_ = t.shape
439 | out = a.gather(-1, t)
440 | return out.reshape(b, *((1,) * (len(x_shape) - 1)))
441 |
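# `extract` gathers each sample's schedule entry for its timestep and reshapes it
# to broadcast over image dimensions. A minimal sketch (illustrative values only):
#   a = torch.linspace(0, 1, 1000)        # any per-timestep schedule buffer
#   t = torch.tensor([10, 500])           # one timestep per batch element
#   extract(a, t, (2, 3, 64, 64)).shape   # -> torch.Size([2, 1, 1, 1])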
442 | def gen_coefficients(timesteps, schedule="increased", sum_scale=1, ratio=1):
443 | if schedule == "increased":
444 | x = np.linspace(0, 1, timesteps, dtype=np.float32)
445 | y = x**ratio
446 | y = torch.from_numpy(y)
447 | y_sum = y.sum()
448 | alphas = y/y_sum
449 | elif schedule == "decreased":
450 | x = np.linspace(0, 1, timesteps, dtype=np.float32)
451 | y = x**ratio
452 | y = torch.from_numpy(y)
453 | y_sum = y.sum()
454 | y = torch.flip(y, dims=[0])
455 | alphas = y/y_sum
456 |     elif schedule == "lamda":  # (sic) key spelling kept for compatibility; note this branch is scaled but not normalized
457 | x = np.linspace(0.0001, 0.02, timesteps, dtype=np.float32)
458 | y = x**ratio
459 | y = torch.from_numpy(y)
460 | alphas = 1 - y
461 | elif schedule == "average":
462 | alphas = torch.full([timesteps], 1/timesteps, dtype=torch.float32)
463 | elif schedule == "normal":
464 | sigma = 1.0
465 | mu = 0.0
466 | x = np.linspace(-3+mu, 3+mu, timesteps, dtype=np.float32)
467 |         y = np.e**(-((x-mu)**2)/(2*(sigma**2)))/(np.sqrt(2*np.pi)*sigma)  # normal pdf (the original divided by sigma**2; the constant cancels in the normalization below)
468 | y = torch.from_numpy(y)
469 | alphas = y/y.sum()
470 | else:
471 | alphas = torch.full([timesteps], 1/timesteps, dtype=torch.float32)
472 |
473 | return alphas*sum_scale
474 |
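# Hedged usage sketch: for the normalized schedules the returned coefficients
# sum to `sum_scale`, e.g.
#   alphas = gen_coefficients(1000, schedule="increased", sum_scale=0.01)
#   float(alphas.sum())  # ~0.01
# (the "lamda" branch above is the exception: it is scaled but not normalized).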
475 | def betas_for_alpha_bar(num_diffusion_timesteps, max_beta=0.999) -> torch.Tensor:
476 | def alpha_bar(time_step):
477 | return math.cos((time_step + 0.008) / 1.008 * math.pi / 2) ** 2
478 |
479 | betas = []
480 | for i in range(num_diffusion_timesteps):
481 | t1 = i / num_diffusion_timesteps
482 | t2 = (i + 1) / num_diffusion_timesteps
483 | betas.append(min(1 - alpha_bar(t2) / alpha_bar(t1), max_beta))
484 | return torch.tensor(betas, dtype=torch.float32)
485 |
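# betas_for_alpha_bar is the cosine ("squaredcos_cap_v2") schedule of Nichol &
# Dhariwal (2021): beta_t = 1 - alpha_bar(t2)/alpha_bar(t1), capped at max_beta
# so no single step removes too much signal.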
486 | class ResidualDiffusion(nn.Module):
487 | def __init__(
488 | self,
489 | model,
490 | *,
491 | image_size,
492 | timesteps=1000,
493 |         delta_end=2.0e-3,
494 | sampling_timesteps=None,
495 | loss_type='l1',
496 | objective='pred_res_noise',
497 |         ddim_sampling_eta=0,
498 | condition=False,
499 | sum_scale=None,
500 | test_res_or_noise="None",
501 | ):
502 | super().__init__()
503 | assert not (
504 | type(self) == ResidualDiffusion and model.channels != model.out_dim)
505 | assert not model.random_or_learned_sinusoidal_cond
506 |
507 | self.model = model
508 | self.channels = self.model.channels
509 | self.image_size = image_size
510 | self.objective = objective
511 | self.condition = condition
512 | self.test_res_or_noise = test_res_or_noise
513 | self.delta_end = delta_end
514 |
515 | if self.condition:
516 | self.sum_scale = sum_scale if sum_scale else 0.01
517 | else:
518 | self.sum_scale = sum_scale if sum_scale else 1.
519 |
520 | beta_schedule = "linear"
521 | beta_start = 0.0001
522 | beta_end = 0.02
523 | if beta_schedule == "linear":
524 | betas = torch.linspace(beta_start, beta_end, timesteps, dtype=torch.float32)
525 | elif beta_schedule == "scaled_linear":
526 | betas = (torch.linspace(beta_start**0.5, beta_end**0.5, timesteps, dtype=torch.float32) ** 2)
527 | elif beta_schedule == "squaredcos_cap_v2":
528 | betas = betas_for_alpha_bar(timesteps)
529 | else:
530 |             raise NotImplementedError(f"{beta_schedule} is not implemented for {self.__class__}")
531 |
532 | delta_start = 1e-6
533 | delta = torch.linspace(delta_start, self.delta_end, timesteps, dtype=torch.float32)
534 | delta_cumsum = delta.cumsum(dim=0).clip(0, 1)
535 |
536 | alphas = 1.0 - betas
537 | alphas_cumprod = torch.cumprod(alphas, dim=0)
538 | alphas_cumsum = 1-alphas_cumprod ** 0.5
539 | betas2_cumsum = 1-alphas_cumprod
540 |
541 | alphas_cumsum_prev = F.pad(alphas_cumsum[:-1], (1, 0), value=1.)
542 | betas2_cumsum_prev = F.pad(betas2_cumsum[:-1], (1, 0), value=1.)
543 | delta_cumsum_prev = F.pad(delta_cumsum[:-1], (1, 0), value=1.)
544 | alphas = alphas_cumsum-alphas_cumsum_prev
545 | alphas[0] = 0
546 | betas2 = betas2_cumsum-betas2_cumsum_prev
547 | betas2[0] = 0
548 | betas_cumsum = torch.sqrt(betas2_cumsum)
549 |
550 | posterior_variance = betas2*betas2_cumsum_prev/betas2_cumsum
551 | posterior_variance[0] = 0
552 |
553 | timesteps, = alphas.shape
554 | self.num_timesteps = int(timesteps)
555 | self.loss_type = loss_type
556 |
557 | self.sampling_timesteps = default(sampling_timesteps, timesteps)
558 |
559 | assert self.sampling_timesteps <= timesteps
560 | self.is_ddim_sampling = self.sampling_timesteps < timesteps
561 | self.ddim_sampling_eta = ddim_sampling_eta
562 |
563 | def register_buffer(name, val): return self.register_buffer(
564 | name, val.to(torch.float32))
565 |
566 | register_buffer('alphas', alphas)
567 | register_buffer('alphas_cumsum', alphas_cumsum)
568 | register_buffer('delta', delta)
569 | register_buffer('delta_cumsum', delta_cumsum)
570 | register_buffer('one_minus_alphas_cumsum', 1-alphas_cumsum)
571 | register_buffer('betas2', betas2)
572 | register_buffer('betas', torch.sqrt(betas2))
573 | register_buffer('betas2_cumsum', betas2_cumsum)
574 | register_buffer('betas_cumsum', betas_cumsum)
575 | register_buffer('posterior_mean_coef1',
576 | betas2_cumsum_prev/betas2_cumsum)
577 | register_buffer('posterior_mean_coef2',
578 | (betas2_cumsum_prev)/(betas2_cumsum)*(alphas - delta) + (betas2)/(betas2_cumsum)*(alphas_cumsum_prev - delta_cumsum_prev)
579 | )
580 | register_buffer('posterior_mean_coef3', delta + betas2/betas2_cumsum*(1 - delta_cumsum_prev))
581 | register_buffer('posterior_variance', posterior_variance)
582 | register_buffer('posterior_log_variance_clipped',
583 | torch.log(posterior_variance.clamp(min=1e-20)))
584 |
585 | self.posterior_mean_coef1[0] = 0
586 | self.posterior_mean_coef2[0] = 0
587 | self.posterior_mean_coef3[0] = 1
588 | self.one_minus_alphas_cumsum[-1] = 1e-6
589 |
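# Taken together, these buffers define the degraded-residual forward process used
# by q_sample below:
#   x_t = x_0 + alphas_cumsum_t * x_res + betas_cumsum_t * noise - delta_cumsum_t * y
# with x_res = y - x_0 the residual to the degraded input y, so at t = T a sample
# approaches (1 - delta_cumsum_T) * y plus betas_cumsum_T-scaled noise.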
590 |
591 | def predict_noise_from_res(self, x_t, t, x_input, pred_res):
592 | return (
593 | (x_t - (1-extract(self.delta_cumsum,t,x_t.shape)) * x_input - (extract(self.alphas_cumsum, t, x_t.shape)-1) * pred_res) /extract(self.betas_cumsum, t, x_t.shape)
594 | )
595 |
596 | def predict_start_from_xinput_noise(self, x_t, t, x_input, noise):
597 | return (
598 | (x_t-extract(self.alphas_cumsum, t, x_t.shape)*x_input -
599 | extract(self.betas_cumsum, t, x_t.shape) * noise + extract(self.delta_cumsum, t, x_t.shape) * x_input )/extract(self.one_minus_alphas_cumsum, t, x_t.shape)
600 | )
601 |
602 | def predict_start_from_res_noise(self, x_t, t, x_res, noise, x_input):
603 | return (
604 | x_t-extract(self.alphas_cumsum, t, x_t.shape) * x_res -
605 | extract(self.betas_cumsum, t, x_t.shape) * noise + extract(self.delta_cumsum, t, x_t.shape) * x_input
606 | )
607 |
608 | def q_posterior_from_res_noise(self, x_res, noise, x_t, t, x_input):
609 | return (x_t-extract(self.alphas, t, x_t.shape) * x_res + extract(self.delta, t, x_t.shape) * x_input -
610 | (extract(self.betas2, t, x_t.shape)/extract(self.betas_cumsum, t, x_t.shape)) * noise)
611 |
612 | def q_posterior(self, pred_res, x_start, x_t, t):
613 | posterior_mean = (
614 | extract(self.posterior_mean_coef1, t, x_t.shape) * x_t +
615 | extract(self.posterior_mean_coef2, t, x_t.shape) * pred_res +
616 | extract(self.posterior_mean_coef3, t, x_t.shape) * x_start
617 | )
618 | posterior_variance = extract(self.posterior_variance, t, x_t.shape)
619 | posterior_log_variance_clipped = extract(
620 | self.posterior_log_variance_clipped, t, x_t.shape)
621 | return posterior_mean, posterior_variance, posterior_log_variance_clipped
622 |
623 | def model_predictions(self, x_input, x, t, task=None, clip_denoised=True):
624 | if not self.condition:
625 | x_in = x
626 | else:
627 | x_in = torch.cat((x, x_input), dim=1)
628 |         model_output = self.model(x_in, [t, t])
629 |
630 | maybe_clip = partial(torch.clamp, min=-1.,
631 | max=1.) if clip_denoised else identity
632 |
633 | if self.objective == 'pred_res_noise':
634 | if self.test_res_or_noise == "res_noise":
635 | pred_res = model_output[0]
636 | pred_noise = model_output[1]
637 | pred_res = maybe_clip(pred_res)
638 | x_start = self.predict_start_from_res_noise(
639 | x, t, pred_res, pred_noise, x_input)
640 | x_start = maybe_clip(x_start)
641 | elif self.test_res_or_noise == "res":
642 | pred_res = model_output[0]
643 | pred_res = maybe_clip(pred_res)
644 | pred_noise = self.predict_noise_from_res(
645 | x, t, x_input, pred_res)
646 | x_start = x_input - pred_res
647 | x_start = maybe_clip(x_start)
648 | elif self.test_res_or_noise == "noise":
649 | pred_noise = model_output[1]
650 | x_start = self.predict_start_from_xinput_noise(
651 | x, t, x_input, pred_noise)
652 | x_start = maybe_clip(x_start)
653 | pred_res = x_input - x_start
654 | pred_res = maybe_clip(pred_res)
655 | elif self.objective == 'pred_x0_noise':
656 | pred_res = x_input-model_output[0]
657 | pred_noise = model_output[1]
658 | pred_res = maybe_clip(pred_res)
659 | x_start = maybe_clip(model_output[0])
660 | elif self.objective == "pred_noise":
661 | pred_noise = model_output[0]
662 | x_start = self.predict_start_from_xinput_noise(
663 | x, t, x_input, pred_noise)
664 | x_start = maybe_clip(x_start)
665 | pred_res = x_input - x_start
666 | pred_res = maybe_clip(pred_res)
667 | elif self.objective == "pred_res":
668 | pred_res = model_output[0]
669 | pred_res = maybe_clip(pred_res)
670 | pred_noise = self.predict_noise_from_res(x, t, x_input, pred_res)
671 | x_start = self.predict_start_from_res_noise(x, t, pred_res, pred_noise, x_input)
672 | x_start = maybe_clip(x_start)
673 |
674 | return ModelResPrediction(pred_res, pred_noise, x_start)
675 |
676 | def p_mean_variance(self, x_input, x, t):
677 | preds = self.model_predictions(x_input, x, t)
678 | pred_res = preds.pred_res
679 | x_start = preds.pred_x_start
680 |
681 | model_mean, posterior_variance, posterior_log_variance = self.q_posterior(
682 | pred_res=pred_res, x_start=x_start, x_t=x, t=t)
683 | return model_mean, posterior_variance, posterior_log_variance, x_start
684 |
685 | @torch.no_grad()
686 | def p_sample(self, x_input, x, t: int):
687 | b, *_, device = *x.shape, x.device
688 | batched_times = torch.full(
689 | (x.shape[0],), t, device=x.device, dtype=torch.long)
690 | model_mean, _, model_log_variance, x_start = self.p_mean_variance(x_input, x=x, t=batched_times)
691 | noise = torch.randn_like(x) if t > 0 else 0.
692 | pred_img = model_mean + (0.5 * model_log_variance).exp() * noise
693 | return pred_img, x_start
694 |
695 | @torch.no_grad()
696 | def p_sample_loop(self, x_input, shape, last=True):
697 | x_input = x_input[0]
698 |
699 | batch, device = shape[0], self.betas.device
700 |
701 | if self.condition:
702 |             # note: differs from q_sample at t=T, which gives (1-delta_cumsum[-1])*x_input + betas_cumsum[-1]*noise; kept as-is for this legacy full-step sampler
703 |             img = x_input + math.sqrt(self.sum_scale) * torch.randn(shape, device=device)
704 | input_add_noise = img
705 | else:
706 | img = torch.randn(shape, device=device)
707 |
708 | x_start = None
709 |
710 | if not last:
711 | img_list = []
712 |
713 | for t in tqdm(reversed(range(0, self.num_timesteps)), desc='sampling loop time step', total=self.num_timesteps):
714 | img, x_start = self.p_sample(x_input, img, t)
715 |
716 | if not last:
717 | img_list.append(img)
718 |
719 | if self.condition:
720 | if not last:
721 | img_list = [input_add_noise]+img_list
722 | else:
723 | img_list = [input_add_noise, img]
724 | return unnormalize_to_zero_to_one(img_list)
725 | else:
726 |             if last:
727 |                 img_list = [img]
730 | return unnormalize_to_zero_to_one(img_list)
731 |
732 | @torch.no_grad()
733 | def first_order_sample(self, x_input, shape, last=True, task=None):
734 | x_input = x_input[0]
735 | batch, device, total_timesteps, sampling_timesteps, eta, objective = shape[
736 | 0], self.betas.device, self.num_timesteps, self.sampling_timesteps, self.ddim_sampling_eta, self.objective
737 |
738 | times = torch.linspace(-1, total_timesteps - 1,
739 | steps=sampling_timesteps + 1)
740 |
741 | times = list(reversed(times.int().tolist()))
742 | time_pairs = list(zip(times[:-1], times[1:]))
743 |
744 | if self.condition:
745 |             img = (1-self.delta_cumsum[-1]) * x_input + self.betas_cumsum[-1] * torch.randn(shape, device=device)  # match q_sample at t=T and the other samplers (the x_input term was missing)
746 | input_add_noise = img
747 | else:
748 | img = torch.randn(shape, device=device)
749 |
750 | x_start = None
751 |         type = "use_pred_noise"  # only this branch is exercised below; the others are unreachable and reference undefined q/sigma2/noise
752 | if not last:
753 | img_list = []
754 |
755 |         for time, time_next in tqdm(time_pairs, desc='sampling loop time step', disable=True):
756 | time_cond = torch.full(
757 | (batch,), time, device=device, dtype=torch.long)
758 | preds = self.model_predictions(x_input, img, time_cond, task)
759 |
760 | pred_res = preds.pred_res
761 | pred_noise = preds.pred_noise
762 | x_start = preds.pred_x_start
763 |
764 | if time_next < 0:
765 | img = x_start
766 | if not last:
767 | img_list.append(img)
768 | continue
769 |
770 | alpha_cumsum = self.alphas_cumsum[time]
771 | alpha_cumsum_next = self.alphas_cumsum[time_next]
772 | alpha = alpha_cumsum-alpha_cumsum_next
773 | delta_cumsum = self.delta_cumsum[time]
774 | delta_cumsum_next = self.delta_cumsum[time_next]
775 | delta = delta_cumsum-delta_cumsum_next
776 | betas2_cumsum = self.betas2_cumsum[time]
777 | betas2_cumsum_next = self.betas2_cumsum[time_next]
778 | betas2 = betas2_cumsum-betas2_cumsum_next
779 | betas = betas2.sqrt()
780 | betas_cumsum = self.betas_cumsum[time]
781 | betas_cumsum_next = self.betas_cumsum[time_next]
782 |             betas_cumsum_diff = betas_cumsum - betas_cumsum_next  # noise coefficient; the old name betas2_div_betas_cumsum was misleading
783 | 
784 |             if type == "use_pred_noise":
785 |                 img = img - alpha*pred_res + delta*x_input - betas_cumsum_diff * pred_noise
786 |
787 | elif type == "use_x_start":
788 | img = q*img + \
789 | (1-q)*x_start + \
790 | (alpha_cumsum_next-alpha_cumsum*q)*pred_res + \
791 | (delta_cumsum*q-delta_cumsum_next)*x_input + \
792 | sigma2.sqrt()*noise
793 | elif type == "special_eta_0":
794 | img = img - alpha*pred_res - \
795 | (betas_cumsum-betas_cumsum_next)*pred_noise
796 | elif type == "special_eta_1":
797 | img = img - alpha*pred_res - betas2/betas_cumsum*pred_noise + \
798 | betas*betas2_cumsum_next.sqrt()/betas_cumsum*noise
799 |
800 | if not last:
801 | img_list.append(img)
802 |
803 | if self.condition:
804 | if not last:
805 | img_list = [input_add_noise]+img_list
806 | else:
807 | img_list = [input_add_noise, img]
808 | return unnormalize_to_zero_to_one(img_list)
809 | else:
810 |             if last:
811 |                 img_list = [img]
814 | return unnormalize_to_zero_to_one(img_list)
815 |
816 |
817 | @torch.no_grad()
818 | def second_order_sample(self, x_input, shape, last=True, task=None):
819 | x_input = x_input[0]
820 | batch, device, total_timesteps, sampling_timesteps, eta, objective = shape[
821 | 0], self.betas.device, self.num_timesteps, self.sampling_timesteps, self.ddim_sampling_eta, self.objective
822 |
823 | times = torch.linspace(-1, total_timesteps - 1,steps=sampling_timesteps + 1)
824 | times = list(reversed(times.int().tolist()))
825 | time_pairs = list(zip(times[:-1], times[1:]))
826 |
827 | if self.condition:
828 | img = (1-self.delta_cumsum[-1]) * x_input + math.sqrt(self.sum_scale) * torch.randn(shape, device=device)
829 | input_add_noise = img
830 | else:
831 | img = torch.randn(shape, device=device)
832 |
833 | x_start = None
834 |         type = "use_pred_noise"  # unused in this sampler; kept from the first-order variant
835 |
836 | if not last:
837 | img_list = []
838 |
839 | r = 0.5
840 |         for time, time_next in tqdm(time_pairs, desc='sampling loop time step', disable=True):
841 | time_internal = int((1-r) * time + r * time_next)
842 | time_cond = torch.full((batch,), time, device=device, dtype=torch.long)
843 | time_cond_internal = torch.full((batch,), time_internal, device=device, dtype=torch.long)
844 | time_cond_next = torch.full((batch,), time_next, device=device, dtype=torch.long)
845 |
846 | alpha_cumsum = self.alphas_cumsum[time]
847 | alpha_cumsum_internal = self.alphas_cumsum[time_internal]
848 | alpha_cumsum_next = self.alphas_cumsum[time_next]
849 | alpha_u = alpha_cumsum-alpha_cumsum_internal
850 | alpha_t = alpha_cumsum-alpha_cumsum_next
851 | delta_cumsum = self.delta_cumsum[time]
852 | delta_cumsum_internal = self.delta_cumsum[time_internal]
853 | delta_cumsum_next = self.delta_cumsum[time_next]
854 | delta_u = delta_cumsum-delta_cumsum_internal
855 | delta_t = delta_cumsum-delta_cumsum_next
856 | betas2_cumsum = self.betas2_cumsum[time]
857 | betas2_cumsum_internal = self.betas2_cumsum[time_internal]
858 | betas2_cumsum_next = self.betas2_cumsum[time_next]
859 | betas2_u = betas2_cumsum-betas2_cumsum_internal
860 | betas2_t = betas2_cumsum-betas2_cumsum_next
861 | betas_cumsum = self.betas_cumsum[time]
862 | betas_cumsum_internal = self.betas_cumsum[time_internal]
863 | betas_cumsum_next = self.betas_cumsum[time_next]
864 | betas_u = betas_cumsum-betas_cumsum_internal
865 | betas_t = betas_cumsum-betas_cumsum_next
866 |
867 | preds_time_cond = self.model_predictions(x_input, img, time_cond)
868 | pred_res_time_cond = preds_time_cond.pred_res
869 | pred_noise_time_cond = preds_time_cond.pred_noise
870 | x_start_time_cond = preds_time_cond.pred_x_start
871 |
872 | if time_next < 0:
873 | img = x_start_time_cond
874 | if not last:
875 | img_list.append(img)
876 | continue
877 |
878 | img_u = img + delta_u*x_input - alpha_u*pred_res_time_cond - betas_u * pred_noise_time_cond
879 |
880 | preds_time_cond_internal = self.model_predictions(x_input, img_u.clone().detach(), time_cond_internal)
881 | pred_res_time_cond_internal = preds_time_cond_internal.pred_res
882 | pred_noise_time_cond_internal = preds_time_cond_internal.pred_noise
883 | x_start_time_cond_internal = preds_time_cond_internal.pred_x_start
884 |             # parenthesized so the correction terms apply (they were dangling no-op statements); betas_t replaces the betas_u typo, matching second_order_UPS
885 |             img_target = (img + delta_t*x_input - alpha_t*pred_res_time_cond - betas_t * pred_noise_time_cond
886 |                           - 1/(2*r)*alpha_t * (pred_res_time_cond_internal - pred_res_time_cond)
887 |                           - 1/(2*r)*betas_t * (pred_noise_time_cond_internal - pred_noise_time_cond))
888 |
889 | img = img_target.clone().detach()
890 |
891 | if not last:
892 | img_list.append(img)
893 |
894 | if self.condition:
895 | if not last:
896 | img_list = [input_add_noise]+img_list
897 | else:
898 | img_list = [input_add_noise, img]
899 | return unnormalize_to_zero_to_one(img_list)
900 | else:
901 |             if last:
902 |                 img_list = [img]
905 | return unnormalize_to_zero_to_one(img_list)
906 |
907 | @torch.no_grad()
908 | def third_order_sample(self, x_input, shape, last=True, task=None):
909 | x_input = x_input[0]
910 | batch, device, total_timesteps, sampling_timesteps, eta, objective = shape[
911 | 0], self.betas.device, self.num_timesteps, self.sampling_timesteps, self.ddim_sampling_eta, self.objective
912 |
913 | times = torch.linspace(-1, total_timesteps - 1,steps=sampling_timesteps + 1)
914 | times = list(reversed(times.int().tolist()))
915 | time_pairs = list(zip(times[:-1], times[1:]))
916 |
917 | if self.condition:
918 | img = (1-self.delta_cumsum[-1]) * x_input + math.sqrt(self.sum_scale) * torch.randn(shape, device=device)
919 | input_add_noise = img
920 | else:
921 | img = torch.randn(shape, device=device)
922 |
923 | x_start = None
924 |         type = "use_pred_noise"  # unused in this sampler; kept from the first-order variant
925 |
926 | if not last:
927 | img_list = []
928 |
929 | r1 = 1/3
930 | r2 = 2/3
931 |         for time, time_next in tqdm(time_pairs, desc='sampling loop time step', disable=True):
932 | time_u = int((1-r1) * time + r1 * time_next)
933 | time_s = int((1-r2) * time + r2 * time_next)
934 | time_cond = torch.full((batch,), time, device=device, dtype=torch.long)
935 | time_cond_u = torch.full((batch,), time_u, device=device, dtype=torch.long)
936 | time_cond_s = torch.full((batch,), time_s, device=device, dtype=torch.long)
937 | time_cond_next = torch.full((batch,), time_next, device=device, dtype=torch.long)
938 |
939 | alpha_cumsum = self.alphas_cumsum[time]
940 | alpha_cumsum_u = self.alphas_cumsum[time_u]
941 | alpha_cumsum_s = self.alphas_cumsum[time_s]
942 | alpha_cumsum_next = self.alphas_cumsum[time_next]
943 | alpha_u = alpha_cumsum-alpha_cumsum_u
944 | alpha_s = alpha_cumsum-alpha_cumsum_s
945 | alpha_t = alpha_cumsum-alpha_cumsum_next
946 | delta_cumsum = self.delta_cumsum[time]
947 | delta_cumsum_u = self.delta_cumsum[time_u]
948 | delta_cumsum_s = self.delta_cumsum[time_s]
949 | delta_cumsum_next = self.delta_cumsum[time_next]
950 | delta_u = delta_cumsum-delta_cumsum_u
951 | delta_s = delta_cumsum-delta_cumsum_s
952 | delta_t = delta_cumsum-delta_cumsum_next
953 | betas2_cumsum = self.betas2_cumsum[time]
954 | betas2_cumsum_u = self.betas2_cumsum[time_u]
955 | betas2_cumsum_s = self.betas2_cumsum[time_s]
956 | betas2_cumsum_next = self.betas2_cumsum[time_next]
957 | betas2_u = betas2_cumsum-betas2_cumsum_u
958 | betas2_s = betas2_cumsum-betas2_cumsum_s
959 | betas2_t = betas2_cumsum-betas2_cumsum_next
960 | betas_cumsum = self.betas_cumsum[time]
961 | betas_cumsum_u = self.betas_cumsum[time_u]
962 | betas_cumsum_s = self.betas_cumsum[time_s]
963 | betas_cumsum_next = self.betas_cumsum[time_next]
964 | betas_u = betas_cumsum-betas_cumsum_u
965 | betas_s = betas_cumsum-betas_cumsum_s
966 | betas_t = betas_cumsum-betas_cumsum_next
967 |
968 | preds_time_cond = self.model_predictions(x_input, img, time_cond)
969 | pred_res_time_cond = preds_time_cond.pred_res
970 | pred_noise_time_cond = preds_time_cond.pred_noise
971 | x_start_time_cond = preds_time_cond.pred_x_start
972 |
973 | if time_next < 0:
974 | img = x_start_time_cond
975 | if not last:
976 | img_list.append(img)
977 | continue
978 |
979 | img_u = img + delta_u*x_input - alpha_u*pred_res_time_cond - betas_u * pred_noise_time_cond
980 | preds_time_u = self.model_predictions(x_input, img_u, time_cond_u)
981 | pred_res_time_u = preds_time_u.pred_res
982 | pred_noise_time_u = preds_time_u.pred_noise
983 | x_start_time_u = preds_time_u.pred_x_start
984 |
985 | img_s = img + delta_s*x_input - alpha_s*pred_res_time_cond - betas_s * pred_noise_time_cond
986 | preds_time_s = self.model_predictions(x_input, img_s, time_cond_s)
987 | pred_res_time_s = preds_time_s.pred_res
988 | pred_noise_time_s = preds_time_s.pred_noise
989 | x_start_time_s = preds_time_s.pred_x_start
990 |
991 |
992 | D1_res = pred_res_time_u - pred_res_time_cond
993 | D1_eps = pred_noise_time_u - pred_noise_time_cond
994 | D2_res = (2/(r1*r2*(r2-r1)))*(r1*pred_res_time_s - r2*pred_res_time_u + (r2-r1)*pred_res_time_cond)
995 | D2_eps = (2/(r1*r2*(r2-r1)))*(r1*pred_noise_time_s - r2*pred_noise_time_u + (r2-r1)*pred_noise_time_cond)
996 |             # parenthesized so the D1/D2 corrections apply (they were dangling no-op statements); betas_t replaces the betas_u typo, matching third_order_UPS
997 |             img_target = (img + delta_t*x_input - alpha_t*pred_res_time_cond - betas_t * pred_noise_time_cond
998 |                           - 1/(2*r1)*alpha_t * D1_res - 1/(2*r1)*betas_t * D1_eps
999 |                           - 1/6*alpha_t*D2_res - 1/6*betas_t*D2_eps)
1000 |
1001 | img = img_target
1002 |
1003 | if not last:
1004 | img_list.append(img)
1005 |
1006 | if self.condition:
1007 | if not last:
1008 | img_list = [input_add_noise]+img_list
1009 | else:
1010 | img_list = [input_add_noise, img]
1011 | return unnormalize_to_zero_to_one(img_list)
1012 | else:
1013 |             if last:
1014 |                 img_list = [img]
1017 | return unnormalize_to_zero_to_one(img_list)
1018 |
1019 |
1020 |     def grad_and_value(self, x_prev, x_0_hat, y0, x_res):
1021 |         # gradient of the data-consistency residual ||y0 - (x_0_hat + x_res)|| w.r.t. x_prev
1022 |         difference = y0 - (x_0_hat + x_res)
1023 |         norm = torch.linalg.norm(difference)
1024 |         norm_grad = torch.autograd.grad(outputs=norm, inputs=x_prev)[0]
1025 |         return norm_grad, norm
1026 |
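# Hedged sketch of how this guidance is applied in the *_UPS samplers below:
#   norm_grad, norm = self.grad_and_value(x_prev=img, x_0_hat=x_start, y0=x_input, x_res=pred_res)
#   pred_noise = pred_noise + betas_cumsum / norm * norm_grad
# i.e. the predicted noise is nudged along the gradient of the consistency
# residual, scaled by the current cumulative noise level.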
1027 | def first_order_UPS(self, x_input, shape, last=True, task=None):
1028 | x_input = x_input[0]
1029 | batch, device, total_timesteps, sampling_timesteps, eta, objective = shape[
1030 | 0], self.betas.device, self.num_timesteps, self.sampling_timesteps, self.ddim_sampling_eta, self.objective
1031 |
1032 | times = torch.linspace(-1, total_timesteps - 1, steps=sampling_timesteps + 1)
1033 |
1034 | times = list(reversed(times.int().tolist()))
1035 | time_pairs = list(zip(times[:-1], times[1:]))
1036 |
1037 | if self.condition:
1038 | img = (1-self.delta_cumsum[-1]) * x_input + self.betas_cumsum[-1] * torch.randn(shape, device=device)
1039 | input_add_noise = img
1040 | else:
1041 | img = torch.randn(shape, device=device)
1042 |
1043 | x_start = None
1044 |         type = "use_pred_noise"  # only this branch is exercised below; the others are unreachable and reference undefined q/sigma2/noise
1045 |
1046 | if not last:
1047 | img_list = []
1048 |
1049 |         for time, time_next in tqdm(time_pairs, desc='sampling loop time step', disable=True):
1050 | img = img.requires_grad_()
1051 |
1052 | time_cond = torch.full(
1053 | (batch,), time, device=device, dtype=torch.long)
1054 | preds = self.model_predictions(x_input, img, time_cond, task)
1055 |
1056 | pred_res = preds.pred_res
1057 | pred_noise = preds.pred_noise
1058 | x_start = preds.pred_x_start
1059 |
1060 | if time_next < 0:
1061 | img = x_start
1062 | if not last:
1063 | img_list.append(img)
1064 | continue
1065 |
1066 | alpha_cumsum = self.alphas_cumsum[time]
1067 | alpha_cumsum_next = self.alphas_cumsum[time_next]
1068 | alpha = alpha_cumsum-alpha_cumsum_next
1069 | delta_cumsum = self.delta_cumsum[time]
1070 | delta_cumsum_next = self.delta_cumsum[time_next]
1071 | delta = delta_cumsum-delta_cumsum_next
1072 | betas2_cumsum = self.betas2_cumsum[time]
1073 | betas2_cumsum_next = self.betas2_cumsum[time_next]
1074 | betas2 = betas2_cumsum-betas2_cumsum_next
1075 | betas = betas2.sqrt()
1076 | betas_cumsum = self.betas_cumsum[time]
1077 | betas_cumsum_next = self.betas_cumsum[time_next]
1078 |
1079 | betas2_div_betas_cumsum = betas2 / betas_cumsum
1080 |
1081 | norm_grad, norm = self.grad_and_value(x_prev=img, x_0_hat=x_start, y0=x_input, x_res = pred_res)
1082 | pred_noise = pred_noise + betas_cumsum / norm * norm_grad
1083 |
1084 | if type == "use_pred_noise":
1085 |                 img = (img - alpha*pred_res + delta*x_input - betas2_div_betas_cumsum * pred_noise).detach_()  # detach so requires_grad_() succeeds next iteration, as in the other UPS samplers
1086 |
1087 | elif type == "use_x_start":
1088 | img = q*img + \
1089 | (1-q)*x_start + \
1090 | (alpha_cumsum_next-alpha_cumsum*q)*pred_res + \
1091 | (delta_cumsum*q-delta_cumsum_next)*x_input + \
1092 | sigma2.sqrt()*noise
1093 | elif type == "special_eta_0":
1094 | img = img - alpha*pred_res - \
1095 | (betas_cumsum-betas_cumsum_next)*pred_noise
1096 | elif type == "special_eta_1":
1097 | img = img - alpha*pred_res - betas2/betas_cumsum*pred_noise + \
1098 | betas*betas2_cumsum_next.sqrt()/betas_cumsum*noise
1099 |
1100 | if not last:
1101 | img_list.append(img)
1102 |
1103 | if self.condition:
1104 | if not last:
1105 | img_list = [input_add_noise]+img_list
1106 | else:
1107 | img_list = [input_add_noise, img]
1108 | return unnormalize_to_zero_to_one(img_list)
1109 | else:
1110 |             if last:  # also fixes the img_lists NameError in the original branch
1111 |                 img_list = [img]
1114 | return unnormalize_to_zero_to_one(img_list)
1115 |
1116 | def second_order_UPS(self, x_input, shape, last=True, task=None):
1117 | x_input = x_input[0]
1118 | batch, device, total_timesteps, sampling_timesteps, eta, objective = shape[
1119 | 0], self.betas.device, self.num_timesteps, self.sampling_timesteps, self.ddim_sampling_eta, self.objective
1120 |
1121 | times = torch.linspace(-1, total_timesteps - 1,steps=sampling_timesteps + 1)
1122 | times = list(reversed(times.int().tolist()))
1123 | time_pairs = list(zip(times[:-1], times[1:]))
1124 |
1125 | if self.condition:
1126 | img = (1-self.delta_cumsum[-1]) * x_input + math.sqrt(self.sum_scale) * torch.randn(shape, device=device)
1127 | input_add_noise = img
1128 | else:
1129 | img = torch.randn(shape, device=device)
1130 |
1131 | x_start = None
1132 |         type = "use_pred_noise"  # unused in this sampler; kept from the first-order variant
1133 |
1134 | if not last:
1135 | img_list = []
1136 |
1137 | r = 0.5
1138 |         for time, time_next in tqdm(time_pairs, desc='sampling loop time step', disable=True):
1139 | img = img.requires_grad_()
1140 |
1141 | time_internal = int((1-r) * time + r * time_next)
1142 | time_cond = torch.full((batch,), time, device=device, dtype=torch.long)
1143 | time_cond_internal = torch.full((batch,), time_internal, device=device, dtype=torch.long)
1144 | time_cond_next = torch.full((batch,), time_next, device=device, dtype=torch.long)
1145 |
1146 | alpha_cumsum = self.alphas_cumsum[time]
1147 | alpha_cumsum_internal = self.alphas_cumsum[time_internal]
1148 | alpha_cumsum_next = self.alphas_cumsum[time_next]
1149 | alpha_u = alpha_cumsum-alpha_cumsum_internal
1150 | alpha_t = alpha_cumsum-alpha_cumsum_next
1151 | delta_cumsum = self.delta_cumsum[time]
1152 | delta_cumsum_internal = self.delta_cumsum[time_internal]
1153 | delta_cumsum_next = self.delta_cumsum[time_next]
1154 | delta_u = delta_cumsum-delta_cumsum_internal
1155 | delta_t = delta_cumsum-delta_cumsum_next
1156 | betas2_cumsum = self.betas2_cumsum[time]
1157 | betas2_cumsum_internal = self.betas2_cumsum[time_internal]
1158 | betas2_cumsum_next = self.betas2_cumsum[time_next]
1159 | betas2_u = betas2_cumsum-betas2_cumsum_internal
1160 | betas2_t = betas2_cumsum-betas2_cumsum_next
1161 | betas_cumsum = self.betas_cumsum[time]
1162 | betas_cumsum_internal = self.betas_cumsum[time_internal]
1163 | betas_cumsum_next = self.betas_cumsum[time_next]
1164 | betas_u = betas_cumsum-betas_cumsum_internal
1165 | betas_t = betas_cumsum-betas_cumsum_next
1166 |
1167 | preds_time_cond = self.model_predictions(x_input, img, time_cond)
1168 | pred_res_time_cond = preds_time_cond.pred_res
1169 | pred_noise_time_cond = preds_time_cond.pred_noise
1170 | x_start_time_cond = preds_time_cond.pred_x_start
1171 |
1172 | if time_next < 0:
1173 | img = x_start_time_cond
1174 | if not last:
1175 | img_list.append(img)
1176 | continue
1177 |
1178 | norm_grad, norm = self.grad_and_value(x_prev=img, x_0_hat=x_start_time_cond, y0=x_input, x_res = pred_res_time_cond)
1179 | pred_noise_time_cond = pred_noise_time_cond + betas_cumsum / norm * norm_grad
1180 |
1181 | img_u = img + delta_u*x_input - alpha_u*pred_res_time_cond - betas_u * pred_noise_time_cond
1182 |
1183 | img_u = img_u.clone().detach_()
1184 | img_u = img_u.requires_grad_()
1185 |
1186 | preds_time_cond_internal = self.model_predictions(x_input, img_u, time_cond_internal)
1187 | pred_res_time_cond_internal = preds_time_cond_internal.pred_res
1188 | pred_noise_time_cond_internal = preds_time_cond_internal.pred_noise
1189 | x_start_time_cond_internal = preds_time_cond_internal.pred_x_start
1190 |
1191 | norm_grad_internal, norm_internal = self.grad_and_value(x_prev=img_u, x_0_hat=x_start_time_cond_internal, y0=x_input, x_res = pred_res_time_cond_internal)
1192 | pred_noise_time_cond_internal = pred_noise_time_cond_internal + betas_cumsum_internal / norm_internal * norm_grad_internal
1193 |             # parenthesized so the correction terms apply (they were dangling no-op statements)
1194 |             img_target = (img + delta_t*x_input - alpha_t*pred_res_time_cond - betas_t * pred_noise_time_cond
1195 |                           - 1/(2*r)*alpha_t * (pred_res_time_cond_internal - pred_res_time_cond)
1196 |                           - 1/(2*r)*betas_t * (pred_noise_time_cond_internal - pred_noise_time_cond))
1197 |
1198 | img = img_target.clone().detach_()
1199 |
1200 | if not last:
1201 | img_list.append(img)
1202 |
1203 | if self.condition:
1204 | if not last:
1205 | img_list = [input_add_noise]+img_list
1206 | else:
1207 | img_list = [input_add_noise, img]
1208 | return unnormalize_to_zero_to_one(img_list)
1209 | else:
1210 |             if last:
1211 |                 img_list = [img]
1214 | return unnormalize_to_zero_to_one(img_list)
1215 |
1216 | def third_order_UPS(self, x_input, shape, last=True, task=None):
1217 | x_input = x_input[0]
1218 | batch, device, total_timesteps, sampling_timesteps, eta, objective = shape[
1219 | 0], self.betas.device, self.num_timesteps, self.sampling_timesteps, self.ddim_sampling_eta, self.objective
1220 |
1221 | times = torch.linspace(-1, total_timesteps - 1,steps=sampling_timesteps + 1)
1222 | times = list(reversed(times.int().tolist()))
1223 | time_pairs = list(zip(times[:-1], times[1:]))
1224 |
1225 | if self.condition:
1226 | img = (1-self.delta_cumsum[-1]) * x_input + math.sqrt(self.sum_scale) * torch.randn(shape, device=device)
1227 | input_add_noise = img
1228 | else:
1229 | img = torch.randn(shape, device=device)
1230 |
1231 | x_start = None
1232 |         type = "use_pred_noise"  # unused in this sampler; kept from the first-order variant
1233 |
1234 | if not last:
1235 | img_list = []
1236 |
1237 | r1 = 1/3
1238 | r2 = 2/3
1239 |         for time, time_next in tqdm(time_pairs, desc='sampling loop time step', disable=True):
1240 | img = img.requires_grad_()
1241 |
1242 | time_u = int((1-r1) * time + r1 * time_next)
1243 | time_s = int((1-r2) * time + r2 * time_next)
1244 | time_cond = torch.full((batch,), time, device=device, dtype=torch.long)
1245 | time_cond_u = torch.full((batch,), time_u, device=device, dtype=torch.long)
1246 | time_cond_s = torch.full((batch,), time_s, device=device, dtype=torch.long)
1247 | time_cond_next = torch.full((batch,), time_next, device=device, dtype=torch.long)
1248 |
1249 | alpha_cumsum = self.alphas_cumsum[time]
1250 | alpha_cumsum_u = self.alphas_cumsum[time_u]
1251 | alpha_cumsum_s = self.alphas_cumsum[time_s]
1252 | alpha_cumsum_next = self.alphas_cumsum[time_next]
1253 | alpha_u = alpha_cumsum-alpha_cumsum_u
1254 | alpha_s = alpha_cumsum-alpha_cumsum_s
1255 | alpha_t = alpha_cumsum-alpha_cumsum_next
1256 | delta_cumsum = self.delta_cumsum[time]
1257 | delta_cumsum_u = self.delta_cumsum[time_u]
1258 | delta_cumsum_s = self.delta_cumsum[time_s]
1259 | delta_cumsum_next = self.delta_cumsum[time_next]
1260 | delta_u = delta_cumsum-delta_cumsum_u
1261 | delta_s = delta_cumsum-delta_cumsum_s
1262 | delta_t = delta_cumsum-delta_cumsum_next
1263 | betas2_cumsum = self.betas2_cumsum[time]
1264 | betas2_cumsum_u = self.betas2_cumsum[time_u]
1265 | betas2_cumsum_s = self.betas2_cumsum[time_s]
1266 | betas2_cumsum_next = self.betas2_cumsum[time_next]
1267 | betas2_u = betas2_cumsum-betas2_cumsum_u
1268 | betas2_s = betas2_cumsum-betas2_cumsum_s
1269 | betas2_t = betas2_cumsum-betas2_cumsum_next
1270 | betas_cumsum = self.betas_cumsum[time]
1271 | betas_cumsum_u = self.betas_cumsum[time_u]
1272 | betas_cumsum_s = self.betas_cumsum[time_s]
1273 | betas_cumsum_next = self.betas_cumsum[time_next]
1274 | betas_u = betas_cumsum-betas_cumsum_u
1275 | betas_s = betas_cumsum-betas_cumsum_s
1276 | betas_t = betas_cumsum-betas_cumsum_next
1277 |
1278 | preds_time_cond = self.model_predictions(x_input, img, time_cond)
1279 | pred_res_time_cond = preds_time_cond.pred_res
1280 | pred_noise_time_cond = preds_time_cond.pred_noise
1281 | x_start_time_cond = preds_time_cond.pred_x_start
1282 |
1283 | if time_next < 0:
1284 | img = x_start_time_cond
1285 | if not last:
1286 | img_list.append(img)
1287 | continue
1288 |
1289 | norm_grad, norm = self.grad_and_value(x_prev=img, x_0_hat=x_start_time_cond, y0=x_input, x_res = pred_res_time_cond)
1290 | pred_noise_time_cond = pred_noise_time_cond + betas_cumsum / norm * norm_grad
1291 | img = img.detach_()
1292 |
1293 | img_u = img + delta_u*x_input - alpha_u*pred_res_time_cond - betas_u * pred_noise_time_cond
1294 | img_u = img_u.clone().detach_()
1295 | img_u = img_u.requires_grad_()
1296 | preds_time_cond_u = self.model_predictions(x_input, img_u, time_cond_u)
1297 | pred_res_time_cond_u = preds_time_cond_u.pred_res
1298 | pred_noise_time_cond_u = preds_time_cond_u.pred_noise
1299 | x_start_time_cond_u = preds_time_cond_u.pred_x_start
1300 |
1301 | norm_grad_u, norm_u = self.grad_and_value(x_prev=img_u, x_0_hat=x_start_time_cond_u, y0=x_input, x_res = pred_res_time_cond_u)
1302 | pred_noise_time_cond_u = pred_noise_time_cond_u + betas_cumsum_u / norm_u * norm_grad_u
1303 | img_u = img_u.detach_()
1304 |             # parenthesized so the correction term applies (it was a dangling no-op statement)
1305 |             img_s = (img + delta_s*x_input - alpha_s*pred_res_time_cond - betas_s * pred_noise_time_cond
1306 |                      - r2/(2*r1)*alpha_s * (pred_res_time_cond_u - pred_res_time_cond)
1307 |                      - r2/(2*r1)*betas_s * (pred_noise_time_cond_u - pred_noise_time_cond))
1308 |
1309 | img_s = img_s.clone().detach_()
1310 | img_s = img_s.requires_grad_()
1311 | preds_time_cond_s = self.model_predictions(x_input, img_s, time_cond_s)
1312 | pred_res_time_cond_s = preds_time_cond_s.pred_res
1313 | pred_noise_time_cond_s = preds_time_cond_s.pred_noise
1314 | x_start_time_cond_s = preds_time_cond_s.pred_x_start
1315 |
1316 | norm_grad_s, norm_s = self.grad_and_value(x_prev=img_s, x_0_hat=x_start_time_cond_s, y0=x_input, x_res = pred_res_time_cond_s)
1317 | pred_noise_time_cond_s = pred_noise_time_cond_s + betas_cumsum_s / norm_s * norm_grad_s
1318 | img_s = img_s.detach_()
1319 |
1320 | D2_res = (2/(r1*r2*(r2-r1)))*(r1*pred_res_time_cond_s - r2*pred_res_time_cond_u + (r2-r1)*pred_res_time_cond)
1321 | D2_eps = (2/(r1*r2*(r2-r1)))*(r1*pred_noise_time_cond_s - r2*pred_noise_time_cond_u + (r2-r1)*pred_noise_time_cond)
1322 |
1323 |             # parenthesized so the D1/D2 corrections apply (they were dangling no-op statements)
1324 |             img = (img + delta_t*x_input - alpha_t*pred_res_time_cond - betas_t * pred_noise_time_cond
1325 |                    - 1/(2*r1)*alpha_t * (pred_res_time_cond_u - pred_res_time_cond) - 1/(2*r1)*betas_t * (pred_noise_time_cond_u - pred_noise_time_cond)
1326 |                    - 1/6*alpha_t*D2_res - 1/6*betas_t*D2_eps)
1327 |
1328 | img = img.clone().detach_()
1329 |
1330 | if not last:
1331 | img_list.append(img)
1332 |
1333 | if self.condition:
1334 | if not last:
1335 | img_list = [input_add_noise]+img_list
1336 | else:
1337 | img_list = [input_add_noise, img]
1338 | return unnormalize_to_zero_to_one(img_list)
1339 | else:
1340 |             if last:
1341 |                 img_list = [img]
1344 | return unnormalize_to_zero_to_one(img_list)
1345 |
1346 |
1347 | def sample(self, x_input=None, batch_size=16, last=True, task=None):
1348 | image_size, channels = self.image_size, self.channels
1349 | sample_fn = self.second_order_UPS
1350 | # self.first_order_sample 1st
1351 | # self.first_order_UPS UPS_1st
1352 | # self.second_order_sample 2nd
1353 | # self.second_order_UPS UPS_2nd
1354 | # self.third_order_sample 3rd
1355 | # self.third_order_UPS UPS_3rd
1356 | if self.condition:
1357 | x_input = 2 * x_input - 1
1358 | x_input = x_input.unsqueeze(0)
1359 |
1360 | batch_size, channels, h, w = x_input[0].shape
1361 | size = (batch_size, channels, h, w)
1362 | else:
1363 | size = (batch_size, channels, image_size, image_size)
1364 |
1365 | gen_samples = sample_fn(x_input, size, last=last, task=task)[1]
1366 | gen_samples = gen_samples.detach()
1367 | return gen_samples
1368 |
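# To switch solvers, point sample_fn at any sampler listed above, e.g.
# `sample_fn = self.third_order_sample`. Hedged call sketch (assumes a
# conditioned model `diffusion` and a [0, 1]-ranged degraded batch `lq`):
#   restored = diffusion.sample(x_input=lq, last=True)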
1369 | def q_sample(self, x_start, x_res, condition, t, noise=None):
1370 | noise = default(noise, lambda: torch.randn_like(x_start))
1371 |
1372 | return (
1373 | x_start+extract(self.alphas_cumsum, t, x_start.shape) * x_res +
1374 | extract(self.betas_cumsum, t, x_start.shape) * noise -
1375 | extract(self.delta_cumsum, t, x_start.shape) * condition
1376 | )
1377 |
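# q_sample is the closed-form jump to timestep t of the forward process;
# predict_start_from_res_noise above is its exact inverse, recovering x_start
# from (x_t, x_res, noise, condition).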
1378 |     @property
1379 |     def loss_fn(self):
1380 |         # use the configured self.loss_type; a property cannot receive the old loss_type argument, so 'l2' was unreachable
1381 |         if self.loss_type == 'l1':
1382 |             return F.l1_loss
1383 |         elif self.loss_type == 'l2':
1384 |             return F.mse_loss
1385 |         raise ValueError(f'invalid loss type {self.loss_type}')
1386 |
1387 | def p_losses(self, imgs, t, noise=None):
1388 |         if isinstance(imgs, list):  # expected layout: [ground truth, degraded input], both in [0, 1]
1389 |             x_input = 2 * imgs[1] - 1
1390 |             x_start = 2 * imgs[0] - 1
1391 |             task = None
1392 |
1393 | noise = default(noise, lambda: torch.randn_like(x_start))
1394 | x_res = x_input - x_start
1395 | b, c, h, w = x_start.shape
1396 | x = self.q_sample(x_start, x_res, x_input, t, noise=noise)
1397 | if not self.condition:
1398 | x_in = x
1399 | else:
1400 | x_in = torch.cat((x, x_input), dim=1)
1401 |
1402 |         model_out = self.model(x_in, [t, t])
1403 |
1404 | target = []
1405 | if self.objective == 'pred_res_noise':
1406 | target.append(x_res)
1407 | target.append(noise)
1408 |
1409 | pred_res = model_out[0]
1410 | pred_noise = model_out[1]
1411 | elif self.objective == 'pred_x0_noise':
1412 | target.append(x_start)
1413 | target.append(noise)
1414 |
1415 | pred_res = x_input-model_out[0]
1416 | pred_noise = model_out[1]
1417 | elif self.objective == "pred_noise":
1418 | target.append(noise)
1419 | pred_noise = model_out[0]
1420 |
1421 | elif self.objective == "pred_res":
1422 | target.append(x_res)
1423 | pred_res = model_out[0]
1424 |
1425 | else:
1426 | raise ValueError(f'unknown objective {self.objective}')
1427 |
1428 | u_loss = False
1429 | if u_loss:
1430 | x_u = self.q_posterior_from_res_noise(pred_res, pred_noise, x, t,x_input)
1431 | u_gt = self.q_posterior_from_res_noise(x_res, noise, x, t,x_input)
1432 | loss = 10000*self.loss_fn(x_u, u_gt, reduction='none')
1433 | return [loss]
1434 | else:
1435 | loss_list = []
1436 | for i in range(len(model_out)):
1437 | loss = self.loss_fn(model_out[i], target[i], reduction='none')
1438 | loss = reduce(loss, 'b ... -> b (...)', 'mean').mean()
1439 | loss_list.append(loss)
1440 | return loss_list
1441 |
1442 | def forward(self, img, *args, **kwargs):
1443 |         if isinstance(img, list):
1444 |             # only batch size and device are needed to draw the timesteps
1445 |             b, c, h, w, device = *img[0].shape, img[0].device
1446 |         else:
1447 |             b, c, h, w, device = *img.shape, img.device
1448 | t = torch.randint(0, self.num_timesteps, (b,), device=device).long()
1449 |
1450 | return self.p_losses(img, t, *args, **kwargs)
1451 |
1452 |
1453 | if __name__ == '__main__':
1454 | print('Hello World')
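1455 | 
1456 |     # Hedged end-to-end sketch (hyper-parameters, names and shapes are
1457 |     # illustrative only, not the repository's training configuration):
1458 |     #   unet = UnetRes(dim=64, condition=True, objective='pred_res',
1459 |     #                  test_res_or_noise='res')
1460 |     #   diffusion = ResidualDiffusion(unet, image_size=256, timesteps=1000,
1461 |     #                                 sampling_timesteps=5, condition=True,
1462 |     #                                 objective='pred_res', test_res_or_noise='res')
1463 |     #   gt, lq = torch.rand(1, 3, 64, 64), torch.rand(1, 3, 64, 64)
1464 |     #   losses = diffusion([gt, lq])               # training: loss per model output
1465 |     #   restored = diffusion.sample(x_input=lq)    # inference via the chosen solver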
--------------------------------------------------------------------------------