├── datasets
│   ├── __init__.py
│   ├── data_augment.py
│   └── dataset.py
├── models
│   ├── __init__.py
│   ├── restoration.py
│   ├── ddm.py
│   ├── decom.py
│   └── unet.py
├── Figures
│   ├── visual.jpg
│   └── pipeline.jpg
├── utils
│   ├── __init__.py
│   ├── sampling.py
│   ├── logging.py
│   └── optimize.py
├── configs
│   └── unsupervised.yml
├── requirements.txt
├── evaluate.py
├── train.py
└── README.md
/datasets/__init__.py:
--------------------------------------------------------------------------------
1 | from datasets.dataset import *
2 |
3 | __all__ = ["LLdataset"]
4 |
--------------------------------------------------------------------------------
/models/__init__.py:
--------------------------------------------------------------------------------
1 | from models.ddm import *
2 | from models.restoration import *
3 |
--------------------------------------------------------------------------------
/Figures/visual.jpg:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/JianghaiSCU/LightenDiffusion/HEAD/Figures/visual.jpg
--------------------------------------------------------------------------------
/Figures/pipeline.jpg:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/JianghaiSCU/LightenDiffusion/HEAD/Figures/pipeline.jpg
--------------------------------------------------------------------------------
/utils/__init__.py:
--------------------------------------------------------------------------------
1 | from utils.logging import *
2 | from utils.sampling import *
3 | from utils.optimize import *
4 |
--------------------------------------------------------------------------------
/utils/sampling.py:
--------------------------------------------------------------------------------
1 | import torch
2 | # This script is adapted from the following repository: https://github.com/ermongroup/ddim
3 |
4 |
5 | def data_transform(X):
6 | return 2 * X - 1.0
7 |
8 |
9 | def inverse_data_transform(X):
10 | return torch.clamp((X + 1.0) / 2.0, 0.0, 1.0)
--------------------------------------------------------------------------------
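A minimal usage sketch of the two helpers above (assuming the repository root is on `PYTHONPATH`): `data_transform` maps images from `[0, 1]` to the `[-1, 1]` range the diffusion model operates in, and `inverse_data_transform` maps predictions back with clamping.

```
import torch
from utils.sampling import data_transform, inverse_data_transform

x = torch.rand(1, 3, 64, 64)             # image tensor in [0, 1]
x_norm = data_transform(x)               # mapped to [-1, 1] for the diffusion model
x_back = inverse_data_transform(x_norm)  # back to [0, 1], clamped before saving
assert torch.allclose(x, x_back, atol=1e-6)
```
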
/utils/logging.py:
--------------------------------------------------------------------------------
1 | import torch
2 | import shutil
3 | import os
4 | import torchvision.utils as tvu
5 |
6 |
7 | def save_image(img, file_directory):
8 | if not os.path.exists(os.path.dirname(file_directory)):
9 | os.makedirs(os.path.dirname(file_directory))
10 | tvu.save_image(img, file_directory)
11 |
12 |
13 | def save_checkpoint(state, filename):
14 | if not os.path.exists(os.path.dirname(filename)):
15 | os.makedirs(os.path.dirname(filename))
16 | torch.save(state, filename + '.pth.tar')
17 |
18 |
19 | def load_checkpoint(path, device):
20 | if device is None:
21 | return torch.load(path)
22 | else:
23 | return torch.load(path, map_location=device)
24 |
--------------------------------------------------------------------------------
/utils/optimize.py:
--------------------------------------------------------------------------------
1 | import torch.optim as optim
2 |
3 |
4 | def get_optimizer(config, parameters):
5 | if config.optim.optimizer == 'Adam':
6 | optimizer = optim.Adam(parameters, lr=config.optim.lr, weight_decay=config.optim.weight_decay,
7 | betas=(0.9, 0.999), amsgrad=config.optim.amsgrad, eps=config.optim.eps)
8 |
9 | elif config.optim.optimizer == 'RMSProp':
10 | optimizer = optim.RMSprop(parameters, lr=config.optim.lr, weight_decay=config.optim.weight_decay)
11 | elif config.optim.optimizer == 'SGD':
12 | optimizer = optim.SGD(parameters, lr=config.optim.lr, momentum=0.9)
13 | else:
14 | raise NotImplementedError('Optimizer {} not understood.'.format(config.optim.optimizer))
15 |
16 | return optimizer
17 |
--------------------------------------------------------------------------------
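A minimal sketch of how `get_optimizer` is driven by the `optim` block of the config; the nested `argparse.Namespace` objects mirror what `dict2namespace` in `train.py` produces, and the `nn.Conv2d` module is only a hypothetical stand-in for the model.

```
import torch.nn as nn
from argparse import Namespace
from utils.optimize import get_optimizer

# Same fields as the optim block in configs/unsupervised.yml.
config = Namespace(optim=Namespace(optimizer='Adam', lr=0.00002, weight_decay=0.000,
                                   amsgrad=False, eps=0.00000001))
model = nn.Conv2d(3, 3, kernel_size=3)   # hypothetical stand-in module
optimizer = get_optimizer(config, model.parameters())
```
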
/configs/unsupervised.yml:
--------------------------------------------------------------------------------
1 | data:
2 | type: "LLdataset"
3 | train_dataset: "unpaired"
4 | val_dataset: "LOLv1"
5 | patch_size: 512
6 | channels: 3
7 | num_workers: 4
8 | data_dir: "/data/Image_restoration/LSRW_dataset"
9 | ckpt_dir: "ckpt/stage2"
10 | conditional: True
11 |
12 | model:
13 | in_channels: 3
14 | out_ch: 3
15 | ch: 64
16 | ch_mult: [1, 2, 3, 4]
17 | num_res_blocks: 2
18 | dropout: 0.0
19 | ema_rate: 0.999
20 | ema: True
21 | resamp_with_conv: True
22 |
23 | diffusion:
24 | beta_schedule: linear
25 | beta_start: 0.0001
26 | beta_end: 0.02
27 | num_diffusion_timesteps: 1000
28 | num_sampling_timesteps: 20
29 |
30 | training:
31 | batch_size: 12
32 | n_epochs: 100
33 | validation_freq: 2000
34 |
35 | sampling:
36 | batch_size: 1
37 |
38 | optim:
39 | weight_decay: 0.000
40 | optimizer: "Adam"
41 | lr: 0.00002
42 | amsgrad: False
43 | eps: 0.00000001
--------------------------------------------------------------------------------
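The YAML above is not consumed as a plain dict: `train.py` and `evaluate.py` convert it into nested namespaces, so every field is reached by attribute access. A minimal sketch of that round trip, assuming it is run from the repository root (the helper below is the same `dict2namespace` defined in those scripts):

```
import yaml
import argparse

def dict2namespace(config):
    # Same helper as in train.py / evaluate.py: nested dicts become nested namespaces.
    namespace = argparse.Namespace()
    for key, value in config.items():
        setattr(namespace, key, dict2namespace(value) if isinstance(value, dict) else value)
    return namespace

with open("configs/unsupervised.yml", "r") as f:
    config = dict2namespace(yaml.safe_load(f))

print(config.data.patch_size)                   # 512
print(config.diffusion.num_sampling_timesteps)  # 20
print(config.optim.lr)                          # 2e-05
```
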
/requirements.txt:
--------------------------------------------------------------------------------
1 | astunparse==1.6.3
2 | boto3==1.18.47
3 | botocore==1.21.47
4 | cffi==1.15.0
5 | coloredlogs==15.0
6 | configargparse==1.5.3
7 | cupy-cuda102==7.8.0
8 | cvxpy==1.1.18
9 | docutils==0.15.2
10 | ecos==2.0.8
11 | fastrlock==0.8
12 | flatbuffers==2.0
13 | gast==0.3.3
14 | gdown==4.2.0
15 | google-pasta==0.2.0
16 | grpcio==1.40.0
17 | humanfriendly==10.0
18 | influxdb==5.3.1
19 | jmespath==0.10.0
20 | jpeg4py==0.1.4
21 | jupyterlab==1.0.0
22 | jupyterlab-server==1.0.0
23 | keras==2.7.0
24 | keras-preprocessing==1.1.2
25 | libclang==12.0.0
26 | loguru==0.5.3
27 | ninja==1.10.2.2
28 | nori2==1.11.8
29 | opencv-python==4.5.3.56
30 | opt-einsum==3.3.0
31 | osqp==0.6.2.post4
32 | pip==21.3.1
33 | protobuf==3.18.0
34 | pycocotools==2.0.4
35 | pypng==0.0.21
36 | python-statemachine==0.8.0
37 | qdldl==0.1.5.post0
38 | redis==3.5.3
39 | refile==5.8.3.post1
40 | s3transfer==0.5.0
41 | scipy==1.4.1
42 | scs==3.0.0
43 | setuptools==59.4.0
44 | smart-open==5.2.1
45 | tabulate==0.8.9
46 | tensorboard==2.2.2
47 | tensorflow==2.2.0
48 | tensorflow-estimator==2.2.0
49 | tensorflow-io-gcs-filesystem==0.22.0
50 | termcolor==1.1.0
51 | timm==0.4.12
52 | torch==1.7.1
53 | torchsummary==1.5.1
54 | torchvision==0.8.2
55 | urllib3==1.25.11
56 | kornia==0.6.7
57 | accelerate==0.15.0
--------------------------------------------------------------------------------
/models/restoration.py:
--------------------------------------------------------------------------------
1 | import torch
2 | import numpy as np
3 | import utils
4 | import os
5 | import time
6 | import torch.nn.functional as F
7 |
8 |
9 | class DiffusiveRestoration:
10 | def __init__(self, diffusion, args, config):
11 | super(DiffusiveRestoration, self).__init__()
12 | self.args = args
13 | self.config = config
14 | self.diffusion = diffusion
15 |
16 | if os.path.isfile(args.resume):
17 | self.diffusion.load_ddm_ckpt(args.resume, ema=False)
18 | self.diffusion.model.eval()
19 | else:
20 | print('Pre-trained model path is missing!')
21 |
22 | def restore(self, val_loader):
23 | image_folder = os.path.join(self.args.image_folder, self.config.data.val_dataset)
24 | with torch.no_grad():
25 | for i, (x, y) in enumerate(val_loader):
26 |
27 | x_cond = x[:, :3, :, :].to(self.diffusion.device)
28 | b, c, h, w = x_cond.shape
29 | img_h_64 = int(64 * np.ceil(h / 64.0))
30 | img_w_64 = int(64 * np.ceil(w / 64.0))
31 | x_cond = F.pad(x_cond, (0, img_w_64 - w, 0, img_h_64 - h), 'reflect')
32 |
33 | t1 = time.time()
34 | pred_x = self.diffusion.model(torch.cat((x_cond, x_cond),
35 | dim=1))["pred_x"][:, :, :h, :w]
36 | t2 = time.time()
37 |
38 | utils.logging.save_image(pred_x, os.path.join(image_folder, f"{y[0]}"))
39 | print(f"processing image {y[0]}, time={t2 - t1}")
40 |
41 |
42 |
43 |
--------------------------------------------------------------------------------
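`restore` reflect-pads each input up to the next multiple of 64 before running the model and crops the prediction back to the original resolution. A small sketch of that arithmetic, using a hypothetical 400x600 input:

```
import numpy as np
import torch
import torch.nn.functional as F

h, w = 400, 600                          # hypothetical input size
img_h_64 = int(64 * np.ceil(h / 64.0))   # 448
img_w_64 = int(64 * np.ceil(w / 64.0))   # 640
x = torch.randn(1, 3, h, w)
x_pad = F.pad(x, (0, img_w_64 - w, 0, img_h_64 - h), 'reflect')
print(x_pad.shape)                       # torch.Size([1, 3, 448, 640])
# The model output is cropped back with [:, :, :h, :w].
```
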
/evaluate.py:
--------------------------------------------------------------------------------
1 | import argparse
2 | import os
3 | import random
4 | import socket
5 | import yaml
6 | import torch
7 | import torch.backends.cudnn as cudnn
8 | import numpy as np
9 | import torchvision
10 | import models
11 | import datasets
12 | import utils
13 | from models import DenoisingDiffusion, DiffusiveRestoration
14 |
15 |
16 | def parse_args_and_config():
17 | parser = argparse.ArgumentParser(description='Latent-Retinex Diffusion Models')
18 | parser.add_argument("--config", default='unsupervised.yml', type=str,
19 | help="Path to the config file")
20 | parser.add_argument('--mode', type=str, default='evaluation', help='training or evaluation')
21 | parser.add_argument('--resume', default='ckpt/stage2/stage2_weight.pth.tar', type=str,
22 | help='Path for the diffusion model checkpoint to load for evaluation')
23 | parser.add_argument("--image_folder", default='results/', type=str,
24 | help="Location to save restored images")
25 | args = parser.parse_args()
26 |
27 | with open(os.path.join("configs", args.config), "r") as f:
28 | config = yaml.safe_load(f)
29 | new_config = dict2namespace(config)
30 |
31 | return args, new_config
32 |
33 |
34 | def dict2namespace(config):
35 | namespace = argparse.Namespace()
36 | for key, value in config.items():
37 | if isinstance(value, dict):
38 | new_value = dict2namespace(value)
39 | else:
40 | new_value = value
41 | setattr(namespace, key, new_value)
42 | return namespace
43 |
44 |
45 | def main():
46 | args, config = parse_args_and_config()
47 |
48 | # setup device to run
49 | device = torch.device("cuda") if torch.cuda.is_available() else torch.device("cpu")
50 | print("Using device: {}".format(device))
51 | config.device = device
52 |
53 | if torch.cuda.is_available():
 54 |         print('Note: Evaluation (restoration) currently supports only single-GPU execution!')
55 |
56 | print("=> using dataset '{}'".format(config.data.val_dataset))
57 | DATASET = datasets.__dict__[config.data.type](config)
58 | _, val_loader = DATASET.get_loaders()
59 |
60 | # create model
61 | print("=> creating denoising-diffusion model")
62 | diffusion = DenoisingDiffusion(args, config)
63 | model = DiffusiveRestoration(diffusion, args, config)
64 | model.restore(val_loader)
65 |
66 |
67 | if __name__ == '__main__':
68 | main()
69 |
--------------------------------------------------------------------------------
/train.py:
--------------------------------------------------------------------------------
1 | import argparse
2 | import os
3 | import random
4 | import socket
5 | import yaml
6 | import torch
7 | import torch.backends.cudnn as cudnn
8 | import torch.utils.data
9 | import numpy as np
10 | import torchvision
11 | import models
12 | import datasets
13 | import utils
14 | from models import DenoisingDiffusion
15 |
16 |
17 | def parse_args_and_config():
18 | parser = argparse.ArgumentParser(description='Latent-Retinex Diffusion Models')
19 | parser.add_argument("--config", default='unsupervised.yml', type=str,
20 | help="Path to the config file")
21 | parser.add_argument('--mode', type=str, default='training', help='training or evaluation')
22 | parser.add_argument('--resume', default='', type=str,
23 | help='Path for checkpoint to load and resume')
24 | parser.add_argument("--image_folder", default='results/', type=str,
25 | help="Location to save restored validation image patches")
26 | parser.add_argument('--seed', default=230, type=int, metavar='N',
27 | help='Seed for initializing training (default: 230)')
28 | args = parser.parse_args()
29 |
30 | with open(os.path.join("configs", args.config), "r") as f:
31 | config = yaml.safe_load(f)
32 | new_config = dict2namespace(config)
33 |
34 | return args, new_config
35 |
36 |
37 | def dict2namespace(config):
38 | namespace = argparse.Namespace()
39 | for key, value in config.items():
40 | if isinstance(value, dict):
41 | new_value = dict2namespace(value)
42 | else:
43 | new_value = value
44 | setattr(namespace, key, new_value)
45 | return namespace
46 |
47 |
48 | def main():
49 | args, config = parse_args_and_config()
50 | # setup device to run
51 | device = torch.device("cuda") if torch.cuda.is_available() else torch.device("cpu")
52 | print("Using device: {}".format(device))
53 | config.device = device
54 |
55 | # set random seed
56 | torch.manual_seed(args.seed)
57 | np.random.seed(args.seed)
58 | if torch.cuda.is_available():
59 | torch.cuda.manual_seed_all(args.seed)
60 | torch.backends.cudnn.benchmark = True
61 |
62 | # data loading
63 | print("=> using dataset '{}'".format(config.data.train_dataset))
64 | DATASET = datasets.__dict__[config.data.type](config)
65 |
66 | # create model
67 | print("=> creating denoising-diffusion model...")
68 | diffusion = DenoisingDiffusion(args, config)
69 | diffusion.train(DATASET)
70 |
71 |
72 | if __name__ == "__main__":
73 | main()
74 |
--------------------------------------------------------------------------------
/datasets/data_augment.py:
--------------------------------------------------------------------------------
1 | import random
2 | import torchvision.transforms as transforms
3 | import torchvision.transforms.functional as F
4 |
5 |
6 | class PairRandomCrop(transforms.RandomCrop):
7 |
8 | def __call__(self, image, label):
9 |
10 | if self.padding is not None:
11 | image = F.pad(image, self.padding, self.fill, self.padding_mode)
12 | label = F.pad(label, self.padding, self.fill, self.padding_mode)
13 |
14 | # pad the width if needed
15 | if self.pad_if_needed and image.size[0] < self.size[1]:
16 | image = F.pad(image, (self.size[1] - image.size[0], 0), self.fill, self.padding_mode)
17 | label = F.pad(label, (self.size[1] - label.size[0], 0), self.fill, self.padding_mode)
18 | # pad the height if needed
19 | if self.pad_if_needed and image.size[1] < self.size[0]:
20 | image = F.pad(image, (0, self.size[0] - image.size[1]), self.fill, self.padding_mode)
21 |             label = F.pad(label, (0, self.size[0] - label.size[1]), self.fill, self.padding_mode)
22 |
23 | i, j, h, w = self.get_params(image, self.size)
24 |
25 | return F.crop(image, i, j, h, w), F.crop(label, i, j, h, w)
26 |
27 |
28 | class PairCompose(transforms.Compose):
29 | def __call__(self, image, label):
30 | for t in self.transforms:
31 | image, label = t(image, label)
32 | return image, label
33 |
34 |
35 | class PairRandomHorizontalFilp(transforms.RandomHorizontalFlip):
36 | def __call__(self, img, label):
37 | """
38 | Args:
39 | img (PIL Image): Image to be flipped.
40 |
41 | Returns:
42 | PIL Image: Randomly flipped image.
43 | """
44 | if random.random() < self.p:
45 | return F.hflip(img), F.hflip(label)
46 | return img, label
47 |
48 |
49 | class PairRandomVerticalFlip(transforms.RandomVerticalFlip):
50 | def __call__(self, img, label):
51 | """
52 | Args:
53 | img (PIL Image): Image to be flipped.
54 |
55 | Returns:
56 | PIL Image: Randomly flipped image.
57 | """
58 | if random.random() < self.p:
59 | return F.vflip(img), F.vflip(label)
60 | return img, label
61 |
62 |
63 | class PairToTensor(transforms.ToTensor):
64 | def __call__(self, pic, label):
65 | """
66 | Args:
67 | pic (PIL Image or numpy.ndarray): Image to be converted to tensor.
68 |
69 | Returns:
70 | Tensor: Converted image.
71 | """
72 | return F.to_tensor(pic), F.to_tensor(label)
73 |
--------------------------------------------------------------------------------
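A minimal sketch of how the paired transforms above compose, so that the low-light image and its reference receive identical random crops/flips (`PairRandomCrop` is defined above, though `datasets/dataset.py` only uses the flip and tensor transforms); the blank PIL images are hypothetical stand-ins for a real pair.

```
from PIL import Image
from datasets.data_augment import (PairCompose, PairRandomCrop,
                                   PairRandomHorizontalFilp, PairToTensor)

low = Image.new("RGB", (600, 400))    # hypothetical low-light image
high = Image.new("RGB", (600, 400))   # hypothetical normal-light reference

paired = PairCompose([
    PairRandomCrop(256),              # same crop window for both images
    PairRandomHorizontalFilp(),       # same flip decision for both images
    PairToTensor(),                   # both converted to [0, 1] float tensors
])
low_t, high_t = paired(low, high)     # each: 3 x 256 x 256
```
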
/README.md:
--------------------------------------------------------------------------------
1 | # [ECCV 2024] LightenDiffusion: Unsupervised Low-Light Image Enhancement with Latent-Retinex Diffusion Models [[Paper]](https://arxiv.org/pdf/2407.08939)
  2 | Hai Jiang1,5, Ao Luo2,5, Xiaohong Liu4, Songchen Han1, Shuaicheng Liu3,5
3 | 1.Sichuan University, 2.Southwest Jiaotong University,
4 | 3.University of Electronic Science and Technology of China,
  5 | 4.Shanghai Jiao Tong University, 5.Megvii Technology
6 |
7 | ## Pipeline
8 | 
9 |
10 | ## Dependencies
11 | ```
12 | pip install -r requirements.txt
 13 | ```
14 |
15 | ## Download the raw training and evaluation datasets
16 | ### Paired datasets
17 | LOL dataset: Chen Wei, Wenjing Wang, Wenhan Yang, and Jiaying Liu. "Deep Retinex Decomposition for Low-Light Enhancement". BMVC, 2018. [[Baiduyun (extracted code: sdd0)]](https://pan.baidu.com/s/1spt0kYU3OqsQSND-be4UaA) [[Google Drive]](https://drive.google.com/file/d/18bs_mAREhLipaM2qvhxs7u7ff2VSHet2/view?usp=sharing)
18 |
19 | LSRW dataset: Jiang Hai, Zhu Xuan, Ren Yang, Yutong Hao, Fengzhu Zou, Fang Lin, and Songchen Han. "R2RNet: Low-light Image Enhancement via Real-low to Real-normal Network". Journal of Visual Communication and Image Representation, 2023. [[Baiduyun (extracted code: wmrr)]](https://pan.baidu.com/s/1XHWQAS0ZNrnCyZ-bq7MKvA)
20 |
21 | ### Unpaired datasets
22 | Please refer to [[Project Page of RetinexNet]](https://daooshee.github.io/BMVC2018website/).
23 |
24 | ## Pre-trained Models
25 | You can download our pre-trained model from [[Google Drive]](https://drive.google.com/drive/folders/1m3t15rWw76IDDWJ0exLOe5P0uEnjk3zl?usp=drive_link) and [[Baidu Yun (extracted code:cjzk)]](https://pan.baidu.com/s/1fPLVgnZbdY1n75Flq54bMQ)
26 |
27 | ## How to train?
28 | You need to modify ```datasets/dataset.py``` slightly for your environment, and then
29 | ```
30 | python train.py
31 | ```
32 |
33 | ## How to test?
34 | ```
35 | python evaluate.py
36 | ```
37 |
38 | ## Visual comparison
39 | 
40 |
41 | ## Citation
42 | If you use this code or ideas from the paper for your research, please cite our paper:
43 | ```
44 | @InProceedings{Jiang_2024_ECCV,
45 | author = {Jiang, Hai and Luo, Ao and Liu, Xiaohong and Han, Songchen and Liu, Shuaicheng},
46 | title = {LightenDiffusion: Unsupervised Low-Light Image Enhancement with Latent-Retinex Diffusion Models},
47 | booktitle = {European Conference on Computer Vision},
48 | year = {2024},
49 | pages = {}
50 | }
51 | ```
52 |
53 | ## Acknowledgement
54 | Part of the code is adapted from previous works: [WeatherDiff](https://github.com/IGITUGraz/WeatherDiffusion) and [MIMO-UNet](https://github.com/chosj95/MIMO-UNet). We thank all the authors for their contributions.
55 |
56 |
--------------------------------------------------------------------------------
/datasets/dataset.py:
--------------------------------------------------------------------------------
1 | import os
2 | import torch
3 | import torch.utils.data
4 | from PIL import Image
5 | from datasets.data_augment import PairCompose, PairToTensor, PairRandomHorizontalFilp
6 |
7 |
8 | class LLdataset:
9 | def __init__(self, config):
10 | self.config = config
11 |
12 | def get_loaders(self):
13 | train_dataset = AllWeatherDataset(self.config.data.data_dir,
14 | patch_size=self.config.data.patch_size,
15 | filelist='{}_train.txt'.format(self.config.data.train_dataset))
16 | val_dataset = AllWeatherDataset(self.config.data.data_dir,
17 | patch_size=self.config.data.patch_size,
18 | filelist='{}_val.txt'.format(self.config.data.val_dataset), train=False)
19 |
20 | train_loader = torch.utils.data.DataLoader(train_dataset, batch_size=self.config.training.batch_size,
21 | shuffle=True, num_workers=self.config.data.num_workers,
22 | pin_memory=True)
23 | val_loader = torch.utils.data.DataLoader(val_dataset, batch_size=self.config.sampling.batch_size,
24 | shuffle=False, num_workers=self.config.data.num_workers,
25 | pin_memory=True)
26 |
27 | return train_loader, val_loader
28 |
29 |
30 | class AllWeatherDataset(torch.utils.data.Dataset):
31 | def __init__(self, dir, patch_size, filelist=None, train=True):
32 | super().__init__()
33 |
34 | self.dir = dir
35 | self.file_list = filelist
36 | self.train_list = os.path.join(dir, self.file_list)
37 | with open(self.train_list) as f:
38 | contents = f.readlines()
39 | input_names = [i.strip() for i in contents]
40 |
41 | self.input_names = input_names
42 | self.patch_size = patch_size
43 |
44 | if train:
45 | self.transforms = PairCompose([
46 | PairRandomHorizontalFilp(),
47 | PairToTensor()
48 | ])
49 | else:
50 | self.transforms = PairCompose([
51 | PairToTensor()
52 | ])
53 |
54 | def get_images(self, index):
55 | input_name = self.input_names[index].replace('\n', '')
56 |
57 | low_img_name, high_img_name = input_name.split(' ')[0], input_name.split(' ')[1]
58 |
59 | img_id = low_img_name.split('/')[-1]
60 | low_img, high_img = Image.open(low_img_name), Image.open(high_img_name)
61 |
62 | low_img, high_img = self.transforms(low_img, high_img)
63 |
64 | return torch.cat([low_img, high_img], dim=0), img_id
65 |
66 | def __getitem__(self, index):
67 | res = self.get_images(index)
68 | return res
69 |
70 | def __len__(self):
71 | return len(self.input_names)
72 |
--------------------------------------------------------------------------------
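`AllWeatherDataset` expects `data.data_dir` to contain plain-text file lists named `<train_dataset>_train.txt` and `<val_dataset>_val.txt`, one pair per line, with the low-light image path and its normal-light (or unpaired reference) path separated by a single space; the low path's basename is reused as the output image id. A minimal sketch that writes such a list, with hypothetical image paths:

```
import os

data_dir = "/data/Image_restoration/LSRW_dataset"   # data.data_dir in the config
pairs = [
    # (low-light image, normal-light reference): hypothetical paths
    ("/data/LSRW/train/low/0001.jpg", "/data/LSRW/train/high/0001.jpg"),
    ("/data/LSRW/train/low/0002.jpg", "/data/LSRW/train/high/0002.jpg"),
]
# data.train_dataset is "unpaired" in configs/unsupervised.yml
with open(os.path.join(data_dir, "unpaired_train.txt"), "w") as f:
    for low_path, high_path in pairs:
        f.write(f"{low_path} {high_path}\n")
```
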
/models/ddm.py:
--------------------------------------------------------------------------------
1 | import os
2 | import time
3 | import numpy as np
4 | import torch
5 | import torch.nn as nn
6 | import torch.backends.cudnn as cudnn
7 | import torch.nn.functional as F
8 | import utils
9 | from models.unet import DiffusionUNet
10 | from models.decom import CTDN
11 |
12 |
13 | class EMAHelper(object):
14 | def __init__(self, mu=0.9999):
15 | self.mu = mu
16 | self.shadow = {}
17 |
18 | def register(self, module):
19 | if isinstance(module, nn.DataParallel):
20 | module = module.module
21 | for name, param in module.named_parameters():
22 | if param.requires_grad:
23 | self.shadow[name] = param.data.clone()
24 |
25 | def update(self, module):
26 | if isinstance(module, nn.DataParallel):
27 | module = module.module
28 | for name, param in module.named_parameters():
29 | if param.requires_grad:
30 | self.shadow[name].data = (1. - self.mu) * param.data + self.mu * self.shadow[name].data
31 |
32 | def ema(self, module):
33 | if isinstance(module, nn.DataParallel):
34 | module = module.module
35 | for name, param in module.named_parameters():
36 | if param.requires_grad:
37 | param.data.copy_(self.shadow[name].data)
38 |
39 | def ema_copy(self, module):
40 | if isinstance(module, nn.DataParallel):
41 | inner_module = module.module
42 | module_copy = type(inner_module)(inner_module.config).to(inner_module.config.device)
43 | module_copy.load_state_dict(inner_module.state_dict())
44 | module_copy = nn.DataParallel(module_copy)
45 | else:
46 | module_copy = type(module)(module.config).to(module.config.device)
47 | module_copy.load_state_dict(module.state_dict())
48 | self.ema(module_copy)
49 | return module_copy
50 |
51 | def state_dict(self):
52 | return self.shadow
53 |
54 | def load_state_dict(self, state_dict):
55 | self.shadow = state_dict
56 |
57 |
58 | def get_beta_schedule(beta_schedule, *, beta_start, beta_end, num_diffusion_timesteps):
59 | def sigmoid(x):
60 | return 1 / (np.exp(-x) + 1)
61 |
62 | if beta_schedule == "quad":
63 | betas = (np.linspace(beta_start ** 0.5, beta_end ** 0.5, num_diffusion_timesteps, dtype=np.float64) ** 2)
64 | elif beta_schedule == "linear":
65 | betas = np.linspace(beta_start, beta_end, num_diffusion_timesteps, dtype=np.float64)
66 | elif beta_schedule == "const":
67 | betas = beta_end * np.ones(num_diffusion_timesteps, dtype=np.float64)
68 | elif beta_schedule == "jsd": # 1/T, 1/(T-1), 1/(T-2), ..., 1
69 | betas = 1.0 / np.linspace(num_diffusion_timesteps, 1, num_diffusion_timesteps, dtype=np.float64)
70 | elif beta_schedule == "sigmoid":
71 | betas = np.linspace(-6, 6, num_diffusion_timesteps)
72 | betas = sigmoid(betas) * (beta_end - beta_start) + beta_start
73 | else:
74 | raise NotImplementedError(beta_schedule)
75 | assert betas.shape == (num_diffusion_timesteps,)
76 | return betas
77 |
78 |
79 | class Net(nn.Module):
80 | def __init__(self, args, config):
81 | super(Net, self).__init__()
82 |
83 | self.args = args
84 | self.config = config
85 | self.device = config.device
86 |
87 | self.Unet = DiffusionUNet(config)
88 | if self.args.mode == 'training':
89 | self.decom = self.load_stage1(CTDN(), 'ckpt/stage1')
90 | else:
91 | self.decom = CTDN()
92 |
93 | betas = get_beta_schedule(
94 | beta_schedule=config.diffusion.beta_schedule,
95 | beta_start=config.diffusion.beta_start,
96 | beta_end=config.diffusion.beta_end,
97 | num_diffusion_timesteps=config.diffusion.num_diffusion_timesteps,
98 | )
99 |
100 | self.betas = torch.from_numpy(betas).float()
101 | self.num_timesteps = self.betas.shape[0]
102 |
103 | @staticmethod
104 | def compute_alpha(beta, t):
105 | beta = torch.cat([torch.zeros(1).to(beta.device), beta], dim=0)
106 | a = (1 - beta).cumprod(dim=0).index_select(0, t + 1).view(-1, 1, 1, 1)
107 | return a
108 |
109 | @staticmethod
110 | def load_stage1(model, model_dir):
111 | checkpoint = utils.logging.load_checkpoint(os.path.join(model_dir, 'stage1_weight.pth.tar'), 'cuda')
112 | model.load_state_dict(checkpoint['model'], strict=True)
113 | return model
114 |
115 | def sample_training(self, x_cond, b, eta=0.):
116 | skip = self.config.diffusion.num_diffusion_timesteps // self.config.diffusion.num_sampling_timesteps
117 | seq = range(0, self.config.diffusion.num_diffusion_timesteps, skip)
118 | n, c, h, w = x_cond.shape
119 | seq_next = [-1] + list(seq[:-1])
120 | x = torch.randn(n, c, h, w, device=self.device)
121 | xs = [x]
122 | for i, j in zip(reversed(seq), reversed(seq_next)):
123 | t = (torch.ones(n) * i).to(x.device)
124 | next_t = (torch.ones(n) * j).to(x.device)
125 | at = self.compute_alpha(b, t.long())
126 | at_next = self.compute_alpha(b, next_t.long())
127 | xt = xs[-1].to(x.device)
128 |
129 | et = self.Unet(torch.cat([x_cond, xt], dim=1), t)
130 | x0_t = (xt - et * (1 - at).sqrt()) / at.sqrt()
131 |
132 | c1 = eta * ((1 - at / at_next) * (1 - at_next) / (1 - at)).sqrt()
133 | c2 = ((1 - at_next) - c1 ** 2).sqrt()
134 | xt_next = at_next.sqrt() * x0_t + c1 * torch.randn_like(x) + c2 * et
135 | xs.append(xt_next.to(x.device))
136 |
137 | return xs[-1]
138 |
139 | def forward(self, inputs):
140 | data_dict = {}
141 |
142 | b = self.betas.to(inputs.device)
143 |
144 | if self.training:
145 | output = self.decom(inputs, pred_fea=None)
146 | low_R, low_L, low_fea, high_L = output["low_R"], output["low_L"], \
147 | output["low_fea"], output["high_L"]
148 | low_condition_norm = utils.data_transform(low_fea)
149 |
150 | t = torch.randint(low=0, high=self.num_timesteps, size=(low_condition_norm.shape[0] // 2 + 1,)).to(
151 | self.device)
152 | t = torch.cat([t, self.num_timesteps - t - 1], dim=0)[:low_condition_norm.shape[0]].to(inputs.device)
153 | a = (1 - b).cumprod(dim=0).index_select(0, t).view(-1, 1, 1, 1)
154 |
155 | e = torch.randn_like(low_condition_norm)
156 |
157 | high_input_norm = utils.data_transform(low_R * high_L)
158 |
159 | x = high_input_norm * a.sqrt() + e * (1.0 - a).sqrt()
160 | noise_output = self.Unet(torch.cat([low_condition_norm, x], dim=1), t.float())
161 |
162 | pred_fea = self.sample_training(low_condition_norm, b)
163 | pred_fea = utils.inverse_data_transform(pred_fea)
164 | reference_fea = low_R * torch.pow(low_L, 0.2)
165 |
166 | data_dict["noise_output"] = noise_output
167 | data_dict["e"] = e
168 |
169 | data_dict["pred_fea"] = pred_fea
170 | data_dict["reference_fea"] = reference_fea
171 |
172 | else:
173 | output = self.decom(inputs, pred_fea=None)
174 | low_fea = output["low_fea"]
175 | low_condition_norm = utils.data_transform(low_fea)
176 |
177 | pred_fea = self.sample_training(low_condition_norm, b)
178 | pred_fea = utils.inverse_data_transform(pred_fea)
179 | pred_x = self.decom(inputs, pred_fea=pred_fea)["pred_img"]
180 | data_dict["pred_x"] = pred_x
181 |
182 | return data_dict
183 |
184 |
185 | class DenoisingDiffusion(object):
186 | def __init__(self, args, config):
187 | super().__init__()
188 | self.args = args
189 | self.config = config
190 | self.device = config.device
191 |
192 | self.model = Net(args, config)
193 | self.model.to(self.device)
194 | self.model = torch.nn.DataParallel(self.model, device_ids=range(torch.cuda.device_count()))
195 |
196 | self.ema_helper = EMAHelper()
197 | self.ema_helper.register(self.model)
198 |
199 | self.l2_loss = torch.nn.MSELoss()
200 | self.l1_loss = torch.nn.L1Loss()
201 |
202 | self.optimizer = utils.optimize.get_optimizer(self.config, self.model.parameters())
203 | self.start_epoch, self.step = 0, 0
204 |
205 | def load_ddm_ckpt(self, load_path, ema=False):
206 | checkpoint = utils.logging.load_checkpoint(load_path, None)
207 | self.model.load_state_dict(checkpoint['state_dict'], strict=True)
208 | if ema:
209 | self.ema_helper.ema(self.model)
210 | print("=> loaded checkpoint {} step {}".format(load_path, self.step))
211 |
212 | def train(self, DATASET):
213 | cudnn.benchmark = True
214 | train_loader, val_loader = DATASET.get_loaders()
215 |
216 | if os.path.isfile(self.args.resume):
217 | self.load_ddm_ckpt(self.args.resume)
218 |
219 | for name, param in self.model.named_parameters():
220 | if "decom" in name:
221 | param.requires_grad = False
222 | else:
223 | param.requires_grad = True
224 |
225 | for epoch in range(self.start_epoch, self.config.training.n_epochs):
226 | print('epoch: ', epoch)
227 | data_start = time.time()
228 | data_time = 0
229 | for i, (x, y) in enumerate(train_loader):
230 | x = x.flatten(start_dim=0, end_dim=1) if x.ndim == 5 else x
231 | self.model.train()
232 | self.step += 1
233 |
234 | x = x.to(self.device)
235 |
236 | output = self.model(x)
237 |
238 | noise_loss, scc_loss = self.noise_estimation_loss(output)
239 | loss = noise_loss + scc_loss
240 |
241 | data_time += time.time() - data_start
242 |
243 | if self.step % 10 == 0:
244 | print("step:{}, noise_loss:{:.5f} scc_loss:{:.5f} time:{:.5f}".
245 | format(self.step, noise_loss.item(),
246 | scc_loss.item(), data_time / (i + 1)))
247 |
248 | self.optimizer.zero_grad()
249 | loss.backward()
250 | self.optimizer.step()
251 | self.ema_helper.update(self.model)
252 | data_start = time.time()
253 |
254 | if self.step % self.config.training.validation_freq == 0 and self.step != 0:
255 | self.model.eval()
256 | self.sample_validation_patches(val_loader, self.step)
257 |
258 | utils.logging.save_checkpoint({'step': self.step,
259 | 'epoch': epoch + 1,
260 | 'state_dict': self.model.state_dict(),
261 | 'optimizer': self.optimizer.state_dict(),
262 | 'ema_helper': self.ema_helper.state_dict(),
263 | 'params': self.args,
264 | 'config': self.config},
265 | filename=os.path.join(self.config.data.ckpt_dir, 'model_latest'))
266 |
267 | def noise_estimation_loss(self, output):
268 | pred_fea, reference_fea = output["pred_fea"], output["reference_fea"]
269 | noise_output, e = output["noise_output"], output["e"]
270 | # ==================noise loss==================
271 | noise_loss = self.l2_loss(noise_output, e)
272 | # ==================scc loss==================
273 | scc_loss = 0.001 * self.l1_loss(pred_fea, reference_fea)
274 |
275 | return noise_loss, scc_loss
276 |
277 | def sample_validation_patches(self, val_loader, step):
278 | image_folder = os.path.join(self.args.image_folder,
279 | self.config.data.type + str(self.config.data.patch_size))
280 | self.model.eval()
281 |
282 | with torch.no_grad():
283 | print('Performing validation at step: {}'.format(step))
284 | for i, (x, y) in enumerate(val_loader):
285 | b, _, img_h, img_w = x.shape
286 |
287 | img_h_64 = int(64 * np.ceil(img_h / 64.0))
288 | img_w_64 = int(64 * np.ceil(img_w / 64.0))
289 | x = F.pad(x, (0, img_w_64 - img_w, 0, img_h_64 - img_h), 'reflect')
290 | pred_x = self.model(x.to(self.device))["pred_x"][:, :, :img_h, :img_w]
291 | utils.logging.save_image(pred_x, os.path.join(image_folder, str(step), '{}'.format(y[0])))
292 |
--------------------------------------------------------------------------------
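For reference, `sample_training` runs the DDIM update of Song et al. over a sub-sequence of `num_sampling_timesteps` steps (`skip = num_diffusion_timesteps // num_sampling_timesteps`), conditioned on the encoded low-light feature `x_cond`, while the training branch perturbs the target latent with the standard forward reparameterization. In the code's notation (`compute_alpha` gives $\bar\alpha$, `self.Unet` is $\epsilon_\theta$):

$$
x_t = \sqrt{\bar\alpha_t}\, x_0 + \sqrt{1-\bar\alpha_t}\,\epsilon, \qquad \epsilon \sim \mathcal{N}(0, I)
$$

$$
\hat{x}_0 = \frac{x_t - \sqrt{1-\bar\alpha_t}\,\epsilon_\theta(x_t, x_{\mathrm{cond}}, t)}{\sqrt{\bar\alpha_t}}, \qquad
x_{t'} = \sqrt{\bar\alpha_{t'}}\,\hat{x}_0 + c_1 z + c_2\,\epsilon_\theta(x_t, x_{\mathrm{cond}}, t)
$$

$$
c_1 = \eta\sqrt{\frac{(1-\bar\alpha_t/\bar\alpha_{t'})(1-\bar\alpha_{t'})}{1-\bar\alpha_t}}, \qquad
c_2 = \sqrt{1-\bar\alpha_{t'}-c_1^2}, \qquad z \sim \mathcal{N}(0, I)
$$

with $\eta = 0$ by default in `sample_training`, so the reverse update is deterministic.
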
/models/decom.py:
--------------------------------------------------------------------------------
1 | import torch
2 | import torch.nn as nn
3 | import warnings
4 | import os
5 | import math
6 | import torch.nn.functional as F
7 | from einops import rearrange
8 |
9 | warnings.filterwarnings("ignore", category=UserWarning)
10 | warnings.filterwarnings("ignore", category=FutureWarning)
11 |
12 |
13 | class Depth_conv(nn.Module):
14 | def __init__(self, in_ch, out_ch):
15 | super(Depth_conv, self).__init__()
16 | self.depth_conv = nn.Conv2d(
17 | in_channels=in_ch,
18 | out_channels=in_ch,
19 | kernel_size=(3, 3),
20 | stride=(1, 1),
21 | padding=1,
22 | groups=in_ch
23 | )
24 | self.point_conv = nn.Conv2d(
25 | in_channels=in_ch,
26 | out_channels=out_ch,
27 | kernel_size=(1, 1),
28 | stride=(1, 1),
29 | padding=0,
30 | groups=1
31 | )
32 |
33 | def forward(self, input):
34 | out = self.depth_conv(input)
35 | out = self.point_conv(out)
36 | return out
37 |
38 |
39 | class Res_block(nn.Module):
40 | def __init__(self, in_channels, out_channels):
41 | super(Res_block, self).__init__()
42 |
43 | sequence = []
44 |
45 | sequence += [
46 | nn.Conv2d(in_channels, out_channels, kernel_size=(3, 3), stride=(1, 1), padding=1),
47 | nn.LeakyReLU(),
48 | nn.Conv2d(out_channels, out_channels, kernel_size=(3, 3), stride=(1, 1), padding=1)
49 | ]
50 |
51 | self.model = nn.Sequential(*sequence)
52 |
53 | self.conv = nn.Conv2d(in_channels, out_channels, kernel_size=(1, 1), stride=(1, 1), padding=0)
54 |
55 | def forward(self, x):
56 | out = self.model(x) + self.conv(x)
57 |
58 | return out
59 |
60 |
61 | class upsampling(nn.Module):
62 | def __init__(self, in_channels, out_channels):
63 | super(upsampling, self).__init__()
64 |
65 | self.conv = nn.ConvTranspose2d(in_channels, out_channels, kernel_size=3, stride=2, padding=1,
66 | output_padding=1)
67 |
68 | self.relu = nn.LeakyReLU()
69 |
70 | def forward(self, x):
71 | out = self.relu(self.conv(x))
72 | return out
73 |
74 |
75 | class channel_down(nn.Module):
76 | def __init__(self, channels):
77 | super(channel_down, self).__init__()
78 |
79 | self.conv0 = nn.Conv2d(channels * 4, channels * 2, kernel_size=(3, 3), stride=(1, 1), padding=1)
80 | self.conv1 = nn.Conv2d(channels * 2, channels, kernel_size=(3, 3), stride=(1, 1), padding=1)
81 | self.conv2 = nn.Conv2d(channels, 3, kernel_size=(3, 3), stride=(1, 1), padding=1)
82 |
83 | self.relu = nn.LeakyReLU()
84 |
85 | def forward(self, x):
86 | out = torch.sigmoid(self.conv2(self.relu(self.conv1(self.relu(self.conv0(x))))))
87 |
88 | return out
89 |
90 |
91 | class channel_up(nn.Module):
92 | def __init__(self, channels):
93 | super(channel_up, self).__init__()
94 |
95 | self.conv0 = nn.Conv2d(3, channels, kernel_size=(3, 3), stride=(1, 1), padding=1)
96 | self.conv1 = nn.Conv2d(channels, channels * 2, kernel_size=(3, 3), stride=(1, 1), padding=1)
97 | self.conv2 = nn.Conv2d(channels * 2, channels * 4, kernel_size=(3, 3), stride=(1, 1), padding=1)
98 |
99 | self.relu = nn.LeakyReLU()
100 |
101 | def forward(self, x):
102 | out = self.conv2(self.relu(self.conv1(self.relu(self.conv0(x)))))
103 |
104 | return out
105 |
106 |
107 | class feature_pyramid(nn.Module):
108 | def __init__(self, channels):
109 | super(feature_pyramid, self).__init__()
110 |
111 | self.convs = nn.Sequential(nn.Conv2d(3, channels, kernel_size=(5, 5), stride=(1, 1), padding=2),
112 | nn.Conv2d(channels, channels, kernel_size=(5, 5), stride=(1, 1), padding=2))
113 |
114 | self.block0 = Res_block(channels, channels)
115 |
116 | self.down0 = nn.Conv2d(channels, channels, kernel_size=(3, 3), stride=(2, 2), padding=1)
117 |
118 | self.block1 = Res_block(channels, channels * 2)
119 |
120 | self.down1 = nn.Conv2d(channels * 2, channels * 2, kernel_size=(3, 3), stride=(2, 2), padding=1)
121 |
122 | self.block2 = Res_block(channels * 2, channels * 4)
123 |
124 | self.down2 = nn.Conv2d(channels * 4, channels * 4, kernel_size=(3, 3), stride=(2, 2), padding=1)
125 |
126 | self.relu = nn.LeakyReLU()
127 |
128 | def forward(self, x):
129 |
130 | level0 = self.down0(self.block0(self.convs(x)))
131 | level1 = self.down1(self.block1(level0))
132 | level2 = self.down2(self.block2(level1))
133 |
134 | return level0, level1, level2
135 |
136 |
137 | class ReconNet(nn.Module):
138 | def __init__(self, channels):
139 | super(ReconNet, self).__init__()
140 |
141 | self.pyramid = feature_pyramid(channels)
142 |
143 | self.channel_down = channel_down(channels)
144 | self.channel_up = channel_up(channels)
145 |
146 | self.block_up0 = Res_block(channels * 4, channels * 4)
147 | self.block_up1 = Res_block(channels * 4, channels * 4)
148 | self.up_sampling0 = upsampling(channels * 4, channels * 2)
149 | self.block_up2 = Res_block(channels * 2, channels * 2)
150 | self.block_up3 = Res_block(channels * 2, channels * 2)
151 | self.up_sampling1 = upsampling(channels * 2, channels)
152 | self.block_up4 = Res_block(channels, channels)
153 | self.block_up5 = Res_block(channels, channels)
154 | self.up_sampling2 = upsampling(channels, channels)
155 |
156 | self.conv2 = nn.Conv2d(channels, channels, kernel_size=(3, 3), stride=(1, 1), padding=1)
157 | self.conv3 = nn.Conv2d(channels, 3, kernel_size=(1, 1), stride=(1, 1), padding=0)
158 |
159 | self.relu = nn.LeakyReLU()
160 |
161 | def forward(self, x, pred_fea=None):
162 |
163 | if pred_fea is None:
164 | low_fea_down2, low_fea_down4, low_fea_down8 = self.pyramid(x[:, :3, ...])
165 | low_fea_down8 = self.channel_down(low_fea_down8)
166 |
167 | high_fea_down2, high_fea_down4, high_fea_down8 = self.pyramid(x[:, 3:, ...])
168 | high_fea_down8 = self.channel_down(high_fea_down8)
169 |
170 | return low_fea_down8, high_fea_down8
171 | else:
172 | # =================low ori decoder=================
173 | low_fea_down2, low_fea_down4, low_fea_down8 = self.pyramid(x[:, :3, ...])
174 |
175 | pred_fea = self.channel_up(pred_fea)
176 |
177 | pred_fea_up2 = self.up_sampling0(
178 | self.block_up1(self.block_up0(pred_fea) + low_fea_down8))
179 | pred_fea_up4 = self.up_sampling1(
180 | self.block_up3(self.block_up2(pred_fea_up2) + low_fea_down4))
181 | pred_fea_up8 = self.up_sampling2(
182 | self.block_up5(self.block_up4(pred_fea_up4) + low_fea_down2))
183 |
184 | pred_img = self.conv3(self.relu(self.conv2(pred_fea_up8)))
185 |
186 | return pred_img
187 |
188 |
189 | class Self_Attention(nn.Module):
190 | def __init__(self, dim, num_heads, bias):
191 | super(Self_Attention, self).__init__()
192 | self.num_heads = num_heads
193 | self.qkv = nn.Conv2d(dim, dim * 3, kernel_size=(1, 1), bias=bias)
194 | self.qkv_dwconv = nn.Conv2d(dim * 3, dim * 3, kernel_size=(3, 3), stride=(1, 1),
195 | padding=1, groups=dim * 3, bias=bias)
196 | self.project_out = nn.Conv2d(dim, dim, kernel_size=(1, 1), bias=bias)
197 |
198 | def forward(self, x):
199 | b, c, h, w = x.shape
200 |
201 | qkv = self.qkv_dwconv(self.qkv(x))
202 | q, k, v = qkv.chunk(3, dim=1)
203 |
204 | q = rearrange(q, 'b (head c) h w -> b head c (h w)', head=self.num_heads)
205 | k = rearrange(k, 'b (head c) h w -> b head c (h w)', head=self.num_heads)
206 | v = rearrange(v, 'b (head c) h w -> b head c (h w)', head=self.num_heads)
207 |
208 | q = torch.nn.functional.normalize(q, dim=-1)
209 | k = torch.nn.functional.normalize(k, dim=-1)
210 |
211 | attn = (q @ k.transpose(-2, -1))
212 | attn = attn.softmax(dim=-1)
213 |
214 | out = (attn @ v)
215 |
216 | out = rearrange(out, 'b head c (h w) -> b (head c) h w', head=self.num_heads, h=h, w=w)
217 |
218 | out = self.project_out(out)
219 | return out
220 |
221 |
222 | class Cross_Attention(nn.Module):
223 | def __init__(self, dim, num_heads, dropout=0.):
224 | super(Cross_Attention, self).__init__()
225 | if dim % num_heads != 0:
226 | raise ValueError(
227 | "The hidden size (%d) is not a multiple of the number of attention "
228 | "heads (%d)" % (dim, num_heads)
229 | )
230 | self.num_heads = num_heads
231 | self.attention_head_size = int(dim / num_heads)
232 |
233 | self.query = Depth_conv(in_ch=dim, out_ch=dim)
234 | self.key = Depth_conv(in_ch=dim, out_ch=dim)
235 | self.value = Depth_conv(in_ch=dim, out_ch=dim)
236 |
237 | self.dropout = nn.Dropout(dropout)
238 |
239 | def transpose_for_scores(self, x):
240 | '''
241 | new_x_shape = x.size()[:-1] + (
242 | self.num_heads,
243 | self.attention_head_size,
244 | )
245 | print(new_x_shape)
246 | x = x.view(*new_x_shape)
247 | '''
248 | return x.permute(0, 2, 1, 3)
249 |
250 | def forward(self, hidden_states, ctx):
251 | mixed_query_layer = self.query(hidden_states)
252 | mixed_key_layer = self.key(ctx)
253 | mixed_value_layer = self.value(ctx)
254 |
255 | query_layer = self.transpose_for_scores(mixed_query_layer)
256 | key_layer = self.transpose_for_scores(mixed_key_layer)
257 | value_layer = self.transpose_for_scores(mixed_value_layer)
258 |
259 | attention_scores = torch.matmul(query_layer, key_layer.transpose(-1, -2))
260 | attention_scores = attention_scores / math.sqrt(self.attention_head_size)
261 |
262 | attention_probs = nn.Softmax(dim=-1)(attention_scores)
263 |
264 | attention_probs = self.dropout(attention_probs)
265 |
266 | ctx_layer = torch.matmul(attention_probs, value_layer)
267 | ctx_layer = ctx_layer.permute(0, 2, 1, 3).contiguous()
268 |
269 | return ctx_layer
270 |
271 |
272 | class Retinex_decom(nn.Module):
273 | def __init__(self, channels):
274 | super(Retinex_decom, self).__init__()
275 |
276 | self.conv0 = nn.Conv2d(3, channels, kernel_size=(3, 3), stride=(1, 1), padding=1)
277 | self.blocks0 = nn.Sequential(Res_block(channels, channels),
278 | Res_block(channels, channels))
279 |
280 | self.conv1 = nn.Conv2d(1, channels, kernel_size=(3, 3), stride=(1, 1), padding=1)
281 | self.blocks1 = nn.Sequential(Res_block(channels, channels),
282 | Res_block(channels, channels))
283 |
284 | self.cross_attention = Cross_Attention(dim=channels, num_heads=8)
285 | self.self_attention = Self_Attention(dim=channels, num_heads=8, bias=True)
286 |
287 | self.conv0_1 = nn.Sequential(Res_block(channels, channels),
288 | nn.Conv2d(channels, 3, kernel_size=(3, 3), stride=(1, 1), padding=1))
289 | self.conv1_1 = nn.Sequential(Res_block(channels, channels),
290 | nn.Conv2d(channels, 1, kernel_size=(3, 3), stride=(1, 1), padding=1))
291 |
292 | def forward(self, x):
293 | init_illumination = torch.max(x, dim=1, keepdim=True)[0]
294 | init_reflectance = x / init_illumination
295 |
296 | Reflectance, Illumination = (self.blocks0(self.conv0(init_reflectance)),
297 | self.blocks1(self.conv1(init_illumination)))
298 |
299 | Reflectance_final = self.cross_attention(Illumination, Reflectance)
300 |
301 | Illumination_content = self.self_attention(Illumination)
302 |
303 | Reflectance_final = self.conv0_1(Reflectance_final + Illumination_content)
304 | Illumination_final = self.conv1_1(Illumination - Illumination_content)
305 |
306 | R = torch.sigmoid(Reflectance_final)
307 | L = torch.sigmoid(Illumination_final)
308 | L = torch.cat([L for i in range(3)], dim=1)
309 |
310 | return R, L
311 |
312 |
313 | class CTDN(nn.Module):
314 | def __init__(self, channels=64):
315 | super(CTDN, self).__init__()
316 |
317 | self.ReconNet = ReconNet(channels)
318 | self.retinex = Retinex_decom(channels)
319 |
320 | def forward(self, images, pred_fea=None):
321 |
322 | output = {}
323 | # =================decomposition low=================
324 | if pred_fea is None:
325 | low_fea_down8, high_fea_down8 = self.ReconNet(images, pred_fea=None)
326 |
327 | low_R, low_L = self.retinex(low_fea_down8)
328 | high_R, high_L = self.retinex(high_fea_down8)
329 |
330 | output["low_R"] = low_R
331 | output["low_L"] = low_L
332 | output["low_fea"] = low_fea_down8
333 | output["high_R"] = high_R
334 | output["high_L"] = high_L
335 | output["high_fea"] = high_fea_down8
336 |
337 | else:
338 | pred_img = self.ReconNet(images[:, :3, ...], pred_fea=pred_fea)
339 | output["pred_img"] = pred_img
340 |
341 | return output
342 |
--------------------------------------------------------------------------------
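For reference, `Retinex_decom` initializes the decomposition of the 3-channel encoded features produced by `ReconNet` with the usual Retinex split (channel-wise maximum as illumination, its quotient as reflectance) before refining both with the attention blocks above; the training branch of `models/ddm.py` then builds its self-constrained consistency reference as a gamma-corrected recombination:

$$
L_0(p) = \max_{c \in \{R,G,B\}} I_c(p), \qquad R_0 = I / L_0, \qquad I \approx R \odot L
$$

$$
F_{\mathrm{ref}} = R_{\mathrm{low}} \odot L_{\mathrm{low}}^{0.2}
$$
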
/models/unet.py:
--------------------------------------------------------------------------------
1 | import math
2 | import torch
3 | import torch.nn as nn
4 | import torch.nn.functional
5 |
6 | # This script is from the following repositories
7 | # https://github.com/ermongroup/ddim
8 | # https://github.com/bahjat-kawar/ddrm
9 |
10 |
11 | def get_timestep_embedding(timesteps, embedding_dim):
12 | """
13 | This matches the implementation in Denoising Diffusion Probabilistic Models:
14 | From Fairseq.
15 | Build sinusoidal embeddings.
16 | This matches the implementation in tensor2tensor, but differs slightly
17 | from the description in Section 3.5 of "Attention Is All You Need".
18 | """
19 | assert len(timesteps.shape) == 1
20 |
21 | half_dim = embedding_dim // 2
22 | emb = math.log(10000) / (half_dim - 1)
23 | emb = torch.exp(torch.arange(half_dim, dtype=torch.float32) * -emb)
24 | emb = emb.to(device=timesteps.device)
25 | emb = timesteps.float()[:, None] * emb[None, :]
26 | emb = torch.cat([torch.sin(emb), torch.cos(emb)], dim=1)
27 | if embedding_dim % 2 == 1: # zero pad
28 | emb = torch.nn.functional.pad(emb, (0, 1, 0, 0))
29 | return emb
30 |
31 |
32 | def nonlinearity(x):
33 | # swish
34 | return x*torch.sigmoid(x)
35 |
36 |
37 | def Normalize(in_channels):
38 | return torch.nn.GroupNorm(num_groups=32, num_channels=in_channels, eps=1e-6, affine=True)
39 |
40 |
41 | class Upsample(nn.Module):
42 | def __init__(self, in_channels, with_conv):
43 | super().__init__()
44 | self.with_conv = with_conv
45 | if self.with_conv:
46 | self.conv = torch.nn.Conv2d(in_channels,
47 | in_channels,
48 | kernel_size=3,
49 | stride=1,
50 | padding=1)
51 |
52 | def forward(self, x):
53 | x = torch.nn.functional.interpolate(
54 | x, scale_factor=2.0, mode="nearest")
55 | if self.with_conv:
56 | x = self.conv(x)
57 | return x
58 |
59 |
60 | class Downsample(nn.Module):
61 | def __init__(self, in_channels, with_conv):
62 | super().__init__()
63 | self.with_conv = with_conv
64 | if self.with_conv:
65 | # no asymmetric padding in torch conv, must do it ourselves
66 | self.conv = torch.nn.Conv2d(in_channels,
67 | in_channels,
68 | kernel_size=3,
69 | stride=2,
70 | padding=0)
71 |
72 | def forward(self, x):
73 | if self.with_conv:
74 | pad = (0, 1, 0, 1)
75 | x = torch.nn.functional.pad(x, pad, mode="constant", value=0)
76 | x = self.conv(x)
77 | else:
78 | x = torch.nn.functional.avg_pool2d(x, kernel_size=2, stride=2)
79 | return x
80 |
81 |
82 | class ResnetBlock(nn.Module):
83 | def __init__(self, *, in_channels, out_channels=None, conv_shortcut=False,
84 | dropout, temb_channels=512):
85 | super().__init__()
86 | self.in_channels = in_channels
87 | out_channels = in_channels if out_channels is None else out_channels
88 | self.out_channels = out_channels
89 | self.use_conv_shortcut = conv_shortcut
90 |
91 | self.norm1 = Normalize(in_channels)
92 | self.conv1 = torch.nn.Conv2d(in_channels,
93 | out_channels,
94 | kernel_size=3,
95 | stride=1,
96 | padding=1)
97 | self.temb_proj = torch.nn.Linear(temb_channels,
98 | out_channels)
99 | self.norm2 = Normalize(out_channels)
100 | self.dropout = torch.nn.Dropout(dropout)
101 | self.conv2 = torch.nn.Conv2d(out_channels,
102 | out_channels,
103 | kernel_size=3,
104 | stride=1,
105 | padding=1)
106 | if self.in_channels != self.out_channels:
107 | if self.use_conv_shortcut:
108 | self.conv_shortcut = torch.nn.Conv2d(in_channels,
109 | out_channels,
110 | kernel_size=3,
111 | stride=1,
112 | padding=1)
113 | else:
114 | self.nin_shortcut = torch.nn.Conv2d(in_channels,
115 | out_channels,
116 | kernel_size=1,
117 | stride=1,
118 | padding=0)
119 |
120 | def forward(self, x, temb):
121 | h = x
122 | h = self.norm1(h)
123 | h = nonlinearity(h)
124 | h = self.conv1(h)
125 |
126 | h = h + self.temb_proj(nonlinearity(temb))[:, :, None, None]
127 |
128 | h = self.norm2(h)
129 | h = nonlinearity(h)
130 | h = self.dropout(h)
131 | h = self.conv2(h)
132 |
133 | if self.in_channels != self.out_channels:
134 | if self.use_conv_shortcut:
135 | x = self.conv_shortcut(x)
136 | else:
137 | x = self.nin_shortcut(x)
138 |
139 | return x+h
140 |
141 |
142 | class AttnBlock(nn.Module):
143 | def __init__(self, in_channels):
144 | super().__init__()
145 | self.in_channels = in_channels
146 |
147 | self.norm = Normalize(in_channels)
148 | self.q = torch.nn.Conv2d(in_channels,
149 | in_channels,
150 | kernel_size=1,
151 | stride=1,
152 | padding=0)
153 | self.k = torch.nn.Conv2d(in_channels,
154 | in_channels,
155 | kernel_size=1,
156 | stride=1,
157 | padding=0)
158 | self.v = torch.nn.Conv2d(in_channels,
159 | in_channels,
160 | kernel_size=1,
161 | stride=1,
162 | padding=0)
163 | self.proj_out = torch.nn.Conv2d(in_channels,
164 | in_channels,
165 | kernel_size=1,
166 | stride=1,
167 | padding=0)
168 |
169 | def forward(self, x):
170 | h_ = x
171 | h_ = self.norm(h_)
172 | q = self.q(h_)
173 | k = self.k(h_)
174 | v = self.v(h_)
175 |
176 | # compute attention
177 | b, c, h, w = q.shape
178 | q = q.reshape(b, c, h*w)
179 | q = q.permute(0, 2, 1) # b,hw,c
180 | k = k.reshape(b, c, h*w) # b,c,hw
181 | w_ = torch.bmm(q, k) # b,hw,hw w[b,i,j]=sum_c q[b,i,c]k[b,c,j]
182 | w_ = w_ * (int(c)**(-0.5))
183 | w_ = torch.nn.functional.softmax(w_, dim=2)
184 |
185 | # attend to values
186 | v = v.reshape(b, c, h*w)
187 | w_ = w_.permute(0, 2, 1) # b,hw,hw (first hw of k, second of q)
188 | # b, c,hw (hw of q) h_[b,c,j] = sum_i v[b,c,i] w_[b,i,j]
189 | h_ = torch.bmm(v, w_)
190 | h_ = h_.reshape(b, c, h, w)
191 |
192 | h_ = self.proj_out(h_)
193 |
194 | return x+h_
195 |
196 |
197 | class DiffusionUNet(nn.Module):
198 | def __init__(self, config):
199 | super().__init__()
200 | self.config = config
201 | ch, out_ch, ch_mult = config.model.ch, config.model.out_ch, tuple(config.model.ch_mult)
202 | num_res_blocks = config.model.num_res_blocks
203 | dropout = config.model.dropout
204 | in_channels = config.model.in_channels * 2 if config.data.conditional else config.model.in_channels
205 | resamp_with_conv = config.model.resamp_with_conv
206 |
207 | self.ch = ch
208 | self.temb_ch = self.ch*4
209 | self.num_resolutions = len(ch_mult)
210 | self.num_res_blocks = num_res_blocks
211 | self.in_channels = in_channels
212 |
213 | # timestep embedding
214 | self.temb = nn.Module()
215 | self.temb.dense = nn.ModuleList([
216 | torch.nn.Linear(self.ch,
217 | self.temb_ch),
218 | torch.nn.Linear(self.temb_ch,
219 | self.temb_ch),
220 | ])
221 |
222 | # downsampling
223 | self.conv_in = torch.nn.Conv2d(in_channels,
224 | self.ch,
225 | kernel_size=3,
226 | stride=1,
227 | padding=1)
228 |
229 | in_ch_mult = (1,)+ch_mult
230 | self.down = nn.ModuleList()
231 | block_in = None
232 | for i_level in range(self.num_resolutions):
233 | block = nn.ModuleList()
234 | attn = nn.ModuleList()
235 | block_in = ch*in_ch_mult[i_level]
236 | block_out = ch*ch_mult[i_level]
237 | for i_block in range(self.num_res_blocks):
238 | block.append(ResnetBlock(in_channels=block_in,
239 | out_channels=block_out,
240 | temb_channels=self.temb_ch,
241 | dropout=dropout))
242 | block_in = block_out
243 | if i_level == 2:
244 | attn.append(AttnBlock(block_in))
245 | down = nn.Module()
246 | down.block = block
247 | down.attn = attn
248 | if i_level != self.num_resolutions-1:
249 | down.downsample = Downsample(block_in, resamp_with_conv)
250 | self.down.append(down)
251 |
252 | # middle
253 | self.mid = nn.Module()
254 | self.mid.block_1 = ResnetBlock(in_channels=block_in,
255 | out_channels=block_in,
256 | temb_channels=self.temb_ch,
257 | dropout=dropout)
258 | self.mid.attn_1 = AttnBlock(block_in)
259 | self.mid.block_2 = ResnetBlock(in_channels=block_in,
260 | out_channels=block_in,
261 | temb_channels=self.temb_ch,
262 | dropout=dropout)
263 |
264 | # upsampling
265 | self.up = nn.ModuleList()
266 | for i_level in reversed(range(self.num_resolutions)):
267 | block = nn.ModuleList()
268 | attn = nn.ModuleList()
269 | block_out = ch*ch_mult[i_level]
270 | skip_in = ch*ch_mult[i_level]
271 | for i_block in range(self.num_res_blocks+1):
272 | if i_block == self.num_res_blocks:
273 | skip_in = ch*in_ch_mult[i_level]
274 | block.append(ResnetBlock(in_channels=block_in+skip_in,
275 | out_channels=block_out,
276 | temb_channels=self.temb_ch,
277 | dropout=dropout))
278 | block_in = block_out
279 | if i_level == 2:
280 | attn.append(AttnBlock(block_in))
281 | up = nn.Module()
282 | up.block = block
283 | up.attn = attn
284 | if i_level != 0:
285 | up.upsample = Upsample(block_in, resamp_with_conv)
286 | self.up.insert(0, up) # prepend to get consistent order
287 |
288 | # end
289 | self.norm_out = Normalize(block_in)
290 | self.conv_out = torch.nn.Conv2d(block_in,
291 | out_ch,
292 | kernel_size=3,
293 | stride=1,
294 | padding=1)
295 |
296 | def forward(self, x, t):
297 | # assert x.shape[2] == x.shape[3] == self.resolution
298 |
299 | # timestep embedding
300 | temb = get_timestep_embedding(t, self.ch)
301 | temb = self.temb.dense[0](temb)
302 | temb = nonlinearity(temb)
303 | temb = self.temb.dense[1](temb)
304 |
305 | # downsampling
306 | hs = [self.conv_in(x)]
307 | for i_level in range(self.num_resolutions):
308 | for i_block in range(self.num_res_blocks):
309 | h = self.down[i_level].block[i_block](hs[-1], temb)
310 | if len(self.down[i_level].attn) > 0:
311 | h = self.down[i_level].attn[i_block](h)
312 | hs.append(h)
313 | if i_level != self.num_resolutions-1:
314 | hs.append(self.down[i_level].downsample(hs[-1]))
315 |
316 | # middle
317 | h = hs[-1]
318 | h = self.mid.block_1(h, temb)
319 | h = self.mid.attn_1(h)
320 | h = self.mid.block_2(h, temb)
321 |
322 | # upsampling
323 | for i_level in reversed(range(self.num_resolutions)):
324 | for i_block in range(self.num_res_blocks+1):
325 | h = self.up[i_level].block[i_block](
326 | torch.cat([h, hs.pop()], dim=1), temb)
327 | if len(self.up[i_level].attn) > 0:
328 | h = self.up[i_level].attn[i_block](h)
329 | if i_level != 0:
330 | h = self.up[i_level].upsample(h)
331 |
332 | # end
333 | h = self.norm_out(h)
334 | h = nonlinearity(h)
335 | h = self.conv_out(h)
336 | return h
337 |
--------------------------------------------------------------------------------
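For reference, `get_timestep_embedding` builds the standard sinusoidal timestep embedding with `half_dim = embedding_dim // 2` frequencies:

$$
\omega_i = \exp\!\left(-\frac{i\,\ln 10000}{\tfrac{d}{2}-1}\right), \quad i = 0,\dots,\tfrac{d}{2}-1, \qquad
\mathrm{emb}(t) = \big[\sin(t\,\omega_0),\dots,\sin(t\,\omega_{d/2-1}),\ \cos(t\,\omega_0),\dots,\cos(t\,\omega_{d/2-1})\big]
$$

with a single zero appended when $d$ is odd.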