├── datasets ├── scratch │ └── LLIE │ │ ├── ckpts │ │ └── pretrained model.txt │ │ └── data │ │ ├── lowlight │ │ └── test │ │ │ ├── gt │ │ │ ├── 1.png │ │ │ ├── 22.png │ │ │ ├── 23.png │ │ │ ├── 55.png │ │ │ ├── 79.png │ │ │ ├── 111.png │ │ │ ├── 146.png │ │ │ ├── 179.png │ │ │ ├── 493.png │ │ │ ├── 547.png │ │ │ ├── 665.png │ │ │ ├── 669.png │ │ │ ├── 748.png │ │ │ ├── 778.png │ │ │ └── 780.png │ │ │ ├── input │ │ │ ├── 1.png │ │ │ ├── 111.png │ │ │ ├── 146.png │ │ │ ├── 179.png │ │ │ ├── 22.png │ │ │ ├── 23.png │ │ │ ├── 493.png │ │ │ ├── 547.png │ │ │ ├── 55.png │ │ │ ├── 665.png │ │ │ ├── 669.png │ │ │ ├── 748.png │ │ │ ├── 778.png │ │ │ ├── 780.png │ │ │ └── 79.png │ │ │ └── lowlighttesta.txt │ │ └── read.py ├── __init__.py ├── __pycache__ │ ├── __init__.cpython-38.pyc │ ├── lowlight.cpython-38.pyc │ ├── raindrop.cpython-38.pyc │ ├── raindrop1.cpython-38.pyc │ ├── snow100k.cpython-38.pyc │ ├── allweather.cpython-38.pyc │ └── outdoorrain.cpython-38.pyc └── lowlight.py ├── core ├── figs ├── com.png ├── v1.png ├── vis.png └── unpaired.png ├── models ├── __init__.py ├── __pycache__ │ ├── ddm.cpython-38.pyc │ ├── ddm1.cpython-38.pyc │ ├── unet.cpython-38.pyc │ ├── unet1.cpython-38.pyc │ ├── unet2.cpython-38.pyc │ ├── unet6.cpython-38.pyc │ ├── __init__.cpython-38.pyc │ ├── restoration.cpython-38.pyc │ └── restoration1.cpython-38.pyc ├── restoration.py ├── ddm.py └── unet.py ├── utils ├── __init__.py ├── __pycache__ │ ├── logging.cpython-38.pyc │ ├── metrics.cpython-38.pyc │ ├── __init__.cpython-38.pyc │ ├── optimize.cpython-38.pyc │ └── sampling.cpython-38.pyc ├── logging.py ├── optimize.py ├── metrics.py └── sampling.py ├── results └── images │ └── lowlight │ └── lowlight │ ├── ['1'].png │ ├── ['111'].png │ ├── ['146'].png │ ├── ['179'].png │ ├── ['22'].png │ ├── ['23'].png │ ├── ['493'].png │ ├── ['547'].png │ ├── ['55'].png │ ├── ['665'].png │ ├── ['669'].png │ ├── ['748'].png │ ├── ['778'].png │ ├── ['780'].png │ └── ['79'].png ├── requirements.txt ├── configs └── lowlight.yml ├── train_diffusion.py ├── eval_diffusion.py ├── README.md └── evaluation.py /datasets/scratch/LLIE/ckpts/pretrained model.txt: -------------------------------------------------------------------------------- 1 | -------------------------------------------------------------------------------- /core: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/Oliiveralien/MDMS/HEAD/core -------------------------------------------------------------------------------- /figs/com.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/Oliiveralien/MDMS/HEAD/figs/com.png -------------------------------------------------------------------------------- /figs/v1.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/Oliiveralien/MDMS/HEAD/figs/v1.png -------------------------------------------------------------------------------- /figs/vis.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/Oliiveralien/MDMS/HEAD/figs/vis.png -------------------------------------------------------------------------------- /figs/unpaired.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/Oliiveralien/MDMS/HEAD/figs/unpaired.png -------------------------------------------------------------------------------- 
/models/__init__.py: -------------------------------------------------------------------------------- 1 | from models.ddm import * 2 | from models.restoration import * 3 | -------------------------------------------------------------------------------- /datasets/__init__.py: -------------------------------------------------------------------------------- 1 | from datasets.lowlight import * 2 | __all__ = ["AllWeather","lowlight"] 3 | -------------------------------------------------------------------------------- /utils/__init__.py: -------------------------------------------------------------------------------- 1 | from utils.logging import * 2 | from utils.sampling import * 3 | from utils.optimize import * 4 | -------------------------------------------------------------------------------- /models/__pycache__/ddm.cpython-38.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/Oliiveralien/MDMS/HEAD/models/__pycache__/ddm.cpython-38.pyc -------------------------------------------------------------------------------- /models/__pycache__/ddm1.cpython-38.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/Oliiveralien/MDMS/HEAD/models/__pycache__/ddm1.cpython-38.pyc -------------------------------------------------------------------------------- /models/__pycache__/unet.cpython-38.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/Oliiveralien/MDMS/HEAD/models/__pycache__/unet.cpython-38.pyc -------------------------------------------------------------------------------- /models/__pycache__/unet1.cpython-38.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/Oliiveralien/MDMS/HEAD/models/__pycache__/unet1.cpython-38.pyc -------------------------------------------------------------------------------- /models/__pycache__/unet2.cpython-38.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/Oliiveralien/MDMS/HEAD/models/__pycache__/unet2.cpython-38.pyc -------------------------------------------------------------------------------- /models/__pycache__/unet6.cpython-38.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/Oliiveralien/MDMS/HEAD/models/__pycache__/unet6.cpython-38.pyc -------------------------------------------------------------------------------- /utils/__pycache__/logging.cpython-38.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/Oliiveralien/MDMS/HEAD/utils/__pycache__/logging.cpython-38.pyc -------------------------------------------------------------------------------- /utils/__pycache__/metrics.cpython-38.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/Oliiveralien/MDMS/HEAD/utils/__pycache__/metrics.cpython-38.pyc -------------------------------------------------------------------------------- /models/__pycache__/__init__.cpython-38.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/Oliiveralien/MDMS/HEAD/models/__pycache__/__init__.cpython-38.pyc -------------------------------------------------------------------------------- 
/results/images/lowlight/lowlight/['1'].png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/Oliiveralien/MDMS/HEAD/results/images/lowlight/lowlight/['1'].png -------------------------------------------------------------------------------- /utils/__pycache__/__init__.cpython-38.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/Oliiveralien/MDMS/HEAD/utils/__pycache__/__init__.cpython-38.pyc -------------------------------------------------------------------------------- /utils/__pycache__/optimize.cpython-38.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/Oliiveralien/MDMS/HEAD/utils/__pycache__/optimize.cpython-38.pyc -------------------------------------------------------------------------------- /utils/__pycache__/sampling.cpython-38.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/Oliiveralien/MDMS/HEAD/utils/__pycache__/sampling.cpython-38.pyc -------------------------------------------------------------------------------- /datasets/__pycache__/__init__.cpython-38.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/Oliiveralien/MDMS/HEAD/datasets/__pycache__/__init__.cpython-38.pyc -------------------------------------------------------------------------------- /datasets/__pycache__/lowlight.cpython-38.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/Oliiveralien/MDMS/HEAD/datasets/__pycache__/lowlight.cpython-38.pyc -------------------------------------------------------------------------------- /datasets/__pycache__/raindrop.cpython-38.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/Oliiveralien/MDMS/HEAD/datasets/__pycache__/raindrop.cpython-38.pyc -------------------------------------------------------------------------------- /datasets/__pycache__/raindrop1.cpython-38.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/Oliiveralien/MDMS/HEAD/datasets/__pycache__/raindrop1.cpython-38.pyc -------------------------------------------------------------------------------- /datasets/__pycache__/snow100k.cpython-38.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/Oliiveralien/MDMS/HEAD/datasets/__pycache__/snow100k.cpython-38.pyc -------------------------------------------------------------------------------- /models/__pycache__/restoration.cpython-38.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/Oliiveralien/MDMS/HEAD/models/__pycache__/restoration.cpython-38.pyc -------------------------------------------------------------------------------- /results/images/lowlight/lowlight/['111'].png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/Oliiveralien/MDMS/HEAD/results/images/lowlight/lowlight/['111'].png -------------------------------------------------------------------------------- /results/images/lowlight/lowlight/['146'].png: -------------------------------------------------------------------------------- 
https://raw.githubusercontent.com/Oliiveralien/MDMS/HEAD/results/images/lowlight/lowlight/['146'].png -------------------------------------------------------------------------------- /results/images/lowlight/lowlight/['179'].png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/Oliiveralien/MDMS/HEAD/results/images/lowlight/lowlight/['179'].png -------------------------------------------------------------------------------- /results/images/lowlight/lowlight/['22'].png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/Oliiveralien/MDMS/HEAD/results/images/lowlight/lowlight/['22'].png -------------------------------------------------------------------------------- /results/images/lowlight/lowlight/['23'].png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/Oliiveralien/MDMS/HEAD/results/images/lowlight/lowlight/['23'].png -------------------------------------------------------------------------------- /results/images/lowlight/lowlight/['493'].png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/Oliiveralien/MDMS/HEAD/results/images/lowlight/lowlight/['493'].png -------------------------------------------------------------------------------- /results/images/lowlight/lowlight/['547'].png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/Oliiveralien/MDMS/HEAD/results/images/lowlight/lowlight/['547'].png -------------------------------------------------------------------------------- /results/images/lowlight/lowlight/['55'].png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/Oliiveralien/MDMS/HEAD/results/images/lowlight/lowlight/['55'].png -------------------------------------------------------------------------------- /results/images/lowlight/lowlight/['665'].png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/Oliiveralien/MDMS/HEAD/results/images/lowlight/lowlight/['665'].png -------------------------------------------------------------------------------- /results/images/lowlight/lowlight/['669'].png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/Oliiveralien/MDMS/HEAD/results/images/lowlight/lowlight/['669'].png -------------------------------------------------------------------------------- /results/images/lowlight/lowlight/['748'].png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/Oliiveralien/MDMS/HEAD/results/images/lowlight/lowlight/['748'].png -------------------------------------------------------------------------------- /results/images/lowlight/lowlight/['778'].png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/Oliiveralien/MDMS/HEAD/results/images/lowlight/lowlight/['778'].png -------------------------------------------------------------------------------- /results/images/lowlight/lowlight/['780'].png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/Oliiveralien/MDMS/HEAD/results/images/lowlight/lowlight/['780'].png 
-------------------------------------------------------------------------------- /results/images/lowlight/lowlight/['79'].png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/Oliiveralien/MDMS/HEAD/results/images/lowlight/lowlight/['79'].png -------------------------------------------------------------------------------- /datasets/__pycache__/allweather.cpython-38.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/Oliiveralien/MDMS/HEAD/datasets/__pycache__/allweather.cpython-38.pyc -------------------------------------------------------------------------------- /datasets/__pycache__/outdoorrain.cpython-38.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/Oliiveralien/MDMS/HEAD/datasets/__pycache__/outdoorrain.cpython-38.pyc -------------------------------------------------------------------------------- /models/__pycache__/restoration1.cpython-38.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/Oliiveralien/MDMS/HEAD/models/__pycache__/restoration1.cpython-38.pyc -------------------------------------------------------------------------------- /datasets/scratch/LLIE/data/lowlight/test/gt/1.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/Oliiveralien/MDMS/HEAD/datasets/scratch/LLIE/data/lowlight/test/gt/1.png -------------------------------------------------------------------------------- /datasets/scratch/LLIE/data/lowlight/test/gt/22.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/Oliiveralien/MDMS/HEAD/datasets/scratch/LLIE/data/lowlight/test/gt/22.png -------------------------------------------------------------------------------- /datasets/scratch/LLIE/data/lowlight/test/gt/23.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/Oliiveralien/MDMS/HEAD/datasets/scratch/LLIE/data/lowlight/test/gt/23.png -------------------------------------------------------------------------------- /datasets/scratch/LLIE/data/lowlight/test/gt/55.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/Oliiveralien/MDMS/HEAD/datasets/scratch/LLIE/data/lowlight/test/gt/55.png -------------------------------------------------------------------------------- /datasets/scratch/LLIE/data/lowlight/test/gt/79.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/Oliiveralien/MDMS/HEAD/datasets/scratch/LLIE/data/lowlight/test/gt/79.png -------------------------------------------------------------------------------- /datasets/scratch/LLIE/data/lowlight/test/gt/111.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/Oliiveralien/MDMS/HEAD/datasets/scratch/LLIE/data/lowlight/test/gt/111.png -------------------------------------------------------------------------------- /datasets/scratch/LLIE/data/lowlight/test/gt/146.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/Oliiveralien/MDMS/HEAD/datasets/scratch/LLIE/data/lowlight/test/gt/146.png 
-------------------------------------------------------------------------------- /datasets/scratch/LLIE/data/lowlight/test/gt/179.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/Oliiveralien/MDMS/HEAD/datasets/scratch/LLIE/data/lowlight/test/gt/179.png -------------------------------------------------------------------------------- /datasets/scratch/LLIE/data/lowlight/test/gt/493.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/Oliiveralien/MDMS/HEAD/datasets/scratch/LLIE/data/lowlight/test/gt/493.png -------------------------------------------------------------------------------- /datasets/scratch/LLIE/data/lowlight/test/gt/547.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/Oliiveralien/MDMS/HEAD/datasets/scratch/LLIE/data/lowlight/test/gt/547.png -------------------------------------------------------------------------------- /datasets/scratch/LLIE/data/lowlight/test/gt/665.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/Oliiveralien/MDMS/HEAD/datasets/scratch/LLIE/data/lowlight/test/gt/665.png -------------------------------------------------------------------------------- /datasets/scratch/LLIE/data/lowlight/test/gt/669.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/Oliiveralien/MDMS/HEAD/datasets/scratch/LLIE/data/lowlight/test/gt/669.png -------------------------------------------------------------------------------- /datasets/scratch/LLIE/data/lowlight/test/gt/748.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/Oliiveralien/MDMS/HEAD/datasets/scratch/LLIE/data/lowlight/test/gt/748.png -------------------------------------------------------------------------------- /datasets/scratch/LLIE/data/lowlight/test/gt/778.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/Oliiveralien/MDMS/HEAD/datasets/scratch/LLIE/data/lowlight/test/gt/778.png -------------------------------------------------------------------------------- /datasets/scratch/LLIE/data/lowlight/test/gt/780.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/Oliiveralien/MDMS/HEAD/datasets/scratch/LLIE/data/lowlight/test/gt/780.png -------------------------------------------------------------------------------- /datasets/scratch/LLIE/data/lowlight/test/input/1.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/Oliiveralien/MDMS/HEAD/datasets/scratch/LLIE/data/lowlight/test/input/1.png -------------------------------------------------------------------------------- /requirements.txt: -------------------------------------------------------------------------------- 1 | numpy==1.21.5 2 | opencv_python==4.5.5.64 3 | Pillow==8.4.0 4 | PyYAML==6.0 5 | torch==1.10.0 6 | torchvision==0.11.1 7 | tqdm==4.61.2 8 | -------------------------------------------------------------------------------- /datasets/scratch/LLIE/data/lowlight/test/input/111.png: -------------------------------------------------------------------------------- 
https://raw.githubusercontent.com/Oliiveralien/MDMS/HEAD/datasets/scratch/LLIE/data/lowlight/test/input/111.png -------------------------------------------------------------------------------- /datasets/scratch/LLIE/data/lowlight/test/input/146.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/Oliiveralien/MDMS/HEAD/datasets/scratch/LLIE/data/lowlight/test/input/146.png -------------------------------------------------------------------------------- /datasets/scratch/LLIE/data/lowlight/test/input/179.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/Oliiveralien/MDMS/HEAD/datasets/scratch/LLIE/data/lowlight/test/input/179.png -------------------------------------------------------------------------------- /datasets/scratch/LLIE/data/lowlight/test/input/22.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/Oliiveralien/MDMS/HEAD/datasets/scratch/LLIE/data/lowlight/test/input/22.png -------------------------------------------------------------------------------- /datasets/scratch/LLIE/data/lowlight/test/input/23.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/Oliiveralien/MDMS/HEAD/datasets/scratch/LLIE/data/lowlight/test/input/23.png -------------------------------------------------------------------------------- /datasets/scratch/LLIE/data/lowlight/test/input/493.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/Oliiveralien/MDMS/HEAD/datasets/scratch/LLIE/data/lowlight/test/input/493.png -------------------------------------------------------------------------------- /datasets/scratch/LLIE/data/lowlight/test/input/547.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/Oliiveralien/MDMS/HEAD/datasets/scratch/LLIE/data/lowlight/test/input/547.png -------------------------------------------------------------------------------- /datasets/scratch/LLIE/data/lowlight/test/input/55.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/Oliiveralien/MDMS/HEAD/datasets/scratch/LLIE/data/lowlight/test/input/55.png -------------------------------------------------------------------------------- /datasets/scratch/LLIE/data/lowlight/test/input/665.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/Oliiveralien/MDMS/HEAD/datasets/scratch/LLIE/data/lowlight/test/input/665.png -------------------------------------------------------------------------------- /datasets/scratch/LLIE/data/lowlight/test/input/669.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/Oliiveralien/MDMS/HEAD/datasets/scratch/LLIE/data/lowlight/test/input/669.png -------------------------------------------------------------------------------- /datasets/scratch/LLIE/data/lowlight/test/input/748.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/Oliiveralien/MDMS/HEAD/datasets/scratch/LLIE/data/lowlight/test/input/748.png -------------------------------------------------------------------------------- /datasets/scratch/LLIE/data/lowlight/test/input/778.png: 
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/Oliiveralien/MDMS/HEAD/datasets/scratch/LLIE/data/lowlight/test/input/778.png
--------------------------------------------------------------------------------
/datasets/scratch/LLIE/data/lowlight/test/input/780.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/Oliiveralien/MDMS/HEAD/datasets/scratch/LLIE/data/lowlight/test/input/780.png
--------------------------------------------------------------------------------
/datasets/scratch/LLIE/data/lowlight/test/input/79.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/Oliiveralien/MDMS/HEAD/datasets/scratch/LLIE/data/lowlight/test/input/79.png
--------------------------------------------------------------------------------
/utils/logging.py:
--------------------------------------------------------------------------------
1 | import torch
2 | import shutil
3 | import os
4 | import torchvision.utils as tvu
5 | 
6 | 
7 | def save_image(img, file_directory):
8 |     if not os.path.exists(os.path.dirname(file_directory)):
9 |         os.makedirs(os.path.dirname(file_directory))
10 |     tvu.save_image(img, file_directory)
11 | 
12 | 
13 | def save_checkpoint(state, filename):
14 |     if not os.path.exists(os.path.dirname(filename)):
15 |         os.makedirs(os.path.dirname(filename))
16 |     torch.save(state, filename + '.pth.tar')
17 | 
18 | 
19 | def load_checkpoint(path, device):
20 |     if device is None:
21 |         return torch.load(path)
22 |     else:
23 |         return torch.load(path, map_location=device)
24 | 
--------------------------------------------------------------------------------
/datasets/scratch/LLIE/data/read.py:
--------------------------------------------------------------------------------
1 | import os
2 | import codecs  # read the file names from the input folder
3 | 
4 | folder = '/home/shangkai/MDMS/MDMS-main/datasets/scratch/LLIE/data/lowlight/test/input'
5 | filenames = os.listdir(folder)
6 | 
7 | # write the file names into a txt file
8 | txt_file = r'/home/shangkai/MDMS/MDMS-main/datasets/scratch/LLIE/data/lowlight/test/lowlighttesta.txt'
9 | with codecs.open(txt_file, 'w', 'utf-8') as f:
10 |     for filename in filenames:
11 |         filepath = os.path.dirname(__file__)
12 |         filepath = os.path.join(filepath, '/lowlight/test/input')
13 |         filename = '/home/shangkai/MDMS/MDMS-main/datasets/scratch/LLIE/data/lowlight/test/input/'+filename
14 | 
15 |         f.write(filename + '\n')
16 | 
--------------------------------------------------------------------------------
/utils/optimize.py:
--------------------------------------------------------------------------------
1 | import torch.optim as optim
2 | 
3 | 
4 | def get_optimizer(config, parameters):
5 |     if config.optim.optimizer == 'Adam':
6 |         return optim.Adam(parameters, lr=config.optim.lr, weight_decay=config.optim.weight_decay,
7 |                           betas=(0.9, 0.999), amsgrad=config.optim.amsgrad, eps=config.optim.eps)
8 |     elif config.optim.optimizer == 'RMSProp':
9 |         return optim.RMSprop(parameters, lr=config.optim.lr, weight_decay=config.optim.weight_decay)
10 |     elif config.optim.optimizer == 'SGD':
11 |         return optim.SGD(parameters, lr=config.optim.lr, momentum=0.9)
12 |     else:
13 |         raise NotImplementedError('Optimizer {} not understood.'.format(config.optim.optimizer))
14 | 
--------------------------------------------------------------------------------
/configs/lowlight.yml:
--------------------------------------------------------------------------------
1 | 
data: 2 | dataset: "lowlight" 3 | image_size: 64 4 | channels: 3 5 | num_workers: 0 6 | data_dir: "/root/LLIE/MDMS-main/datasets/scratch/LLIE/" 7 | conditional: True 8 | 9 | model: 10 | in_channels: 3 11 | out_ch: 3 12 | ch: 128 13 | #ch: 32 14 | # ch_mult: [1, 1, 2, 2, 4, 4] 15 | ch_mult: [1, 2, 4,8] 16 | num_res_blocks: 2 17 | attn_resolutions: [] 18 | dropout: 0.0 19 | ema_rate: 0.999 20 | ema: True 21 | resamp_with_conv: True 22 | 23 | diffusion: 24 | beta_schedule: linear 25 | beta_start: 0.0001 26 | beta_end: 0.02 27 | num_diffusion_timesteps: 1000 28 | 29 | training: 30 | patch_n: 2 31 | batch_size: 38 32 | n_epochs: 200000 33 | n_iters: 2000000 34 | snapshot_freq: 200 35 | validation_freq: 200 36 | 37 | sampling: 38 | batch_size: 1 #4 39 | last_only: True 40 | 41 | optim: 42 | weight_decay: 0.000 43 | optimizer: "Adam" 44 | lr: 0.00002 45 | amsgrad: False 46 | eps: 0.00000001 47 | -------------------------------------------------------------------------------- /datasets/scratch/LLIE/data/lowlight/test/lowlighttesta.txt: -------------------------------------------------------------------------------- 1 | /root/LLIE/MDMS-main/datasets/scratch/LLIE/data/lowlight/test/input/1.png 2 | /root/LLIE/MDMS-main/datasets/scratch/LLIE/data/lowlight/test/input/111.png 3 | /root/LLIE/MDMS-main/datasets/scratch/LLIE/data/lowlight/test/input/146.png 4 | /root/LLIE/MDMS-main/datasets/scratch/LLIE/data/lowlight/test/input/179.png 5 | /root/LLIE/MDMS-main/datasets/scratch/LLIE/data/lowlight/test/input/22.png 6 | /root/LLIE/MDMS-main/datasets/scratch/LLIE/data/lowlight/test/input/23.png 7 | /root/LLIE/MDMS-main/datasets/scratch/LLIE/data/lowlight/test/input/493.png 8 | /root/LLIE/MDMS-main/datasets/scratch/LLIE/data/lowlight/test/input/547.png 9 | /root/LLIE/MDMS-main/datasets/scratch/LLIE/data/lowlight/test/input/55.png 10 | /root/LLIE/MDMS-main/datasets/scratch/LLIE/data/lowlight/test/input/665.png 11 | /root/LLIE/MDMS-main/datasets/scratch/LLIE/data/lowlight/test/input/669.png 12 | /root/LLIE/MDMS-main/datasets/scratch/LLIE/data/lowlight/test/input/748.png 13 | /root/LLIE/MDMS-main/datasets/scratch/LLIE/data/lowlight/test/input/778.png 14 | /root/LLIE/MDMS-main/datasets/scratch/LLIE/data/lowlight/test/input/780.png 15 | /root/LLIE/MDMS-main/datasets/scratch/LLIE/data/lowlight/test/input/79.png 16 | -------------------------------------------------------------------------------- /train_diffusion.py: -------------------------------------------------------------------------------- 1 | import argparse 2 | import os 3 | import random 4 | import socket 5 | import yaml 6 | import torch 7 | import torch.backends.cudnn as cudnn 8 | import torch.utils.data 9 | import numpy as np 10 | import torchvision 11 | import models 12 | import datasets 13 | import utils 14 | from models import DenoisingDiffusion 15 | 16 | 17 | def parse_args_and_config(): 18 | parser = argparse.ArgumentParser(description='Training Patch-Based Denoising Diffusion Models') 19 | parser.add_argument("--config", type=str, required=False, default="lowlight.yml",help="Path to the config file") 20 | parser.add_argument('--resume', default=r'', type=str, 21 | help='Path for checkpoint to load and resume') 22 | parser.add_argument("--sampling_timesteps", type=int, default=25, 23 | help="Number of implicit sampling steps for validation image patches") 24 | parser.add_argument("--image_folder", default='results/images/', type=str, 25 | help="Location to save restored validation image patches") 26 | parser.add_argument('--seed', default=61, 
type=int, metavar='N', 27 | help='Seed for initializing training (default: 61)') 28 | args = parser.parse_args() 29 | 30 | with open(os.path.join("configs", args.config), "r") as f: 31 | config = yaml.safe_load(f) 32 | new_config = dict2namespace(config) 33 | 34 | return args, new_config 35 | 36 | 37 | def dict2namespace(config): 38 | namespace = argparse.Namespace() 39 | for key, value in config.items(): 40 | if isinstance(value, dict): 41 | new_value = dict2namespace(value) 42 | else: 43 | new_value = value 44 | setattr(namespace, key, new_value) 45 | return namespace 46 | 47 | 48 | def main(): 49 | args, config = parse_args_and_config() 50 | 51 | # setup device to run 52 | device = torch.device("cuda") if torch.cuda.is_available() else torch.device("cpu") 53 | print("Using device: {}".format(device)) 54 | config.device = device 55 | 56 | # set random seed 57 | torch.manual_seed(args.seed) 58 | np.random.seed(args.seed) 59 | if torch.cuda.is_available(): 60 | torch.cuda.manual_seed_all(args.seed) 61 | torch.backends.cudnn.benchmark = True 62 | 63 | # data loading 64 | print("=> using dataset '{}'".format(config.data.dataset)) 65 | DATASET = datasets.__dict__[config.data.dataset](config) 66 | 67 | # create model 68 | print("=> creating denoising-diffusion model...") 69 | diffusion = DenoisingDiffusion(args, config) 70 | diffusion.train(DATASET) 71 | 72 | 73 | if __name__ == "__main__": 74 | main() 75 | -------------------------------------------------------------------------------- /eval_diffusion.py: -------------------------------------------------------------------------------- 1 | import argparse 2 | import os 3 | import random 4 | import socket 5 | import yaml 6 | import torch 7 | import torch.backends.cudnn as cudnn 8 | import numpy as np 9 | import torchvision 10 | import models 11 | import datasets 12 | import utils 13 | from models import DenoisingDiffusion, DiffusiveRestoration 14 | 15 | 16 | def parse_args_and_config(): 17 | parser = argparse.ArgumentParser(description='Restoring Weather with Patch-Based Denoising Diffusion Models') 18 | parser.add_argument("--config", type=str, default='lowlight.yml', 19 | help="Path to the config file") 20 | 21 | parser.add_argument('--resume', default=r'/root/LLIE/MDMS-main/datasets/scratch/LLIE/ckpts/bestpsnr.pth.tar', type=str, 22 | help='Path for the diffusion model checkpoint to load for evaluation') 23 | parser.add_argument("--grid_r", type=int, default=16, 24 | help="Grid cell width r that defines the overlap between patches") 25 | parser.add_argument("--sampling_timesteps", type=int, default=25, 26 | help="Number of implicit sampling steps") 27 | parser.add_argument("--test_set", type=str, default='lowlight') 28 | parser.add_argument("--image_folder", default='results/images/', type=str, 29 | help="Location to save restored images") 30 | parser.add_argument('--seed', default=61, type=int, metavar='N', 31 | help='Seed for initializing training (default: 61)') 32 | args = parser.parse_args() 33 | 34 | with open(os.path.join("configs", args.config), "r") as f: 35 | config = yaml.safe_load(f) 36 | new_config = dict2namespace(config) 37 | 38 | return args, new_config 39 | 40 | 41 | def dict2namespace(config): 42 | namespace = argparse.Namespace() 43 | for key, value in config.items(): 44 | if isinstance(value, dict): 45 | new_value = dict2namespace(value) 46 | else: 47 | new_value = value 48 | setattr(namespace, key, new_value) 49 | return namespace 50 | 51 | 52 | def main(): 53 | args, config = parse_args_and_config() 54 | 55 | # setup 
device to run 56 | device = torch.device("cuda") if torch.cuda.is_available() else torch.device("cpu") 57 | print("Using device: {}".format(device)) 58 | config.device = device 59 | 60 | if torch.cuda.is_available(): 61 | print('Note: Currently supports evaluations (restoration) when run only on a single GPU!') 62 | 63 | # set random seed 64 | torch.manual_seed(args.seed) 65 | np.random.seed(args.seed) 66 | if torch.cuda.is_available(): 67 | torch.cuda.manual_seed_all(args.seed) 68 | torch.backends.cudnn.benchmark = True 69 | 70 | # data loading 71 | print("=> using dataset '{}'".format(config.data.dataset)) 72 | DATASET = datasets.__dict__[config.data.dataset](config) 73 | _, val_loader = DATASET.get_loaders(parse_patches=False, validation=args.test_set) 74 | 75 | # create model 76 | print("=> creating denoising-diffusion model with wrapper...") 77 | diffusion = DenoisingDiffusion(args, config) 78 | model = DiffusiveRestoration(diffusion, args, config) 79 | model.restore(val_loader, validation=args.test_set, r=args.grid_r,use_align=True) 80 | 81 | 82 | if __name__ == '__main__': 83 | main() 84 | -------------------------------------------------------------------------------- /README.md: -------------------------------------------------------------------------------- 1 |
2 | 3 | # 【AAAI'2024】Multi-Domain Multi-Scale Diffusion Model for Low-Light Image Enhancement 4 |
5 | 
6 | The official implementation of the AAAI 2024 paper [Multi-Domain Multi-Scale Diffusion Model for Low-Light Image Enhancement](https://ojs.aaai.org/index.php/AAAI/article/view/28273).
7 | 
8 | ## Environment
9 | Create a new conda env
10 | and run
11 | ```
12 | $ pip install -r requirements.txt
13 | ```
14 | torch/torchvision built with CUDA >= 11.3 should be fine.
15 | 
16 | ## Demo
17 | #### 1. Download pretrained model
18 | 
19 | Download the pretrained MDMS model from [Baidu NetDisk](https://pan.baidu.com/s/1AQzsofBfsiSEy6cG6wcbYg?pwd=gfri) or [Google Drive](https://drive.google.com/file/d/1PQIqAs0mw8xp5obrVcg7QnxLLsfQ5H9L/view?usp=sharing).
20 | 
21 | Put the downloaded ckpt in `datasets/scratch/LLIE/ckpts` (the default `--resume` path in `eval_diffusion.py` expects `ckpts/bestpsnr.pth.tar`; adjust it to your local path if needed).
22 | 
23 | 
24 | #### 2. Inference
25 | ```
26 | # in {path_to_this_repo}/,
27 | $ python eval_diffusion.py
28 | ```
29 | Put the test input in `datasets/scratch/LLIE/data/lowlight/test/input`.
30 | 
31 | Output results will be saved in `results/images/lowlight/lowlight`.
32 | 
33 | ## Evaluation
34 | 
35 | Put the test GT in `datasets/scratch/LLIE/data/lowlight/test/gt` for paired evaluation.
36 | 
37 | ```
38 | # in {path_to_this_repo}/,
39 | $ python evaluation.py
40 | ```
41 | * Note that our [evaluation metrics](https://github.com/Oliiveralien/MDMS/tree/main/evaluation.py) differ slightly from [PyDiff](https://github.com/limuloo/PyDIff/tree/862f8cc428450ef02822fd218b15705e2214ec2d/BasicSR-light/basicsr/metrics) (inherited from [BasicSR](https://github.com/XPixelGroup/BasicSR)).
42 | 
43 | ## Results
44 | All results reported in our paper, including those of the compared methods, are available on [Baidu Netdisk](https://pan.baidu.com/s/1O8hOVflnLGLSLP07nXp_sg?pwd=zftu) or [Google Drive](https://drive.google.com/file/d/1k9-vD-I5JaHj7Y9bGq1gen2TKEzEhCzs/view?usp=sharing).
45 | 
46 | * Note that the provided model is trained on the [LOLv1](https://daooshee.github.io/BMVC2018website/) training set, but generalizes well to other datasets.
47 | * For SSIM, we compute the score directly on the [RGB channels](https://github.com/Oliiveralien/MDMS/tree/main/evaluation.py#L49-L51) rather than only on [grayscale images](https://github.com/limuloo/PyDIff/blob/862f8cc428450ef02822fd218b15705e2214ec2d/BasicSR-light/basicsr/metrics/ssim_lol.py#L7C1-L12C132) as in PyDiff.
48 | * For LPIPS, we use a different input normalization ([NormA](https://github.com/Oliiveralien/MDMS/tree/main/evaluation.py#L74)) than PyDiff ([NormB](https://github.com/limuloo/PyDIff/blob/862f8cc428450ef02822fd218b15705e2214ec2d/BasicSR-light/basicsr/metrics/lpips_lol.py#L19)); see the sketch below.
49 | 
50 | Our method remains superior under the same setting as PyDiff.
51 | 
52 | 
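For concreteness, the snippet below is a minimal, self-contained sketch of the two protocol differences listed above. It is illustrative only, not the repository's `evaluation.py`; the helper names `ssim_rgb` and `to_lpips_tensor` are invented here, and the sketch assumes both inputs are 8-bit RGB images of identical size (in `evaluation.py` the restored output is additionally resized to the GT resolution before measuring).

```python
import numpy as np
import torch
from skimage.metrics import structural_similarity as ssim


def ssim_rgb(img_a: np.ndarray, img_b: np.ndarray) -> float:
    # SSIM computed per RGB channel and averaged (this repo's protocol),
    # instead of converting to a single grayscale image first (PyDiff's protocol).
    # evaluation.py passes the older multichannel=True keyword; channel_axis=-1
    # is the equivalent spelling on current scikit-image releases.
    return ssim(img_a, img_b, channel_axis=-1)


def to_lpips_tensor(img: np.ndarray, norm: str = "A") -> torch.Tensor:
    # HWC uint8 image -> 1x3xHxW float tensor to feed to the LPIPS network.
    t = torch.from_numpy(img.astype(np.float32)).permute(2, 0, 1).unsqueeze(0)
    if norm == "A":
        return t / 255.0        # NormA: values in [0, 1] (the active path in evaluation.py)
    return t / 127.5 - 1.0      # NormB: values in [-1, 1] (PyDiff-style, commented out there)
```

Note that `evaluation.py` additionally depends on `scikit-image`, `lpips`, and `natsort`, which are not listed in `requirements.txt` and may need to be installed separately.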
53 | 54 |
55 | 56 | ### 1. Test results on LOLv1 test set. 57 | 58 |
59 | 60 |
61 | 62 | ### 2. Generalization results on LOLv2 syn and real test sets. 63 | 64 |
65 | 66 |
67 | 68 | ### 3. Generalization results on other unpaired datasets. 69 | 70 |
71 | 72 |
73 | 
74 | We will perform more training and tests on other datasets in the future.
75 | 
76 | ## Training
77 | Put the training dataset in `datasets/scratch/LLIE/data/lowlight/train` (with `input` and `gt` subfolders containing identically named image pairs, mirroring the provided test set layout).
78 | 
79 | ```
80 | # in {path_to_this_repo}/,
81 | $ python train_diffusion.py
82 | ```
83 | 
84 | Detailed training instructions will be updated soon.
85 | 
86 | ## Citation
87 | If you find this paper useful, please consider starring this repo and citing our paper:
88 | ```
89 | @inproceedings{shang2024multi,
90 |   title={Multi-Domain Multi-Scale Diffusion Model for Low-Light Image Enhancement},
91 |   author={Shang, Kai and Shao, Mingwen and Wang, Chao and Cheng, Yuanshuo and Wang, Shuigen},
92 |   booktitle={Proceedings of the AAAI Conference on Artificial Intelligence},
93 |   volume={38},
94 |   number={5},
95 |   pages={4722--4730},
96 |   year={2024}
97 | }
98 | ```
--------------------------------------------------------------------------------
/models/restoration.py:
--------------------------------------------------------------------------------
1 | import torch
2 | import torch.nn as nn
3 | import utils
4 | import torchvision
5 | import os
6 | import PIL
7 | import re
8 | from torchvision.transforms import Resize
9 | 
10 | def data_transform(X):
11 |     return 2 * X - 1.0
12 | 
13 | 
14 | def inverse_data_transform(X):
15 |     return torch.clamp((X + 1.0) / 2.0, 0.0, 1.0)
16 | 
17 | 
18 | class DiffusiveRestoration:
19 |     def __init__(self, diffusion, args, config):
20 |         super(DiffusiveRestoration, self).__init__()
21 |         self.args = args
22 |         self.config = config
23 |         self.diffusion = diffusion
24 | 
25 |         if os.path.isfile(args.resume):
26 |             self.diffusion.load_ddm_ckpt(args.resume, ema=True)
27 |             self.diffusion.model.eval()
28 |         else:
29 |             print('Pre-trained diffusion model path is missing!')
30 | 
31 |     def restore(self, val_loader, validation='lowlight', r=None, use_align=False):
32 |         image_folder = os.path.join(self.args.image_folder, self.config.data.dataset, validation)
33 |         with torch.no_grad():
34 |             for i, (x, y, wd, ht) in enumerate(val_loader):
35 |                 print(f"starting processing from image {y}")
36 |                 y = re.findall(r'\d+', y[0])
37 |                 x = x.flatten(start_dim=0, end_dim=1) if x.ndim == 5 else x
38 |                 x_cond = x[:, :6, :, :].to(self.diffusion.device)
39 |                 x_output1 = self.diffusive_restoration(x_cond, r=r, fullresolusion=False)
40 |                 #x_output = inverse_data_transform(x_output)
41 |                 x_output1 = inverse_data_transform(x_output1)
42 |                 x_output = x_output1
43 |                 #x_output=(x_output+x_output1)/2
44 |                 b, c, h, w = x_output.shape
45 |                 ht = ht.item()
46 |                 wd = wd.item()
47 |                 torch_resize = Resize([ht, wd])
48 |                 x_output = torch_resize(x_output)
49 | 
50 |                 if use_align == True:  # align the global brightness of the output to the GT mean
51 |                     target = x[:, 6:, :, :]
52 |                     gt_mean = torch.mean(target)
53 |                     sr_mean = torch.mean(x_output)
54 |                     x_output = x_output * gt_mean / sr_mean
55 | 
56 | 
57 | 
58 |                 utils.logging.save_image(x_output, os.path.join(image_folder, f"{y}.png"))
59 | 
60 |     def diffusive_restoration(self, x_cond, r=None, fullresolusion=False):
61 |         if fullresolusion == False:
62 |             # p_size = self.config.data.image_size
63 |             p_size = 64  # multi-scale patches: 64, 96 and 128, each on a stride-8 overlapping grid
64 |             h_list, w_list = self.overlapping_grid_indices(x_cond, output_size=p_size, r=8)
65 |             corners = [(i, j) for i in h_list for j in w_list]
66 |             h_list1, w_list1 = self.overlapping_grid_indices(x_cond, output_size=96, r=8)
67 |             corners1 = [(i, j) for i in h_list1 for j in w_list1]
68 | 
69 |             h_list2, w_list2 = self.overlapping_grid_indices(x_cond, output_size=128, r=8)
70 |             corners2 = [(i, j) for i in h_list2 for j in w_list2]
71 | 
72 | 
73 |             x =
torch.randn(x_cond.size()[0],3,x_cond.size()[2],x_cond.size()[3], device=self.diffusion.device) 74 | 75 | ii = torch.tensor([item[0] for item in corners]) 76 | jj = torch.tensor([item[1] for item in corners]) 77 | ii=ii/x_cond.size()[2]*2-1 78 | jj=jj/x_cond.size()[3]*2-1 79 | osize=torch.full((len(corners),), p_size) 80 | x_output = self.diffusion.sample_image(x_cond, x, ii,jj,osize,patch_locs=corners, patch_size=p_size,patch_locs1=corners1,patch_locs2=corners2) 81 | else: 82 | x = torch.randn(x_cond.size()[0], 3, x_cond.size()[2], x_cond.size()[3], device=self.diffusion.device) 83 | ii=torch.tensor(-1).unsqueeze(0) 84 | jj=torch.tensor(-1).unsqueeze(0) 85 | osize=torch.tensor(x_cond.size()[2]).unsqueeze(0) 86 | x_output = self.diffusion.sample_image(x_cond, x, ii, jj, osize, patch_locs=None, patch_size=None) 87 | 88 | 89 | return x_output 90 | #return x_output,x_output1 91 | 92 | 93 | def overlapping_grid_indices(self, x_cond, output_size, r=None): 94 | _, c, h, w = x_cond.shape 95 | r = 16 if r is None else r 96 | h_list = [i for i in range(0, h - output_size + 1, r)] 97 | w_list = [i for i in range(0, w - output_size + 1, r)] 98 | return h_list, w_list 99 | -------------------------------------------------------------------------------- /evaluation.py: -------------------------------------------------------------------------------- 1 | import glob 2 | import os 3 | import time 4 | from collections import OrderedDict 5 | 6 | import numpy as np 7 | import torch 8 | import cv2 9 | import argparse 10 | 11 | from natsort import natsort 12 | from skimage.metrics import structural_similarity as ssim 13 | from skimage.metrics import peak_signal_noise_ratio as psnr 14 | import lpips 15 | from numpy.lib.histograms import histogram 16 | from numpy.lib.function_base import interp 17 | def histeq(im, nbr_bins=256): 18 | im=im.cpu() 19 | im=im.detach().numpy() 20 | imhist, bins = histogram(im.flatten(), nbr_bins) 21 | cdf = imhist.cumsum() 22 | cdf = 1.0 * cdf / cdf[-1] 23 | im2 = interp(im.flatten(), bins[:-1], cdf) 24 | return im2.reshape(im.shape) 25 | def save_img(filepath, img): 26 | cv2.imwrite(filepath, cv2.cvtColor(img, cv2.COLOR_RGB2BGR)) 27 | def rgb(t): return ( 28 | np.clip((t[0] if len(t.shape) == 4 else t).detach().cpu().numpy().transpose([1, 2, 0]), 0, 1) * 255).astype( 29 | np.uint8) 30 | class Measure(): 31 | def __init__(self, net='alex', use_gpu=False): 32 | self.device = 'cuda' if use_gpu else 'cpu' 33 | self.model = lpips.LPIPS(net=net) 34 | self.model.to(self.device) 35 | 36 | def measure(self, imgA, imgB):#A=gt B=out 37 | 38 | return [float(f(imgA, imgB)) for f in [self.psnr, self.ssim, self.lpips]],imgB 39 | 40 | def lpips(self, imgA, imgB, model=None): 41 | tA = t(imgA).to(self.device) 42 | tB = t(imgB).to(self.device) 43 | # dist01 = self.model.forward(tA, tB).item() 44 | 45 | dist01 = self.model(tA, tB).item() 46 | 47 | return dist01 48 | 49 | def ssim(self, imgA, imgB): 50 | score, diff = ssim(imgA, imgB, full=True, multichannel=True) 51 | # score, diff = ssim( cv2.cvtColor(imgA, cv2.COLOR_RGB2GRAY), cv2.cvtColor(imgB, cv2.COLOR_RGB2GRAY), full=True, multichannel=True) 52 | 53 | return score 54 | 55 | def psnr(self, imgA, imgB): 56 | psnr_val = psnr(imgA, imgB) 57 | return psnr_val 58 | 59 | 60 | def t(img): 61 | def to_4d(img): 62 | assert len(img.shape) == 3 63 | assert img.dtype == np.uint8 64 | img_new = np.expand_dims(img, axis=0) 65 | assert len(img_new.shape) == 4 66 | return img_new 67 | 68 | def to_CHW(img): 69 | return np.transpose(img, [2, 0, 1]) 70 | 71 | def 
to_tensor(img): 72 | return torch.Tensor(img) 73 | 74 | return to_tensor(to_4d(to_CHW(img))) / 255 75 | # return to_tensor(to_4d(to_CHW(img))) / 127.5 - 1 76 | 77 | 78 | 79 | def fiFindByWildcard(wildcard): 80 | return natsort.natsorted(glob.glob(wildcard, recursive=True)) 81 | 82 | 83 | def imread(path): 84 | return cv2.imread(path)[:, :, [2, 1, 0]] 85 | 86 | 87 | def format_result(psnr, ssim, lpips): 88 | return f'{psnr:0.2f}, {ssim:0.3f}, {lpips:0.3f}' 89 | 90 | def measure_dirs(dirA, dirB, use_gpu, verbose=False): 91 | if verbose: 92 | vprint = lambda x: print(x) 93 | else: 94 | vprint = lambda x: None 95 | 96 | 97 | t_init = time.time() 98 | 99 | paths_A = fiFindByWildcard(os.path.join(dirA, f'*.{type}')) 100 | paths_B = fiFindByWildcard(os.path.join(dirB, f'*.{type}')) 101 | 102 | vprint("Comparing: ") 103 | vprint(dirA) 104 | vprint(dirB) 105 | 106 | measure = Measure(use_gpu=use_gpu) 107 | 108 | results = [] 109 | for pathA, pathB in zip(paths_A, paths_B): 110 | result = OrderedDict() 111 | 112 | t = time.time() 113 | A=imread(pathA) 114 | B=imread(pathB) 115 | As = A.shape 116 | Bs = B.shape 117 | A0=As[0] 118 | A1= As[1] 119 | b = cv2.resize(B, (A1, A0)) 120 | [result['psnr'], result['ssim'], result['lpips']],imgb= measure.measure(A, b) 121 | 122 | d = time.time() - t 123 | vprint(f"{pathA.split('/')[-1]}, {pathB.split('/')[-1]}, {format_result(**result)}, {d:0.1f}") 124 | 125 | results.append(result) 126 | 127 | psnr = np.mean([result['psnr'] for result in results]) 128 | ssim = np.mean([result['ssim'] for result in results]) 129 | lpips = np.mean([result['lpips'] for result in results]) 130 | 131 | vprint(f"Final Result: {format_result(psnr, ssim, lpips)}, {time.time() - t_init:0.1f}s") 132 | 133 | 134 | if __name__ == "__main__": 135 | parser = argparse.ArgumentParser() 136 | parser.add_argument('-dirA', default=r'./datasets/scratch/LLIE/data/lowlight/test/gt', type=str) 137 | parser.add_argument('-dirB', default=r'./results/images/lowlight/lowlight', type=str) 138 | parser.add_argument('-type', default='png') 139 | parser.add_argument('--use_gpu', default=True) 140 | args = parser.parse_args() 141 | 142 | dirA = args.dirA 143 | dirB = args.dirB 144 | type = args.type 145 | use_gpu = args.use_gpu 146 | 147 | if len(dirA) > 0 and len(dirB) > 0: 148 | measure_dirs(dirA, dirB, use_gpu=use_gpu, verbose=True) 149 | -------------------------------------------------------------------------------- /utils/metrics.py: -------------------------------------------------------------------------------- 1 | import cv2 2 | import numpy as np 3 | import torch 4 | 5 | # This script is adapted from the following repository: https://github.com/JingyunLiang/SwinIR 6 | 7 | 8 | def calculate_psnr(img1, img2, test_y_channel=False): 9 | """Calculate PSNR (Peak Signal-to-Noise Ratio). 10 | 11 | Ref: https://en.wikipedia.org/wiki/Peak_signal-to-noise_ratio 12 | 13 | Args: 14 | img1 (ndarray): Images with range [0, 255]. 15 | img2 (ndarray): Images with range [0, 255]. 16 | test_y_channel (bool): Test on Y channel of YCbCr. Default: False. 17 | 18 | Returns: 19 | float: psnr result. 20 | """ 21 | 22 | assert img1.shape == img2.shape, (f'Image shapes are differnet: {img1.shape}, {img2.shape}.') 23 | assert img1.shape[2] == 3 24 | img1 = img1.astype(np.float64) 25 | img2 = img2.astype(np.float64) 26 | 27 | if test_y_channel: 28 | img1 = to_y_channel(img1) 29 | img2 = to_y_channel(img2) 30 | 31 | mse = np.mean((img1 - img2) ** 2) 32 | if mse == 0: 33 | return float('inf') 34 | return 20. * np.log10(255. 
/ np.sqrt(mse)) 35 | 36 | 37 | def _ssim(img1, img2): 38 | """Calculate SSIM (structural similarity) for one channel images. 39 | 40 | It is called by func:`calculate_ssim`. 41 | 42 | Args: 43 | img1 (ndarray): Images with range [0, 255] with order 'HWC'. 44 | img2 (ndarray): Images with range [0, 255] with order 'HWC'. 45 | 46 | Returns: 47 | float: ssim result. 48 | """ 49 | 50 | C1 = (0.01 * 255) ** 2 51 | C2 = (0.03 * 255) ** 2 52 | 53 | img1 = img1.astype(np.float64) 54 | img2 = img2.astype(np.float64) 55 | kernel = cv2.getGaussianKernel(11, 1.5) 56 | window = np.outer(kernel, kernel.transpose()) 57 | 58 | mu1 = cv2.filter2D(img1, -1, window)[5:-5, 5:-5] 59 | mu2 = cv2.filter2D(img2, -1, window)[5:-5, 5:-5] 60 | mu1_sq = mu1 ** 2 61 | mu2_sq = mu2 ** 2 62 | mu1_mu2 = mu1 * mu2 63 | sigma1_sq = cv2.filter2D(img1 ** 2, -1, window)[5:-5, 5:-5] - mu1_sq 64 | sigma2_sq = cv2.filter2D(img2 ** 2, -1, window)[5:-5, 5:-5] - mu2_sq 65 | sigma12 = cv2.filter2D(img1 * img2, -1, window)[5:-5, 5:-5] - mu1_mu2 66 | 67 | ssim_map = ((2 * mu1_mu2 + C1) * (2 * sigma12 + C2)) / ((mu1_sq + mu2_sq + C1) * (sigma1_sq + sigma2_sq + C2)) 68 | return ssim_map.mean() 69 | 70 | 71 | def calculate_ssim(img1, img2, test_y_channel=False): 72 | """Calculate SSIM (structural similarity). 73 | 74 | Ref: 75 | Image quality assessment: From error visibility to structural similarity 76 | 77 | The results are the same as that of the official released MATLAB code in 78 | https://ece.uwaterloo.ca/~z70wang/research/ssim/. 79 | 80 | For three-channel images, SSIM is calculated for each channel and then 81 | averaged. 82 | 83 | Args: 84 | img1 (ndarray): Images with range [0, 255]. 85 | img2 (ndarray): Images with range [0, 255]. 86 | test_y_channel (bool): Test on Y channel of YCbCr. Default: False. 87 | 88 | Returns: 89 | float: ssim result. 90 | """ 91 | 92 | assert img1.shape == img2.shape, (f'Image shapes are differnet: {img1.shape}, {img2.shape}.') 93 | assert img1.shape[2] == 3 94 | img1 = img1.astype(np.float64) 95 | img2 = img2.astype(np.float64) 96 | 97 | if test_y_channel: 98 | img1 = to_y_channel(img1) 99 | img2 = to_y_channel(img2) 100 | 101 | ssims = [] 102 | for i in range(img1.shape[2]): 103 | ssims.append(_ssim(img1[..., i], img2[..., i])) 104 | return np.array(ssims).mean() 105 | 106 | 107 | def to_y_channel(img): 108 | """Change to Y channel of YCbCr. 109 | 110 | Args: 111 | img (ndarray): Images with range [0, 255]. 112 | 113 | Returns: 114 | (ndarray): Images with range [0, 255] (float type) without round. 115 | """ 116 | img = img.astype(np.float32) / 255. 117 | if img.ndim == 3 and img.shape[2] == 3: 118 | img = bgr2ycbcr(img, y_only=True) 119 | img = img[..., None] 120 | return img * 255. 121 | 122 | 123 | def _convert_input_type_range(img): 124 | """Convert the type and range of the input image. 125 | 126 | It converts the input image to np.float32 type and range of [0, 1]. 127 | It is mainly used for pre-processing the input image in colorspace 128 | convertion functions such as rgb2ycbcr and ycbcr2rgb. 129 | 130 | Args: 131 | img (ndarray): The input image. It accepts: 132 | 1. np.uint8 type with range [0, 255]; 133 | 2. np.float32 type with range [0, 1]. 134 | 135 | Returns: 136 | (ndarray): The converted image with type of np.float32 and range of 137 | [0, 1]. 138 | """ 139 | img_type = img.dtype 140 | img = img.astype(np.float32) 141 | if img_type == np.float32: 142 | pass 143 | elif img_type == np.uint8: 144 | img /= 255. 
145 | else: 146 | raise TypeError('The img type should be np.float32 or np.uint8, ' f'but got {img_type}') 147 | return img 148 | 149 | 150 | def _convert_output_type_range(img, dst_type): 151 | """Convert the type and range of the image according to dst_type. 152 | 153 | It converts the image to desired type and range. If `dst_type` is np.uint8, 154 | images will be converted to np.uint8 type with range [0, 255]. If 155 | `dst_type` is np.float32, it converts the image to np.float32 type with 156 | range [0, 1]. 157 | It is mainly used for post-processing images in colorspace convertion 158 | functions such as rgb2ycbcr and ycbcr2rgb. 159 | 160 | Args: 161 | img (ndarray): The image to be converted with np.float32 type and 162 | range [0, 255]. 163 | dst_type (np.uint8 | np.float32): If dst_type is np.uint8, it 164 | converts the image to np.uint8 type with range [0, 255]. If 165 | dst_type is np.float32, it converts the image to np.float32 type 166 | with range [0, 1]. 167 | 168 | Returns: 169 | (ndarray): The converted image with desired type and range. 170 | """ 171 | if dst_type not in (np.uint8, np.float32): 172 | raise TypeError('The dst_type should be np.float32 or np.uint8, ' f'but got {dst_type}') 173 | if dst_type == np.uint8: 174 | img = img.round() 175 | else: 176 | img /= 255. 177 | return img.astype(dst_type) 178 | 179 | 180 | def bgr2ycbcr(img, y_only=False): 181 | """Convert a BGR image to YCbCr image. 182 | 183 | The bgr version of rgb2ycbcr. 184 | It implements the ITU-R BT.601 conversion for standard-definition 185 | television. See more details in 186 | https://en.wikipedia.org/wiki/YCbCr#ITU-R_BT.601_conversion. 187 | 188 | It differs from a similar function in cv2.cvtColor: `BGR <-> YCrCb`. 189 | In OpenCV, it implements a JPEG conversion. See more details in 190 | https://en.wikipedia.org/wiki/YCbCr#JPEG_conversion. 191 | 192 | Args: 193 | img (ndarray): The input image. It accepts: 194 | 1. np.uint8 type with range [0, 255]; 195 | 2. np.float32 type with range [0, 1]. 196 | y_only (bool): Whether to only return Y channel. Default: False. 197 | 198 | Returns: 199 | ndarray: The converted YCbCr image. The output image has the same type 200 | and range as input image. 
201 | """ 202 | img_type = img.dtype 203 | img = _convert_input_type_range(img) 204 | if y_only: 205 | out_img = np.dot(img, [24.966, 128.553, 65.481]) + 16.0 206 | else: 207 | out_img = np.matmul( 208 | img, [[24.966, 112.0, -18.214], [128.553, -74.203, -93.786], [65.481, -37.797, 112.0]]) + [16, 128, 128] 209 | out_img = _convert_output_type_range(out_img, img_type) 210 | return out_img 211 | -------------------------------------------------------------------------------- /utils/sampling.py: -------------------------------------------------------------------------------- 1 | import torch 2 | import utils.logging 3 | import os 4 | import torchvision 5 | from torchvision.transforms.functional import crop 6 | 7 | 8 | # This script is adapted from the following repository: https://github.com/ermongroup/ddim 9 | 10 | def compute_alpha(beta, t): 11 | beta = torch.cat([torch.zeros(1).to(beta.device), beta], dim=0) 12 | a = (1 - beta).cumprod(dim=0).index_select(0, t + 1).view(-1, 1, 1, 1) 13 | return a 14 | 15 | 16 | def data_transform(X): 17 | return 2 * X - 1.0 18 | 19 | 20 | def inverse_data_transform(X): 21 | return torch.clamp((X + 1.0) / 2.0, 0.0, 1.0) 22 | 23 | 24 | def generalized_steps(x, x_cond,seq, model, b,ii,jj,osize,eta=0.): 25 | with torch.no_grad(): 26 | n = x.size(0) 27 | seq_next = [-1] + list(seq[:-1]) 28 | x0_preds = [] 29 | xs = [x] 30 | for i, j in zip(reversed(seq), reversed(seq_next)): 31 | t = (torch.ones(n) * i).to(x.device) 32 | next_t = (torch.ones(n) * j).to(x.device) 33 | at = compute_alpha(b, t.long()) 34 | at_next = compute_alpha(b, next_t.long()) 35 | xt = xs[-1].to('cuda') 36 | ii = ii.to(x.device) 37 | jj = jj.to(x.device) 38 | osize = osize.to(x.device) 39 | # print(xt.shape) 40 | # print(x_cond.shape) 41 | et = model(torch.cat([x_cond, xt], dim=1), t,ii,jj,osize) 42 | 43 | x0_t = (xt - et * (1 - at).sqrt()) / at.sqrt() 44 | x0_preds.append(x0_t.to('cpu')) 45 | 46 | c1 = eta * ((1 - at / at_next) * (1 - at_next) / (1 - at)).sqrt() 47 | c2 = ((1 - at_next) - c1 ** 2).sqrt() 48 | xt_next = at_next.sqrt() * x0_t + c1 * torch.randn_like(x) + c2 * et 49 | xs.append(xt_next.to('cpu')) 50 | return xs, x0_preds 51 | 52 | 53 | def generalized_steps_overlapping(x, x_cond, seq, model, b,ii,jj,osize, eta=0., corners=None, p_size=None,corners1=None,corners2=None,manual_batching=True): 54 | with torch.no_grad(): 55 | n = x.size(0) 56 | seq_next = [-1] + list(seq[:-1]) 57 | x0_preds = [] 58 | xs = [x] 59 | 60 | x_grid_mask = torch.zeros(x_cond.size(0),3,x_cond.size(2),x_cond.size(3), device=x.device) 61 | for (hi, wi) in corners: 62 | x_grid_mask[:, :, hi:hi + p_size, wi:wi + p_size] += 1 63 | if corners1!=None: 64 | p_size1=96 65 | x_grid_mask1 = torch.zeros(x_cond.size(0),3,x_cond.size(2),x_cond.size(3), device=x.device) 66 | for (hi, wi) in corners1: 67 | x_grid_mask1[:, :, hi:hi + p_size1, wi:wi + p_size1] += 1 68 | if corners2 != None: 69 | p_size2=128 70 | x_grid_mask2 = torch.zeros(x_cond.size(0),3,x_cond.size(2),x_cond.size(3), device=x.device) 71 | for (hi, wi) in corners2: 72 | x_grid_mask2[:, :, hi:hi + p_size2, wi:wi + p_size2] += 1 73 | 74 | 75 | for i, j in zip(reversed(seq), reversed(seq_next)): 76 | t = (torch.ones(n) * i).to(x.device) 77 | next_t = (torch.ones(n) * j).to(x.device) 78 | at = compute_alpha(b, t.long()) 79 | at_next = compute_alpha(b, next_t.long()) 80 | xt = xs[-1].to('cuda') 81 | et_output = torch.zeros(x_cond.size(0),3,x_cond.size(2),x_cond.size(3), device=x.device) 82 | 83 | if manual_batching==True: 84 | manual_batching_size = p_size 85 | 
xt_patch = torch.cat([crop(xt, hi, wi, p_size, p_size) for (hi, wi) in corners], dim=0) 86 | x_cond_patch = torch.cat([data_transform(crop(x_cond, hi, wi, p_size, p_size)) for (hi, wi) in corners], dim=0) 87 | for i in range(0, len(corners), manual_batching_size):# the patches taken at a stride of 16 are the corners 88 | ii_input=torch.unsqueeze(ii[i],dim=0) 89 | jj_input=torch.unsqueeze(jj[i],dim=0) 90 | osize_input=torch.unsqueeze(osize[i],dim=0) 91 | x_input=torch.cat([x_cond_patch[i:i + manual_batching_size], xt_patch[i:i + manual_batching_size]], dim=1) 92 | # print("x_input", x_input.shape) 93 | outputs = model(x_input, t,ii_input,jj_input,osize_input) 94 | # print("outputs", outputs.shape) 95 | for idx, (hi, wi) in enumerate(corners[i:i+manual_batching_size]): 96 | # print("idx", idx) 97 | et_output[0, :, hi:hi + p_size, wi:wi + p_size] += outputs[idx] 98 | 99 | if corners1 != None: 100 | et_output1 = torch.zeros(x_cond.size(0), 3, x_cond.size(2), x_cond.size(3), device=x.device) 101 | manual_batching_size = p_size1 102 | xt_patch = torch.cat([crop(xt, hi, wi, p_size1, p_size1) for (hi, wi) in corners1], dim=0) 103 | x_cond_patch = torch.cat( 104 | [data_transform(crop(x_cond, hi, wi, p_size1, p_size1)) for (hi, wi) in corners1], dim=0) 105 | for i in range(0, len(corners1), manual_batching_size): # the patches taken at a stride of 16 are the corners 106 | ii_input = torch.unsqueeze(ii[i], dim=0) 107 | jj_input = torch.unsqueeze(jj[i], dim=0) 108 | osize_input = torch.unsqueeze(osize[i], dim=0) 109 | x_input = torch.cat( 110 | [x_cond_patch[i:i + manual_batching_size], xt_patch[i:i + manual_batching_size]], dim=1) 111 | # print("x_input", x_input.shape) 112 | outputs1 = model(x_input, t, ii_input, jj_input, osize_input) 113 | # print("outputs", outputs.shape) 114 | for idx, (hi, wi) in enumerate(corners1[i:i + manual_batching_size]): 115 | # print("idx", idx) 116 | et_output1[0, :, hi:hi + p_size1, wi:wi + p_size1] += outputs1[idx] 117 | if corners2 != None: 118 | et_output2= torch.zeros(x_cond.size(0), 3, x_cond.size(2), x_cond.size(3), device=x.device) 119 | manual_batching_size = p_size2 120 | xt_patch = torch.cat([crop(xt, hi, wi, p_size2, p_size2) for (hi, wi) in corners2], dim=0) 121 | x_cond_patch = torch.cat( 122 | [data_transform(crop(x_cond, hi, wi, p_size2, p_size2)) for (hi, wi) in corners2], dim=0) 123 | for i in range(0, len(corners2), manual_batching_size): # the patches taken at a stride of 16 are the corners 124 | ii_input = torch.unsqueeze(ii[i], dim=0) 125 | jj_input = torch.unsqueeze(jj[i], dim=0) 126 | osize_input = torch.unsqueeze(osize[i], dim=0) 127 | x_input = torch.cat( 128 | [x_cond_patch[i:i + manual_batching_size], xt_patch[i:i + manual_batching_size]], dim=1) 129 | # print("x_input", x_input.shape) 130 | outputs2 = model(x_input, t, ii_input, jj_input, osize_input) 131 | # print("outputs", outputs.shape) 132 | for idx, (hi, wi) in enumerate(corners2[i:i + manual_batching_size]): 133 | # print("idx", idx) 134 | et_output2[0, :, hi:hi + p_size2, wi:wi + p_size2] += outputs2[idx] 135 | 136 | else: 137 | for (hi, wi) in corners: 138 | xt_patch = crop(xt, hi, wi, p_size, p_size) 139 | x_cond_patch = crop(x_cond, hi, wi, p_size, p_size) 140 | x_cond_patch = data_transform(x_cond_patch) 141 | et_output[:, :, hi:hi + p_size, wi:wi + p_size] += model(torch.cat([x_cond_patch, xt_patch], dim=1), t) 142 | 143 | et0 = torch.div(et_output, x_grid_mask) 144 | if corners1 != None: 145 | et1 = torch.div(et_output1, x_grid_mask1) 146 | et=(et0+et1)/2.0 147 | if corners2 != None: 148 | et2 = torch.div(et_output2, x_grid_mask2) 149 | et = (et0 +et1+ et2) / 3.0 150 | else:
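# Note: et0 above is the overlap-normalized noise estimate for the base patch grid
# (x_grid_mask counts how many patches cover each pixel, so the division turns the
# accumulated sums into per-pixel means); et1 and et2 are the analogous estimates
# for the 96- and 128-pixel grids, and et averages whichever scales are available.
# The lines that follow then apply the standard DDIM update with this fused estimate:
#   x0_t = (x_t - sqrt(1 - a_t) * et) / sqrt(a_t)
#   x_next = sqrt(a_next) * x0_t + c1 * z + c2 * et,  with z ~ N(0, I).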
151 | et=et0 152 | x0_t = (xt - et * (1 - at).sqrt()) / at.sqrt() 153 | x0_preds.append(x0_t.to('cpu')) 154 | 155 | c1 = eta * ((1 - at / at_next) * (1 - at_next) / (1 - at)).sqrt() 156 | c2 = ((1 - at_next) - c1 ** 2).sqrt() 157 | xt_next = at_next.sqrt() * x0_t + c1 * torch.randn_like(x) + c2 * et 158 | # if i>500: 159 | # xt_next=xt_next/torch.mean(xt_next)*torch.mean(x_cond[:,3:6,:,]) 160 | xs.append(xt_next.to('cpu')) 161 | 162 | 163 | 164 | 165 | 166 | return xs, x0_preds 167 | -------------------------------------------------------------------------------- /datasets/lowlight.py: -------------------------------------------------------------------------------- 1 | 2 | import os 3 | from os import listdir 4 | from os.path import isfile 5 | import torch 6 | import numpy as np 7 | import torchvision 8 | import torch.utils.data 9 | import PIL 10 | import re 11 | import random 12 | import torchvision.transforms as transforms 13 | 14 | class lowlight: 15 | def __init__(self, config): 16 | self.config = config 17 | self.transforms = torchvision.transforms.Compose([torchvision.transforms.ToTensor()]) 18 | 19 | def get_loaders(self, parse_patches=True, validation='lowlight'): 20 | print("=> evaluating lowlight test set...") 21 | train_dataset = lowlightDataset(dir=os.path.join(self.config.data.data_dir, 'data', 'lowlight', 'train'), 22 | n=self.config.training.patch_n, 23 | patch_size=self.config.data.image_size, 24 | transforms=self.transforms, 25 | filelist=None, 26 | parse_patches=parse_patches, 27 | train=True) 28 | val_dataset = lowlightDataset(dir=os.path.join(self.config.data.data_dir, 'data', 'lowlight', 'test'), 29 | n=self.config.training.patch_n, 30 | patch_size=self.config.data.image_size, 31 | transforms=self.transforms, 32 | filelist='lowlighttesta.txt', 33 | parse_patches=parse_patches, 34 | train=False) 35 | 36 | if not parse_patches: 37 | self.config.training.batch_size = 1 38 | self.config.sampling.batch_size = 1 39 | 40 | train_loader = torch.utils.data.DataLoader(train_dataset, batch_size=self.config.training.batch_size, 41 | shuffle=True, num_workers=self.config.data.num_workers, 42 | pin_memory=True) 43 | val_loader = torch.utils.data.DataLoader(val_dataset, batch_size=self.config.sampling.batch_size, 44 | shuffle=False, num_workers=self.config.data.num_workers, 45 | pin_memory=True) 46 | 47 | return train_loader, val_loader 48 | 49 | 50 | class lowlightDataset(torch.utils.data.Dataset): 51 | def __init__(self, dir, patch_size, n, transforms, train,filelist=None, parse_patches=True): 52 | super().__init__() 53 | 54 | if filelist is None: 55 | lowlight_dir = dir 56 | input_names, gt_names = [], [] 57 | 58 | 59 | 60 | filepath = os.path.dirname(__file__) 61 | print(f"filepath:{filepath}") 62 | lowlight_dir=os.path.join(filepath,lowlight_dir) 63 | #lowlight_inputs = os.path.join(filepath, lowlight_inputs0) 64 | 65 | # lowlight train filelist 66 | lowlight_inputs = os.path.join(lowlight_dir, 'input') 67 | listdir(lowlight_inputs) 68 | images = [f for f in listdir(lowlight_inputs) if isfile(os.path.join(lowlight_inputs, f))] 69 | assert len(images) == 485 70 | input_names += [os.path.join(lowlight_inputs, i) for i in images] 71 | # gt_names += [os.path.join(os.path.join(lowlight_dir, 'gt'), i.replace('rain', 'clean')) for i in images] 72 | gt_names += [os.path.join(os.path.join(lowlight_dir, 'gt'), i.replace('', '')) for i in images] 73 | print(len(input_names)) 74 | 75 | x = list(enumerate(input_names)) 76 | random.shuffle(x) 77 | indices, input_names = zip(*x) 78 | gt_names 
= [gt_names[idx] for idx in indices] 79 | self.dir = None 80 | else: 81 | self.dir = dir 82 | filepath = os.path.dirname(__file__) 83 | dir = os.path.join(filepath, dir) 84 | train_list = os.path.join(dir, filelist) 85 | with open(train_list) as f: 86 | contents = f.readlines() 87 | input_names = [i.strip() for i in contents] 88 | gt_names = [i.strip().replace('input', 'gt') for i in input_names] 89 | 90 | self.input_names = input_names 91 | self.gt_names = gt_names 92 | self.patch_size = patch_size 93 | self.transforms = transforms 94 | self.n = n 95 | self.parse_patches = parse_patches 96 | self.batchnum = 0 97 | self.batchsize = 1 98 | self.train=train 99 | 100 | @staticmethod 101 | def get_params(img, output_size, n,random_size): 102 | w, h = img.size 103 | if random_size == 0: 104 | output_size = (64,64) 105 | elif random_size == 1: 106 | output_size = (128,128) 107 | else: 108 | output_size = (256,256) 109 | # output_size = (192,192) 110 | th, tw = output_size 111 | if w == tw and h == th: 112 | return 0, 0, h, w 113 | 114 | i_list = [random.randint(0, h - th) for _ in range(n)] 115 | j_list = [random.randint(0, w - tw) for _ in range(n)] 116 | osize = [output_size[0] for _ in range(n)] 117 | return i_list, j_list, th, tw,osize,h,w 118 | 119 | @staticmethod 120 | def n_random_crops(img, x, y, h, w): 121 | crops = [] 122 | for i in range(len(x)): 123 | new_crop = img.crop((y[i], x[i], y[i] + w, x[i] + h)) 124 | 125 | if h!=64: 126 | resize_transform = transforms.Resize(64) 127 | new_crop=resize_transform(new_crop) 128 | crops.append(new_crop) 129 | return tuple(crops) 130 | 131 | def get_max(self,input): 132 | T,_=torch.max(input,dim=0) 133 | T=T+0.1 134 | input[0,:,:] = input[0,:,:]/ T 135 | input[1,:,:] = input[1,:,:]/ T 136 | input[2,:,:]= input[2,:,:] / T 137 | return input 138 | 139 | def get_images(self, index): 140 | if self.train==True: 141 | if self.batchnum==0: 142 | self.random_size = random.randint(0, 2) 143 | self.batchnum=self.batchnum+1 144 | if self.batchnum==self.batchsize: 145 | self.batchnum=0 146 | else: 147 | self.random_size=0 148 | 149 | input_name = self.input_names[index] 150 | gt_name = self.gt_names[index] 151 | img_id = re.split('/', input_name)[-1][:-4] 152 | input_img = PIL.Image.open(os.path.join(self.dir, input_name)) if self.dir else PIL.Image.open(input_name) 153 | try: 154 | gt_img = PIL.Image.open(os.path.join(self.dir, gt_name)) if self.dir else PIL.Image.open(gt_name) 155 | except: 156 | gt_img = PIL.Image.open(os.path.join(self.dir, gt_name)).convert('RGB') if self.dir else \ 157 | PIL.Image.open(gt_name).convert('RGB') 158 | 159 | if self.parse_patches: 160 | 161 | i, j, h, w,osize,h_org,w_org = self.get_params(input_img, (self.patch_size, self.patch_size), self.n,self.random_size) 162 | input_img = self.n_random_crops(input_img, i, j, h, w) 163 | gt_img = self.n_random_crops(gt_img, i, j, h, w) 164 | 165 | 166 | random_input=0 167 | if random_input==0: 168 | outputs = [torch.cat([self.transforms(input_img[i]),self.get_max(self.transforms(input_img[i])),self.transforms(gt_img[i])], dim=0) 169 | for i in range(self.n)] 170 | else: 171 | outputs = [torch.cat([self.get_max(self.transforms(input_img[i])), self.transforms(input_img[i]), 172 | self.transforms(gt_img[i])], dim=0) 173 | for i in range(self.n)] 174 | ii=torch.tensor(i) 175 | jj = torch.tensor(j) 176 | ii = (ii / h_org) * 2 - 1 177 | jj = (jj / w_org) * 2 - 1 178 | osize=torch.tensor(osize) 179 | return torch.stack(outputs, dim=0), img_id,ii,jj,osize 180 | else: 181 | # Resizing images to 
multiples of 16 for whole-image restoration 182 | wd_new, ht_new = input_img.size 183 | wd=wd_new 184 | ht=ht_new 185 | if ht_new > wd_new and ht_new > 1024: 186 | wd_new = int(np.ceil(wd_new * 1024 / ht_new)) 187 | ht_new = 1024 188 | elif ht_new <= wd_new and wd_new > 1024: 189 | ht_new = int(np.ceil(ht_new * 1024 / wd_new)) 190 | wd_new = 1024 191 | 192 | 193 | 194 | wd_new = int(8 * np.ceil(wd_new / 8.0)) 195 | ht_new = int(8 * np.ceil(ht_new / 8.0)) 196 | input_img = input_img.resize((wd_new, ht_new), PIL.Image.ANTIALIAS) 197 | gt_img = gt_img.resize((wd_new, ht_new), PIL.Image.ANTIALIAS) 198 | 199 | return torch.cat([self.transforms(input_img),self.get_max(self.transforms(input_img)), self.transforms(gt_img)], dim=0), img_id,wd,ht 200 | 201 | def __getitem__(self, index): 202 | res = self.get_images(index) 203 | return res 204 | 205 | def __len__(self): 206 | return len(self.input_names) 207 | -------------------------------------------------------------------------------- /models/ddm.py: -------------------------------------------------------------------------------- 1 | import os 2 | import time 3 | import glob 4 | import numpy as np 5 | import tqdm 6 | import torch 7 | import torch.nn as nn 8 | import torch.utils.data as data 9 | import torch.backends.cudnn as cudnn 10 | import utils 11 | from models.unet import DiffusionUNet 12 | 13 | 14 | 15 | def data_transform(X): 16 | return 2 * X - 1.0 17 | 18 | 19 | def inverse_data_transform(X): 20 | return torch.clamp((X + 1.0) / 2.0, 0.0, 1.0) 21 | 22 | 23 | class EMAHelper(object): 24 | def __init__(self, mu=0.9999): 25 | self.mu = mu 26 | self.shadow = {} 27 | 28 | def register(self, module): 29 | if isinstance(module, nn.DataParallel): 30 | module = module.module 31 | for name, param in module.named_parameters(): 32 | if param.requires_grad: 33 | self.shadow[name] = param.data.clone() 34 | 35 | def update(self, module): 36 | if isinstance(module, nn.DataParallel): 37 | module = module.module 38 | for name, param in module.named_parameters(): 39 | if param.requires_grad: 40 | self.shadow[name].data = (1. 
- self.mu) * param.data + self.mu * self.shadow[name].data 41 | 42 | def ema(self, module): 43 | if isinstance(module, nn.DataParallel): 44 | module = module.module 45 | for name, param in module.named_parameters(): 46 | if param.requires_grad: 47 | param.data.copy_(self.shadow[name].data) 48 | 49 | def ema_copy(self, module): 50 | if isinstance(module, nn.DataParallel): 51 | inner_module = module.module 52 | module_copy = type(inner_module)(inner_module.config).to(inner_module.config.device) 53 | module_copy.load_state_dict(inner_module.state_dict()) 54 | module_copy = nn.DataParallel(module_copy) 55 | else: 56 | module_copy = type(module)(module.config).to(module.config.device) 57 | module_copy.load_state_dict(module.state_dict()) 58 | self.ema(module_copy) 59 | return module_copy 60 | 61 | def state_dict(self): 62 | return self.shadow 63 | 64 | def load_state_dict(self, state_dict): 65 | self.shadow = state_dict 66 | 67 | 68 | def get_beta_schedule(beta_schedule, *, beta_start, beta_end, num_diffusion_timesteps): 69 | def sigmoid(x): 70 | return 1 / (np.exp(-x) + 1) 71 | 72 | if beta_schedule == "quad": 73 | betas = (np.linspace(beta_start ** 0.5, beta_end ** 0.5, num_diffusion_timesteps, dtype=np.float64) ** 2) 74 | elif beta_schedule == "linear": 75 | betas = np.linspace(beta_start, beta_end, num_diffusion_timesteps, dtype=np.float64) 76 | elif beta_schedule == "const": 77 | betas = beta_end * np.ones(num_diffusion_timesteps, dtype=np.float64) 78 | elif beta_schedule == "jsd": # 1/T, 1/(T-1), 1/(T-2), ..., 1 79 | betas = 1.0 / np.linspace(num_diffusion_timesteps, 1, num_diffusion_timesteps, dtype=np.float64) 80 | elif beta_schedule == "sigmoid": 81 | betas = np.linspace(-6, 6, num_diffusion_timesteps) 82 | betas = sigmoid(betas) * (beta_end - beta_start) + beta_start 83 | else: 84 | raise NotImplementedError(beta_schedule) 85 | assert betas.shape == (num_diffusion_timesteps,) 86 | return betas 87 | 88 | 89 | def noise_estimation_loss(model, x0, t, e, b, i, j, osize): 90 | a = (1 - b).cumprod(dim=0).index_select(0, t).view(-1, 1, 1, 1) 91 | x = x0[:, 6:, :, :] * a.sqrt() + e * (1.0 - a).sqrt() 92 | output = model(torch.cat([x0[:, :6, :, :], x], dim=1), t.float(), i, j, 93 | osize) 94 | 95 | 96 | return (e - output).square().sum(dim=(1, 2, 3)).mean(dim=0) 97 | 98 | 99 | def count_parameters(model): 100 | return sum(p.numel() for p in model.parameters() if p.requires_grad) 101 | 102 | 103 | def count_flops(model, input_shape): 104 | flops = 0 105 | model.eval() 106 | with torch.no_grad(): 107 | input = torch.randn(1, *input_shape) 108 | model(input) 109 | flops = torch.cuda.memory_stats(0)["allocated_bytes.all.current"] 110 | model.train() 111 | return flops 112 | 113 | 114 | class DenoisingDiffusion(object): 115 | def __init__(self, args, config): 116 | super().__init__() 117 | self.args = args 118 | self.config = config 119 | self.device = config.device 120 | 121 | self.model = DiffusionUNet(config) 122 | self.model.to(self.device) 123 | self.model = torch.nn.DataParallel(self.model) 124 | 125 | 126 | num_params = count_parameters(self.model) 127 | print("parameters: {:,}".format(num_params)) 128 | 129 | self.ema_helper = EMAHelper() 130 | self.ema_helper.register(self.model) 131 | 132 | self.optimizer = utils.optimize.get_optimizer(self.config, self.model.parameters()) 133 | self.start_epoch, self.step = 0, 0 134 | 135 | betas = get_beta_schedule( 136 | beta_schedule=config.diffusion.beta_schedule, 137 | beta_start=config.diffusion.beta_start, 138 | beta_end=config.diffusion.beta_end, 
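# get_beta_schedule (defined above) returns a numpy array of length
# num_diffusion_timesteps; for the "linear" schedule the betas are evenly
# spaced between beta_start and beta_end.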
139 | num_diffusion_timesteps=config.diffusion.num_diffusion_timesteps, 140 | ) 141 | 142 | betas = self.betas = torch.from_numpy(betas).float().to(self.device) 143 | self.num_timesteps = betas.shape[0] 144 | 145 | def load_ddm_ckpt(self, load_path, ema=False): 146 | checkpoint = utils.logging.load_checkpoint(load_path, None) 147 | self.start_epoch = checkpoint['epoch'] 148 | self.step = checkpoint['step'] 149 | self.model.load_state_dict(checkpoint['state_dict'], strict=True) 150 | self.optimizer.load_state_dict(checkpoint['optimizer']) 151 | self.ema_helper.load_state_dict(checkpoint['ema_helper']) 152 | if ema: 153 | self.ema_helper.ema(self.model) 154 | print("=> loaded checkpoint '{}' (epoch {}, step {})".format(load_path, checkpoint['epoch'], self.step)) 155 | 156 | def train(self, DATASET): 157 | cudnn.benchmark = True 158 | train_loader, val_loader = DATASET.get_loaders() 159 | 160 | if os.path.isfile(self.args.resume): 161 | self.load_ddm_ckpt(self.args.resume) 162 | 163 | for epoch in range(self.start_epoch, self.config.training.n_epochs): 164 | print('epoch: ', epoch) 165 | data_start = time.time() 166 | data_time = 0 167 | for inter, (x, y, i, j, osize) in enumerate(train_loader): 168 | x = x.flatten(start_dim=0, end_dim=1) if x.ndim == 5 else x 169 | n = x.size(0) 170 | data_time += time.time() - data_start 171 | self.model.train() 172 | self.step += 1 173 | 174 | x = x.to(self.device) 175 | x = data_transform(x) 176 | e = torch.randn_like(x[:, 6:, :, :]) 177 | b = self.betas 178 | 179 | # antithetic sampling 180 | t = torch.randint(low=0, high=self.num_timesteps, size=(n // 2 + 1,)).to( 181 | self.device) 182 | t = torch.cat([t, self.num_timesteps - t - 1], dim=0)[ 183 | :n] 184 | i = i.squeeze().to(self.device) 185 | j = j.squeeze().to(self.device) 186 | osize = osize.squeeze().to(self.device) 187 | i = i.view(n) 188 | j = j.view(n) 189 | osize = osize.view(n) 190 | # cond=torch.stack([ i, j, osize], dim=1) 191 | loss = noise_estimation_loss(self.model, x, t, e, b, i, j, 192 | osize) 193 | 194 | if self.step % 10 == 0: 195 | print(f"step: {self.step}, loss: {loss.item()}, data time: {data_time / (inter + 1)}") 196 | 197 | self.optimizer.zero_grad() 198 | loss.backward() 199 | self.optimizer.step() 200 | self.ema_helper.update(self.model) 201 | data_start = time.time() 202 | 203 | if self.step % self.config.training.validation_freq == 0: 204 | self.model.eval() 205 | self.sample_validation_patches(val_loader, self.step) 206 | 207 | if self.step % self.config.training.snapshot_freq == 0 or self.step == 1: 208 | utils.logging.save_checkpoint({ 209 | 'epoch': epoch + 1, 210 | 'step': self.step, 211 | 'state_dict': self.model.state_dict(), 212 | 'optimizer': self.optimizer.state_dict(), 213 | 'ema_helper': self.ema_helper.state_dict(), 214 | 'params': self.args, 215 | 'config': self.config 216 | }, filename=os.path.join(self.config.data.data_dir, 'ckpts',self.config.data.dataset + '_ddpm')) 217 | 218 | def sample_image(self, x_cond, x, ii, jj, osize, last=True, patch_locs=None, patch_size=None,patch_locs1=None,patch_locs2=None): 219 | skip = self.config.diffusion.num_diffusion_timesteps // self.args.sampling_timesteps 220 | seq = range(0, self.config.diffusion.num_diffusion_timesteps, skip) 221 | if patch_locs is not None: 222 | xs = utils.sampling.generalized_steps_overlapping(x, x_cond, seq, self.model, self.betas, ii=ii, jj=jj, 223 | osize=osize, eta=0., 224 | corners=patch_locs, p_size=patch_size,corners1=patch_locs1,corners2=patch_locs2) 225 | else: 226 | xs = 
utils.sampling.generalized_steps(x, x_cond, seq, self.model, self.betas, ii=ii, jj=jj, osize=osize, 227 | eta=0.) 228 | if last: 229 | xs = xs[0][-1] 230 | return xs 231 | 232 | def sample_validation_patches(self, val_loader, step): 233 | image_folder = os.path.join(self.args.image_folder, self.config.data.dataset + str(self.config.data.image_size)) 234 | with torch.no_grad(): 235 | print(f"Processing a single batch of validation images at step: {step}") 236 | for i, (x, y, ii, jj, osize) in enumerate(val_loader): 237 | x = x.flatten(start_dim=0, end_dim=1) if x.ndim == 5 else x 238 | break 239 | n = x.size(0) 240 | x_cond = x[:, :6, :, :].to(self.device) 241 | x_cond = data_transform(x_cond) 242 | # print("osize:", osize) 243 | # print("osize shape:", osize.shape) 244 | 245 | x = torch.randn(n, 3, int(osize[0, 0]), int(osize[0, 1]), device=self.device) # this is the noise term, which needs 3 channels 246 | ii = ii.squeeze() 247 | jj = jj.squeeze() 248 | osize = osize.squeeze() 249 | ii = ii.view(n) 250 | jj = jj.view(n) 251 | osize = osize.view(n) 252 | x = self.sample_image(x_cond, x, ii, jj, osize) 253 | x = inverse_data_transform(x) 254 | x_cond = inverse_data_transform(x_cond) 255 | 256 | for i in range(n): 257 | utils.logging.save_image(x_cond[i][:3, :, :], os.path.join(image_folder, str(step), f"{i}_cond.png")) 258 | utils.logging.save_image(x[i], os.path.join(image_folder, str(step), f"{i}.png")) 259 | -------------------------------------------------------------------------------- /models/unet.py: -------------------------------------------------------------------------------- 1 | 2 | 3 | import math 4 | import torch 5 | import torch.nn as nn 6 | 7 | def get_timestep_embedding(timesteps, embedding_dim): 8 | """ 9 | This matches the implementation in Denoising Diffusion Probabilistic Models: 10 | From Fairseq. 11 | Build sinusoidal embeddings. 12 | This matches the implementation in tensor2tensor, but differs slightly 13 | from the description in Section 3.5 of "Attention Is All You Need".
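    For a 1-D tensor of timesteps t and half_dim = embedding_dim // 2, the k-th
    frequency is freq_k = exp(-log(10000) * k / (half_dim - 1)); the first half of
    the embedding is sin(t * freq_k) and the second half is cos(t * freq_k), with a
    single zero column padded on when embedding_dim is odd.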
14 | """ 15 | # assert len(timesteps.shape) == 1 16 | 17 | half_dim = embedding_dim // 2 18 | 19 | emb = math.log(10000) / (half_dim - 1) 20 | 21 | emb = torch.exp(torch.arange(half_dim, dtype=torch.float32) * -emb) 22 | 23 | emb = emb.to(device=timesteps.device) 24 | 25 | emb = timesteps.float()[:, None] * emb[None, :] 26 | 27 | emb = torch.cat([torch.sin(emb), torch.cos(emb)], dim=1) 28 | 29 | if embedding_dim % 2 != 0: 30 | emb = torch.nn.functional.pad(emb, (0, 1, 0, 0)) 31 | 32 | return emb 33 | 34 | 35 | 36 | def nonlinearity(x): 37 | # swish 38 | return x*torch.sigmoid(x) 39 | 40 | 41 | def Normalize(in_channels): 42 | return torch.nn.GroupNorm(num_groups=8, num_channels=in_channels, eps=1e-6, affine=True) 43 | 44 | 45 | class Upsample(nn.Module): 46 | def __init__(self, in_channels, with_conv): 47 | super().__init__() 48 | self.with_conv = with_conv 49 | if self.with_conv: 50 | self.conv = torch.nn.Conv2d(in_channels, 51 | in_channels, 52 | kernel_size=3, 53 | stride=1, 54 | padding=1) 55 | 56 | def forward(self, x): 57 | x = torch.nn.functional.interpolate( 58 | x, scale_factor=2.0, mode="nearest") 59 | if self.with_conv: 60 | x = self.conv(x) 61 | return x 62 | 63 | 64 | class Downsample(nn.Module): 65 | def __init__(self, in_channels, with_conv): 66 | super().__init__() 67 | self.with_conv = with_conv 68 | if self.with_conv: 69 | # no asymmetric padding in torch conv, must do it ourselves 70 | self.conv = torch.nn.Conv2d(in_channels, 71 | in_channels, 72 | kernel_size=3, 73 | stride=2, 74 | padding=0) 75 | 76 | def forward(self, x): 77 | if self.with_conv: 78 | pad = (0, 1, 0, 1) 79 | x = torch.nn.functional.pad(x, pad, mode="constant", value=0) 80 | x = self.conv(x) 81 | else: 82 | x = torch.nn.functional.avg_pool2d(x, kernel_size=2, stride=2) 83 | return x 84 | 85 | 86 | class HighMixer(nn.Module): 87 | def __init__(self, dim, kernel_size=3, stride=1, padding=1, 88 | **kwargs, ): 89 | super().__init__() 90 | 91 | self.cnn_in = cnn_in = dim 92 | # self.pool_in = pool_in = dim-cnn_in 93 | 94 | self.cnn_dim = cnn_dim = cnn_in 95 | # self.pool_dim = pool_dim = pool_in 96 | 97 | self.conv1 = nn.Conv2d(cnn_in, cnn_dim, kernel_size=1, stride=1, padding=0, bias=False) 98 | self.proj1 = nn.Conv2d(cnn_dim, cnn_dim, kernel_size=kernel_size, stride=stride, padding=padding, bias=False, 99 | groups=cnn_dim) 100 | self.mid_gelu1 = nn.GELU() 101 | 102 | 103 | 104 | def forward(self, x): 105 | # B, C H, W 106 | 107 | cx = x[:, :self.cnn_in, :, :].contiguous() 108 | cx = self.conv1(cx) 109 | cx = self.proj1(cx) 110 | cx = self.mid_gelu1(cx) 111 | 112 | return cx 113 | 114 | 115 | class LowMixer(nn.Module): 116 | def __init__(self, dim, num_heads=8, qkv_bias=False, attn_drop=0., pool_size=2, 117 | **kwargs, ): 118 | super().__init__() 119 | self.num_heads = num_heads 120 | self.head_dim = head_dim = dim // num_heads 121 | self.scale = head_dim ** -0.5 122 | self.dim = dim 123 | 124 | self.qkv = nn.Linear(dim, dim * 3, bias=qkv_bias) 125 | self.attn_drop = nn.Dropout(attn_drop) 126 | 127 | self.pool = nn.AvgPool2d(pool_size, stride=pool_size, padding=0, 128 | count_include_pad=False) if pool_size > 1 else nn.Identity() 129 | self.uppool = nn.Upsample(scale_factor=pool_size) if pool_size > 1 else nn.Identity() 130 | 131 | def att_fun(self, q, k, v, B, N, C): 132 | attn = (q @ k.transpose(-2, -1)) * self.scale 133 | attn = attn.softmax(dim=-1) 134 | attn = self.attn_drop(attn) 135 | # x = (attn @ v).transpose(1, 2).reshape(B, N, C) 136 | x = (attn @ v).transpose(2, 3).reshape(B, C, N) 137 | return 
x 138 | 139 | def forward(self, x): 140 | # B, C, H, W 141 | B, _, _, _ = x.shape 142 | x = self.pool(x) 143 | xa = x.permute(0, 2, 3, 1).view(B, -1, self.dim) 144 | B, N, C = xa.shape 145 | qkv = self.qkv(xa).reshape(B, N, 3, self.num_heads, C // self.num_heads).permute(2, 0, 3, 1, 4) 146 | q, k, v = qkv.unbind(0) # make torchscript happy (cannot use tensor as tuple) 147 | xa = self.att_fun(q, k, v, B, N, C) 148 | xa = xa.view(B, C, int(N ** 0.5), int(N ** 0.5)) # .permute(0, 3, 1, 2) 149 | 150 | xa = self.uppool(xa) 151 | return xa 152 | 153 | 154 | class Mixer(nn.Module): 155 | def __init__(self, dim, num_heads=8, qkv_bias=False, attn_drop=0., proj_drop=0., attention_head=1, pool_size=2, 156 | **kwargs, ): 157 | super().__init__() 158 | self.num_heads = num_heads 159 | self.low_dim = low_dim = dim //2 160 | self.high_dim = high_dim = dim - low_dim 161 | self.high_mixer = HighMixer(high_dim) 162 | self.low_mixer = LowMixer(low_dim, num_heads=attention_head, qkv_bias=qkv_bias, attn_drop=attn_drop, 163 | pool_size=pool_size, ) 164 | 165 | self.conv_fuse = nn.Conv2d(low_dim + high_dim , low_dim + high_dim , kernel_size=3, stride=1, padding=1, 166 | bias=False, groups=low_dim + high_dim) 167 | self.proj = nn.Conv2d(low_dim + high_dim, dim, kernel_size=1, stride=1, padding=0) 168 | self.proj_drop = nn.Dropout(proj_drop) 169 | 170 | self.freblock=FreBlock(dim,dim) 171 | self.finalproj = nn.Conv2d(2 * dim, dim, 1, 1, 0) 172 | def forward(self, x): 173 | 174 | B, C, H, W = x.shape #16,128,64,64 175 | #x = x.permute(0, 3, 1, 2) 176 | x_ori=x 177 | 178 | hx = x[:, :self.high_dim, :, :].contiguous()#16,64,64,64 179 | hx = self.high_mixer(hx)#16,64,64,64 180 | 181 | lx = x[:, self.high_dim:, :, :].contiguous()#16,64,64,64 182 | lx = self.low_mixer(lx)#16,64,64,64 183 | x = torch.cat((hx, lx), dim=1)#16,128,64,64 184 | 185 | x = x + self.conv_fuse(x) 186 | x_sptial = self.proj(x) 187 | 188 | x_freq=self.freblock(x_ori) 189 | 190 | x_out=torch.cat((x_sptial,x_freq),1) 191 | x_out=self.finalproj(x_out) 192 | x_out=self.proj_drop(x_out) 193 | return x_out+x_ori 194 | 195 | 196 | class ResnetBlock(nn.Module): 197 | def __init__(self, *, in_channels, out_channels=None, conv_shortcut=False, 198 | dropout, temb_channels=512,incep=False): 199 | super().__init__() 200 | self.in_channels = in_channels 201 | out_channels = in_channels if out_channels is None else out_channels 202 | self.out_channels = out_channels 203 | self.use_conv_shortcut = conv_shortcut 204 | 205 | self.norm1 = Normalize(in_channels) 206 | self.conv1 = torch.nn.Conv2d(in_channels, 207 | out_channels, 208 | kernel_size=3, 209 | stride=1, 210 | padding=1) 211 | # self.conv1=FreBlock(in_channels,out_channels) 212 | self.temb_proj = torch.nn.Linear(temb_channels, 213 | out_channels) 214 | self.norm2 = Normalize(out_channels) 215 | self.dropout = torch.nn.Dropout(dropout) 216 | if incep==True: 217 | self.conv2 =Mixer(dim=out_channels) 218 | else: 219 | self.conv2 = torch.nn.Conv2d(out_channels, 220 | out_channels, 221 | kernel_size=3, 222 | stride=1, 223 | padding=1) 224 | if self.in_channels != self.out_channels: 225 | if self.use_conv_shortcut: 226 | self.conv_shortcut = torch.nn.Conv2d(in_channels, 227 | out_channels, 228 | kernel_size=3, 229 | stride=1, 230 | padding=1) 231 | else: 232 | self.nin_shortcut = torch.nn.Conv2d(in_channels, 233 | out_channels, 234 | kernel_size=1, 235 | stride=1, 236 | padding=0) 237 | 238 | def forward(self, x, temb): 239 | h = x 240 | h = self.norm1(h) 241 | h = nonlinearity(h) #(16,128,64,64) 242 | h = 
self.conv1(h) #(16,128,64,64) 243 | 244 | h = h + self.temb_proj(nonlinearity(temb))[:, :, None, None] 245 | 246 | h = self.norm2(h) 247 | h = nonlinearity(h) 248 | h = self.dropout(h) 249 | h = self.conv2(h) #(16,128,64,64) 250 | 251 | if self.in_channels != self.out_channels: 252 | if self.use_conv_shortcut: 253 | x = self.conv_shortcut(x) 254 | else: 255 | x = self.nin_shortcut(x) 256 | 257 | return x+h 258 | 259 | 260 | class AttnBlock(nn.Module): 261 | def __init__(self, in_channels): 262 | super().__init__() 263 | self.in_channels = in_channels 264 | 265 | self.norm = Normalize(in_channels) 266 | self.q = torch.nn.Conv2d(in_channels, 267 | in_channels, 268 | kernel_size=1, 269 | stride=1, 270 | padding=0) 271 | self.k = torch.nn.Conv2d(in_channels, 272 | in_channels, 273 | kernel_size=1, 274 | stride=1, 275 | padding=0) 276 | self.v = torch.nn.Conv2d(in_channels, 277 | in_channels, 278 | kernel_size=1, 279 | stride=1, 280 | padding=0) 281 | self.proj_out = torch.nn.Conv2d(in_channels, 282 | in_channels, 283 | kernel_size=1, 284 | stride=1, 285 | padding=0) 286 | 287 | def forward(self, x): 288 | h_ = x 289 | h_ = self.norm(h_) 290 | q = self.q(h_) 291 | k = self.k(h_) 292 | v = self.v(h_) 293 | 294 | # compute attention 295 | b, c, h, w = q.shape 296 | q = q.reshape(b, c, h*w) 297 | q = q.permute(0, 2, 1) # b,hw,c 298 | k = k.reshape(b, c, h*w) # b,c,hw 299 | w_ = torch.bmm(q, k) # b,hw,hw w[b,i,j]=sum_c q[b,i,c]k[b,c,j] 300 | w_ = w_ * (int(c)**(-0.5)) 301 | w_ = torch.nn.functional.softmax(w_, dim=2) 302 | 303 | # attend to values 304 | v = v.reshape(b, c, h*w) 305 | w_ = w_.permute(0, 2, 1) # b,hw,hw (first hw of k, second of q) 306 | # b, c,hw (hw of q) h_[b,c,j] = sum_i v[b,c,i] w_[b,i,j] 307 | h_ = torch.bmm(v, w_) 308 | h_ = h_.reshape(b, c, h, w) 309 | 310 | h_ = self.proj_out(h_) 311 | 312 | return x+h_ 313 | 314 | class FreBlock(nn.Module): 315 | def __init__(self, in_channels,out_channels): 316 | super(FreBlock, self).__init__() 317 | self.processmag = nn.Sequential( 318 | nn.Conv2d(in_channels,in_channels,1,1,0), 319 | nn.LeakyReLU(0.1,inplace=True), 320 | nn.Conv2d(in_channels,out_channels,1,1,0)) 321 | self.processpha = nn.Sequential( 322 | nn.Conv2d(in_channels, in_channels, 1, 1, 0), 323 | nn.LeakyReLU(0.1, inplace=True), 324 | nn.Conv2d(in_channels, out_channels, 1, 1, 0)) 325 | 326 | def forward(self,x): 327 | xori = x 328 | _, _, H, W = x.shape 329 | x_freq = torch.fft.rfft2(x, norm='backward') 330 | mag = torch.abs(x_freq) 331 | pha = torch.angle(x_freq) 332 | mag = self.processmag(mag) 333 | pha = self.processpha(pha) 334 | real = mag * torch.cos(pha) 335 | imag = mag * torch.sin(pha) 336 | x_out = torch.complex(real, imag) 337 | x_out1 = torch.fft.irfft2(x_out, s=(H, W), norm='backward') 338 | return x_out1 339 | class DiffusionUNet(nn.Module): 340 | def __init__(self, config): 341 | super().__init__() 342 | self.config = config 343 | ch, out_ch, ch_mult = config.model.ch, config.model.out_ch, tuple(config.model.ch_mult) 344 | num_res_blocks = config.model.num_res_blocks 345 | attn_resolutions = config.model.attn_resolutions 346 | dropout = config.model.dropout 347 | in_channels = config.model.in_channels * 3 if config.data.conditional else config.model.in_channels## *3 because the conditioning also includes the max image 348 | resolution = config.data.image_size 349 | resamp_with_conv = config.model.resamp_with_conv 350 | 351 | self.ch = ch #128 352 | self.temb_ch = self.ch*4 #512 353 | self.num_resolutions = len(ch_mult) 354 | self.num_res_blocks = num_res_blocks 355 | self.resolution = resolution 356 |
self.in_channels = in_channels 357 | 358 | self.inceplayers=2 359 | 360 | # timestep embedding 361 | self.temb = nn.Module() 362 | self.temb.dense = nn.ModuleList([ 363 | torch.nn.Linear(self.ch, 364 | self.temb_ch), 365 | torch.nn.Linear(self.temb_ch, 366 | self.temb_ch), 367 | ]) 368 | 369 | # downsampling 370 | self.conv_in = torch.nn.Conv2d(in_channels, 371 | self.ch, 372 | kernel_size=3, 373 | stride=1, 374 | padding=1) 375 | 376 | curr_res = resolution 377 | in_ch_mult = (1,)+ch_mult 378 | self.down = nn.ModuleList() 379 | block_in = None 380 | for i_level in range(self.num_resolutions): 381 | block = nn.ModuleList() 382 | attn = nn.ModuleList() 383 | block_in = ch*in_ch_mult[i_level] 384 | block_out = ch*ch_mult[i_level] 385 | for i_block in range(self.num_res_blocks): 386 | if i_level+self.inceplayers=self.num_resolutions: 431 | block.append(ResnetBlock(in_channels=block_in+skip_in, 432 | out_channels=block_out, 433 | temb_channels=self.temb_ch, 434 | dropout=dropout,incep=True)) 435 | else: 436 | block.append(ResnetBlock(in_channels=block_in+skip_in, 437 | out_channels=block_out, 438 | temb_channels=self.temb_ch, 439 | dropout=dropout,incep=True)) 440 | block_in = block_out 441 | # if curr_res in attn_resolutions: 442 | # attn.append(AttnBlock(block_in)) 443 | up = nn.Module() 444 | up.block = block 445 | up.attn = attn 446 | if i_level != 0 : 447 | if ch_mult[i_level]== ch_mult[i_level-1]*2: 448 | up.upsample = Upsample(block_in, resamp_with_conv) 449 | curr_res = curr_res * 2 450 | self.up.insert(0, up) # prepend to get consistent order 451 | 452 | # end 453 | self.norm_out = Normalize(block_in) 454 | self.conv_out = torch.nn.Conv2d(block_in, 455 | out_ch, 456 | kernel_size=3, 457 | stride=1, 458 | padding=1) 459 | 460 | def forward(self, x, t,i,j,osize): 461 | assert x.shape[2] == x.shape[3] 462 | # timestep embedding 463 | # t=torch.stack([t.unsqueeze(1),cond],dim=1) 464 | temb1 = get_timestep_embedding(t, self.ch//4)#(16,32) 465 | temb2 = get_timestep_embedding(i, self.ch//4) # (16,32) 466 | temb3 = get_timestep_embedding(j, self.ch//4) # (16,32) 467 | temb4 = get_timestep_embedding(osize, self.ch//4) # (16,32) 468 | 469 | temb=torch.cat([temb1,temb2,temb3,temb4],dim=1) 470 | temb = self.temb.dense[0](temb) #(16,512) 471 | temb = nonlinearity(temb) #(16,512) 472 | temb = self.temb.dense[1](temb) #(16,512) 473 | 474 | # downsampling 475 | hs = [self.conv_in(x)] 476 | for i_level in range(self.num_resolutions): 477 | for i_block in range(self.num_res_blocks): 478 | h = self.down[i_level].block[i_block](hs[-1], temb) 479 | if len(self.down[i_level].attn) > 0: 480 | h = self.down[i_level].attn[i_block](h) 481 | hs.append(h) 482 | 483 | if i_level != self.num_resolutions-1 : 484 | hs.append(self.down[i_level].downsample(hs[-1])) 485 | 486 | 487 | # middle 488 | h = hs[-1] 489 | h = self.mid.block_1(h, temb) 490 | h = self.mid.block_2(h, temb) 491 | 492 | # upsampling 493 | for i_level in reversed(range(self.num_resolutions)): 494 | for i_block in range(self.num_res_blocks+1): 495 | h = self.up[i_level].block[i_block]( 496 | torch.cat([h, hs.pop()], dim=1), temb) 497 | if len(self.up[i_level].attn) > 0: 498 | h = self.up[i_level].attn[i_block](h) 499 | if i_level != 0: 500 | h = self.up[i_level].upsample(h) 501 | 502 | # end 503 | h = self.norm_out(h) 504 | h = nonlinearity(h) 505 | h = self.conv_out(h) 506 | return h 507 | --------------------------------------------------------------------------------
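Note on the sampling code in utils/sampling.py: the per-step update can be exercised in isolation. The following is a minimal, self-contained sketch (not a file in this repository); `dummy_noise` is a placeholder for the network's noise prediction and the linear beta schedule is only an assumed example.

import torch

def compute_alpha(beta, t):
    # cumulative alpha_bar up to step t, padded so that t = -1 maps to 1.0 (as in utils/sampling.py)
    beta = torch.cat([torch.zeros(1), beta], dim=0)
    return (1 - beta).cumprod(dim=0).index_select(0, t + 1).view(-1, 1, 1, 1)

@torch.no_grad()
def ddim_step(xt, et, beta, t, t_next, eta=0.0):
    # one generalized (DDIM) step: recover x0 from the noise estimate, then move to t_next
    at, at_next = compute_alpha(beta, t), compute_alpha(beta, t_next)
    x0_t = (xt - et * (1 - at).sqrt()) / at.sqrt()
    c1 = eta * ((1 - at / at_next) * (1 - at_next) / (1 - at)).sqrt()
    c2 = ((1 - at_next) - c1 ** 2).sqrt()
    return at_next.sqrt() * x0_t + c1 * torch.randn_like(xt) + c2 * et

beta = torch.linspace(1e-4, 2e-2, 1000)   # assumed linear schedule
xt = torch.randn(1, 3, 64, 64)            # current noisy sample
dummy_noise = torch.randn_like(xt)        # placeholder for model(torch.cat([x_cond, xt], 1), t, ...)
xt = ddim_step(xt, dummy_noise, beta, torch.tensor([999]), torch.tensor([974]))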