├── datasets ├── __init__.py ├── data_augment.py └── dataset.py ├── models ├── __init__.py ├── restoration.py ├── ddm.py ├── decom.py └── unet.py ├── Figures ├── visual.jpg └── pipeline.jpg ├── utils ├── __init__.py ├── sampling.py ├── logging.py └── optimize.py ├── configs └── unsupervised.yml ├── requirements.txt ├── evaluate.py ├── train.py └── README.md /datasets/__init__.py: -------------------------------------------------------------------------------- 1 | from datasets.dataset import * 2 | 3 | __all__ = ["LLdataset"] 4 | -------------------------------------------------------------------------------- /models/__init__.py: -------------------------------------------------------------------------------- 1 | from models.ddm import * 2 | from models.restoration import * 3 | -------------------------------------------------------------------------------- /Figures/visual.jpg: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/JianghaiSCU/LightenDiffusion/HEAD/Figures/visual.jpg -------------------------------------------------------------------------------- /Figures/pipeline.jpg: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/JianghaiSCU/LightenDiffusion/HEAD/Figures/pipeline.jpg -------------------------------------------------------------------------------- /utils/__init__.py: -------------------------------------------------------------------------------- 1 | from utils.logging import * 2 | from utils.sampling import * 3 | from utils.optimize import * 4 | -------------------------------------------------------------------------------- /utils/sampling.py: -------------------------------------------------------------------------------- 1 | import torch 2 | # This script is adapted from the following repository: https://github.com/ermongroup/ddim 3 | 4 | 5 | def data_transform(X): 6 | return 2 * X - 1.0 7 | 8 | 9 | def inverse_data_transform(X): 10 | return torch.clamp((X + 1.0) / 2.0, 0.0, 1.0) -------------------------------------------------------------------------------- /utils/logging.py: -------------------------------------------------------------------------------- 1 | import torch 2 | import shutil 3 | import os 4 | import torchvision.utils as tvu 5 | 6 | 7 | def save_image(img, file_directory): 8 | if not os.path.exists(os.path.dirname(file_directory)): 9 | os.makedirs(os.path.dirname(file_directory)) 10 | tvu.save_image(img, file_directory) 11 | 12 | 13 | def save_checkpoint(state, filename): 14 | if not os.path.exists(os.path.dirname(filename)): 15 | os.makedirs(os.path.dirname(filename)) 16 | torch.save(state, filename + '.pth.tar') 17 | 18 | 19 | def load_checkpoint(path, device): 20 | if device is None: 21 | return torch.load(path) 22 | else: 23 | return torch.load(path, map_location=device) 24 | -------------------------------------------------------------------------------- /utils/optimize.py: -------------------------------------------------------------------------------- 1 | import torch.optim as optim 2 | 3 | 4 | def get_optimizer(config, parameters): 5 | if config.optim.optimizer == 'Adam': 6 | optimizer = optim.Adam(parameters, lr=config.optim.lr, weight_decay=config.optim.weight_decay, 7 | betas=(0.9, 0.999), amsgrad=config.optim.amsgrad, eps=config.optim.eps) 8 | 9 | elif config.optim.optimizer == 'RMSProp': 10 | optimizer = optim.RMSprop(parameters, lr=config.optim.lr, weight_decay=config.optim.weight_decay) 11 | elif 
config.optim.optimizer == 'SGD': 12 | optimizer = optim.SGD(parameters, lr=config.optim.lr, momentum=0.9) 13 | else: 14 | raise NotImplementedError('Optimizer {} not understood.'.format(config.optim.optimizer)) 15 | 16 | return optimizer 17 | -------------------------------------------------------------------------------- /configs/unsupervised.yml: -------------------------------------------------------------------------------- 1 | data: 2 | type: "LLdataset" 3 | train_dataset: "unpaired" 4 | val_dataset: "LOLv1" 5 | patch_size: 512 6 | channels: 3 7 | num_workers: 4 8 | data_dir: "/data/Image_restoration/LSRW_dataset" 9 | ckpt_dir: "ckpt/stage2" 10 | conditional: True 11 | 12 | model: 13 | in_channels: 3 14 | out_ch: 3 15 | ch: 64 16 | ch_mult: [1, 2, 3, 4] 17 | num_res_blocks: 2 18 | dropout: 0.0 19 | ema_rate: 0.999 20 | ema: True 21 | resamp_with_conv: True 22 | 23 | diffusion: 24 | beta_schedule: linear 25 | beta_start: 0.0001 26 | beta_end: 0.02 27 | num_diffusion_timesteps: 1000 28 | num_sampling_timesteps: 20 29 | 30 | training: 31 | batch_size: 12 32 | n_epochs: 100 33 | validation_freq: 2000 34 | 35 | sampling: 36 | batch_size: 1 37 | 38 | optim: 39 | weight_decay: 0.000 40 | optimizer: "Adam" 41 | lr: 0.00002 42 | amsgrad: False 43 | eps: 0.00000001 -------------------------------------------------------------------------------- /requirements.txt: -------------------------------------------------------------------------------- 1 | astunparse==1.6.3 2 | boto3==1.18.47 3 | botocore==1.21.47 4 | cffi==1.15.0 5 | coloredlogs==15.0 6 | configargparse==1.5.3 7 | cupy-cuda102==7.8.0 8 | cvxpy==1.1.18 9 | docutils==0.15.2 10 | ecos==2.0.8 11 | fastrlock==0.8 12 | flatbuffers==2.0 13 | gast==0.3.3 14 | gdown==4.2.0 15 | google-pasta==0.2.0 16 | grpcio==1.40.0 17 | humanfriendly==10.0 18 | influxdb==5.3.1 19 | jmespath==0.10.0 20 | jpeg4py==0.1.4 21 | jupyterlab==1.0.0 22 | jupyterlab-server==1.0.0 23 | keras==2.7.0 24 | keras-preprocessing==1.1.2 25 | libclang==12.0.0 26 | loguru==0.5.3 27 | ninja==1.10.2.2 28 | nori2==1.11.8 29 | opencv-python==4.5.3.56 30 | opt-einsum==3.3.0 31 | osqp==0.6.2.post4 32 | pip==21.3.1 33 | protobuf==3.18.0 34 | pycocotools==2.0.4 35 | pypng==0.0.21 36 | python-statemachine==0.8.0 37 | qdldl==0.1.5.post0 38 | redis==3.5.3 39 | refile==5.8.3.post1 40 | s3transfer==0.5.0 41 | scipy==1.4.1 42 | scs==3.0.0 43 | setuptools==59.4.0 44 | smart-open==5.2.1 45 | tabulate==0.8.9 46 | tensorboard==2.2.2 47 | tensorflow==2.2.0 48 | tensorflow-estimator==2.2.0 49 | tensorflow-io-gcs-filesystem==0.22.0 50 | termcolor==1.1.0 51 | timm==0.4.12 52 | torch==1.7.1 53 | torchsummary==1.5.1 54 | torchvision==0.8.2 55 | urllib3==1.25.11 56 | kornia==0.6.7 57 | accelerate==0.15.0 -------------------------------------------------------------------------------- /models/restoration.py: -------------------------------------------------------------------------------- 1 | import torch 2 | import numpy as np 3 | import utils 4 | import os 5 | import time 6 | import torch.nn.functional as F 7 | 8 | 9 | class DiffusiveRestoration: 10 | def __init__(self, diffusion, args, config): 11 | super(DiffusiveRestoration, self).__init__() 12 | self.args = args 13 | self.config = config 14 | self.diffusion = diffusion 15 | 16 | if os.path.isfile(args.resume): 17 | self.diffusion.load_ddm_ckpt(args.resume, ema=False) 18 | self.diffusion.model.eval() 19 | else: 20 | print('Pre-trained model path is missing!') 21 | 22 | def restore(self, val_loader): 23 | image_folder = 
os.path.join(self.args.image_folder, self.config.data.val_dataset) 24 | with torch.no_grad(): 25 | for i, (x, y) in enumerate(val_loader): 26 | 27 | x_cond = x[:, :3, :, :].to(self.diffusion.device) 28 | b, c, h, w = x_cond.shape 29 | img_h_64 = int(64 * np.ceil(h / 64.0)) 30 | img_w_64 = int(64 * np.ceil(w / 64.0)) 31 | x_cond = F.pad(x_cond, (0, img_w_64 - w, 0, img_h_64 - h), 'reflect') 32 | 33 | t1 = time.time() 34 | pred_x = self.diffusion.model(torch.cat((x_cond, x_cond), 35 | dim=1))["pred_x"][:, :, :h, :w] 36 | t2 = time.time() 37 | 38 | utils.logging.save_image(pred_x, os.path.join(image_folder, f"{y[0]}")) 39 | print(f"processing image {y[0]}, time={t2 - t1}") 40 | 41 | 42 | 43 | -------------------------------------------------------------------------------- /evaluate.py: -------------------------------------------------------------------------------- 1 | import argparse 2 | import os 3 | import random 4 | import socket 5 | import yaml 6 | import torch 7 | import torch.backends.cudnn as cudnn 8 | import numpy as np 9 | import torchvision 10 | import models 11 | import datasets 12 | import utils 13 | from models import DenoisingDiffusion, DiffusiveRestoration 14 | 15 | 16 | def parse_args_and_config(): 17 | parser = argparse.ArgumentParser(description='Latent-Retinex Diffusion Models') 18 | parser.add_argument("--config", default='unsupervised.yml', type=str, 19 | help="Path to the config file") 20 | parser.add_argument('--mode', type=str, default='evaluation', help='training or evaluation') 21 | parser.add_argument('--resume', default='ckpt/stage2/stage2_weight.pth.tar', type=str, 22 | help='Path for the diffusion model checkpoint to load for evaluation') 23 | parser.add_argument("--image_folder", default='results/', type=str, 24 | help="Location to save restored images") 25 | args = parser.parse_args() 26 | 27 | with open(os.path.join("configs", args.config), "r") as f: 28 | config = yaml.safe_load(f) 29 | new_config = dict2namespace(config) 30 | 31 | return args, new_config 32 | 33 | 34 | def dict2namespace(config): 35 | namespace = argparse.Namespace() 36 | for key, value in config.items(): 37 | if isinstance(value, dict): 38 | new_value = dict2namespace(value) 39 | else: 40 | new_value = value 41 | setattr(namespace, key, new_value) 42 | return namespace 43 | 44 | 45 | def main(): 46 | args, config = parse_args_and_config() 47 | 48 | # setup device to run 49 | device = torch.device("cuda") if torch.cuda.is_available() else torch.device("cpu") 50 | print("Using device: {}".format(device)) 51 | config.device = device 52 | 53 | if torch.cuda.is_available(): 54 | print('Note: Currently supports evaluations (restoration) when run only on a single GPU!') 55 | 56 | print("=> using dataset '{}'".format(config.data.val_dataset)) 57 | DATASET = datasets.__dict__[config.data.type](config) 58 | _, val_loader = DATASET.get_loaders() 59 | 60 | # create model 61 | print("=> creating denoising-diffusion model") 62 | diffusion = DenoisingDiffusion(args, config) 63 | model = DiffusiveRestoration(diffusion, args, config) 64 | model.restore(val_loader) 65 | 66 | 67 | if __name__ == '__main__': 68 | main() 69 | -------------------------------------------------------------------------------- /train.py: -------------------------------------------------------------------------------- 1 | import argparse 2 | import os 3 | import random 4 | import socket 5 | import yaml 6 | import torch 7 | import torch.backends.cudnn as cudnn 8 | import torch.utils.data 9 | import numpy as np 10 | import torchvision 
11 | import models 12 | import datasets 13 | import utils 14 | from models import DenoisingDiffusion 15 | 16 | 17 | def parse_args_and_config(): 18 | parser = argparse.ArgumentParser(description='Latent-Retinex Diffusion Models') 19 | parser.add_argument("--config", default='unsupervised.yml', type=str, 20 | help="Path to the config file") 21 | parser.add_argument('--mode', type=str, default='training', help='training or evaluation') 22 | parser.add_argument('--resume', default='', type=str, 23 | help='Path for checkpoint to load and resume') 24 | parser.add_argument("--image_folder", default='results/', type=str, 25 | help="Location to save restored validation image patches") 26 | parser.add_argument('--seed', default=230, type=int, metavar='N', 27 | help='Seed for initializing training (default: 230)') 28 | args = parser.parse_args() 29 | 30 | with open(os.path.join("configs", args.config), "r") as f: 31 | config = yaml.safe_load(f) 32 | new_config = dict2namespace(config) 33 | 34 | return args, new_config 35 | 36 | 37 | def dict2namespace(config): 38 | namespace = argparse.Namespace() 39 | for key, value in config.items(): 40 | if isinstance(value, dict): 41 | new_value = dict2namespace(value) 42 | else: 43 | new_value = value 44 | setattr(namespace, key, new_value) 45 | return namespace 46 | 47 | 48 | def main(): 49 | args, config = parse_args_and_config() 50 | # setup device to run 51 | device = torch.device("cuda") if torch.cuda.is_available() else torch.device("cpu") 52 | print("Using device: {}".format(device)) 53 | config.device = device 54 | 55 | # set random seed 56 | torch.manual_seed(args.seed) 57 | np.random.seed(args.seed) 58 | if torch.cuda.is_available(): 59 | torch.cuda.manual_seed_all(args.seed) 60 | torch.backends.cudnn.benchmark = True 61 | 62 | # data loading 63 | print("=> using dataset '{}'".format(config.data.train_dataset)) 64 | DATASET = datasets.__dict__[config.data.type](config) 65 | 66 | # create model 67 | print("=> creating denoising-diffusion model...") 68 | diffusion = DenoisingDiffusion(args, config) 69 | diffusion.train(DATASET) 70 | 71 | 72 | if __name__ == "__main__": 73 | main() 74 | -------------------------------------------------------------------------------- /datasets/data_augment.py: -------------------------------------------------------------------------------- 1 | import random 2 | import torchvision.transforms as transforms 3 | import torchvision.transforms.functional as F 4 | 5 | 6 | class PairRandomCrop(transforms.RandomCrop): 7 | 8 | def __call__(self, image, label): 9 | 10 | if self.padding is not None: 11 | image = F.pad(image, self.padding, self.fill, self.padding_mode) 12 | label = F.pad(label, self.padding, self.fill, self.padding_mode) 13 | 14 | # pad the width if needed 15 | if self.pad_if_needed and image.size[0] < self.size[1]: 16 | image = F.pad(image, (self.size[1] - image.size[0], 0), self.fill, self.padding_mode) 17 | label = F.pad(label, (self.size[1] - label.size[0], 0), self.fill, self.padding_mode) 18 | # pad the height if needed 19 | if self.pad_if_needed and image.size[1] < self.size[0]: 20 | image = F.pad(image, (0, self.size[0] - image.size[1]), self.fill, self.padding_mode) 21 | label = F.pad(label, (0, self.size[0] - image.size[1]), self.fill, self.padding_mode) 22 | 23 | i, j, h, w = self.get_params(image, self.size) 24 | 25 | return F.crop(image, i, j, h, w), F.crop(label, i, j, h, w) 26 | 27 | 28 | class PairCompose(transforms.Compose): 29 | def __call__(self, image, label): 30 | for t in self.transforms: 31 | 
image, label = t(image, label) 32 | return image, label 33 | 34 | 35 | class PairRandomHorizontalFilp(transforms.RandomHorizontalFlip): 36 | def __call__(self, img, label): 37 | """ 38 | Args: 39 | img (PIL Image): Image to be flipped. 40 | 41 | Returns: 42 | PIL Image: Randomly flipped image. 43 | """ 44 | if random.random() < self.p: 45 | return F.hflip(img), F.hflip(label) 46 | return img, label 47 | 48 | 49 | class PairRandomVerticalFlip(transforms.RandomVerticalFlip): 50 | def __call__(self, img, label): 51 | """ 52 | Args: 53 | img (PIL Image): Image to be flipped. 54 | 55 | Returns: 56 | PIL Image: Randomly flipped image. 57 | """ 58 | if random.random() < self.p: 59 | return F.vflip(img), F.vflip(label) 60 | return img, label 61 | 62 | 63 | class PairToTensor(transforms.ToTensor): 64 | def __call__(self, pic, label): 65 | """ 66 | Args: 67 | pic (PIL Image or numpy.ndarray): Image to be converted to tensor. 68 | 69 | Returns: 70 | Tensor: Converted image. 71 | """ 72 | return F.to_tensor(pic), F.to_tensor(label) 73 | -------------------------------------------------------------------------------- /README.md: -------------------------------------------------------------------------------- 1 | # [ECCV 2024] LightenDiffusion: Unsupervised Low-Light Image Enhancement with Latent-Retinex Diffusion Models [[Paper]](https://arxiv.org/pdf/2407.08939) 2 |

Hai Jiang<sup>1,5</sup>, Ao Luo<sup>2,5</sup>, Xiaohong Liu<sup>4</sup>, Songchen Han<sup>1</sup>, Shuaicheng Liu<sup>3,5</sup> 3 |

1.Sichuan University, 2.Southwest Jiaotong University, 4 |

3.University of Electronic Science and Technology of China, 5 |

4.Shanghai Jiao Tong University, 5.Megvii Technology 6 | 7 | ## Pipeline 8 | ![](./Figures/pipeline.jpg) 9 | 10 | ## Dependencies 11 | ``` 12 | pip install -r requirements.txt 13 | ``` 14 | 15 | ## Download the raw training and evaluation datasets 16 | ### Paired datasets 17 | LOL dataset: Chen Wei, Wenjing Wang, Wenhan Yang, and Jiaying Liu. "Deep Retinex Decomposition for Low-Light Enhancement". BMVC, 2018. [[Baiduyun (extracted code: sdd0)]](https://pan.baidu.com/s/1spt0kYU3OqsQSND-be4UaA) [[Google Drive]](https://drive.google.com/file/d/18bs_mAREhLipaM2qvhxs7u7ff2VSHet2/view?usp=sharing) 18 | 19 | LSRW dataset: Jiang Hai, Zhu Xuan, Ren Yang, Yutong Hao, Fengzhu Zou, Fang Lin, and Songchen Han. "R2RNet: Low-light Image Enhancement via Real-low to Real-normal Network". Journal of Visual Communication and Image Representation, 2023. [[Baiduyun (extracted code: wmrr)]](https://pan.baidu.com/s/1XHWQAS0ZNrnCyZ-bq7MKvA) 20 | 21 | ### Unpaired datasets 22 | Please refer to [[Project Page of RetinexNet]](https://daooshee.github.io/BMVC2018website/). 23 | 24 | ## Pre-trained Models 25 | You can download our pre-trained models from [[Google Drive]](https://drive.google.com/drive/folders/1m3t15rWw76IDDWJ0exLOe5P0uEnjk3zl?usp=drive_link) or [[Baidu Yun (extracted code: cjzk)]](https://pan.baidu.com/s/1fPLVgnZbdY1n75Flq54bMQ). The default paths in the code expect the stage-one weight at ```ckpt/stage1/stage1_weight.pth.tar``` and the stage-two weight at ```ckpt/stage2/stage2_weight.pth.tar```. 26 | 27 | ## How to train? 28 | Set ```data.data_dir``` in ```configs/unsupervised.yml``` to your dataset root, adjust ```datasets/dataset.py``` for your environment if needed (the expected filelist format is illustrated at the end of this README), and then run 29 | ``` 30 | python train.py 31 | ``` 32 | 33 | ## How to test? 34 | ``` 35 | python evaluate.py 36 | ``` 37 | 38 | ## Visual comparison 39 | ![](./Figures/visual.jpg) 40 | 41 | ## Citation 42 | If you use this code or ideas from our paper in your research, please cite it as follows: 43 | ``` 44 | @InProceedings{Jiang_2024_ECCV, 45 | author = {Jiang, Hai and Luo, Ao and Liu, Xiaohong and Han, Songchen and Liu, Shuaicheng}, 46 | title = {LightenDiffusion: Unsupervised Low-Light Image Enhancement with Latent-Retinex Diffusion Models}, 47 | booktitle = {European Conference on Computer Vision}, 48 | year = {2024}, 49 | pages = {} 50 | } 51 | ``` 52 | 53 | ## Acknowledgement 54 | Part of the code is adapted from previous works: [WeatherDiff](https://github.com/IGITUGraz/WeatherDiffusion) and [MIMO-UNet](https://github.com/chosj95/MIMO-UNet). We thank all the authors for their contributions.
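A note on the filelists referenced in the "How to train?" section above: ```datasets/dataset.py``` reads ```{train_dataset}_train.txt``` and ```{val_dataset}_val.txt``` from ```data.data_dir``` (i.e., ```unpaired_train.txt``` and ```LOLv1_val.txt``` with the default ```configs/unsupervised.yml```), and each line holds a low-light image path and a (possibly unpaired) normal-light image path separated by a single space. The paths below are placeholders for illustration only:
```
/path/to/low/0001.png /path/to/normal/0001.png
/path/to/low/0002.png /path/to/normal/0002.png
```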
55 | 56 | -------------------------------------------------------------------------------- /datasets/dataset.py: -------------------------------------------------------------------------------- 1 | import os 2 | import torch 3 | import torch.utils.data 4 | from PIL import Image 5 | from datasets.data_augment import PairCompose, PairToTensor, PairRandomHorizontalFilp 6 | 7 | 8 | class LLdataset: 9 | def __init__(self, config): 10 | self.config = config 11 | 12 | def get_loaders(self): 13 | train_dataset = AllWeatherDataset(self.config.data.data_dir, 14 | patch_size=self.config.data.patch_size, 15 | filelist='{}_train.txt'.format(self.config.data.train_dataset)) 16 | val_dataset = AllWeatherDataset(self.config.data.data_dir, 17 | patch_size=self.config.data.patch_size, 18 | filelist='{}_val.txt'.format(self.config.data.val_dataset), train=False) 19 | 20 | train_loader = torch.utils.data.DataLoader(train_dataset, batch_size=self.config.training.batch_size, 21 | shuffle=True, num_workers=self.config.data.num_workers, 22 | pin_memory=True) 23 | val_loader = torch.utils.data.DataLoader(val_dataset, batch_size=self.config.sampling.batch_size, 24 | shuffle=False, num_workers=self.config.data.num_workers, 25 | pin_memory=True) 26 | 27 | return train_loader, val_loader 28 | 29 | 30 | class AllWeatherDataset(torch.utils.data.Dataset): 31 | def __init__(self, dir, patch_size, filelist=None, train=True): 32 | super().__init__() 33 | 34 | self.dir = dir 35 | self.file_list = filelist 36 | self.train_list = os.path.join(dir, self.file_list) 37 | with open(self.train_list) as f: 38 | contents = f.readlines() 39 | input_names = [i.strip() for i in contents] 40 | 41 | self.input_names = input_names 42 | self.patch_size = patch_size 43 | 44 | if train: 45 | self.transforms = PairCompose([ 46 | PairRandomHorizontalFilp(), 47 | PairToTensor() 48 | ]) 49 | else: 50 | self.transforms = PairCompose([ 51 | PairToTensor() 52 | ]) 53 | 54 | def get_images(self, index): 55 | input_name = self.input_names[index].replace('\n', '') 56 | 57 | low_img_name, high_img_name = input_name.split(' ')[0], input_name.split(' ')[1] 58 | 59 | img_id = low_img_name.split('/')[-1] 60 | low_img, high_img = Image.open(low_img_name), Image.open(high_img_name) 61 | 62 | low_img, high_img = self.transforms(low_img, high_img) 63 | 64 | return torch.cat([low_img, high_img], dim=0), img_id 65 | 66 | def __getitem__(self, index): 67 | res = self.get_images(index) 68 | return res 69 | 70 | def __len__(self): 71 | return len(self.input_names) 72 | -------------------------------------------------------------------------------- /models/ddm.py: -------------------------------------------------------------------------------- 1 | import os 2 | import time 3 | import numpy as np 4 | import torch 5 | import torch.nn as nn 6 | import torch.backends.cudnn as cudnn 7 | import torch.nn.functional as F 8 | import utils 9 | from models.unet import DiffusionUNet 10 | from models.decom import CTDN 11 | 12 | 13 | class EMAHelper(object): 14 | def __init__(self, mu=0.9999): 15 | self.mu = mu 16 | self.shadow = {} 17 | 18 | def register(self, module): 19 | if isinstance(module, nn.DataParallel): 20 | module = module.module 21 | for name, param in module.named_parameters(): 22 | if param.requires_grad: 23 | self.shadow[name] = param.data.clone() 24 | 25 | def update(self, module): 26 | if isinstance(module, nn.DataParallel): 27 | module = module.module 28 | for name, param in module.named_parameters(): 29 | if param.requires_grad: 30 | self.shadow[name].data = 
(1. - self.mu) * param.data + self.mu * self.shadow[name].data 31 | 32 | def ema(self, module): 33 | if isinstance(module, nn.DataParallel): 34 | module = module.module 35 | for name, param in module.named_parameters(): 36 | if param.requires_grad: 37 | param.data.copy_(self.shadow[name].data) 38 | 39 | def ema_copy(self, module): 40 | if isinstance(module, nn.DataParallel): 41 | inner_module = module.module 42 | module_copy = type(inner_module)(inner_module.config).to(inner_module.config.device) 43 | module_copy.load_state_dict(inner_module.state_dict()) 44 | module_copy = nn.DataParallel(module_copy) 45 | else: 46 | module_copy = type(module)(module.config).to(module.config.device) 47 | module_copy.load_state_dict(module.state_dict()) 48 | self.ema(module_copy) 49 | return module_copy 50 | 51 | def state_dict(self): 52 | return self.shadow 53 | 54 | def load_state_dict(self, state_dict): 55 | self.shadow = state_dict 56 | 57 | 58 | def get_beta_schedule(beta_schedule, *, beta_start, beta_end, num_diffusion_timesteps): 59 | def sigmoid(x): 60 | return 1 / (np.exp(-x) + 1) 61 | 62 | if beta_schedule == "quad": 63 | betas = (np.linspace(beta_start ** 0.5, beta_end ** 0.5, num_diffusion_timesteps, dtype=np.float64) ** 2) 64 | elif beta_schedule == "linear": 65 | betas = np.linspace(beta_start, beta_end, num_diffusion_timesteps, dtype=np.float64) 66 | elif beta_schedule == "const": 67 | betas = beta_end * np.ones(num_diffusion_timesteps, dtype=np.float64) 68 | elif beta_schedule == "jsd": # 1/T, 1/(T-1), 1/(T-2), ..., 1 69 | betas = 1.0 / np.linspace(num_diffusion_timesteps, 1, num_diffusion_timesteps, dtype=np.float64) 70 | elif beta_schedule == "sigmoid": 71 | betas = np.linspace(-6, 6, num_diffusion_timesteps) 72 | betas = sigmoid(betas) * (beta_end - beta_start) + beta_start 73 | else: 74 | raise NotImplementedError(beta_schedule) 75 | assert betas.shape == (num_diffusion_timesteps,) 76 | return betas 77 | 78 | 79 | class Net(nn.Module): 80 | def __init__(self, args, config): 81 | super(Net, self).__init__() 82 | 83 | self.args = args 84 | self.config = config 85 | self.device = config.device 86 | 87 | self.Unet = DiffusionUNet(config) 88 | if self.args.mode == 'training': 89 | self.decom = self.load_stage1(CTDN(), 'ckpt/stage1') 90 | else: 91 | self.decom = CTDN() 92 | 93 | betas = get_beta_schedule( 94 | beta_schedule=config.diffusion.beta_schedule, 95 | beta_start=config.diffusion.beta_start, 96 | beta_end=config.diffusion.beta_end, 97 | num_diffusion_timesteps=config.diffusion.num_diffusion_timesteps, 98 | ) 99 | 100 | self.betas = torch.from_numpy(betas).float() 101 | self.num_timesteps = self.betas.shape[0] 102 | 103 | @staticmethod 104 | def compute_alpha(beta, t): 105 | beta = torch.cat([torch.zeros(1).to(beta.device), beta], dim=0) 106 | a = (1 - beta).cumprod(dim=0).index_select(0, t + 1).view(-1, 1, 1, 1) 107 | return a 108 | 109 | @staticmethod 110 | def load_stage1(model, model_dir): 111 | checkpoint = utils.logging.load_checkpoint(os.path.join(model_dir, 'stage1_weight.pth.tar'), 'cuda') 112 | model.load_state_dict(checkpoint['model'], strict=True) 113 | return model 114 | 115 | def sample_training(self, x_cond, b, eta=0.): 116 | skip = self.config.diffusion.num_diffusion_timesteps // self.config.diffusion.num_sampling_timesteps 117 | seq = range(0, self.config.diffusion.num_diffusion_timesteps, skip) 118 | n, c, h, w = x_cond.shape 119 | seq_next = [-1] + list(seq[:-1]) 120 | x = torch.randn(n, c, h, w, device=self.device) 121 | xs = [x] 122 | for i, j in zip(reversed(seq), 
reversed(seq_next)): 123 | t = (torch.ones(n) * i).to(x.device) 124 | next_t = (torch.ones(n) * j).to(x.device) 125 | at = self.compute_alpha(b, t.long()) 126 | at_next = self.compute_alpha(b, next_t.long()) 127 | xt = xs[-1].to(x.device) 128 | 129 | et = self.Unet(torch.cat([x_cond, xt], dim=1), t) 130 | x0_t = (xt - et * (1 - at).sqrt()) / at.sqrt() 131 | 132 | c1 = eta * ((1 - at / at_next) * (1 - at_next) / (1 - at)).sqrt() 133 | c2 = ((1 - at_next) - c1 ** 2).sqrt() 134 | xt_next = at_next.sqrt() * x0_t + c1 * torch.randn_like(x) + c2 * et 135 | xs.append(xt_next.to(x.device)) 136 | 137 | return xs[-1] 138 | 139 | def forward(self, inputs): 140 | data_dict = {} 141 | 142 | b = self.betas.to(inputs.device) 143 | 144 | if self.training: 145 | output = self.decom(inputs, pred_fea=None) 146 | low_R, low_L, low_fea, high_L = output["low_R"], output["low_L"], \ 147 | output["low_fea"], output["high_L"] 148 | low_condition_norm = utils.data_transform(low_fea) 149 | 150 | t = torch.randint(low=0, high=self.num_timesteps, size=(low_condition_norm.shape[0] // 2 + 1,)).to( 151 | self.device) 152 | t = torch.cat([t, self.num_timesteps - t - 1], dim=0)[:low_condition_norm.shape[0]].to(inputs.device) 153 | a = (1 - b).cumprod(dim=0).index_select(0, t).view(-1, 1, 1, 1) 154 | 155 | e = torch.randn_like(low_condition_norm) 156 | 157 | high_input_norm = utils.data_transform(low_R * high_L) 158 | 159 | x = high_input_norm * a.sqrt() + e * (1.0 - a).sqrt() 160 | noise_output = self.Unet(torch.cat([low_condition_norm, x], dim=1), t.float()) 161 | 162 | pred_fea = self.sample_training(low_condition_norm, b) 163 | pred_fea = utils.inverse_data_transform(pred_fea) 164 | reference_fea = low_R * torch.pow(low_L, 0.2) 165 | 166 | data_dict["noise_output"] = noise_output 167 | data_dict["e"] = e 168 | 169 | data_dict["pred_fea"] = pred_fea 170 | data_dict["reference_fea"] = reference_fea 171 | 172 | else: 173 | output = self.decom(inputs, pred_fea=None) 174 | low_fea = output["low_fea"] 175 | low_condition_norm = utils.data_transform(low_fea) 176 | 177 | pred_fea = self.sample_training(low_condition_norm, b) 178 | pred_fea = utils.inverse_data_transform(pred_fea) 179 | pred_x = self.decom(inputs, pred_fea=pred_fea)["pred_img"] 180 | data_dict["pred_x"] = pred_x 181 | 182 | return data_dict 183 | 184 | 185 | class DenoisingDiffusion(object): 186 | def __init__(self, args, config): 187 | super().__init__() 188 | self.args = args 189 | self.config = config 190 | self.device = config.device 191 | 192 | self.model = Net(args, config) 193 | self.model.to(self.device) 194 | self.model = torch.nn.DataParallel(self.model, device_ids=range(torch.cuda.device_count())) 195 | 196 | self.ema_helper = EMAHelper() 197 | self.ema_helper.register(self.model) 198 | 199 | self.l2_loss = torch.nn.MSELoss() 200 | self.l1_loss = torch.nn.L1Loss() 201 | 202 | self.optimizer = utils.optimize.get_optimizer(self.config, self.model.parameters()) 203 | self.start_epoch, self.step = 0, 0 204 | 205 | def load_ddm_ckpt(self, load_path, ema=False): 206 | checkpoint = utils.logging.load_checkpoint(load_path, None) 207 | self.model.load_state_dict(checkpoint['state_dict'], strict=True) 208 | if ema: 209 | self.ema_helper.ema(self.model) 210 | print("=> loaded checkpoint {} step {}".format(load_path, self.step)) 211 | 212 | def train(self, DATASET): 213 | cudnn.benchmark = True 214 | train_loader, val_loader = DATASET.get_loaders() 215 | 216 | if os.path.isfile(self.args.resume): 217 | self.load_ddm_ckpt(self.args.resume) 218 | 219 | for name, 
param in self.model.named_parameters(): 220 | if "decom" in name: 221 | param.requires_grad = False 222 | else: 223 | param.requires_grad = True 224 | 225 | for epoch in range(self.start_epoch, self.config.training.n_epochs): 226 | print('epoch: ', epoch) 227 | data_start = time.time() 228 | data_time = 0 229 | for i, (x, y) in enumerate(train_loader): 230 | x = x.flatten(start_dim=0, end_dim=1) if x.ndim == 5 else x 231 | self.model.train() 232 | self.step += 1 233 | 234 | x = x.to(self.device) 235 | 236 | output = self.model(x) 237 | 238 | noise_loss, scc_loss = self.noise_estimation_loss(output) 239 | loss = noise_loss + scc_loss 240 | 241 | data_time += time.time() - data_start 242 | 243 | if self.step % 10 == 0: 244 | print("step:{}, noise_loss:{:.5f} scc_loss:{:.5f} time:{:.5f}". 245 | format(self.step, noise_loss.item(), 246 | scc_loss.item(), data_time / (i + 1))) 247 | 248 | self.optimizer.zero_grad() 249 | loss.backward() 250 | self.optimizer.step() 251 | self.ema_helper.update(self.model) 252 | data_start = time.time() 253 | 254 | if self.step % self.config.training.validation_freq == 0 and self.step != 0: 255 | self.model.eval() 256 | self.sample_validation_patches(val_loader, self.step) 257 | 258 | utils.logging.save_checkpoint({'step': self.step, 259 | 'epoch': epoch + 1, 260 | 'state_dict': self.model.state_dict(), 261 | 'optimizer': self.optimizer.state_dict(), 262 | 'ema_helper': self.ema_helper.state_dict(), 263 | 'params': self.args, 264 | 'config': self.config}, 265 | filename=os.path.join(self.config.data.ckpt_dir, 'model_latest')) 266 | 267 | def noise_estimation_loss(self, output): 268 | pred_fea, reference_fea = output["pred_fea"], output["reference_fea"] 269 | noise_output, e = output["noise_output"], output["e"] 270 | # ==================noise loss================== 271 | noise_loss = self.l2_loss(noise_output, e) 272 | # ==================scc loss================== 273 | scc_loss = 0.001 * self.l1_loss(pred_fea, reference_fea) 274 | 275 | return noise_loss, scc_loss 276 | 277 | def sample_validation_patches(self, val_loader, step): 278 | image_folder = os.path.join(self.args.image_folder, 279 | self.config.data.type + str(self.config.data.patch_size)) 280 | self.model.eval() 281 | 282 | with torch.no_grad(): 283 | print('Performing validation at step: {}'.format(step)) 284 | for i, (x, y) in enumerate(val_loader): 285 | b, _, img_h, img_w = x.shape 286 | 287 | img_h_64 = int(64 * np.ceil(img_h / 64.0)) 288 | img_w_64 = int(64 * np.ceil(img_w / 64.0)) 289 | x = F.pad(x, (0, img_w_64 - img_w, 0, img_h_64 - img_h), 'reflect') 290 | pred_x = self.model(x.to(self.device))["pred_x"][:, :, :img_h, :img_w] 291 | utils.logging.save_image(pred_x, os.path.join(image_folder, str(step), '{}'.format(y[0]))) 292 | -------------------------------------------------------------------------------- /models/decom.py: -------------------------------------------------------------------------------- 1 | import torch 2 | import torch.nn as nn 3 | import warnings 4 | import os 5 | import math 6 | import torch.nn.functional as F 7 | from einops import rearrange 8 | 9 | warnings.filterwarnings("ignore", category=UserWarning) 10 | warnings.filterwarnings("ignore", category=FutureWarning) 11 | 12 | 13 | class Depth_conv(nn.Module): 14 | def __init__(self, in_ch, out_ch): 15 | super(Depth_conv, self).__init__() 16 | self.depth_conv = nn.Conv2d( 17 | in_channels=in_ch, 18 | out_channels=in_ch, 19 | kernel_size=(3, 3), 20 | stride=(1, 1), 21 | padding=1, 22 | groups=in_ch 23 | ) 24 | 
self.point_conv = nn.Conv2d( 25 | in_channels=in_ch, 26 | out_channels=out_ch, 27 | kernel_size=(1, 1), 28 | stride=(1, 1), 29 | padding=0, 30 | groups=1 31 | ) 32 | 33 | def forward(self, input): 34 | out = self.depth_conv(input) 35 | out = self.point_conv(out) 36 | return out 37 | 38 | 39 | class Res_block(nn.Module): 40 | def __init__(self, in_channels, out_channels): 41 | super(Res_block, self).__init__() 42 | 43 | sequence = [] 44 | 45 | sequence += [ 46 | nn.Conv2d(in_channels, out_channels, kernel_size=(3, 3), stride=(1, 1), padding=1), 47 | nn.LeakyReLU(), 48 | nn.Conv2d(out_channels, out_channels, kernel_size=(3, 3), stride=(1, 1), padding=1) 49 | ] 50 | 51 | self.model = nn.Sequential(*sequence) 52 | 53 | self.conv = nn.Conv2d(in_channels, out_channels, kernel_size=(1, 1), stride=(1, 1), padding=0) 54 | 55 | def forward(self, x): 56 | out = self.model(x) + self.conv(x) 57 | 58 | return out 59 | 60 | 61 | class upsampling(nn.Module): 62 | def __init__(self, in_channels, out_channels): 63 | super(upsampling, self).__init__() 64 | 65 | self.conv = nn.ConvTranspose2d(in_channels, out_channels, kernel_size=3, stride=2, padding=1, 66 | output_padding=1) 67 | 68 | self.relu = nn.LeakyReLU() 69 | 70 | def forward(self, x): 71 | out = self.relu(self.conv(x)) 72 | return out 73 | 74 | 75 | class channel_down(nn.Module): 76 | def __init__(self, channels): 77 | super(channel_down, self).__init__() 78 | 79 | self.conv0 = nn.Conv2d(channels * 4, channels * 2, kernel_size=(3, 3), stride=(1, 1), padding=1) 80 | self.conv1 = nn.Conv2d(channels * 2, channels, kernel_size=(3, 3), stride=(1, 1), padding=1) 81 | self.conv2 = nn.Conv2d(channels, 3, kernel_size=(3, 3), stride=(1, 1), padding=1) 82 | 83 | self.relu = nn.LeakyReLU() 84 | 85 | def forward(self, x): 86 | out = torch.sigmoid(self.conv2(self.relu(self.conv1(self.relu(self.conv0(x)))))) 87 | 88 | return out 89 | 90 | 91 | class channel_up(nn.Module): 92 | def __init__(self, channels): 93 | super(channel_up, self).__init__() 94 | 95 | self.conv0 = nn.Conv2d(3, channels, kernel_size=(3, 3), stride=(1, 1), padding=1) 96 | self.conv1 = nn.Conv2d(channels, channels * 2, kernel_size=(3, 3), stride=(1, 1), padding=1) 97 | self.conv2 = nn.Conv2d(channels * 2, channels * 4, kernel_size=(3, 3), stride=(1, 1), padding=1) 98 | 99 | self.relu = nn.LeakyReLU() 100 | 101 | def forward(self, x): 102 | out = self.conv2(self.relu(self.conv1(self.relu(self.conv0(x))))) 103 | 104 | return out 105 | 106 | 107 | class feature_pyramid(nn.Module): 108 | def __init__(self, channels): 109 | super(feature_pyramid, self).__init__() 110 | 111 | self.convs = nn.Sequential(nn.Conv2d(3, channels, kernel_size=(5, 5), stride=(1, 1), padding=2), 112 | nn.Conv2d(channels, channels, kernel_size=(5, 5), stride=(1, 1), padding=2)) 113 | 114 | self.block0 = Res_block(channels, channels) 115 | 116 | self.down0 = nn.Conv2d(channels, channels, kernel_size=(3, 3), stride=(2, 2), padding=1) 117 | 118 | self.block1 = Res_block(channels, channels * 2) 119 | 120 | self.down1 = nn.Conv2d(channels * 2, channels * 2, kernel_size=(3, 3), stride=(2, 2), padding=1) 121 | 122 | self.block2 = Res_block(channels * 2, channels * 4) 123 | 124 | self.down2 = nn.Conv2d(channels * 4, channels * 4, kernel_size=(3, 3), stride=(2, 2), padding=1) 125 | 126 | self.relu = nn.LeakyReLU() 127 | 128 | def forward(self, x): 129 | 130 | level0 = self.down0(self.block0(self.convs(x))) 131 | level1 = self.down1(self.block1(level0)) 132 | level2 = self.down2(self.block2(level1)) 133 | 134 | return level0, level1, 
level2 135 | 136 | 137 | class ReconNet(nn.Module): 138 | def __init__(self, channels): 139 | super(ReconNet, self).__init__() 140 | 141 | self.pyramid = feature_pyramid(channels) 142 | 143 | self.channel_down = channel_down(channels) 144 | self.channel_up = channel_up(channels) 145 | 146 | self.block_up0 = Res_block(channels * 4, channels * 4) 147 | self.block_up1 = Res_block(channels * 4, channels * 4) 148 | self.up_sampling0 = upsampling(channels * 4, channels * 2) 149 | self.block_up2 = Res_block(channels * 2, channels * 2) 150 | self.block_up3 = Res_block(channels * 2, channels * 2) 151 | self.up_sampling1 = upsampling(channels * 2, channels) 152 | self.block_up4 = Res_block(channels, channels) 153 | self.block_up5 = Res_block(channels, channels) 154 | self.up_sampling2 = upsampling(channels, channels) 155 | 156 | self.conv2 = nn.Conv2d(channels, channels, kernel_size=(3, 3), stride=(1, 1), padding=1) 157 | self.conv3 = nn.Conv2d(channels, 3, kernel_size=(1, 1), stride=(1, 1), padding=0) 158 | 159 | self.relu = nn.LeakyReLU() 160 | 161 | def forward(self, x, pred_fea=None): 162 | 163 | if pred_fea is None: 164 | low_fea_down2, low_fea_down4, low_fea_down8 = self.pyramid(x[:, :3, ...]) 165 | low_fea_down8 = self.channel_down(low_fea_down8) 166 | 167 | high_fea_down2, high_fea_down4, high_fea_down8 = self.pyramid(x[:, 3:, ...]) 168 | high_fea_down8 = self.channel_down(high_fea_down8) 169 | 170 | return low_fea_down8, high_fea_down8 171 | else: 172 | # =================low ori decoder================= 173 | low_fea_down2, low_fea_down4, low_fea_down8 = self.pyramid(x[:, :3, ...]) 174 | 175 | pred_fea = self.channel_up(pred_fea) 176 | 177 | pred_fea_up2 = self.up_sampling0( 178 | self.block_up1(self.block_up0(pred_fea) + low_fea_down8)) 179 | pred_fea_up4 = self.up_sampling1( 180 | self.block_up3(self.block_up2(pred_fea_up2) + low_fea_down4)) 181 | pred_fea_up8 = self.up_sampling2( 182 | self.block_up5(self.block_up4(pred_fea_up4) + low_fea_down2)) 183 | 184 | pred_img = self.conv3(self.relu(self.conv2(pred_fea_up8))) 185 | 186 | return pred_img 187 | 188 | 189 | class Self_Attention(nn.Module): 190 | def __init__(self, dim, num_heads, bias): 191 | super(Self_Attention, self).__init__() 192 | self.num_heads = num_heads 193 | self.qkv = nn.Conv2d(dim, dim * 3, kernel_size=(1, 1), bias=bias) 194 | self.qkv_dwconv = nn.Conv2d(dim * 3, dim * 3, kernel_size=(3, 3), stride=(1, 1), 195 | padding=1, groups=dim * 3, bias=bias) 196 | self.project_out = nn.Conv2d(dim, dim, kernel_size=(1, 1), bias=bias) 197 | 198 | def forward(self, x): 199 | b, c, h, w = x.shape 200 | 201 | qkv = self.qkv_dwconv(self.qkv(x)) 202 | q, k, v = qkv.chunk(3, dim=1) 203 | 204 | q = rearrange(q, 'b (head c) h w -> b head c (h w)', head=self.num_heads) 205 | k = rearrange(k, 'b (head c) h w -> b head c (h w)', head=self.num_heads) 206 | v = rearrange(v, 'b (head c) h w -> b head c (h w)', head=self.num_heads) 207 | 208 | q = torch.nn.functional.normalize(q, dim=-1) 209 | k = torch.nn.functional.normalize(k, dim=-1) 210 | 211 | attn = (q @ k.transpose(-2, -1)) 212 | attn = attn.softmax(dim=-1) 213 | 214 | out = (attn @ v) 215 | 216 | out = rearrange(out, 'b head c (h w) -> b (head c) h w', head=self.num_heads, h=h, w=w) 217 | 218 | out = self.project_out(out) 219 | return out 220 | 221 | 222 | class Cross_Attention(nn.Module): 223 | def __init__(self, dim, num_heads, dropout=0.): 224 | super(Cross_Attention, self).__init__() 225 | if dim % num_heads != 0: 226 | raise ValueError( 227 | "The hidden size (%d) is not a multiple 
of the number of attention " 228 | "heads (%d)" % (dim, num_heads) 229 | ) 230 | self.num_heads = num_heads 231 | self.attention_head_size = int(dim / num_heads) 232 | 233 | self.query = Depth_conv(in_ch=dim, out_ch=dim) 234 | self.key = Depth_conv(in_ch=dim, out_ch=dim) 235 | self.value = Depth_conv(in_ch=dim, out_ch=dim) 236 | 237 | self.dropout = nn.Dropout(dropout) 238 | 239 | def transpose_for_scores(self, x): 240 | ''' 241 | new_x_shape = x.size()[:-1] + ( 242 | self.num_heads, 243 | self.attention_head_size, 244 | ) 245 | print(new_x_shape) 246 | x = x.view(*new_x_shape) 247 | ''' 248 | return x.permute(0, 2, 1, 3) 249 | 250 | def forward(self, hidden_states, ctx): 251 | mixed_query_layer = self.query(hidden_states) 252 | mixed_key_layer = self.key(ctx) 253 | mixed_value_layer = self.value(ctx) 254 | 255 | query_layer = self.transpose_for_scores(mixed_query_layer) 256 | key_layer = self.transpose_for_scores(mixed_key_layer) 257 | value_layer = self.transpose_for_scores(mixed_value_layer) 258 | 259 | attention_scores = torch.matmul(query_layer, key_layer.transpose(-1, -2)) 260 | attention_scores = attention_scores / math.sqrt(self.attention_head_size) 261 | 262 | attention_probs = nn.Softmax(dim=-1)(attention_scores) 263 | 264 | attention_probs = self.dropout(attention_probs) 265 | 266 | ctx_layer = torch.matmul(attention_probs, value_layer) 267 | ctx_layer = ctx_layer.permute(0, 2, 1, 3).contiguous() 268 | 269 | return ctx_layer 270 | 271 | 272 | class Retinex_decom(nn.Module): 273 | def __init__(self, channels): 274 | super(Retinex_decom, self).__init__() 275 | 276 | self.conv0 = nn.Conv2d(3, channels, kernel_size=(3, 3), stride=(1, 1), padding=1) 277 | self.blocks0 = nn.Sequential(Res_block(channels, channels), 278 | Res_block(channels, channels)) 279 | 280 | self.conv1 = nn.Conv2d(1, channels, kernel_size=(3, 3), stride=(1, 1), padding=1) 281 | self.blocks1 = nn.Sequential(Res_block(channels, channels), 282 | Res_block(channels, channels)) 283 | 284 | self.cross_attention = Cross_Attention(dim=channels, num_heads=8) 285 | self.self_attention = Self_Attention(dim=channels, num_heads=8, bias=True) 286 | 287 | self.conv0_1 = nn.Sequential(Res_block(channels, channels), 288 | nn.Conv2d(channels, 3, kernel_size=(3, 3), stride=(1, 1), padding=1)) 289 | self.conv1_1 = nn.Sequential(Res_block(channels, channels), 290 | nn.Conv2d(channels, 1, kernel_size=(3, 3), stride=(1, 1), padding=1)) 291 | 292 | def forward(self, x): 293 | init_illumination = torch.max(x, dim=1, keepdim=True)[0] 294 | init_reflectance = x / init_illumination 295 | 296 | Reflectance, Illumination = (self.blocks0(self.conv0(init_reflectance)), 297 | self.blocks1(self.conv1(init_illumination))) 298 | 299 | Reflectance_final = self.cross_attention(Illumination, Reflectance) 300 | 301 | Illumination_content = self.self_attention(Illumination) 302 | 303 | Reflectance_final = self.conv0_1(Reflectance_final + Illumination_content) 304 | Illumination_final = self.conv1_1(Illumination - Illumination_content) 305 | 306 | R = torch.sigmoid(Reflectance_final) 307 | L = torch.sigmoid(Illumination_final) 308 | L = torch.cat([L for i in range(3)], dim=1) 309 | 310 | return R, L 311 | 312 | 313 | class CTDN(nn.Module): 314 | def __init__(self, channels=64): 315 | super(CTDN, self).__init__() 316 | 317 | self.ReconNet = ReconNet(channels) 318 | self.retinex = Retinex_decom(channels) 319 | 320 | def forward(self, images, pred_fea=None): 321 | 322 | output = {} 323 | # =================decomposition low================= 324 | if 
pred_fea is None: 325 | low_fea_down8, high_fea_down8 = self.ReconNet(images, pred_fea=None) 326 | 327 | low_R, low_L = self.retinex(low_fea_down8) 328 | high_R, high_L = self.retinex(high_fea_down8) 329 | 330 | output["low_R"] = low_R 331 | output["low_L"] = low_L 332 | output["low_fea"] = low_fea_down8 333 | output["high_R"] = high_R 334 | output["high_L"] = high_L 335 | output["high_fea"] = high_fea_down8 336 | 337 | else: 338 | pred_img = self.ReconNet(images[:, :3, ...], pred_fea=pred_fea) 339 | output["pred_img"] = pred_img 340 | 341 | return output 342 | -------------------------------------------------------------------------------- /models/unet.py: -------------------------------------------------------------------------------- 1 | import math 2 | import torch 3 | import torch.nn as nn 4 | import torch.nn.functional 5 | 6 | # This script is from the following repositories 7 | # https://github.com/ermongroup/ddim 8 | # https://github.com/bahjat-kawar/ddrm 9 | 10 | 11 | def get_timestep_embedding(timesteps, embedding_dim): 12 | """ 13 | This matches the implementation in Denoising Diffusion Probabilistic Models: 14 | From Fairseq. 15 | Build sinusoidal embeddings. 16 | This matches the implementation in tensor2tensor, but differs slightly 17 | from the description in Section 3.5 of "Attention Is All You Need". 18 | """ 19 | assert len(timesteps.shape) == 1 20 | 21 | half_dim = embedding_dim // 2 22 | emb = math.log(10000) / (half_dim - 1) 23 | emb = torch.exp(torch.arange(half_dim, dtype=torch.float32) * -emb) 24 | emb = emb.to(device=timesteps.device) 25 | emb = timesteps.float()[:, None] * emb[None, :] 26 | emb = torch.cat([torch.sin(emb), torch.cos(emb)], dim=1) 27 | if embedding_dim % 2 == 1: # zero pad 28 | emb = torch.nn.functional.pad(emb, (0, 1, 0, 0)) 29 | return emb 30 | 31 | 32 | def nonlinearity(x): 33 | # swish 34 | return x*torch.sigmoid(x) 35 | 36 | 37 | def Normalize(in_channels): 38 | return torch.nn.GroupNorm(num_groups=32, num_channels=in_channels, eps=1e-6, affine=True) 39 | 40 | 41 | class Upsample(nn.Module): 42 | def __init__(self, in_channels, with_conv): 43 | super().__init__() 44 | self.with_conv = with_conv 45 | if self.with_conv: 46 | self.conv = torch.nn.Conv2d(in_channels, 47 | in_channels, 48 | kernel_size=3, 49 | stride=1, 50 | padding=1) 51 | 52 | def forward(self, x): 53 | x = torch.nn.functional.interpolate( 54 | x, scale_factor=2.0, mode="nearest") 55 | if self.with_conv: 56 | x = self.conv(x) 57 | return x 58 | 59 | 60 | class Downsample(nn.Module): 61 | def __init__(self, in_channels, with_conv): 62 | super().__init__() 63 | self.with_conv = with_conv 64 | if self.with_conv: 65 | # no asymmetric padding in torch conv, must do it ourselves 66 | self.conv = torch.nn.Conv2d(in_channels, 67 | in_channels, 68 | kernel_size=3, 69 | stride=2, 70 | padding=0) 71 | 72 | def forward(self, x): 73 | if self.with_conv: 74 | pad = (0, 1, 0, 1) 75 | x = torch.nn.functional.pad(x, pad, mode="constant", value=0) 76 | x = self.conv(x) 77 | else: 78 | x = torch.nn.functional.avg_pool2d(x, kernel_size=2, stride=2) 79 | return x 80 | 81 | 82 | class ResnetBlock(nn.Module): 83 | def __init__(self, *, in_channels, out_channels=None, conv_shortcut=False, 84 | dropout, temb_channels=512): 85 | super().__init__() 86 | self.in_channels = in_channels 87 | out_channels = in_channels if out_channels is None else out_channels 88 | self.out_channels = out_channels 89 | self.use_conv_shortcut = conv_shortcut 90 | 91 | self.norm1 = Normalize(in_channels) 92 | self.conv1 = 
torch.nn.Conv2d(in_channels, 93 | out_channels, 94 | kernel_size=3, 95 | stride=1, 96 | padding=1) 97 | self.temb_proj = torch.nn.Linear(temb_channels, 98 | out_channels) 99 | self.norm2 = Normalize(out_channels) 100 | self.dropout = torch.nn.Dropout(dropout) 101 | self.conv2 = torch.nn.Conv2d(out_channels, 102 | out_channels, 103 | kernel_size=3, 104 | stride=1, 105 | padding=1) 106 | if self.in_channels != self.out_channels: 107 | if self.use_conv_shortcut: 108 | self.conv_shortcut = torch.nn.Conv2d(in_channels, 109 | out_channels, 110 | kernel_size=3, 111 | stride=1, 112 | padding=1) 113 | else: 114 | self.nin_shortcut = torch.nn.Conv2d(in_channels, 115 | out_channels, 116 | kernel_size=1, 117 | stride=1, 118 | padding=0) 119 | 120 | def forward(self, x, temb): 121 | h = x 122 | h = self.norm1(h) 123 | h = nonlinearity(h) 124 | h = self.conv1(h) 125 | 126 | h = h + self.temb_proj(nonlinearity(temb))[:, :, None, None] 127 | 128 | h = self.norm2(h) 129 | h = nonlinearity(h) 130 | h = self.dropout(h) 131 | h = self.conv2(h) 132 | 133 | if self.in_channels != self.out_channels: 134 | if self.use_conv_shortcut: 135 | x = self.conv_shortcut(x) 136 | else: 137 | x = self.nin_shortcut(x) 138 | 139 | return x+h 140 | 141 | 142 | class AttnBlock(nn.Module): 143 | def __init__(self, in_channels): 144 | super().__init__() 145 | self.in_channels = in_channels 146 | 147 | self.norm = Normalize(in_channels) 148 | self.q = torch.nn.Conv2d(in_channels, 149 | in_channels, 150 | kernel_size=1, 151 | stride=1, 152 | padding=0) 153 | self.k = torch.nn.Conv2d(in_channels, 154 | in_channels, 155 | kernel_size=1, 156 | stride=1, 157 | padding=0) 158 | self.v = torch.nn.Conv2d(in_channels, 159 | in_channels, 160 | kernel_size=1, 161 | stride=1, 162 | padding=0) 163 | self.proj_out = torch.nn.Conv2d(in_channels, 164 | in_channels, 165 | kernel_size=1, 166 | stride=1, 167 | padding=0) 168 | 169 | def forward(self, x): 170 | h_ = x 171 | h_ = self.norm(h_) 172 | q = self.q(h_) 173 | k = self.k(h_) 174 | v = self.v(h_) 175 | 176 | # compute attention 177 | b, c, h, w = q.shape 178 | q = q.reshape(b, c, h*w) 179 | q = q.permute(0, 2, 1) # b,hw,c 180 | k = k.reshape(b, c, h*w) # b,c,hw 181 | w_ = torch.bmm(q, k) # b,hw,hw w[b,i,j]=sum_c q[b,i,c]k[b,c,j] 182 | w_ = w_ * (int(c)**(-0.5)) 183 | w_ = torch.nn.functional.softmax(w_, dim=2) 184 | 185 | # attend to values 186 | v = v.reshape(b, c, h*w) 187 | w_ = w_.permute(0, 2, 1) # b,hw,hw (first hw of k, second of q) 188 | # b, c,hw (hw of q) h_[b,c,j] = sum_i v[b,c,i] w_[b,i,j] 189 | h_ = torch.bmm(v, w_) 190 | h_ = h_.reshape(b, c, h, w) 191 | 192 | h_ = self.proj_out(h_) 193 | 194 | return x+h_ 195 | 196 | 197 | class DiffusionUNet(nn.Module): 198 | def __init__(self, config): 199 | super().__init__() 200 | self.config = config 201 | ch, out_ch, ch_mult = config.model.ch, config.model.out_ch, tuple(config.model.ch_mult) 202 | num_res_blocks = config.model.num_res_blocks 203 | dropout = config.model.dropout 204 | in_channels = config.model.in_channels * 2 if config.data.conditional else config.model.in_channels 205 | resamp_with_conv = config.model.resamp_with_conv 206 | 207 | self.ch = ch 208 | self.temb_ch = self.ch*4 209 | self.num_resolutions = len(ch_mult) 210 | self.num_res_blocks = num_res_blocks 211 | self.in_channels = in_channels 212 | 213 | # timestep embedding 214 | self.temb = nn.Module() 215 | self.temb.dense = nn.ModuleList([ 216 | torch.nn.Linear(self.ch, 217 | self.temb_ch), 218 | torch.nn.Linear(self.temb_ch, 219 | self.temb_ch), 220 | ]) 221 | 222 | 
# downsampling 223 | self.conv_in = torch.nn.Conv2d(in_channels, 224 | self.ch, 225 | kernel_size=3, 226 | stride=1, 227 | padding=1) 228 | 229 | in_ch_mult = (1,)+ch_mult 230 | self.down = nn.ModuleList() 231 | block_in = None 232 | for i_level in range(self.num_resolutions): 233 | block = nn.ModuleList() 234 | attn = nn.ModuleList() 235 | block_in = ch*in_ch_mult[i_level] 236 | block_out = ch*ch_mult[i_level] 237 | for i_block in range(self.num_res_blocks): 238 | block.append(ResnetBlock(in_channels=block_in, 239 | out_channels=block_out, 240 | temb_channels=self.temb_ch, 241 | dropout=dropout)) 242 | block_in = block_out 243 | if i_level == 2: 244 | attn.append(AttnBlock(block_in)) 245 | down = nn.Module() 246 | down.block = block 247 | down.attn = attn 248 | if i_level != self.num_resolutions-1: 249 | down.downsample = Downsample(block_in, resamp_with_conv) 250 | self.down.append(down) 251 | 252 | # middle 253 | self.mid = nn.Module() 254 | self.mid.block_1 = ResnetBlock(in_channels=block_in, 255 | out_channels=block_in, 256 | temb_channels=self.temb_ch, 257 | dropout=dropout) 258 | self.mid.attn_1 = AttnBlock(block_in) 259 | self.mid.block_2 = ResnetBlock(in_channels=block_in, 260 | out_channels=block_in, 261 | temb_channels=self.temb_ch, 262 | dropout=dropout) 263 | 264 | # upsampling 265 | self.up = nn.ModuleList() 266 | for i_level in reversed(range(self.num_resolutions)): 267 | block = nn.ModuleList() 268 | attn = nn.ModuleList() 269 | block_out = ch*ch_mult[i_level] 270 | skip_in = ch*ch_mult[i_level] 271 | for i_block in range(self.num_res_blocks+1): 272 | if i_block == self.num_res_blocks: 273 | skip_in = ch*in_ch_mult[i_level] 274 | block.append(ResnetBlock(in_channels=block_in+skip_in, 275 | out_channels=block_out, 276 | temb_channels=self.temb_ch, 277 | dropout=dropout)) 278 | block_in = block_out 279 | if i_level == 2: 280 | attn.append(AttnBlock(block_in)) 281 | up = nn.Module() 282 | up.block = block 283 | up.attn = attn 284 | if i_level != 0: 285 | up.upsample = Upsample(block_in, resamp_with_conv) 286 | self.up.insert(0, up) # prepend to get consistent order 287 | 288 | # end 289 | self.norm_out = Normalize(block_in) 290 | self.conv_out = torch.nn.Conv2d(block_in, 291 | out_ch, 292 | kernel_size=3, 293 | stride=1, 294 | padding=1) 295 | 296 | def forward(self, x, t): 297 | # assert x.shape[2] == x.shape[3] == self.resolution 298 | 299 | # timestep embedding 300 | temb = get_timestep_embedding(t, self.ch) 301 | temb = self.temb.dense[0](temb) 302 | temb = nonlinearity(temb) 303 | temb = self.temb.dense[1](temb) 304 | 305 | # downsampling 306 | hs = [self.conv_in(x)] 307 | for i_level in range(self.num_resolutions): 308 | for i_block in range(self.num_res_blocks): 309 | h = self.down[i_level].block[i_block](hs[-1], temb) 310 | if len(self.down[i_level].attn) > 0: 311 | h = self.down[i_level].attn[i_block](h) 312 | hs.append(h) 313 | if i_level != self.num_resolutions-1: 314 | hs.append(self.down[i_level].downsample(hs[-1])) 315 | 316 | # middle 317 | h = hs[-1] 318 | h = self.mid.block_1(h, temb) 319 | h = self.mid.attn_1(h) 320 | h = self.mid.block_2(h, temb) 321 | 322 | # upsampling 323 | for i_level in reversed(range(self.num_resolutions)): 324 | for i_block in range(self.num_res_blocks+1): 325 | h = self.up[i_level].block[i_block]( 326 | torch.cat([h, hs.pop()], dim=1), temb) 327 | if len(self.up[i_level].attn) > 0: 328 | h = self.up[i_level].attn[i_block](h) 329 | if i_level != 0: 330 | h = self.up[i_level].upsample(h) 331 | 332 | # end 333 | h = self.norm_out(h) 334 
| h = nonlinearity(h) 335 | h = self.conv_out(h) 336 | return h 337 | --------------------------------------------------------------------------------