├── datasets
│   ├── scratch
│   │   └── LLIE
│   │       ├── ckpts
│   │       │   └── pretrained model.txt
│   │       └── data
│   │           ├── lowlight
│   │           │   └── test
│   │           │       ├── gt
│   │           │       │   ├── 1.png
│   │           │       │   ├── 22.png
│   │           │       │   ├── 23.png
│   │           │       │   ├── 55.png
│   │           │       │   ├── 79.png
│   │           │       │   ├── 111.png
│   │           │       │   ├── 146.png
│   │           │       │   ├── 179.png
│   │           │       │   ├── 493.png
│   │           │       │   ├── 547.png
│   │           │       │   ├── 665.png
│   │           │       │   ├── 669.png
│   │           │       │   ├── 748.png
│   │           │       │   ├── 778.png
│   │           │       │   └── 780.png
│   │           │       ├── input
│   │           │       │   ├── 1.png
│   │           │       │   ├── 111.png
│   │           │       │   ├── 146.png
│   │           │       │   ├── 179.png
│   │           │       │   ├── 22.png
│   │           │       │   ├── 23.png
│   │           │       │   ├── 493.png
│   │           │       │   ├── 547.png
│   │           │       │   ├── 55.png
│   │           │       │   ├── 665.png
│   │           │       │   ├── 669.png
│   │           │       │   ├── 748.png
│   │           │       │   ├── 778.png
│   │           │       │   ├── 780.png
│   │           │       │   └── 79.png
│   │           │       └── lowlighttesta.txt
│   │           └── read.py
│   ├── __init__.py
│   ├── __pycache__
│   │   ├── __init__.cpython-38.pyc
│   │   ├── lowlight.cpython-38.pyc
│   │   ├── raindrop.cpython-38.pyc
│   │   ├── raindrop1.cpython-38.pyc
│   │   ├── snow100k.cpython-38.pyc
│   │   ├── allweather.cpython-38.pyc
│   │   └── outdoorrain.cpython-38.pyc
│   └── lowlight.py
├── core
├── figs
│   ├── com.png
│   ├── v1.png
│   ├── vis.png
│   └── unpaired.png
├── models
│   ├── __init__.py
│   ├── __pycache__
│   │   ├── ddm.cpython-38.pyc
│   │   ├── ddm1.cpython-38.pyc
│   │   ├── unet.cpython-38.pyc
│   │   ├── unet1.cpython-38.pyc
│   │   ├── unet2.cpython-38.pyc
│   │   ├── unet6.cpython-38.pyc
│   │   ├── __init__.cpython-38.pyc
│   │   ├── restoration.cpython-38.pyc
│   │   └── restoration1.cpython-38.pyc
│   ├── restoration.py
│   ├── ddm.py
│   └── unet.py
├── utils
│   ├── __init__.py
│   ├── __pycache__
│   │   ├── logging.cpython-38.pyc
│   │   ├── metrics.cpython-38.pyc
│   │   ├── __init__.cpython-38.pyc
│   │   ├── optimize.cpython-38.pyc
│   │   └── sampling.cpython-38.pyc
│   ├── logging.py
│   ├── optimize.py
│   ├── metrics.py
│   └── sampling.py
├── results
│   └── images
│       └── lowlight
│           └── lowlight
│               ├── ['1'].png
│               ├── ['111'].png
│               ├── ['146'].png
│               ├── ['179'].png
│               ├── ['22'].png
│               ├── ['23'].png
│               ├── ['493'].png
│               ├── ['547'].png
│               ├── ['55'].png
│               ├── ['665'].png
│               ├── ['669'].png
│               ├── ['748'].png
│               ├── ['778'].png
│               ├── ['780'].png
│               └── ['79'].png
├── requirements.txt
├── configs
│   └── lowlight.yml
├── train_diffusion.py
├── eval_diffusion.py
├── README.md
└── evaluation.py
/datasets/scratch/LLIE/ckpts/pretrained model.txt:
--------------------------------------------------------------------------------
1 |
--------------------------------------------------------------------------------
/core:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/Oliiveralien/MDMS/HEAD/core
--------------------------------------------------------------------------------
/figs/com.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/Oliiveralien/MDMS/HEAD/figs/com.png
--------------------------------------------------------------------------------
/figs/v1.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/Oliiveralien/MDMS/HEAD/figs/v1.png
--------------------------------------------------------------------------------
/figs/vis.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/Oliiveralien/MDMS/HEAD/figs/vis.png
--------------------------------------------------------------------------------
/figs/unpaired.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/Oliiveralien/MDMS/HEAD/figs/unpaired.png
--------------------------------------------------------------------------------
/models/__init__.py:
--------------------------------------------------------------------------------
1 | from models.ddm import *
2 | from models.restoration import *
3 |
--------------------------------------------------------------------------------
/datasets/__init__.py:
--------------------------------------------------------------------------------
1 | from datasets.lowlight import *
2 | __all__ = ["AllWeather","lowlight"]
3 |
--------------------------------------------------------------------------------
/utils/__init__.py:
--------------------------------------------------------------------------------
1 | from utils.logging import *
2 | from utils.sampling import *
3 | from utils.optimize import *
4 |
--------------------------------------------------------------------------------
/models/__pycache__/ddm.cpython-38.pyc:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/Oliiveralien/MDMS/HEAD/models/__pycache__/ddm.cpython-38.pyc
--------------------------------------------------------------------------------
/models/__pycache__/ddm1.cpython-38.pyc:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/Oliiveralien/MDMS/HEAD/models/__pycache__/ddm1.cpython-38.pyc
--------------------------------------------------------------------------------
/models/__pycache__/unet.cpython-38.pyc:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/Oliiveralien/MDMS/HEAD/models/__pycache__/unet.cpython-38.pyc
--------------------------------------------------------------------------------
/models/__pycache__/unet1.cpython-38.pyc:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/Oliiveralien/MDMS/HEAD/models/__pycache__/unet1.cpython-38.pyc
--------------------------------------------------------------------------------
/models/__pycache__/unet2.cpython-38.pyc:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/Oliiveralien/MDMS/HEAD/models/__pycache__/unet2.cpython-38.pyc
--------------------------------------------------------------------------------
/models/__pycache__/unet6.cpython-38.pyc:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/Oliiveralien/MDMS/HEAD/models/__pycache__/unet6.cpython-38.pyc
--------------------------------------------------------------------------------
/utils/__pycache__/logging.cpython-38.pyc:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/Oliiveralien/MDMS/HEAD/utils/__pycache__/logging.cpython-38.pyc
--------------------------------------------------------------------------------
/utils/__pycache__/metrics.cpython-38.pyc:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/Oliiveralien/MDMS/HEAD/utils/__pycache__/metrics.cpython-38.pyc
--------------------------------------------------------------------------------
/models/__pycache__/__init__.cpython-38.pyc:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/Oliiveralien/MDMS/HEAD/models/__pycache__/__init__.cpython-38.pyc
--------------------------------------------------------------------------------
/results/images/lowlight/lowlight/['1'].png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/Oliiveralien/MDMS/HEAD/results/images/lowlight/lowlight/['1'].png
--------------------------------------------------------------------------------
/utils/__pycache__/__init__.cpython-38.pyc:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/Oliiveralien/MDMS/HEAD/utils/__pycache__/__init__.cpython-38.pyc
--------------------------------------------------------------------------------
/utils/__pycache__/optimize.cpython-38.pyc:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/Oliiveralien/MDMS/HEAD/utils/__pycache__/optimize.cpython-38.pyc
--------------------------------------------------------------------------------
/utils/__pycache__/sampling.cpython-38.pyc:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/Oliiveralien/MDMS/HEAD/utils/__pycache__/sampling.cpython-38.pyc
--------------------------------------------------------------------------------
/datasets/__pycache__/__init__.cpython-38.pyc:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/Oliiveralien/MDMS/HEAD/datasets/__pycache__/__init__.cpython-38.pyc
--------------------------------------------------------------------------------
/datasets/__pycache__/lowlight.cpython-38.pyc:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/Oliiveralien/MDMS/HEAD/datasets/__pycache__/lowlight.cpython-38.pyc
--------------------------------------------------------------------------------
/datasets/__pycache__/raindrop.cpython-38.pyc:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/Oliiveralien/MDMS/HEAD/datasets/__pycache__/raindrop.cpython-38.pyc
--------------------------------------------------------------------------------
/datasets/__pycache__/raindrop1.cpython-38.pyc:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/Oliiveralien/MDMS/HEAD/datasets/__pycache__/raindrop1.cpython-38.pyc
--------------------------------------------------------------------------------
/datasets/__pycache__/snow100k.cpython-38.pyc:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/Oliiveralien/MDMS/HEAD/datasets/__pycache__/snow100k.cpython-38.pyc
--------------------------------------------------------------------------------
/models/__pycache__/restoration.cpython-38.pyc:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/Oliiveralien/MDMS/HEAD/models/__pycache__/restoration.cpython-38.pyc
--------------------------------------------------------------------------------
/results/images/lowlight/lowlight/['111'].png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/Oliiveralien/MDMS/HEAD/results/images/lowlight/lowlight/['111'].png
--------------------------------------------------------------------------------
/results/images/lowlight/lowlight/['146'].png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/Oliiveralien/MDMS/HEAD/results/images/lowlight/lowlight/['146'].png
--------------------------------------------------------------------------------
/results/images/lowlight/lowlight/['179'].png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/Oliiveralien/MDMS/HEAD/results/images/lowlight/lowlight/['179'].png
--------------------------------------------------------------------------------
/results/images/lowlight/lowlight/['22'].png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/Oliiveralien/MDMS/HEAD/results/images/lowlight/lowlight/['22'].png
--------------------------------------------------------------------------------
/results/images/lowlight/lowlight/['23'].png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/Oliiveralien/MDMS/HEAD/results/images/lowlight/lowlight/['23'].png
--------------------------------------------------------------------------------
/results/images/lowlight/lowlight/['493'].png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/Oliiveralien/MDMS/HEAD/results/images/lowlight/lowlight/['493'].png
--------------------------------------------------------------------------------
/results/images/lowlight/lowlight/['547'].png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/Oliiveralien/MDMS/HEAD/results/images/lowlight/lowlight/['547'].png
--------------------------------------------------------------------------------
/results/images/lowlight/lowlight/['55'].png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/Oliiveralien/MDMS/HEAD/results/images/lowlight/lowlight/['55'].png
--------------------------------------------------------------------------------
/results/images/lowlight/lowlight/['665'].png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/Oliiveralien/MDMS/HEAD/results/images/lowlight/lowlight/['665'].png
--------------------------------------------------------------------------------
/results/images/lowlight/lowlight/['669'].png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/Oliiveralien/MDMS/HEAD/results/images/lowlight/lowlight/['669'].png
--------------------------------------------------------------------------------
/results/images/lowlight/lowlight/['748'].png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/Oliiveralien/MDMS/HEAD/results/images/lowlight/lowlight/['748'].png
--------------------------------------------------------------------------------
/results/images/lowlight/lowlight/['778'].png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/Oliiveralien/MDMS/HEAD/results/images/lowlight/lowlight/['778'].png
--------------------------------------------------------------------------------
/results/images/lowlight/lowlight/['780'].png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/Oliiveralien/MDMS/HEAD/results/images/lowlight/lowlight/['780'].png
--------------------------------------------------------------------------------
/results/images/lowlight/lowlight/['79'].png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/Oliiveralien/MDMS/HEAD/results/images/lowlight/lowlight/['79'].png
--------------------------------------------------------------------------------
/datasets/__pycache__/allweather.cpython-38.pyc:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/Oliiveralien/MDMS/HEAD/datasets/__pycache__/allweather.cpython-38.pyc
--------------------------------------------------------------------------------
/datasets/__pycache__/outdoorrain.cpython-38.pyc:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/Oliiveralien/MDMS/HEAD/datasets/__pycache__/outdoorrain.cpython-38.pyc
--------------------------------------------------------------------------------
/models/__pycache__/restoration1.cpython-38.pyc:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/Oliiveralien/MDMS/HEAD/models/__pycache__/restoration1.cpython-38.pyc
--------------------------------------------------------------------------------
/datasets/scratch/LLIE/data/lowlight/test/gt/1.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/Oliiveralien/MDMS/HEAD/datasets/scratch/LLIE/data/lowlight/test/gt/1.png
--------------------------------------------------------------------------------
/datasets/scratch/LLIE/data/lowlight/test/gt/22.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/Oliiveralien/MDMS/HEAD/datasets/scratch/LLIE/data/lowlight/test/gt/22.png
--------------------------------------------------------------------------------
/datasets/scratch/LLIE/data/lowlight/test/gt/23.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/Oliiveralien/MDMS/HEAD/datasets/scratch/LLIE/data/lowlight/test/gt/23.png
--------------------------------------------------------------------------------
/datasets/scratch/LLIE/data/lowlight/test/gt/55.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/Oliiveralien/MDMS/HEAD/datasets/scratch/LLIE/data/lowlight/test/gt/55.png
--------------------------------------------------------------------------------
/datasets/scratch/LLIE/data/lowlight/test/gt/79.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/Oliiveralien/MDMS/HEAD/datasets/scratch/LLIE/data/lowlight/test/gt/79.png
--------------------------------------------------------------------------------
/datasets/scratch/LLIE/data/lowlight/test/gt/111.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/Oliiveralien/MDMS/HEAD/datasets/scratch/LLIE/data/lowlight/test/gt/111.png
--------------------------------------------------------------------------------
/datasets/scratch/LLIE/data/lowlight/test/gt/146.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/Oliiveralien/MDMS/HEAD/datasets/scratch/LLIE/data/lowlight/test/gt/146.png
--------------------------------------------------------------------------------
/datasets/scratch/LLIE/data/lowlight/test/gt/179.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/Oliiveralien/MDMS/HEAD/datasets/scratch/LLIE/data/lowlight/test/gt/179.png
--------------------------------------------------------------------------------
/datasets/scratch/LLIE/data/lowlight/test/gt/493.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/Oliiveralien/MDMS/HEAD/datasets/scratch/LLIE/data/lowlight/test/gt/493.png
--------------------------------------------------------------------------------
/datasets/scratch/LLIE/data/lowlight/test/gt/547.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/Oliiveralien/MDMS/HEAD/datasets/scratch/LLIE/data/lowlight/test/gt/547.png
--------------------------------------------------------------------------------
/datasets/scratch/LLIE/data/lowlight/test/gt/665.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/Oliiveralien/MDMS/HEAD/datasets/scratch/LLIE/data/lowlight/test/gt/665.png
--------------------------------------------------------------------------------
/datasets/scratch/LLIE/data/lowlight/test/gt/669.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/Oliiveralien/MDMS/HEAD/datasets/scratch/LLIE/data/lowlight/test/gt/669.png
--------------------------------------------------------------------------------
/datasets/scratch/LLIE/data/lowlight/test/gt/748.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/Oliiveralien/MDMS/HEAD/datasets/scratch/LLIE/data/lowlight/test/gt/748.png
--------------------------------------------------------------------------------
/datasets/scratch/LLIE/data/lowlight/test/gt/778.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/Oliiveralien/MDMS/HEAD/datasets/scratch/LLIE/data/lowlight/test/gt/778.png
--------------------------------------------------------------------------------
/datasets/scratch/LLIE/data/lowlight/test/gt/780.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/Oliiveralien/MDMS/HEAD/datasets/scratch/LLIE/data/lowlight/test/gt/780.png
--------------------------------------------------------------------------------
/datasets/scratch/LLIE/data/lowlight/test/input/1.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/Oliiveralien/MDMS/HEAD/datasets/scratch/LLIE/data/lowlight/test/input/1.png
--------------------------------------------------------------------------------
/requirements.txt:
--------------------------------------------------------------------------------
1 | numpy==1.21.5
2 | opencv_python==4.5.5.64
3 | Pillow==8.4.0
4 | PyYAML==6.0
5 | torch==1.10.0
6 | torchvision==0.11.1
7 | tqdm==4.61.2
8 |
--------------------------------------------------------------------------------
/datasets/scratch/LLIE/data/lowlight/test/input/111.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/Oliiveralien/MDMS/HEAD/datasets/scratch/LLIE/data/lowlight/test/input/111.png
--------------------------------------------------------------------------------
/datasets/scratch/LLIE/data/lowlight/test/input/146.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/Oliiveralien/MDMS/HEAD/datasets/scratch/LLIE/data/lowlight/test/input/146.png
--------------------------------------------------------------------------------
/datasets/scratch/LLIE/data/lowlight/test/input/179.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/Oliiveralien/MDMS/HEAD/datasets/scratch/LLIE/data/lowlight/test/input/179.png
--------------------------------------------------------------------------------
/datasets/scratch/LLIE/data/lowlight/test/input/22.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/Oliiveralien/MDMS/HEAD/datasets/scratch/LLIE/data/lowlight/test/input/22.png
--------------------------------------------------------------------------------
/datasets/scratch/LLIE/data/lowlight/test/input/23.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/Oliiveralien/MDMS/HEAD/datasets/scratch/LLIE/data/lowlight/test/input/23.png
--------------------------------------------------------------------------------
/datasets/scratch/LLIE/data/lowlight/test/input/493.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/Oliiveralien/MDMS/HEAD/datasets/scratch/LLIE/data/lowlight/test/input/493.png
--------------------------------------------------------------------------------
/datasets/scratch/LLIE/data/lowlight/test/input/547.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/Oliiveralien/MDMS/HEAD/datasets/scratch/LLIE/data/lowlight/test/input/547.png
--------------------------------------------------------------------------------
/datasets/scratch/LLIE/data/lowlight/test/input/55.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/Oliiveralien/MDMS/HEAD/datasets/scratch/LLIE/data/lowlight/test/input/55.png
--------------------------------------------------------------------------------
/datasets/scratch/LLIE/data/lowlight/test/input/665.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/Oliiveralien/MDMS/HEAD/datasets/scratch/LLIE/data/lowlight/test/input/665.png
--------------------------------------------------------------------------------
/datasets/scratch/LLIE/data/lowlight/test/input/669.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/Oliiveralien/MDMS/HEAD/datasets/scratch/LLIE/data/lowlight/test/input/669.png
--------------------------------------------------------------------------------
/datasets/scratch/LLIE/data/lowlight/test/input/748.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/Oliiveralien/MDMS/HEAD/datasets/scratch/LLIE/data/lowlight/test/input/748.png
--------------------------------------------------------------------------------
/datasets/scratch/LLIE/data/lowlight/test/input/778.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/Oliiveralien/MDMS/HEAD/datasets/scratch/LLIE/data/lowlight/test/input/778.png
--------------------------------------------------------------------------------
/datasets/scratch/LLIE/data/lowlight/test/input/780.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/Oliiveralien/MDMS/HEAD/datasets/scratch/LLIE/data/lowlight/test/input/780.png
--------------------------------------------------------------------------------
/datasets/scratch/LLIE/data/lowlight/test/input/79.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/Oliiveralien/MDMS/HEAD/datasets/scratch/LLIE/data/lowlight/test/input/79.png
--------------------------------------------------------------------------------
/utils/logging.py:
--------------------------------------------------------------------------------
1 | import torch
2 | import shutil
3 | import os
4 | import torchvision.utils as tvu
5 |
6 |
7 | def save_image(img, file_directory):
8 | if not os.path.exists(os.path.dirname(file_directory)):
9 | os.makedirs(os.path.dirname(file_directory))
10 | tvu.save_image(img, file_directory)
11 |
12 |
13 | def save_checkpoint(state, filename):
14 | if not os.path.exists(os.path.dirname(filename)):
15 | os.makedirs(os.path.dirname(filename))
16 | torch.save(state, filename + '.pth.tar')
17 |
18 |
19 | def load_checkpoint(path, device):
20 | if device is None:
21 | return torch.load(path)
22 | else:
23 | return torch.load(path, map_location=device)
24 |
--------------------------------------------------------------------------------
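A minimal usage sketch for the helpers above; the checkpoint path and the dictionary keys are illustrative assumptions, not the format produced by `models/ddm.py`:

```python
# Illustrative only: save and reload a checkpoint with the helpers from utils/logging.py.
import torch.nn as nn
from utils.logging import save_checkpoint, load_checkpoint

net = nn.Linear(4, 4)                                  # stand-in for the diffusion model
state = {'state_dict': net.state_dict(), 'step': 0}    # keys are an assumption
save_checkpoint(state, filename='datasets/scratch/LLIE/ckpts/demo')  # writes demo.pth.tar
ckpt = load_checkpoint('datasets/scratch/LLIE/ckpts/demo.pth.tar', device=None)
net.load_state_dict(ckpt['state_dict'])
```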
/datasets/scratch/LLIE/data/read.py:
--------------------------------------------------------------------------------
1 | import os
2 | import codecs  # read the file names in the input folder
3 |
4 | folder = '/home/shangkai/MDMS/MDMS-main/datasets/scratch/LLIE/data/lowlight/test/input'
5 | filenames = os.listdir(folder)
6 |
7 | # write the file names into a txt file
8 | txt_file = r'/home/shangkai/MDMS/MDMS-main/datasets/scratch/LLIE/data/lowlight/test/lowlighttesta.txt'
9 | with codecs.open(txt_file, 'w', 'utf-8') as f:
10 | for filename in filenames:
11 |         # build the absolute path of each test input image
12 |         filename = os.path.join(folder, filename)
13 |
14 |
15 | f.write(filename + '\n')
16 |
--------------------------------------------------------------------------------
/utils/optimize.py:
--------------------------------------------------------------------------------
1 | import torch.optim as optim
2 |
3 |
4 | def get_optimizer(config, parameters):
5 | if config.optim.optimizer == 'Adam':
6 | return optim.Adam(parameters, lr=config.optim.lr, weight_decay=config.optim.weight_decay,
7 | betas=(0.9, 0.999), amsgrad=config.optim.amsgrad, eps=config.optim.eps)
8 | elif config.optim.optimizer == 'RMSProp':
9 | return optim.RMSprop(parameters, lr=config.optim.lr, weight_decay=config.optim.weight_decay)
10 | elif config.optim.optimizer == 'SGD':
11 | return optim.SGD(parameters, lr=config.optim.lr, momentum=0.9)
12 | else:
13 | raise NotImplementedError('Optimizer {} not understood.'.format(config.optim.optimizer))
14 |
--------------------------------------------------------------------------------
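A small sketch of how `get_optimizer` is driven by the `optim` block of `configs/lowlight.yml` once it has been converted to a namespace (as `dict2namespace` in `train_diffusion.py` does); the tiny `nn.Linear` is only a stand-in for the real network:

```python
# Sketch: rebuild the optim settings from configs/lowlight.yml and request the optimizer.
import argparse
import torch.nn as nn
from utils.optimize import get_optimizer

config = argparse.Namespace(optim=argparse.Namespace(
    optimizer="Adam", lr=0.00002, weight_decay=0.0, amsgrad=False, eps=0.00000001))
model = nn.Linear(8, 8)                      # stand-in for the diffusion UNet
optimizer = get_optimizer(config, model.parameters())
print(optimizer)                             # Adam with lr=2e-05, eps=1e-08
```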
/configs/lowlight.yml:
--------------------------------------------------------------------------------
1 | data:
2 | dataset: "lowlight"
3 | image_size: 64
4 | channels: 3
5 | num_workers: 0
6 | data_dir: "/root/LLIE/MDMS-main/datasets/scratch/LLIE/"
7 | conditional: True
8 |
9 | model:
10 | in_channels: 3
11 | out_ch: 3
12 | ch: 128
13 | #ch: 32
14 | # ch_mult: [1, 1, 2, 2, 4, 4]
15 |   ch_mult: [1, 2, 4, 8]
16 | num_res_blocks: 2
17 | attn_resolutions: []
18 | dropout: 0.0
19 | ema_rate: 0.999
20 | ema: True
21 | resamp_with_conv: True
22 |
23 | diffusion:
24 | beta_schedule: linear
25 | beta_start: 0.0001
26 | beta_end: 0.02
27 | num_diffusion_timesteps: 1000
28 |
29 | training:
30 | patch_n: 2
31 | batch_size: 38
32 | n_epochs: 200000
33 | n_iters: 2000000
34 | snapshot_freq: 200
35 | validation_freq: 200
36 |
37 | sampling:
38 | batch_size: 1 #4
39 | last_only: True
40 |
41 | optim:
42 | weight_decay: 0.000
43 | optimizer: "Adam"
44 | lr: 0.00002
45 | amsgrad: False
46 | eps: 0.00000001
47 |
--------------------------------------------------------------------------------
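The `diffusion` block above fully determines the noise schedule. A hedged sketch of a standard linear schedule with these values (the actual construction lives in `models/ddm.py`, which is not reproduced here) and the cumulative product that `compute_alpha` in `utils/sampling.py` consumes:

```python
# Assumption: a plain linear beta schedule; models/ddm.py may build it slightly differently.
import torch

beta_start, beta_end, T = 0.0001, 0.02, 1000    # values from the diffusion block above
betas = torch.linspace(beta_start, beta_end, T)
alphas_cumprod = (1.0 - betas).cumprod(dim=0)   # cumulative alpha_bar_t used during sampling
print(alphas_cumprod[0].item())                 # ~0.9999 (almost no noise at t=0)
print(alphas_cumprod[-1].item())                # ~4e-5  (nearly pure noise at t=999)
```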
/datasets/scratch/LLIE/data/lowlight/test/lowlighttesta.txt:
--------------------------------------------------------------------------------
1 | /root/LLIE/MDMS-main/datasets/scratch/LLIE/data/lowlight/test/input/1.png
2 | /root/LLIE/MDMS-main/datasets/scratch/LLIE/data/lowlight/test/input/111.png
3 | /root/LLIE/MDMS-main/datasets/scratch/LLIE/data/lowlight/test/input/146.png
4 | /root/LLIE/MDMS-main/datasets/scratch/LLIE/data/lowlight/test/input/179.png
5 | /root/LLIE/MDMS-main/datasets/scratch/LLIE/data/lowlight/test/input/22.png
6 | /root/LLIE/MDMS-main/datasets/scratch/LLIE/data/lowlight/test/input/23.png
7 | /root/LLIE/MDMS-main/datasets/scratch/LLIE/data/lowlight/test/input/493.png
8 | /root/LLIE/MDMS-main/datasets/scratch/LLIE/data/lowlight/test/input/547.png
9 | /root/LLIE/MDMS-main/datasets/scratch/LLIE/data/lowlight/test/input/55.png
10 | /root/LLIE/MDMS-main/datasets/scratch/LLIE/data/lowlight/test/input/665.png
11 | /root/LLIE/MDMS-main/datasets/scratch/LLIE/data/lowlight/test/input/669.png
12 | /root/LLIE/MDMS-main/datasets/scratch/LLIE/data/lowlight/test/input/748.png
13 | /root/LLIE/MDMS-main/datasets/scratch/LLIE/data/lowlight/test/input/778.png
14 | /root/LLIE/MDMS-main/datasets/scratch/LLIE/data/lowlight/test/input/780.png
15 | /root/LLIE/MDMS-main/datasets/scratch/LLIE/data/lowlight/test/input/79.png
16 |
--------------------------------------------------------------------------------
/train_diffusion.py:
--------------------------------------------------------------------------------
1 | import argparse
2 | import os
3 | import random
4 | import socket
5 | import yaml
6 | import torch
7 | import torch.backends.cudnn as cudnn
8 | import torch.utils.data
9 | import numpy as np
10 | import torchvision
11 | import models
12 | import datasets
13 | import utils
14 | from models import DenoisingDiffusion
15 |
16 |
17 | def parse_args_and_config():
18 | parser = argparse.ArgumentParser(description='Training Patch-Based Denoising Diffusion Models')
19 | parser.add_argument("--config", type=str, required=False, default="lowlight.yml",help="Path to the config file")
20 | parser.add_argument('--resume', default=r'', type=str,
21 | help='Path for checkpoint to load and resume')
22 | parser.add_argument("--sampling_timesteps", type=int, default=25,
23 | help="Number of implicit sampling steps for validation image patches")
24 | parser.add_argument("--image_folder", default='results/images/', type=str,
25 | help="Location to save restored validation image patches")
26 | parser.add_argument('--seed', default=61, type=int, metavar='N',
27 | help='Seed for initializing training (default: 61)')
28 | args = parser.parse_args()
29 |
30 | with open(os.path.join("configs", args.config), "r") as f:
31 | config = yaml.safe_load(f)
32 | new_config = dict2namespace(config)
33 |
34 | return args, new_config
35 |
36 |
37 | def dict2namespace(config):
38 | namespace = argparse.Namespace()
39 | for key, value in config.items():
40 | if isinstance(value, dict):
41 | new_value = dict2namespace(value)
42 | else:
43 | new_value = value
44 | setattr(namespace, key, new_value)
45 | return namespace
46 |
47 |
48 | def main():
49 | args, config = parse_args_and_config()
50 |
51 | # setup device to run
52 | device = torch.device("cuda") if torch.cuda.is_available() else torch.device("cpu")
53 | print("Using device: {}".format(device))
54 | config.device = device
55 |
56 | # set random seed
57 | torch.manual_seed(args.seed)
58 | np.random.seed(args.seed)
59 | if torch.cuda.is_available():
60 | torch.cuda.manual_seed_all(args.seed)
61 | torch.backends.cudnn.benchmark = True
62 |
63 | # data loading
64 | print("=> using dataset '{}'".format(config.data.dataset))
65 | DATASET = datasets.__dict__[config.data.dataset](config)
66 |
67 | # create model
68 | print("=> creating denoising-diffusion model...")
69 | diffusion = DenoisingDiffusion(args, config)
70 | diffusion.train(DATASET)
71 |
72 |
73 | if __name__ == "__main__":
74 | main()
75 |
--------------------------------------------------------------------------------
/eval_diffusion.py:
--------------------------------------------------------------------------------
1 | import argparse
2 | import os
3 | import random
4 | import socket
5 | import yaml
6 | import torch
7 | import torch.backends.cudnn as cudnn
8 | import numpy as np
9 | import torchvision
10 | import models
11 | import datasets
12 | import utils
13 | from models import DenoisingDiffusion, DiffusiveRestoration
14 |
15 |
16 | def parse_args_and_config():
17 | parser = argparse.ArgumentParser(description='Restoring Weather with Patch-Based Denoising Diffusion Models')
18 | parser.add_argument("--config", type=str, default='lowlight.yml',
19 | help="Path to the config file")
20 |
21 | parser.add_argument('--resume', default=r'/root/LLIE/MDMS-main/datasets/scratch/LLIE/ckpts/bestpsnr.pth.tar', type=str,
22 | help='Path for the diffusion model checkpoint to load for evaluation')
23 | parser.add_argument("--grid_r", type=int, default=16,
24 | help="Grid cell width r that defines the overlap between patches")
25 | parser.add_argument("--sampling_timesteps", type=int, default=25,
26 | help="Number of implicit sampling steps")
27 | parser.add_argument("--test_set", type=str, default='lowlight')
28 | parser.add_argument("--image_folder", default='results/images/', type=str,
29 | help="Location to save restored images")
30 | parser.add_argument('--seed', default=61, type=int, metavar='N',
31 | help='Seed for initializing training (default: 61)')
32 | args = parser.parse_args()
33 |
34 | with open(os.path.join("configs", args.config), "r") as f:
35 | config = yaml.safe_load(f)
36 | new_config = dict2namespace(config)
37 |
38 | return args, new_config
39 |
40 |
41 | def dict2namespace(config):
42 | namespace = argparse.Namespace()
43 | for key, value in config.items():
44 | if isinstance(value, dict):
45 | new_value = dict2namespace(value)
46 | else:
47 | new_value = value
48 | setattr(namespace, key, new_value)
49 | return namespace
50 |
51 |
52 | def main():
53 | args, config = parse_args_and_config()
54 |
55 | # setup device to run
56 | device = torch.device("cuda") if torch.cuda.is_available() else torch.device("cpu")
57 | print("Using device: {}".format(device))
58 | config.device = device
59 |
60 | if torch.cuda.is_available():
61 | print('Note: Currently supports evaluations (restoration) when run only on a single GPU!')
62 |
63 | # set random seed
64 | torch.manual_seed(args.seed)
65 | np.random.seed(args.seed)
66 | if torch.cuda.is_available():
67 | torch.cuda.manual_seed_all(args.seed)
68 | torch.backends.cudnn.benchmark = True
69 |
70 | # data loading
71 | print("=> using dataset '{}'".format(config.data.dataset))
72 | DATASET = datasets.__dict__[config.data.dataset](config)
73 | _, val_loader = DATASET.get_loaders(parse_patches=False, validation=args.test_set)
74 |
75 | # create model
76 | print("=> creating denoising-diffusion model with wrapper...")
77 | diffusion = DenoisingDiffusion(args, config)
78 | model = DiffusiveRestoration(diffusion, args, config)
79 | model.restore(val_loader, validation=args.test_set, r=args.grid_r,use_align=True)
80 |
81 |
82 | if __name__ == '__main__':
83 | main()
84 |
--------------------------------------------------------------------------------
/README.md:
--------------------------------------------------------------------------------
1 |
2 |
3 | # 【AAAI'2024】Multi-Domain Multi-Scale Diffusion Model for Low-Light Image Enhancement
4 |
5 |
6 | The official implementation of AAAI24 paper [Multi-Domain Multi-Scale Diffusion Model for Low-Light Image Enhancement](https://ojs.aaai.org/index.php/AAAI/article/view/28273).
7 |
8 | ## Environment
9 | Create a new conda environment and run:
10 |
11 | ```
12 | $ pip install -r requirements.txt
13 | ```
14 | torch/torchvision with CUDA version >= 11.3 should be fine.
15 |
16 | ## Demo
17 | #### 1. Download pretrained model
18 |
19 | Download the Pretrained MDMS model from [Baidu NetDisk](https://pan.baidu.com/s/1AQzsofBfsiSEy6cG6wcbYg?pwd=gfri) or [Google Drive](https://drive.google.com/file/d/1PQIqAs0mw8xp5obrVcg7QnxLLsfQ5H9L/view?usp=sharing).
20 |
21 | Put the downloaded ckpt in `datasets/scratch/LLIE/ckpts`.
22 |
23 |
24 | #### 2. Inference
25 | ```
26 | # in {path_to_this_repo}/,
27 | $ python eval_diffusion.py
28 | ```
29 | Put the test input in `datasets/scratch/LLIE/data/lowlight/test/input`.
30 |
31 | Output results will be saved in `results/images/lowlight/lowlight`.
32 |
33 | ## Evaluation
34 |
35 | Put the test GT in `datasets/scratch/LLIE/data/lowlight/test/gt` for paired evaluation.
36 |
37 | ```
38 | # in {path_to_this_repo}/,
39 | $ python evaluation.py
40 | ```
41 | * Note that our [evaluation metrics](https://github.com/Oliiveralien/MDMS/tree/main/evaluation.py) are slightly different from [PyDiff](https://github.com/limuloo/PyDIff/tree/862f8cc428450ef02822fd218b15705e2214ec2d/BasicSR-light/basicsr/metrics) (inherited from [BasicSR](https://github.com/XPixelGroup/BasicSR)).
42 |
43 | ## Results
44 | All results listed in our paper including the compared methods are available in [Baidu Netdisk](https://pan.baidu.com/s/1O8hOVflnLGLSLP07nXp_sg?pwd=zftu) or [Google Drive](https://drive.google.com/file/d/1k9-vD-I5JaHj7Y9bGq1gen2TKEzEhCzs/view?usp=sharing).
45 |
46 | * Note that the provided model is trained on the [LOLv1](https://daooshee.github.io/BMVC2018website/) training set, but generalizes well to other datasets.
47 | * For SSIM, we directly calculate the performance on [RGB channel](https://github.com/Oliiveralien/MDMS/tree/main/evaluation.py#L49-L51) rather than just [grayscale images](https://github.com/limuloo/PyDIff/blob/862f8cc428450ef02822fd218b15705e2214ec2d/BasicSR-light/basicsr/metrics/ssim_lol.py#L7C1-L12C132) in PyDiff.
48 | * For LPIPS, we use a different normalization method ([NormA](https://github.com/Oliiveralien/MDMS/tree/main/evaluation.py#L74)) compared to PyDiff ([NormB](https://github.com/limuloo/PyDIff/blob/862f8cc428450ef02822fd218b15705e2214ec2d/BasicSR-light/basicsr/metrics/lpips_lol.py#L19)).
49 |
50 | Our method remains superior under the same setting as PyDiff.
51 |
52 |
53 |
54 |
55 |
56 | ### 1. Test results on LOLv1 test set.
57 |
58 |
59 |
60 |
61 |
62 | ### 2. Generalization results on LOLv2 syn and real test sets.
63 |
64 |
65 |
66 |
67 |
68 | ### 3. Generalization results on other unpaired datasets.
69 |
70 |
71 |
72 |
73 |
74 | We will perform more training and tests on other datasets in the future.
75 |
76 | ## Training
77 | Put the training dataset in `datasets/scratch/LLIE/data/lowlight/train`.
78 |
79 | ```
80 | # in {path_to_this_repo}/,
81 | $ python train_diffusion.py
82 | ```
83 |
84 | Detailed training instructions will be updated soon.
85 |
86 | ## Citation
87 | If you find this paper useful, please consider starring this repo and citing our paper:
88 | ```
89 | @inproceedings{shang2024multi,
90 | title={Multi-Domain Multi-Scale Diffusion Model for Low-Light Image Enhancement},
91 | author={Shang, Kai and Shao, Mingwen and Wang, Chao and Cheng, Yuanshuo and Wang, Shuigen},
92 | booktitle={Proceedings of the AAAI Conference on Artificial Intelligence},
93 | volume={38},
94 | number={5},
95 | pages={4722--4730},
96 | year={2024}
97 | }
98 |
--------------------------------------------------------------------------------
/models/restoration.py:
--------------------------------------------------------------------------------
1 | import torch
2 | import torch.nn as nn
3 | import utils
4 | import torchvision
5 | import os
6 | import PIL
7 | import re
8 | from torchvision.transforms import Resize
9 |
10 | def data_transform(X):
11 | return 2 * X - 1.0
12 |
13 |
14 | def inverse_data_transform(X):
15 | return torch.clamp((X + 1.0) / 2.0, 0.0, 1.0)
16 |
17 |
18 | class DiffusiveRestoration:
19 | def __init__(self, diffusion, args, config):
20 | super(DiffusiveRestoration, self).__init__()
21 | self.args = args
22 | self.config = config
23 | self.diffusion = diffusion
24 |
25 | if os.path.isfile(args.resume):
26 | self.diffusion.load_ddm_ckpt(args.resume, ema=True)
27 | self.diffusion.model.eval()
28 | else:
29 | print('Pre-trained diffusion model path is missing!')
30 |
31 | def restore(self, val_loader, validation='lowlight', r=None,use_align=False):
32 | image_folder = os.path.join(self.args.image_folder, self.config.data.dataset, validation)
33 | with torch.no_grad():
34 | for i, (x, y,wd,ht) in enumerate(val_loader):
35 | print(f"starting processing from image {y}")
36 | y = re.findall(r'\d+', y[0])
37 | x = x.flatten(start_dim=0, end_dim=1) if x.ndim == 5 else x
38 | x_cond = x[:, :6, :, :].to(self.diffusion.device)
39 | x_output1 = self.diffusive_restoration(x_cond, r=r,fullresolusion=False)
40 | #x_output = inverse_data_transform(x_output)
41 | x_output1 = inverse_data_transform(x_output1)
42 | x_output=x_output1
43 | #x_output=(x_output+x_output1)/2
44 | b,c,h,w=x_output.shape
45 | ht=ht.item()
46 | wd=wd.item()
47 | torch_resize = Resize([ht, wd])
48 | x_output=torch_resize(x_output)
49 |
50 | if use_align==True:
51 | target = x[:, 6:, :, :]
52 | gt_mean = torch.mean(target)
53 | sr_mean = torch.mean(x_output)
54 | x_output = x_output * gt_mean / sr_mean
55 |
56 |
57 |
58 | utils.logging.save_image(x_output, os.path.join(image_folder, f"{y}.png"))
59 |
60 | def diffusive_restoration(self, x_cond, r=None,fullresolusion=False):
61 |         if fullresolusion == False:
62 | # p_size = self.config.data.image_size
63 | p_size=64
64 | h_list, w_list = self.overlapping_grid_indices(x_cond, output_size=p_size, r=8)
65 | corners = [(i, j) for i in h_list for j in w_list]
66 | h_list1, w_list1 = self.overlapping_grid_indices(x_cond, output_size=96, r=8)
67 | corners1 = [(i, j) for i in h_list1 for j in w_list1]
68 |
69 | h_list2, w_list2 = self.overlapping_grid_indices(x_cond, output_size=128, r=8)
70 | corners2 = [(i, j) for i in h_list2 for j in w_list2]
71 |
72 |
73 | x = torch.randn(x_cond.size()[0],3,x_cond.size()[2],x_cond.size()[3], device=self.diffusion.device)
74 |
75 | ii = torch.tensor([item[0] for item in corners])
76 | jj = torch.tensor([item[1] for item in corners])
77 | ii=ii/x_cond.size()[2]*2-1
78 | jj=jj/x_cond.size()[3]*2-1
79 | osize=torch.full((len(corners),), p_size)
80 | x_output = self.diffusion.sample_image(x_cond, x, ii,jj,osize,patch_locs=corners, patch_size=p_size,patch_locs1=corners1,patch_locs2=corners2)
81 | else:
82 | x = torch.randn(x_cond.size()[0], 3, x_cond.size()[2], x_cond.size()[3], device=self.diffusion.device)
83 | ii=torch.tensor(-1).unsqueeze(0)
84 | jj=torch.tensor(-1).unsqueeze(0)
85 | osize=torch.tensor(x_cond.size()[2]).unsqueeze(0)
86 | x_output = self.diffusion.sample_image(x_cond, x, ii, jj, osize, patch_locs=None, patch_size=None)
87 |
88 |
89 | return x_output
90 | #return x_output,x_output1
91 |
92 |
93 | def overlapping_grid_indices(self, x_cond, output_size, r=None):
94 | _, c, h, w = x_cond.shape
95 | r = 16 if r is None else r
96 | h_list = [i for i in range(0, h - output_size + 1, r)]
97 | w_list = [i for i in range(0, w - output_size + 1, r)]
98 | return h_list, w_list
99 |
--------------------------------------------------------------------------------
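To make the patch layout in `overlapping_grid_indices` / `diffusive_restoration` concrete, a standalone sketch; the 400x600 resolution is only an example, while the 64-pixel patch and stride of 8 mirror the values hard-coded above:

```python
# Standalone illustration of the overlapping grid used for patch-based sampling.
h, w = 400, 600          # example H x W of a test image
p_size, r = 64, 8        # patch size and grid stride from diffusive_restoration()

h_list = list(range(0, h - p_size + 1, r))
w_list = list(range(0, w - p_size + 1, r))
corners = [(i, j) for i in h_list for j in w_list]   # top-left corner of every 64x64 patch
print(len(h_list), len(w_list), len(corners))        # 43 68 2924 overlapping patches
```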
/evaluation.py:
--------------------------------------------------------------------------------
1 | import glob
2 | import os
3 | import time
4 | from collections import OrderedDict
5 |
6 | import numpy as np
7 | import torch
8 | import cv2
9 | import argparse
10 |
11 | from natsort import natsort
12 | from skimage.metrics import structural_similarity as ssim
13 | from skimage.metrics import peak_signal_noise_ratio as psnr
14 | import lpips
15 | from numpy.lib.histograms import histogram
16 | from numpy.lib.function_base import interp
17 | def histeq(im, nbr_bins=256):
18 | im=im.cpu()
19 | im=im.detach().numpy()
20 | imhist, bins = histogram(im.flatten(), nbr_bins)
21 | cdf = imhist.cumsum()
22 | cdf = 1.0 * cdf / cdf[-1]
23 | im2 = interp(im.flatten(), bins[:-1], cdf)
24 | return im2.reshape(im.shape)
25 | def save_img(filepath, img):
26 | cv2.imwrite(filepath, cv2.cvtColor(img, cv2.COLOR_RGB2BGR))
27 | def rgb(t): return (
28 | np.clip((t[0] if len(t.shape) == 4 else t).detach().cpu().numpy().transpose([1, 2, 0]), 0, 1) * 255).astype(
29 | np.uint8)
30 | class Measure():
31 | def __init__(self, net='alex', use_gpu=False):
32 | self.device = 'cuda' if use_gpu else 'cpu'
33 | self.model = lpips.LPIPS(net=net)
34 | self.model.to(self.device)
35 |
36 | def measure(self, imgA, imgB):#A=gt B=out
37 |
38 | return [float(f(imgA, imgB)) for f in [self.psnr, self.ssim, self.lpips]],imgB
39 |
40 | def lpips(self, imgA, imgB, model=None):
41 | tA = t(imgA).to(self.device)
42 | tB = t(imgB).to(self.device)
43 | # dist01 = self.model.forward(tA, tB).item()
44 |
45 | dist01 = self.model(tA, tB).item()
46 |
47 | return dist01
48 |
49 | def ssim(self, imgA, imgB):
50 | score, diff = ssim(imgA, imgB, full=True, multichannel=True)
51 | # score, diff = ssim( cv2.cvtColor(imgA, cv2.COLOR_RGB2GRAY), cv2.cvtColor(imgB, cv2.COLOR_RGB2GRAY), full=True, multichannel=True)
52 |
53 | return score
54 |
55 | def psnr(self, imgA, imgB):
56 | psnr_val = psnr(imgA, imgB)
57 | return psnr_val
58 |
59 |
60 | def t(img):
61 | def to_4d(img):
62 | assert len(img.shape) == 3
63 | assert img.dtype == np.uint8
64 | img_new = np.expand_dims(img, axis=0)
65 | assert len(img_new.shape) == 4
66 | return img_new
67 |
68 | def to_CHW(img):
69 | return np.transpose(img, [2, 0, 1])
70 |
71 | def to_tensor(img):
72 | return torch.Tensor(img)
73 |
74 | return to_tensor(to_4d(to_CHW(img))) / 255
75 | # return to_tensor(to_4d(to_CHW(img))) / 127.5 - 1
76 |
77 |
78 |
79 | def fiFindByWildcard(wildcard):
80 | return natsort.natsorted(glob.glob(wildcard, recursive=True))
81 |
82 |
83 | def imread(path):
84 | return cv2.imread(path)[:, :, [2, 1, 0]]
85 |
86 |
87 | def format_result(psnr, ssim, lpips):
88 | return f'{psnr:0.2f}, {ssim:0.3f}, {lpips:0.3f}'
89 |
90 | def measure_dirs(dirA, dirB, use_gpu, verbose=False):
91 | if verbose:
92 | vprint = lambda x: print(x)
93 | else:
94 | vprint = lambda x: None
95 |
96 |
97 | t_init = time.time()
98 |
99 | paths_A = fiFindByWildcard(os.path.join(dirA, f'*.{type}'))
100 | paths_B = fiFindByWildcard(os.path.join(dirB, f'*.{type}'))
101 |
102 | vprint("Comparing: ")
103 | vprint(dirA)
104 | vprint(dirB)
105 |
106 | measure = Measure(use_gpu=use_gpu)
107 |
108 | results = []
109 | for pathA, pathB in zip(paths_A, paths_B):
110 | result = OrderedDict()
111 |
112 | t = time.time()
113 | A=imread(pathA)
114 | B=imread(pathB)
115 | As = A.shape
116 | Bs = B.shape
117 | A0=As[0]
118 | A1= As[1]
119 | b = cv2.resize(B, (A1, A0))
120 | [result['psnr'], result['ssim'], result['lpips']],imgb= measure.measure(A, b)
121 |
122 | d = time.time() - t
123 | vprint(f"{pathA.split('/')[-1]}, {pathB.split('/')[-1]}, {format_result(**result)}, {d:0.1f}")
124 |
125 | results.append(result)
126 |
127 | psnr = np.mean([result['psnr'] for result in results])
128 | ssim = np.mean([result['ssim'] for result in results])
129 | lpips = np.mean([result['lpips'] for result in results])
130 |
131 | vprint(f"Final Result: {format_result(psnr, ssim, lpips)}, {time.time() - t_init:0.1f}s")
132 |
133 |
134 | if __name__ == "__main__":
135 | parser = argparse.ArgumentParser()
136 | parser.add_argument('-dirA', default=r'./datasets/scratch/LLIE/data/lowlight/test/gt', type=str)
137 | parser.add_argument('-dirB', default=r'./results/images/lowlight/lowlight', type=str)
138 | parser.add_argument('-type', default='png')
139 | parser.add_argument('--use_gpu', default=True)
140 | args = parser.parse_args()
141 |
142 | dirA = args.dirA
143 | dirB = args.dirB
144 | type = args.type
145 | use_gpu = args.use_gpu
146 |
147 | if len(dirA) > 0 and len(dirB) > 0:
148 | measure_dirs(dirA, dirB, use_gpu=use_gpu, verbose=True)
149 |
--------------------------------------------------------------------------------
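A minimal sketch of scoring one restored image against its ground truth with the `Measure` class above, run from the repo root after inference (file names come from this repo's test split):

```python
# Sketch: single-image PSNR / SSIM / LPIPS, mirroring what measure_dirs() does per pair.
import cv2
from evaluation import Measure, imread, format_result

gt = imread('datasets/scratch/LLIE/data/lowlight/test/gt/1.png')
out = imread("results/images/lowlight/lowlight/['1'].png")
out = cv2.resize(out, (gt.shape[1], gt.shape[0]))            # match the GT resolution
(psnr_v, ssim_v, lpips_v), _ = Measure(use_gpu=False).measure(gt, out)
print(format_result(psnr_v, ssim_v, lpips_v))                # "PSNR, SSIM, LPIPS"
```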
/utils/metrics.py:
--------------------------------------------------------------------------------
1 | import cv2
2 | import numpy as np
3 | import torch
4 |
5 | # This script is adapted from the following repository: https://github.com/JingyunLiang/SwinIR
6 |
7 |
8 | def calculate_psnr(img1, img2, test_y_channel=False):
9 | """Calculate PSNR (Peak Signal-to-Noise Ratio).
10 |
11 | Ref: https://en.wikipedia.org/wiki/Peak_signal-to-noise_ratio
12 |
13 | Args:
14 | img1 (ndarray): Images with range [0, 255].
15 | img2 (ndarray): Images with range [0, 255].
16 | test_y_channel (bool): Test on Y channel of YCbCr. Default: False.
17 |
18 | Returns:
19 | float: psnr result.
20 | """
21 |
22 |     assert img1.shape == img2.shape, (f'Image shapes are different: {img1.shape}, {img2.shape}.')
23 | assert img1.shape[2] == 3
24 | img1 = img1.astype(np.float64)
25 | img2 = img2.astype(np.float64)
26 |
27 | if test_y_channel:
28 | img1 = to_y_channel(img1)
29 | img2 = to_y_channel(img2)
30 |
31 | mse = np.mean((img1 - img2) ** 2)
32 | if mse == 0:
33 | return float('inf')
34 | return 20. * np.log10(255. / np.sqrt(mse))
35 |
36 |
37 | def _ssim(img1, img2):
38 | """Calculate SSIM (structural similarity) for one channel images.
39 |
40 | It is called by func:`calculate_ssim`.
41 |
42 | Args:
43 | img1 (ndarray): Images with range [0, 255] with order 'HWC'.
44 | img2 (ndarray): Images with range [0, 255] with order 'HWC'.
45 |
46 | Returns:
47 | float: ssim result.
48 | """
49 |
50 | C1 = (0.01 * 255) ** 2
51 | C2 = (0.03 * 255) ** 2
52 |
53 | img1 = img1.astype(np.float64)
54 | img2 = img2.astype(np.float64)
55 | kernel = cv2.getGaussianKernel(11, 1.5)
56 | window = np.outer(kernel, kernel.transpose())
57 |
58 | mu1 = cv2.filter2D(img1, -1, window)[5:-5, 5:-5]
59 | mu2 = cv2.filter2D(img2, -1, window)[5:-5, 5:-5]
60 | mu1_sq = mu1 ** 2
61 | mu2_sq = mu2 ** 2
62 | mu1_mu2 = mu1 * mu2
63 | sigma1_sq = cv2.filter2D(img1 ** 2, -1, window)[5:-5, 5:-5] - mu1_sq
64 | sigma2_sq = cv2.filter2D(img2 ** 2, -1, window)[5:-5, 5:-5] - mu2_sq
65 | sigma12 = cv2.filter2D(img1 * img2, -1, window)[5:-5, 5:-5] - mu1_mu2
66 |
67 | ssim_map = ((2 * mu1_mu2 + C1) * (2 * sigma12 + C2)) / ((mu1_sq + mu2_sq + C1) * (sigma1_sq + sigma2_sq + C2))
68 | return ssim_map.mean()
69 |
70 |
71 | def calculate_ssim(img1, img2, test_y_channel=False):
72 | """Calculate SSIM (structural similarity).
73 |
74 | Ref:
75 | Image quality assessment: From error visibility to structural similarity
76 |
77 | The results are the same as that of the official released MATLAB code in
78 | https://ece.uwaterloo.ca/~z70wang/research/ssim/.
79 |
80 | For three-channel images, SSIM is calculated for each channel and then
81 | averaged.
82 |
83 | Args:
84 | img1 (ndarray): Images with range [0, 255].
85 | img2 (ndarray): Images with range [0, 255].
86 | test_y_channel (bool): Test on Y channel of YCbCr. Default: False.
87 |
88 | Returns:
89 | float: ssim result.
90 | """
91 |
92 |     assert img1.shape == img2.shape, (f'Image shapes are different: {img1.shape}, {img2.shape}.')
93 | assert img1.shape[2] == 3
94 | img1 = img1.astype(np.float64)
95 | img2 = img2.astype(np.float64)
96 |
97 | if test_y_channel:
98 | img1 = to_y_channel(img1)
99 | img2 = to_y_channel(img2)
100 |
101 | ssims = []
102 | for i in range(img1.shape[2]):
103 | ssims.append(_ssim(img1[..., i], img2[..., i]))
104 | return np.array(ssims).mean()
105 |
106 |
107 | def to_y_channel(img):
108 | """Change to Y channel of YCbCr.
109 |
110 | Args:
111 | img (ndarray): Images with range [0, 255].
112 |
113 | Returns:
114 | (ndarray): Images with range [0, 255] (float type) without round.
115 | """
116 | img = img.astype(np.float32) / 255.
117 | if img.ndim == 3 and img.shape[2] == 3:
118 | img = bgr2ycbcr(img, y_only=True)
119 | img = img[..., None]
120 | return img * 255.
121 |
122 |
123 | def _convert_input_type_range(img):
124 | """Convert the type and range of the input image.
125 |
126 | It converts the input image to np.float32 type and range of [0, 1].
127 | It is mainly used for pre-processing the input image in colorspace
128 |     conversion functions such as rgb2ycbcr and ycbcr2rgb.
129 |
130 | Args:
131 | img (ndarray): The input image. It accepts:
132 | 1. np.uint8 type with range [0, 255];
133 | 2. np.float32 type with range [0, 1].
134 |
135 | Returns:
136 | (ndarray): The converted image with type of np.float32 and range of
137 | [0, 1].
138 | """
139 | img_type = img.dtype
140 | img = img.astype(np.float32)
141 | if img_type == np.float32:
142 | pass
143 | elif img_type == np.uint8:
144 | img /= 255.
145 | else:
146 | raise TypeError('The img type should be np.float32 or np.uint8, ' f'but got {img_type}')
147 | return img
148 |
149 |
150 | def _convert_output_type_range(img, dst_type):
151 | """Convert the type and range of the image according to dst_type.
152 |
153 | It converts the image to desired type and range. If `dst_type` is np.uint8,
154 | images will be converted to np.uint8 type with range [0, 255]. If
155 | `dst_type` is np.float32, it converts the image to np.float32 type with
156 | range [0, 1].
157 |     It is mainly used for post-processing images in colorspace conversion
158 | functions such as rgb2ycbcr and ycbcr2rgb.
159 |
160 | Args:
161 | img (ndarray): The image to be converted with np.float32 type and
162 | range [0, 255].
163 | dst_type (np.uint8 | np.float32): If dst_type is np.uint8, it
164 | converts the image to np.uint8 type with range [0, 255]. If
165 | dst_type is np.float32, it converts the image to np.float32 type
166 | with range [0, 1].
167 |
168 | Returns:
169 | (ndarray): The converted image with desired type and range.
170 | """
171 | if dst_type not in (np.uint8, np.float32):
172 | raise TypeError('The dst_type should be np.float32 or np.uint8, ' f'but got {dst_type}')
173 | if dst_type == np.uint8:
174 | img = img.round()
175 | else:
176 | img /= 255.
177 | return img.astype(dst_type)
178 |
179 |
180 | def bgr2ycbcr(img, y_only=False):
181 | """Convert a BGR image to YCbCr image.
182 |
183 | The bgr version of rgb2ycbcr.
184 | It implements the ITU-R BT.601 conversion for standard-definition
185 | television. See more details in
186 | https://en.wikipedia.org/wiki/YCbCr#ITU-R_BT.601_conversion.
187 |
188 | It differs from a similar function in cv2.cvtColor: `BGR <-> YCrCb`.
189 | In OpenCV, it implements a JPEG conversion. See more details in
190 | https://en.wikipedia.org/wiki/YCbCr#JPEG_conversion.
191 |
192 | Args:
193 | img (ndarray): The input image. It accepts:
194 | 1. np.uint8 type with range [0, 255];
195 | 2. np.float32 type with range [0, 1].
196 | y_only (bool): Whether to only return Y channel. Default: False.
197 |
198 | Returns:
199 | ndarray: The converted YCbCr image. The output image has the same type
200 | and range as input image.
201 | """
202 | img_type = img.dtype
203 | img = _convert_input_type_range(img)
204 | if y_only:
205 | out_img = np.dot(img, [24.966, 128.553, 65.481]) + 16.0
206 | else:
207 | out_img = np.matmul(
208 | img, [[24.966, 112.0, -18.214], [128.553, -74.203, -93.786], [65.481, -37.797, 112.0]]) + [16, 128, 128]
209 | out_img = _convert_output_type_range(out_img, img_type)
210 | return out_img
211 |
--------------------------------------------------------------------------------
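A quick usage sketch for the repo's own metric helpers (image paths are examples from the test split); both images must share the same shape, so the output is resized first:

```python
# Sketch: PSNR and Y-channel SSIM between a GT image and a restored result.
import cv2
from utils.metrics import calculate_psnr, calculate_ssim

gt = cv2.imread('datasets/scratch/LLIE/data/lowlight/test/gt/1.png')    # HWC, [0, 255]
out = cv2.imread("results/images/lowlight/lowlight/['1'].png")
out = cv2.resize(out, (gt.shape[1], gt.shape[0]))                       # shapes must match
print(calculate_psnr(gt, out))                       # PSNR over all three channels
print(calculate_ssim(gt, out, test_y_channel=True))  # SSIM on the Y channel of YCbCr
```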
/utils/sampling.py:
--------------------------------------------------------------------------------
1 | import torch
2 | import utils.logging
3 | import os
4 | import torchvision
5 | from torchvision.transforms.functional import crop
6 |
7 |
8 | # This script is adapted from the following repository: https://github.com/ermongroup/ddim
9 |
10 | def compute_alpha(beta, t):
11 | beta = torch.cat([torch.zeros(1).to(beta.device), beta], dim=0)
12 | a = (1 - beta).cumprod(dim=0).index_select(0, t + 1).view(-1, 1, 1, 1)
13 | return a
14 |
15 |
16 | def data_transform(X):
17 | return 2 * X - 1.0
18 |
19 |
20 | def inverse_data_transform(X):
21 | return torch.clamp((X + 1.0) / 2.0, 0.0, 1.0)
22 |
23 |
24 | def generalized_steps(x, x_cond,seq, model, b,ii,jj,osize,eta=0.):
25 | with torch.no_grad():
26 | n = x.size(0)
27 | seq_next = [-1] + list(seq[:-1])
28 | x0_preds = []
29 | xs = [x]
30 | for i, j in zip(reversed(seq), reversed(seq_next)):
31 | t = (torch.ones(n) * i).to(x.device)
32 | next_t = (torch.ones(n) * j).to(x.device)
33 | at = compute_alpha(b, t.long())
34 | at_next = compute_alpha(b, next_t.long())
35 | xt = xs[-1].to('cuda')
36 | ii = ii.to(x.device)
37 | jj = jj.to(x.device)
38 | osize = osize.to(x.device)
39 | # print(xt.shape)
40 | # print(x_cond.shape)
41 | et = model(torch.cat([x_cond, xt], dim=1), t,ii,jj,osize)
42 |
43 | x0_t = (xt - et * (1 - at).sqrt()) / at.sqrt()
44 | x0_preds.append(x0_t.to('cpu'))
45 |
46 | c1 = eta * ((1 - at / at_next) * (1 - at_next) / (1 - at)).sqrt()
47 | c2 = ((1 - at_next) - c1 ** 2).sqrt()
48 | xt_next = at_next.sqrt() * x0_t + c1 * torch.randn_like(x) + c2 * et
49 | xs.append(xt_next.to('cpu'))
50 | return xs, x0_preds
51 |
52 |
53 | def generalized_steps_overlapping(x, x_cond, seq, model, b,ii,jj,osize, eta=0., corners=None, p_size=None,corners1=None,corners2=None,manual_batching=True):
54 | with torch.no_grad():
55 | n = x.size(0)
56 | seq_next = [-1] + list(seq[:-1])
57 | x0_preds = []
58 | xs = [x]
59 |
60 | x_grid_mask = torch.zeros(x_cond.size(0),3,x_cond.size(2),x_cond.size(3), device=x.device)
61 | for (hi, wi) in corners:
62 | x_grid_mask[:, :, hi:hi + p_size, wi:wi + p_size] += 1
63 |             if corners1 is not None:
64 | p_size1=96
65 | x_grid_mask1 = torch.zeros(x_cond.size(0),3,x_cond.size(2),x_cond.size(3), device=x.device)
66 | for (hi, wi) in corners1:
67 | x_grid_mask1[:, :, hi:hi + p_size1, wi:wi + p_size1] += 1
68 |             if corners2 is not None:
69 | p_size2=128
70 | x_grid_mask2 = torch.zeros(x_cond.size(0),3,x_cond.size(2),x_cond.size(3), device=x.device)
71 | for (hi, wi) in corners2:
72 | x_grid_mask2[:, :, hi:hi + p_size2, wi:wi + p_size2] += 1
73 |
74 |
75 | for i, j in zip(reversed(seq), reversed(seq_next)):
76 | t = (torch.ones(n) * i).to(x.device)
77 | next_t = (torch.ones(n) * j).to(x.device)
78 | at = compute_alpha(b, t.long())
79 | at_next = compute_alpha(b, next_t.long())
80 | xt = xs[-1].to('cuda')
81 | et_output = torch.zeros(x_cond.size(0),3,x_cond.size(2),x_cond.size(3), device=x.device)
82 |
83 |             if manual_batching:
84 | manual_batching_size = p_size
85 | xt_patch = torch.cat([crop(xt, hi, wi, p_size, p_size) for (hi, wi) in corners], dim=0)
86 | x_cond_patch = torch.cat([data_transform(crop(x_cond, hi, wi, p_size, p_size)) for (hi, wi) in corners], dim=0)
87 |                 for i in range(0, len(corners), manual_batching_size):  # corners are the patch top-left positions taken at a stride of 16
88 | ii_input=torch.unsqueeze(ii[i],dim=0)
89 | jj_input=torch.unsqueeze(jj[i],dim=0)
90 | osize_input=torch.unsqueeze(osize[i],dim=0)
91 | x_input=torch.cat([x_cond_patch[i:i + manual_batching_size], xt_patch[i:i + manual_batching_size]], dim=1)
92 | # print("x_input", x_input.shape)
93 | outputs = model(x_input, t,ii_input,jj_input,osize_input)
94 | # print("outputs", outputs.shape)
95 | for idx, (hi, wi) in enumerate(corners[i:i+manual_batching_size]):
96 | # print("idx", idx)
97 | et_output[0, :, hi:hi + p_size, wi:wi + p_size] += outputs[idx]
98 |
99 |                 if corners1 is not None:
100 | et_output1 = torch.zeros(x_cond.size(0), 3, x_cond.size(2), x_cond.size(3), device=x.device)
101 | manual_batching_size = p_size1
102 | xt_patch = torch.cat([crop(xt, hi, wi, p_size1, p_size1) for (hi, wi) in corners1], dim=0)
103 | x_cond_patch = torch.cat(
104 | [data_transform(crop(x_cond, hi, wi, p_size1, p_size1)) for (hi, wi) in corners1], dim=0)
105 |                     for i in range(0, len(corners1), manual_batching_size):  # corners1 are the patch top-left positions taken at a stride of 16
106 | ii_input = torch.unsqueeze(ii[i], dim=0)
107 | jj_input = torch.unsqueeze(jj[i], dim=0)
108 | osize_input = torch.unsqueeze(osize[i], dim=0)
109 | x_input = torch.cat(
110 | [x_cond_patch[i:i + manual_batching_size], xt_patch[i:i + manual_batching_size]], dim=1)
111 | # print("x_input", x_input.shape)
112 | outputs1 = model(x_input, t, ii_input, jj_input, osize_input)
113 | # print("outputs", outputs.shape)
114 | for idx, (hi, wi) in enumerate(corners1[i:i + manual_batching_size]):
115 | # print("idx", idx)
116 | et_output1[0, :, hi:hi + p_size1, wi:wi + p_size1] += outputs1[idx]
117 |                 if corners2 is not None:
118 | et_output2= torch.zeros(x_cond.size(0), 3, x_cond.size(2), x_cond.size(3), device=x.device)
119 | manual_batching_size = p_size2
120 | xt_patch = torch.cat([crop(xt, hi, wi, p_size2, p_size2) for (hi, wi) in corners2], dim=0)
121 | x_cond_patch = torch.cat(
122 | [data_transform(crop(x_cond, hi, wi, p_size2, p_size2)) for (hi, wi) in corners2], dim=0)
123 |                     for i in range(0, len(corners2), manual_batching_size):  # corners2 are the patch top-left positions taken at a stride of 16
124 | ii_input = torch.unsqueeze(ii[i], dim=0)
125 | jj_input = torch.unsqueeze(jj[i], dim=0)
126 | osize_input = torch.unsqueeze(osize[i], dim=0)
127 | x_input = torch.cat(
128 | [x_cond_patch[i:i + manual_batching_size], xt_patch[i:i + manual_batching_size]], dim=1)
129 | # print("x_input", x_input.shape)
130 | outputs2 = model(x_input, t, ii_input, jj_input, osize_input)
131 | # print("outputs", outputs.shape)
132 | for idx, (hi, wi) in enumerate(corners2[i:i + manual_batching_size]):
133 | # print("idx", idx)
134 | et_output2[0, :, hi:hi + p_size2, wi:wi + p_size2] += outputs2[idx]
135 |
136 | else:
137 | for (hi, wi) in corners:
138 | xt_patch = crop(xt, hi, wi, p_size, p_size)
139 | x_cond_patch = crop(x_cond, hi, wi, p_size, p_size)
140 | x_cond_patch = data_transform(x_cond_patch)
141 |                     et_output[:, :, hi:hi + p_size, wi:wi + p_size] += model(torch.cat([x_cond_patch, xt_patch], dim=1), t)  # note: this fallback path does not pass the ii/jj/osize conditioning that the modified UNet forward expects
142 |
143 | et0 = torch.div(et_output, x_grid_mask)
144 |             if corners1 is not None:
145 | et1 = torch.div(et_output1, x_grid_mask1)
146 | et=(et0+et1)/2.0
147 |                 if corners2 is not None:
148 | et2 = torch.div(et_output2, x_grid_mask2)
149 | et = (et0 +et1+ et2) / 3.0
150 | else:
151 | et=et0
152 | x0_t = (xt - et * (1 - at).sqrt()) / at.sqrt()
153 | x0_preds.append(x0_t.to('cpu'))
154 |
155 | c1 = eta * ((1 - at / at_next) * (1 - at_next) / (1 - at)).sqrt()
156 | c2 = ((1 - at_next) - c1 ** 2).sqrt()
157 | xt_next = at_next.sqrt() * x0_t + c1 * torch.randn_like(x) + c2 * et
158 | # if i>500:
159 | # xt_next=xt_next/torch.mean(xt_next)*torch.mean(x_cond[:,3:6,:,])
160 | xs.append(xt_next.to('cpu'))
161 |
162 |
163 |
164 |
165 |
166 | return xs, x0_preds
167 |
--------------------------------------------------------------------------------
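Both samplers above apply the same DDIM update at every reverse step. The sketch below isolates a single step with eta = 0, the deterministic setting this repo passes from sample_image; the random tensors merely stand in for the noisy sample and the network's noise prediction:

import torch

def ddim_step(xt, et, at, at_next, eta=0.0):
    # predicted clean image from the current noisy sample and the noise estimate
    x0_t = (xt - et * (1 - at).sqrt()) / at.sqrt()
    c1 = eta * ((1 - at / at_next) * (1 - at_next) / (1 - at)).sqrt()
    c2 = ((1 - at_next) - c1 ** 2).sqrt()
    return at_next.sqrt() * x0_t + c1 * torch.randn_like(xt) + c2 * et

xt = torch.randn(1, 3, 64, 64)          # stand-in for xs[-1]
et = torch.randn_like(xt)               # stand-in for the model's noise prediction
at, at_next = torch.tensor(0.5), torch.tensor(0.7)
print(ddim_step(xt, et, at, at_next).shape)  # torch.Size([1, 3, 64, 64])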
/datasets/lowlight.py:
--------------------------------------------------------------------------------
1 |
2 | import os
3 | from os import listdir
4 | from os.path import isfile
5 | import torch
6 | import numpy as np
7 | import torchvision
8 | import torch.utils.data
9 | import PIL
10 | import re
11 | import random
12 | import torchvision.transforms as transforms
13 |
14 | class lowlight:
15 | def __init__(self, config):
16 | self.config = config
17 | self.transforms = torchvision.transforms.Compose([torchvision.transforms.ToTensor()])
18 |
19 | def get_loaders(self, parse_patches=True, validation='lowlight'):
20 | print("=> evaluating lowlight test set...")
21 | train_dataset = lowlightDataset(dir=os.path.join(self.config.data.data_dir, 'data', 'lowlight', 'train'),
22 | n=self.config.training.patch_n,
23 | patch_size=self.config.data.image_size,
24 | transforms=self.transforms,
25 | filelist=None,
26 | parse_patches=parse_patches,
27 | train=True)
28 | val_dataset = lowlightDataset(dir=os.path.join(self.config.data.data_dir, 'data', 'lowlight', 'test'),
29 | n=self.config.training.patch_n,
30 | patch_size=self.config.data.image_size,
31 | transforms=self.transforms,
32 | filelist='lowlighttesta.txt',
33 | parse_patches=parse_patches,
34 | train=False)
35 |
36 | if not parse_patches:
37 | self.config.training.batch_size = 1
38 | self.config.sampling.batch_size = 1
39 |
40 | train_loader = torch.utils.data.DataLoader(train_dataset, batch_size=self.config.training.batch_size,
41 | shuffle=True, num_workers=self.config.data.num_workers,
42 | pin_memory=True)
43 | val_loader = torch.utils.data.DataLoader(val_dataset, batch_size=self.config.sampling.batch_size,
44 | shuffle=False, num_workers=self.config.data.num_workers,
45 | pin_memory=True)
46 |
47 | return train_loader, val_loader
48 |
49 |
50 | class lowlightDataset(torch.utils.data.Dataset):
51 | def __init__(self, dir, patch_size, n, transforms, train,filelist=None, parse_patches=True):
52 | super().__init__()
53 |
54 | if filelist is None:
55 | lowlight_dir = dir
56 | input_names, gt_names = [], []
57 |
58 |
59 |
60 | filepath = os.path.dirname(__file__)
61 | print(f"filepath:{filepath}")
62 | lowlight_dir=os.path.join(filepath,lowlight_dir)
63 | #lowlight_inputs = os.path.join(filepath, lowlight_inputs0)
64 |
65 | # lowlight train filelist
66 | lowlight_inputs = os.path.join(lowlight_dir, 'input')
67 | listdir(lowlight_inputs)
68 | images = [f for f in listdir(lowlight_inputs) if isfile(os.path.join(lowlight_inputs, f))]
69 | assert len(images) == 485
70 | input_names += [os.path.join(lowlight_inputs, i) for i in images]
71 | # gt_names += [os.path.join(os.path.join(lowlight_dir, 'gt'), i.replace('rain', 'clean')) for i in images]
72 | gt_names += [os.path.join(os.path.join(lowlight_dir, 'gt'), i.replace('', '')) for i in images]
73 | print(len(input_names))
74 |
75 | x = list(enumerate(input_names))
76 | random.shuffle(x)
77 | indices, input_names = zip(*x)
78 | gt_names = [gt_names[idx] for idx in indices]
79 | self.dir = None
80 | else:
81 | self.dir = dir
82 | filepath = os.path.dirname(__file__)
83 | dir = os.path.join(filepath, dir)
84 | train_list = os.path.join(dir, filelist)
85 | with open(train_list) as f:
86 | contents = f.readlines()
87 | input_names = [i.strip() for i in contents]
88 | gt_names = [i.strip().replace('input', 'gt') for i in input_names]
89 |
90 | self.input_names = input_names
91 | self.gt_names = gt_names
92 | self.patch_size = patch_size
93 | self.transforms = transforms
94 | self.n = n
95 | self.parse_patches = parse_patches
96 | self.batchnum = 0
97 | self.batchsize = 1
98 | self.train=train
99 |
100 | @staticmethod
101 | def get_params(img, output_size, n,random_size):
102 | w, h = img.size
103 | if random_size == 0:
104 | output_size = (64,64)
105 | elif random_size == 1:
106 | output_size = (128,128)
107 | else:
108 | output_size = (256,256)
109 | # output_size = (192,192)
110 | th, tw = output_size
111 | if w == tw and h == th:
112 | return 0, 0, h, w
113 |
114 | i_list = [random.randint(0, h - th) for _ in range(n)]
115 | j_list = [random.randint(0, w - tw) for _ in range(n)]
116 | osize = [output_size[0] for _ in range(n)]
117 | return i_list, j_list, th, tw,osize,h,w
118 |
119 | @staticmethod
120 | def n_random_crops(img, x, y, h, w):
121 | crops = []
122 | for i in range(len(x)):
123 | new_crop = img.crop((y[i], x[i], y[i] + w, x[i] + h))
124 |
125 | if h!=64:
126 | resize_transform = transforms.Resize(64)
127 | new_crop=resize_transform(new_crop)
128 | crops.append(new_crop)
129 | return tuple(crops)
130 |
131 | def get_max(self,input):
132 | T,_=torch.max(input,dim=0)
133 | T=T+0.1
134 | input[0,:,:] = input[0,:,:]/ T
135 | input[1,:,:] = input[1,:,:]/ T
136 | input[2,:,:]= input[2,:,:] / T
137 | return input
138 |
139 | def get_images(self, index):
140 |         if self.train:
141 | if self.batchnum==0:
142 | self.random_size = random.randint(0, 2)
143 | self.batchnum=self.batchnum+1
144 | if self.batchnum==self.batchsize:
145 | self.batchnum=0
146 | else:
147 | self.random_size=0
148 |
149 | input_name = self.input_names[index]
150 | gt_name = self.gt_names[index]
151 | img_id = re.split('/', input_name)[-1][:-4]
152 | input_img = PIL.Image.open(os.path.join(self.dir, input_name)) if self.dir else PIL.Image.open(input_name)
153 | try:
154 | gt_img = PIL.Image.open(os.path.join(self.dir, gt_name)) if self.dir else PIL.Image.open(gt_name)
155 |         except Exception:
156 | gt_img = PIL.Image.open(os.path.join(self.dir, gt_name)).convert('RGB') if self.dir else \
157 | PIL.Image.open(gt_name).convert('RGB')
158 |
159 | if self.parse_patches:
160 |
161 | i, j, h, w,osize,h_org,w_org = self.get_params(input_img, (self.patch_size, self.patch_size), self.n,self.random_size)
162 | input_img = self.n_random_crops(input_img, i, j, h, w)
163 | gt_img = self.n_random_crops(gt_img, i, j, h, w)
164 |
165 |
166 | random_input=0
167 | if random_input==0:
168 | outputs = [torch.cat([self.transforms(input_img[i]),self.get_max(self.transforms(input_img[i])),self.transforms(gt_img[i])], dim=0)
169 | for i in range(self.n)]
170 | else:
171 | outputs = [torch.cat([self.get_max(self.transforms(input_img[i])), self.transforms(input_img[i]),
172 | self.transforms(gt_img[i])], dim=0)
173 | for i in range(self.n)]
174 | ii=torch.tensor(i)
175 | jj = torch.tensor(j)
176 | ii = (ii / h_org) * 2 - 1
177 | jj = (jj / w_org) * 2 - 1
178 | osize=torch.tensor(osize)
179 | return torch.stack(outputs, dim=0), img_id,ii,jj,osize
180 | else:
181 | # Resizing images to multiples of 16 for whole-image restoration
182 | wd_new, ht_new = input_img.size
183 | wd=wd_new
184 | ht=ht_new
185 | if ht_new > wd_new and ht_new > 1024:
186 | wd_new = int(np.ceil(wd_new * 1024 / ht_new))
187 | ht_new = 1024
188 | elif ht_new <= wd_new and wd_new > 1024:
189 | ht_new = int(np.ceil(ht_new * 1024 / wd_new))
190 | wd_new = 1024
191 |
192 |
193 |
194 | wd_new = int(8 * np.ceil(wd_new / 8.0))
195 | ht_new = int(8 * np.ceil(ht_new / 8.0))
196 | input_img = input_img.resize((wd_new, ht_new), PIL.Image.ANTIALIAS)
197 | gt_img = gt_img.resize((wd_new, ht_new), PIL.Image.ANTIALIAS)
198 |
199 | return torch.cat([self.transforms(input_img),self.get_max(self.transforms(input_img)), self.transforms(gt_img)], dim=0), img_id,wd,ht
200 |
201 | def __getitem__(self, index):
202 | res = self.get_images(index)
203 | return res
204 |
205 | def __len__(self):
206 | return len(self.input_names)
207 |
--------------------------------------------------------------------------------
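get_images above hands the model each crop's position as top-left offsets normalized to [-1, 1] against the original image size (ii = i / h_org * 2 - 1, jj = j / w_org * 2 - 1). A standalone sketch of that mapping with illustrative numbers:

import torch

def normalize_coords(i_list, j_list, h_org, w_org):
    # map pixel offsets of each crop to [-1, 1] relative to the full image
    ii = torch.tensor(i_list, dtype=torch.float32)
    jj = torch.tensor(j_list, dtype=torch.float32)
    return (ii / h_org) * 2 - 1, (jj / w_org) * 2 - 1

ii, jj = normalize_coords([0, 100, 336], [0, 200, 536], h_org=400, w_org=600)
print(ii)  # tensor([-1.0000, -0.5000,  0.6800])
print(jj)  # tensor([-1.0000, -0.3333,  0.7867])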
/models/ddm.py:
--------------------------------------------------------------------------------
1 | import os
2 | import time
3 | import glob
4 | import numpy as np
5 | import tqdm
6 | import torch
7 | import torch.nn as nn
8 | import torch.utils.data as data
9 | import torch.backends.cudnn as cudnn
10 | import utils
11 | from models.unet import DiffusionUNet
12 |
13 |
14 |
15 | def data_transform(X):
16 | return 2 * X - 1.0
17 |
18 |
19 | def inverse_data_transform(X):
20 | return torch.clamp((X + 1.0) / 2.0, 0.0, 1.0)
21 |
22 |
23 | class EMAHelper(object):
24 | def __init__(self, mu=0.9999):
25 | self.mu = mu
26 | self.shadow = {}
27 |
28 | def register(self, module):
29 | if isinstance(module, nn.DataParallel):
30 | module = module.module
31 | for name, param in module.named_parameters():
32 | if param.requires_grad:
33 | self.shadow[name] = param.data.clone()
34 |
35 | def update(self, module):
36 | if isinstance(module, nn.DataParallel):
37 | module = module.module
38 | for name, param in module.named_parameters():
39 | if param.requires_grad:
40 | self.shadow[name].data = (1. - self.mu) * param.data + self.mu * self.shadow[name].data
41 |
42 | def ema(self, module):
43 | if isinstance(module, nn.DataParallel):
44 | module = module.module
45 | for name, param in module.named_parameters():
46 | if param.requires_grad:
47 | param.data.copy_(self.shadow[name].data)
48 |
49 | def ema_copy(self, module):
50 | if isinstance(module, nn.DataParallel):
51 | inner_module = module.module
52 | module_copy = type(inner_module)(inner_module.config).to(inner_module.config.device)
53 | module_copy.load_state_dict(inner_module.state_dict())
54 | module_copy = nn.DataParallel(module_copy)
55 | else:
56 | module_copy = type(module)(module.config).to(module.config.device)
57 | module_copy.load_state_dict(module.state_dict())
58 | self.ema(module_copy)
59 | return module_copy
60 |
61 | def state_dict(self):
62 | return self.shadow
63 |
64 | def load_state_dict(self, state_dict):
65 | self.shadow = state_dict
66 |
67 |
68 | def get_beta_schedule(beta_schedule, *, beta_start, beta_end, num_diffusion_timesteps):
69 | def sigmoid(x):
70 | return 1 / (np.exp(-x) + 1)
71 |
72 | if beta_schedule == "quad":
73 | betas = (np.linspace(beta_start ** 0.5, beta_end ** 0.5, num_diffusion_timesteps, dtype=np.float64) ** 2)
74 | elif beta_schedule == "linear":
75 | betas = np.linspace(beta_start, beta_end, num_diffusion_timesteps, dtype=np.float64)
76 | elif beta_schedule == "const":
77 | betas = beta_end * np.ones(num_diffusion_timesteps, dtype=np.float64)
78 | elif beta_schedule == "jsd": # 1/T, 1/(T-1), 1/(T-2), ..., 1
79 | betas = 1.0 / np.linspace(num_diffusion_timesteps, 1, num_diffusion_timesteps, dtype=np.float64)
80 | elif beta_schedule == "sigmoid":
81 | betas = np.linspace(-6, 6, num_diffusion_timesteps)
82 | betas = sigmoid(betas) * (beta_end - beta_start) + beta_start
83 | else:
84 | raise NotImplementedError(beta_schedule)
85 | assert betas.shape == (num_diffusion_timesteps,)
86 | return betas
87 |
88 |
89 | def noise_estimation_loss(model, x0, t, e, b, i, j, osize):
90 | a = (1 - b).cumprod(dim=0).index_select(0, t).view(-1, 1, 1, 1)
91 | x = x0[:, 6:, :, :] * a.sqrt() + e * (1.0 - a).sqrt()
92 | output = model(torch.cat([x0[:, :6, :, :], x], dim=1), t.float(), i, j,
93 | osize)
94 |
95 |
96 | return (e - output).square().sum(dim=(1, 2, 3)).mean(dim=0)
97 |
98 |
99 | def count_parameters(model):
100 | return sum(p.numel() for p in model.parameters() if p.requires_grad)
101 |
102 |
103 | def count_flops(model, input_shape):
104 | flops = 0
105 | model.eval()
106 | with torch.no_grad():
107 | input = torch.randn(1, *input_shape)
108 | model(input)
109 |         flops = torch.cuda.memory_stats(0)["allocated_bytes.all.current"]  # note: this reports currently allocated CUDA memory in bytes, not a true FLOP count
110 | model.train()
111 | return flops
112 |
113 |
114 | class DenoisingDiffusion(object):
115 | def __init__(self, args, config):
116 | super().__init__()
117 | self.args = args
118 | self.config = config
119 | self.device = config.device
120 |
121 | self.model = DiffusionUNet(config)
122 | self.model.to(self.device)
123 | self.model = torch.nn.DataParallel(self.model)
124 |
125 |
126 | num_params = count_parameters(self.model)
127 | print("parameters: {:,}".format(num_params))
128 |
129 | self.ema_helper = EMAHelper()
130 | self.ema_helper.register(self.model)
131 |
132 | self.optimizer = utils.optimize.get_optimizer(self.config, self.model.parameters())
133 | self.start_epoch, self.step = 0, 0
134 |
135 | betas = get_beta_schedule(
136 | beta_schedule=config.diffusion.beta_schedule,
137 | beta_start=config.diffusion.beta_start,
138 | beta_end=config.diffusion.beta_end,
139 | num_diffusion_timesteps=config.diffusion.num_diffusion_timesteps,
140 | )
141 |
142 | betas = self.betas = torch.from_numpy(betas).float().to(self.device)
143 | self.num_timesteps = betas.shape[0]
144 |
145 | def load_ddm_ckpt(self, load_path, ema=False):
146 | checkpoint = utils.logging.load_checkpoint(load_path, None)
147 | self.start_epoch = checkpoint['epoch']
148 | self.step = checkpoint['step']
149 | self.model.load_state_dict(checkpoint['state_dict'], strict=True)
150 | self.optimizer.load_state_dict(checkpoint['optimizer'])
151 | self.ema_helper.load_state_dict(checkpoint['ema_helper'])
152 | if ema:
153 | self.ema_helper.ema(self.model)
154 | print("=> loaded checkpoint '{}' (epoch {}, step {})".format(load_path, checkpoint['epoch'], self.step))
155 |
156 | def train(self, DATASET):
157 | cudnn.benchmark = True
158 | train_loader, val_loader = DATASET.get_loaders()
159 |
160 | if os.path.isfile(self.args.resume):
161 | self.load_ddm_ckpt(self.args.resume)
162 |
163 | for epoch in range(self.start_epoch, self.config.training.n_epochs):
164 | print('epoch: ', epoch)
165 | data_start = time.time()
166 | data_time = 0
167 | for inter, (x, y, i, j, osize) in enumerate(train_loader):
168 | x = x.flatten(start_dim=0, end_dim=1) if x.ndim == 5 else x
169 | n = x.size(0)
170 | data_time += time.time() - data_start
171 | self.model.train()
172 | self.step += 1
173 |
174 | x = x.to(self.device)
175 | x = data_transform(x)
176 | e = torch.randn_like(x[:, 6:, :, :])
177 | b = self.betas
178 |
179 | # antithetic sampling
180 | t = torch.randint(low=0, high=self.num_timesteps, size=(n // 2 + 1,)).to(
181 | self.device)
182 | t = torch.cat([t, self.num_timesteps - t - 1], dim=0)[
183 | :n]
184 | i = i.squeeze().to(self.device)
185 | j = j.squeeze().to(self.device)
186 | osize = osize.squeeze().to(self.device)
187 | i = i.view(n)
188 | j = j.view(n)
189 | osize = osize.view(n)
190 | # cond=torch.stack([ i, j, osize], dim=1)
191 | loss = noise_estimation_loss(self.model, x, t, e, b, i, j,
192 | osize)
193 |
194 | if self.step % 10 == 0:
195 | print(f"step: {self.step}, loss: {loss.item()}, data time: {data_time / (inter + 1)}")
196 |
197 | self.optimizer.zero_grad()
198 | loss.backward()
199 | self.optimizer.step()
200 | self.ema_helper.update(self.model)
201 | data_start = time.time()
202 |
203 | if self.step % self.config.training.validation_freq == 0:
204 | self.model.eval()
205 | self.sample_validation_patches(val_loader, self.step)
206 |
207 | if self.step % self.config.training.snapshot_freq == 0 or self.step == 1:
208 | utils.logging.save_checkpoint({
209 | 'epoch': epoch + 1,
210 | 'step': self.step,
211 | 'state_dict': self.model.state_dict(),
212 | 'optimizer': self.optimizer.state_dict(),
213 | 'ema_helper': self.ema_helper.state_dict(),
214 | 'params': self.args,
215 | 'config': self.config
216 | }, filename=os.path.join(self.config.data.data_dir, 'ckpts',self.config.data.dataset + '_ddpm'))
217 |
218 | def sample_image(self, x_cond, x, ii, jj, osize, last=True, patch_locs=None, patch_size=None,patch_locs1=None,patch_locs2=None):
219 | skip = self.config.diffusion.num_diffusion_timesteps // self.args.sampling_timesteps
220 | seq = range(0, self.config.diffusion.num_diffusion_timesteps, skip)
221 | if patch_locs is not None:
222 | xs = utils.sampling.generalized_steps_overlapping(x, x_cond, seq, self.model, self.betas, ii=ii, jj=jj,
223 | osize=osize, eta=0.,
224 | corners=patch_locs, p_size=patch_size,corners1=patch_locs1,corners2=patch_locs2)
225 | else:
226 | xs = utils.sampling.generalized_steps(x, x_cond, seq, self.model, self.betas, ii=ii, jj=jj, osize=osize,
227 | eta=0.)
228 | if last:
229 | xs = xs[0][-1]
230 | return xs
231 |
232 | def sample_validation_patches(self, val_loader, step):
233 | image_folder = os.path.join(self.args.image_folder, self.config.data.dataset + str(self.config.data.image_size))
234 | with torch.no_grad():
235 | print(f"Processing a single batch of validation images at step: {step}")
236 | for i, (x, y, ii, jj, osize) in enumerate(val_loader):
237 | x = x.flatten(start_dim=0, end_dim=1) if x.ndim == 5 else x
238 | break
239 | n = x.size(0)
240 | x_cond = x[:, :6, :, :].to(self.device)
241 | x_cond = data_transform(x_cond)
242 | # print("osize:", osize)
243 | # print("osize shape:", osize.shape)
244 |
245 |             x = torch.randn(n, 3, int(osize[0, 0]), int(osize[0, 1]), device=self.device)  # this is the noise term, which needs 3 channels
246 | ii = ii.squeeze()
247 | jj = jj.squeeze()
248 | osize = osize.squeeze()
249 | ii = ii.view(n)
250 | jj = jj.view(n)
251 | osize = osize.view(n)
252 | x = self.sample_image(x_cond, x, ii, jj, osize)
253 | x = inverse_data_transform(x)
254 | x_cond = inverse_data_transform(x_cond)
255 |
256 | for i in range(n):
257 | utils.logging.save_image(x_cond[i][:3, :, :], os.path.join(image_folder, str(step), f"{i}_cond.png"))
258 | utils.logging.save_image(x[i], os.path.join(image_folder, str(step), f"{i}.png"))
259 |
--------------------------------------------------------------------------------
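noise_estimation_loss and the training loop above depend on the cumulative product of (1 - beta) and on antithetic timestep sampling. The sketch below reproduces both in isolation; the linear schedule endpoints (1e-4, 2e-2) and T = 1000 are typical DDPM values assumed here for illustration, not values read from this repo's config:

import numpy as np
import torch

# Linear beta schedule and its cumulative alpha (the `a` used in the loss).
betas = torch.from_numpy(np.linspace(1e-4, 2e-2, 1000, dtype=np.float64)).float()
alpha_bar = (1 - betas).cumprod(dim=0)
print(alpha_bar[0].item(), alpha_bar[-1].item())  # ~1.0 at t=0, close to 0 at t=999

# Antithetic timestep sampling: draw n//2+1 timesteps, mirror them, keep n.
n, T = 8, 1000
t = torch.randint(low=0, high=T, size=(n // 2 + 1,))
t = torch.cat([t, T - t - 1], dim=0)[:n]
print(t.shape)  # torch.Size([8])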
/models/unet.py:
--------------------------------------------------------------------------------
1 |
2 |
3 | import math
4 | import torch
5 | import torch.nn as nn
6 |
7 | def get_timestep_embedding(timesteps, embedding_dim):
8 | """
9 | This matches the implementation in Denoising Diffusion Probabilistic Models:
10 | From Fairseq.
11 | Build sinusoidal embeddings.
12 | This matches the implementation in tensor2tensor, but differs slightly
13 | from the description in Section 3.5 of "Attention Is All You Need".
14 | """
15 | # assert len(timesteps.shape) == 1
16 |
17 | half_dim = embedding_dim // 2
18 |
19 | emb = math.log(10000) / (half_dim - 1)
20 |
21 | emb = torch.exp(torch.arange(half_dim, dtype=torch.float32) * -emb)
22 |
23 | emb = emb.to(device=timesteps.device)
24 |
25 | emb = timesteps.float()[:, None] * emb[None, :]
26 |
27 | emb = torch.cat([torch.sin(emb), torch.cos(emb)], dim=1)
28 |
29 | if embedding_dim % 2 != 0:
30 | emb = torch.nn.functional.pad(emb, (0, 1, 0, 0))
31 |
32 | return emb
33 |
34 |
35 |
36 | def nonlinearity(x):
37 | # swish
38 | return x*torch.sigmoid(x)
39 |
40 |
41 | def Normalize(in_channels):
42 | return torch.nn.GroupNorm(num_groups=8, num_channels=in_channels, eps=1e-6, affine=True)
43 |
44 |
45 | class Upsample(nn.Module):
46 | def __init__(self, in_channels, with_conv):
47 | super().__init__()
48 | self.with_conv = with_conv
49 | if self.with_conv:
50 | self.conv = torch.nn.Conv2d(in_channels,
51 | in_channels,
52 | kernel_size=3,
53 | stride=1,
54 | padding=1)
55 |
56 | def forward(self, x):
57 | x = torch.nn.functional.interpolate(
58 | x, scale_factor=2.0, mode="nearest")
59 | if self.with_conv:
60 | x = self.conv(x)
61 | return x
62 |
63 |
64 | class Downsample(nn.Module):
65 | def __init__(self, in_channels, with_conv):
66 | super().__init__()
67 | self.with_conv = with_conv
68 | if self.with_conv:
69 | # no asymmetric padding in torch conv, must do it ourselves
70 | self.conv = torch.nn.Conv2d(in_channels,
71 | in_channels,
72 | kernel_size=3,
73 | stride=2,
74 | padding=0)
75 |
76 | def forward(self, x):
77 | if self.with_conv:
78 | pad = (0, 1, 0, 1)
79 | x = torch.nn.functional.pad(x, pad, mode="constant", value=0)
80 | x = self.conv(x)
81 | else:
82 | x = torch.nn.functional.avg_pool2d(x, kernel_size=2, stride=2)
83 | return x
84 |
85 |
86 | class HighMixer(nn.Module):
87 | def __init__(self, dim, kernel_size=3, stride=1, padding=1,
88 | **kwargs, ):
89 | super().__init__()
90 |
91 | self.cnn_in = cnn_in = dim
92 | # self.pool_in = pool_in = dim-cnn_in
93 |
94 | self.cnn_dim = cnn_dim = cnn_in
95 | # self.pool_dim = pool_dim = pool_in
96 |
97 | self.conv1 = nn.Conv2d(cnn_in, cnn_dim, kernel_size=1, stride=1, padding=0, bias=False)
98 | self.proj1 = nn.Conv2d(cnn_dim, cnn_dim, kernel_size=kernel_size, stride=stride, padding=padding, bias=False,
99 | groups=cnn_dim)
100 | self.mid_gelu1 = nn.GELU()
101 |
102 |
103 |
104 | def forward(self, x):
105 | # B, C H, W
106 |
107 | cx = x[:, :self.cnn_in, :, :].contiguous()
108 | cx = self.conv1(cx)
109 | cx = self.proj1(cx)
110 | cx = self.mid_gelu1(cx)
111 |
112 | return cx
113 |
114 |
115 | class LowMixer(nn.Module):
116 | def __init__(self, dim, num_heads=8, qkv_bias=False, attn_drop=0., pool_size=2,
117 | **kwargs, ):
118 | super().__init__()
119 | self.num_heads = num_heads
120 | self.head_dim = head_dim = dim // num_heads
121 | self.scale = head_dim ** -0.5
122 | self.dim = dim
123 |
124 | self.qkv = nn.Linear(dim, dim * 3, bias=qkv_bias)
125 | self.attn_drop = nn.Dropout(attn_drop)
126 |
127 | self.pool = nn.AvgPool2d(pool_size, stride=pool_size, padding=0,
128 | count_include_pad=False) if pool_size > 1 else nn.Identity()
129 | self.uppool = nn.Upsample(scale_factor=pool_size) if pool_size > 1 else nn.Identity()
130 |
131 | def att_fun(self, q, k, v, B, N, C):
132 | attn = (q @ k.transpose(-2, -1)) * self.scale
133 | attn = attn.softmax(dim=-1)
134 | attn = self.attn_drop(attn)
135 | # x = (attn @ v).transpose(1, 2).reshape(B, N, C)
136 | x = (attn @ v).transpose(2, 3).reshape(B, C, N)
137 | return x
138 |
139 | def forward(self, x):
140 | # B, C, H, W
141 | B, _, _, _ = x.shape
142 | x = self.pool(x)
143 | xa = x.permute(0, 2, 3, 1).view(B, -1, self.dim)
144 | B, N, C = xa.shape
145 | qkv = self.qkv(xa).reshape(B, N, 3, self.num_heads, C // self.num_heads).permute(2, 0, 3, 1, 4)
146 | q, k, v = qkv.unbind(0) # make torchscript happy (cannot use tensor as tuple)
147 | xa = self.att_fun(q, k, v, B, N, C)
148 | xa = xa.view(B, C, int(N ** 0.5), int(N ** 0.5)) # .permute(0, 3, 1, 2)
149 |
150 | xa = self.uppool(xa)
151 | return xa
152 |
153 |
154 | class Mixer(nn.Module):
155 | def __init__(self, dim, num_heads=8, qkv_bias=False, attn_drop=0., proj_drop=0., attention_head=1, pool_size=2,
156 | **kwargs, ):
157 | super().__init__()
158 | self.num_heads = num_heads
159 | self.low_dim = low_dim = dim //2
160 | self.high_dim = high_dim = dim - low_dim
161 | self.high_mixer = HighMixer(high_dim)
162 | self.low_mixer = LowMixer(low_dim, num_heads=attention_head, qkv_bias=qkv_bias, attn_drop=attn_drop,
163 | pool_size=pool_size, )
164 |
165 | self.conv_fuse = nn.Conv2d(low_dim + high_dim , low_dim + high_dim , kernel_size=3, stride=1, padding=1,
166 | bias=False, groups=low_dim + high_dim)
167 | self.proj = nn.Conv2d(low_dim + high_dim, dim, kernel_size=1, stride=1, padding=0)
168 | self.proj_drop = nn.Dropout(proj_drop)
169 |
170 | self.freblock=FreBlock(dim,dim)
171 | self.finalproj = nn.Conv2d(2 * dim, dim, 1, 1, 0)
172 | def forward(self, x):
173 |
174 | B, C, H, W = x.shape #16,128,64,64
175 | #x = x.permute(0, 3, 1, 2)
176 | x_ori=x
177 |
178 | hx = x[:, :self.high_dim, :, :].contiguous()#16,64,64,64
179 | hx = self.high_mixer(hx)#16,64,64,64
180 |
181 | lx = x[:, self.high_dim:, :, :].contiguous()#16,64,64,64
182 | lx = self.low_mixer(lx)#16,64,64,64
183 | x = torch.cat((hx, lx), dim=1)#16,128,64,64
184 |
185 | x = x + self.conv_fuse(x)
186 |         x_spatial = self.proj(x)
187 |
188 | x_freq=self.freblock(x_ori)
189 |
190 |         x_out = torch.cat((x_spatial, x_freq), 1)
191 | x_out=self.finalproj(x_out)
192 | x_out=self.proj_drop(x_out)
193 | return x_out+x_ori
194 |
195 |
196 | class ResnetBlock(nn.Module):
197 | def __init__(self, *, in_channels, out_channels=None, conv_shortcut=False,
198 | dropout, temb_channels=512,incep=False):
199 | super().__init__()
200 | self.in_channels = in_channels
201 | out_channels = in_channels if out_channels is None else out_channels
202 | self.out_channels = out_channels
203 | self.use_conv_shortcut = conv_shortcut
204 |
205 | self.norm1 = Normalize(in_channels)
206 | self.conv1 = torch.nn.Conv2d(in_channels,
207 | out_channels,
208 | kernel_size=3,
209 | stride=1,
210 | padding=1)
211 | # self.conv1=FreBlock(in_channels,out_channels)
212 | self.temb_proj = torch.nn.Linear(temb_channels,
213 | out_channels)
214 | self.norm2 = Normalize(out_channels)
215 | self.dropout = torch.nn.Dropout(dropout)
216 |         if incep:
217 | self.conv2 =Mixer(dim=out_channels)
218 | else:
219 | self.conv2 = torch.nn.Conv2d(out_channels,
220 | out_channels,
221 | kernel_size=3,
222 | stride=1,
223 | padding=1)
224 | if self.in_channels != self.out_channels:
225 | if self.use_conv_shortcut:
226 | self.conv_shortcut = torch.nn.Conv2d(in_channels,
227 | out_channels,
228 | kernel_size=3,
229 | stride=1,
230 | padding=1)
231 | else:
232 | self.nin_shortcut = torch.nn.Conv2d(in_channels,
233 | out_channels,
234 | kernel_size=1,
235 | stride=1,
236 | padding=0)
237 |
238 | def forward(self, x, temb):
239 | h = x
240 | h = self.norm1(h)
241 | h = nonlinearity(h) #(16,128,64,64)
242 | h = self.conv1(h) #(16,128,64,64)
243 |
244 | h = h + self.temb_proj(nonlinearity(temb))[:, :, None, None]
245 |
246 | h = self.norm2(h)
247 | h = nonlinearity(h)
248 | h = self.dropout(h)
249 | h = self.conv2(h) #(16,128,64,64)
250 |
251 | if self.in_channels != self.out_channels:
252 | if self.use_conv_shortcut:
253 | x = self.conv_shortcut(x)
254 | else:
255 | x = self.nin_shortcut(x)
256 |
257 | return x+h
258 |
259 |
260 | class AttnBlock(nn.Module):
261 | def __init__(self, in_channels):
262 | super().__init__()
263 | self.in_channels = in_channels
264 |
265 | self.norm = Normalize(in_channels)
266 | self.q = torch.nn.Conv2d(in_channels,
267 | in_channels,
268 | kernel_size=1,
269 | stride=1,
270 | padding=0)
271 | self.k = torch.nn.Conv2d(in_channels,
272 | in_channels,
273 | kernel_size=1,
274 | stride=1,
275 | padding=0)
276 | self.v = torch.nn.Conv2d(in_channels,
277 | in_channels,
278 | kernel_size=1,
279 | stride=1,
280 | padding=0)
281 | self.proj_out = torch.nn.Conv2d(in_channels,
282 | in_channels,
283 | kernel_size=1,
284 | stride=1,
285 | padding=0)
286 |
287 | def forward(self, x):
288 | h_ = x
289 | h_ = self.norm(h_)
290 | q = self.q(h_)
291 | k = self.k(h_)
292 | v = self.v(h_)
293 |
294 | # compute attention
295 | b, c, h, w = q.shape
296 | q = q.reshape(b, c, h*w)
297 | q = q.permute(0, 2, 1) # b,hw,c
298 | k = k.reshape(b, c, h*w) # b,c,hw
299 | w_ = torch.bmm(q, k) # b,hw,hw w[b,i,j]=sum_c q[b,i,c]k[b,c,j]
300 | w_ = w_ * (int(c)**(-0.5))
301 | w_ = torch.nn.functional.softmax(w_, dim=2)
302 |
303 | # attend to values
304 | v = v.reshape(b, c, h*w)
305 | w_ = w_.permute(0, 2, 1) # b,hw,hw (first hw of k, second of q)
306 | # b, c,hw (hw of q) h_[b,c,j] = sum_i v[b,c,i] w_[b,i,j]
307 | h_ = torch.bmm(v, w_)
308 | h_ = h_.reshape(b, c, h, w)
309 |
310 | h_ = self.proj_out(h_)
311 |
312 | return x+h_
313 |
314 | class FreBlock(nn.Module):
315 | def __init__(self, in_channels,out_channels):
316 | super(FreBlock, self).__init__()
317 | self.processmag = nn.Sequential(
318 | nn.Conv2d(in_channels,in_channels,1,1,0),
319 | nn.LeakyReLU(0.1,inplace=True),
320 | nn.Conv2d(in_channels,out_channels,1,1,0))
321 | self.processpha = nn.Sequential(
322 | nn.Conv2d(in_channels, in_channels, 1, 1, 0),
323 | nn.LeakyReLU(0.1, inplace=True),
324 | nn.Conv2d(in_channels, out_channels, 1, 1, 0))
325 |
326 | def forward(self,x):
327 | xori = x
328 | _, _, H, W = x.shape
329 | x_freq = torch.fft.rfft2(x, norm='backward')
330 | mag = torch.abs(x_freq)
331 | pha = torch.angle(x_freq)
332 | mag = self.processmag(mag)
333 | pha = self.processpha(pha)
334 | real = mag * torch.cos(pha)
335 | imag = mag * torch.sin(pha)
336 | x_out = torch.complex(real, imag)
337 | x_out1 = torch.fft.irfft2(x_out, s=(H, W), norm='backward')
338 | return x_out1
339 | class DiffusionUNet(nn.Module):
340 | def __init__(self, config):
341 | super().__init__()
342 | self.config = config
343 | ch, out_ch, ch_mult = config.model.ch, config.model.out_ch, tuple(config.model.ch_mult)
344 | num_res_blocks = config.model.num_res_blocks
345 | attn_resolutions = config.model.attn_resolutions
346 | dropout = config.model.dropout
347 |         in_channels = config.model.in_channels * 3 if config.data.conditional else config.model.in_channels  # *3 to make room for the max-normalized image
348 | resolution = config.data.image_size
349 | resamp_with_conv = config.model.resamp_with_conv
350 |
351 | self.ch = ch #128
352 | self.temb_ch = self.ch*4 #512
353 | self.num_resolutions = len(ch_mult)
354 | self.num_res_blocks = num_res_blocks
355 | self.resolution = resolution
356 | self.in_channels = in_channels
357 |
358 | self.inceplayers=2
359 |
360 | # timestep embedding
361 | self.temb = nn.Module()
362 | self.temb.dense = nn.ModuleList([
363 | torch.nn.Linear(self.ch,
364 | self.temb_ch),
365 | torch.nn.Linear(self.temb_ch,
366 | self.temb_ch),
367 | ])
368 |
369 | # downsampling
370 | self.conv_in = torch.nn.Conv2d(in_channels,
371 | self.ch,
372 | kernel_size=3,
373 | stride=1,
374 | padding=1)
375 |
376 | curr_res = resolution
377 | in_ch_mult = (1,)+ch_mult
378 | self.down = nn.ModuleList()
379 | block_in = None
380 | for i_level in range(self.num_resolutions):
381 | block = nn.ModuleList()
382 | attn = nn.ModuleList()
383 | block_in = ch*in_ch_mult[i_level]
384 | block_out = ch*ch_mult[i_level]
385 | for i_block in range(self.num_res_blocks):
386 |                 if i_level+self.inceplayers>=self.num_resolutions:
431 | block.append(ResnetBlock(in_channels=block_in+skip_in,
432 | out_channels=block_out,
433 | temb_channels=self.temb_ch,
434 | dropout=dropout,incep=True))
435 | else:
436 | block.append(ResnetBlock(in_channels=block_in+skip_in,
437 | out_channels=block_out,
438 | temb_channels=self.temb_ch,
439 | dropout=dropout,incep=True))
440 | block_in = block_out
441 | # if curr_res in attn_resolutions:
442 | # attn.append(AttnBlock(block_in))
443 | up = nn.Module()
444 | up.block = block
445 | up.attn = attn
446 | if i_level != 0 :
447 | if ch_mult[i_level]== ch_mult[i_level-1]*2:
448 | up.upsample = Upsample(block_in, resamp_with_conv)
449 | curr_res = curr_res * 2
450 | self.up.insert(0, up) # prepend to get consistent order
451 |
452 | # end
453 | self.norm_out = Normalize(block_in)
454 | self.conv_out = torch.nn.Conv2d(block_in,
455 | out_ch,
456 | kernel_size=3,
457 | stride=1,
458 | padding=1)
459 |
460 | def forward(self, x, t,i,j,osize):
461 | assert x.shape[2] == x.shape[3]
462 | # timestep embedding
463 | # t=torch.stack([t.unsqueeze(1),cond],dim=1)
464 | temb1 = get_timestep_embedding(t, self.ch//4)#(16,32)
465 | temb2 = get_timestep_embedding(i, self.ch//4) # (16,32)
466 | temb3 = get_timestep_embedding(j, self.ch//4) # (16,32)
467 | temb4 = get_timestep_embedding(osize, self.ch//4) # (16,32)
468 |
469 | temb=torch.cat([temb1,temb2,temb3,temb4],dim=1)
470 | temb = self.temb.dense[0](temb) #(16,512)
471 | temb = nonlinearity(temb) #(16,512)
472 | temb = self.temb.dense[1](temb) #(16,512)
473 |
474 | # downsampling
475 | hs = [self.conv_in(x)]
476 | for i_level in range(self.num_resolutions):
477 | for i_block in range(self.num_res_blocks):
478 | h = self.down[i_level].block[i_block](hs[-1], temb)
479 | if len(self.down[i_level].attn) > 0:
480 | h = self.down[i_level].attn[i_block](h)
481 | hs.append(h)
482 |
483 | if i_level != self.num_resolutions-1 :
484 | hs.append(self.down[i_level].downsample(hs[-1]))
485 |
486 |
487 | # middle
488 | h = hs[-1]
489 | h = self.mid.block_1(h, temb)
490 | h = self.mid.block_2(h, temb)
491 |
492 | # upsampling
493 | for i_level in reversed(range(self.num_resolutions)):
494 | for i_block in range(self.num_res_blocks+1):
495 | h = self.up[i_level].block[i_block](
496 | torch.cat([h, hs.pop()], dim=1), temb)
497 | if len(self.up[i_level].attn) > 0:
498 | h = self.up[i_level].attn[i_block](h)
499 | if i_level != 0:
500 | h = self.up[i_level].upsample(h)
501 |
502 | # end
503 | h = self.norm_out(h)
504 | h = nonlinearity(h)
505 | h = self.conv_out(h)
506 | return h
507 |
--------------------------------------------------------------------------------
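DiffusionUNet.forward above builds its conditioning vector by concatenating four sinusoidal embeddings of dimension ch//4 (timestep, normalized row offset, normalized column offset, patch size) into a ch-dimensional vector before the two dense layers. A minimal sketch, assuming ch = 128 (as the in-code shape comments suggest) and that models.unet is importable from the repo root:

import torch
from models.unet import get_timestep_embedding  # assumed import path

ch, batch = 128, 16
t = torch.randint(0, 1000, (batch,)).float()   # diffusion timestep
ii = torch.rand(batch) * 2 - 1                 # normalized patch row offset
jj = torch.rand(batch) * 2 - 1                 # normalized patch column offset
osize = torch.full((batch,), 64.0)             # patch/output size

temb = torch.cat([get_timestep_embedding(s, ch // 4) for s in (t, ii, jj, osize)], dim=1)
print(temb.shape)  # torch.Size([16, 128]); temb.dense[0] then maps 128 -> 512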